[XPU] support kernel for mtp(base) (#4748)

* [XPU] support kernel for mtp(base)

* [XPU] support kernel for mtp(base)

* format

* format

* format

* fix gather next token

* fix step && add test

* fix

* mv pre/post process

* add adjust batch / gather next token for mtp

* fix code style

* fix mtp kernel name

* fix mtp kernel test

* mv xpu pre/post process

* mv xpu pre/post process
cmcamdy
2025-11-27 15:05:44 +08:00
committed by GitHub
parent e63d715fc3
commit 5a67a6d960
32 changed files with 3618 additions and 972 deletions

View File

@@ -18,38 +18,49 @@
#include "utility/helper.h"
#include "xpu/plugin.h"
#ifndef PD_BUILD_STATIC_OP
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
#endif
template <paddle::DataType T>
std::vector<paddle::Tensor> AdjustBatchKernel(
const paddle::Tensor &x, // [token_num, dim_embed]
const paddle::Tensor &cum_offsets, // [bsz, 1]
const paddle::Tensor &encoder_seq_lod,
const paddle::Tensor &decoder_seq_lod,
const paddle::Tensor &encoder_batch_idx,
const paddle::Tensor &decoder_batch_idx,
const paddle::Tensor &encoder_seq_lod_cpu,
const paddle::Tensor &decoder_seq_lod_cpu,
const paddle::Tensor &encoder_batch_idx_cpu,
const paddle::Tensor &decoder_batch_idx_cpu,
const paddle::Tensor &enc_batch_tensor,
const paddle::Tensor &dec_batch_tensor,
const paddle::Tensor &len_info_cpu,
const paddle::optional<paddle::Tensor> &output_padding_offset,
int max_input_length) {
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
auto xpu_ctx = static_cast<const phi::XPUContext *>(dev_ctx);
auto ctx = static_cast<const phi::XPUContext *>(dev_ctx)->x_context();
PD_CHECK(x.dtype() == T);
PD_CHECK(x.dims().size() == 2);
if (x.is_cpu()) {
ctx = new baidu::xpu::api::Context(baidu::xpu::api::kCPU);
}
using XPUType = typename XPUTypeTrait<typename PDTraits<T>::DataType>::Type;
using data_t = typename PDTraits<T>::data_t;
const int token_num = x.dims()[0];
const int dim = x.dims()[1];
const int bsz = cum_offsets.shape()[0];
int enc_batch = enc_batch_tensor.data<int32_t>()[0];
int dec_batch = dec_batch_tensor.data<int32_t>()[0];
int enc_batch = len_info_cpu.data<int32_t>()[0];
int dec_batch = len_info_cpu.data<int32_t>()[1];
baidu::xpu::api::VectorParam<int32_t> encoder_seqs_lods_vp{
const_cast<int32_t *>(encoder_seq_lod_cpu.data<int32_t>()),
enc_batch + 1,
const_cast<int32_t *>(encoder_seq_lod.data<int32_t>())};
baidu::xpu::api::VectorParam<int32_t> decoder_seqs_lods_vp{
const_cast<int32_t *>(decoder_seq_lod_cpu.data<int32_t>()),
dec_batch + 1,
const_cast<int32_t *>(decoder_seq_lod.data<int32_t>())};
baidu::xpu::api::VectorParam<int32_t> encoder_batch_map_vp{
const_cast<int32_t *>(encoder_batch_idx_cpu.data<int32_t>()),
enc_batch,
@@ -59,13 +70,14 @@ std::vector<paddle::Tensor> AdjustBatchKernel(
dec_batch,
const_cast<int32_t *>(decoder_batch_idx.data<int32_t>())};
auto out = paddle::full({token_num, dim}, -2, x.type(), x.place());
auto out = paddle::empty({token_num, dim}, x.type(), x.place());
int r = baidu::xpu::api::plugin::eb_adjust_batch<XPUType, XPUType>(
xpu_ctx->x_context(),
ctx,
reinterpret_cast<const XPUType *>(x.data<data_t>()),
reinterpret_cast<XPUType *>(out.data<data_t>()),
encoder_seqs_lods_vp,
decoder_seqs_lods_vp,
encoder_batch_map_vp,
decoder_batch_map_vp,
dim);
@@ -76,13 +88,14 @@ using AdjustBatchKernelFuncPtr = std::vector<paddle::Tensor> (*)(
const paddle::Tensor &x, // [token_num, dim_embed]
const paddle::Tensor &cum_offsets, // [bsz, 1]
const paddle::Tensor &encoder_seq_lod,
const paddle::Tensor &decoder_seq_lod,
const paddle::Tensor &encoder_batch_idx,
const paddle::Tensor &decoder_batch_idx,
const paddle::Tensor &encoder_seq_lod_cpu,
const paddle::Tensor &decoder_seq_lod_cpu,
const paddle::Tensor &encoder_batch_idx_cpu,
const paddle::Tensor &decoder_batch_idx_cpu,
const paddle::Tensor &enc_batch_tensor,
const paddle::Tensor &dec_batch_tensor,
const paddle::Tensor &len_info_cpu,
const paddle::optional<paddle::Tensor> &output_padding_offset,
int max_input_length);
@@ -90,13 +103,14 @@ std::vector<paddle::Tensor> AdjustBatch(
const paddle::Tensor &x, // [token_num, dim_embed]
const paddle::Tensor &cum_offsets, // [bsz, 1]
const paddle::Tensor &encoder_seq_lod,
const paddle::Tensor &decoder_seq_lod,
const paddle::Tensor &encoder_batch_idx,
const paddle::Tensor &decoder_batch_idx,
const paddle::Tensor &encoder_seq_lod_cpu,
const paddle::Tensor &decoder_seq_lod_cpu,
const paddle::Tensor &encoder_batch_idx_cpu,
const paddle::Tensor &decoder_batch_idx_cpu,
const paddle::Tensor &enc_batch_tensor,
const paddle::Tensor &dec_batch_tensor,
const paddle::Tensor &len_info_cpu,
const paddle::optional<paddle::Tensor> &output_padding_offset,
int max_input_length) {
AdjustBatchKernelFuncPtr func = nullptr;
@@ -108,12 +122,12 @@ std::vector<paddle::Tensor> AdjustBatch(
case paddle::DataType::FLOAT16:
func = &AdjustBatchKernel<paddle::DataType::FLOAT16>;
break;
case paddle::DataType::FLOAT32:
func = &AdjustBatchKernel<paddle::DataType::FLOAT32>;
break;
case paddle::DataType::INT64:
func = &AdjustBatchKernel<paddle::DataType::INT64>;
break;
case paddle::DataType::FLOAT32:
func = &AdjustBatchKernel<paddle::DataType::FLOAT32>;
break;
default:
PD_THROW("Unsupported data type: ", x.dtype());
}
@@ -121,13 +135,14 @@ std::vector<paddle::Tensor> AdjustBatch(
return func(x,
cum_offsets,
encoder_seq_lod,
decoder_seq_lod,
encoder_batch_idx,
decoder_batch_idx,
encoder_seq_lod_cpu,
decoder_seq_lod_cpu,
encoder_batch_idx_cpu,
decoder_batch_idx_cpu,
enc_batch_tensor,
dec_batch_tensor,
len_info_cpu,
output_padding_offset,
max_input_length);
}
@@ -136,13 +151,14 @@ std::vector<std::vector<int64_t>> AdjustBatchInferShape(
const std::vector<int64_t> &x_shape,
const std::vector<int64_t> &cum_offsets_shape,
const std::vector<int64_t> &encoder_seq_lod_shape,
const std::vector<int64_t> &decoder_seq_lod_shape,
const std::vector<int64_t> &encoder_batch_idx_shape,
const std::vector<int64_t> &decoder_batch_idx_shape,
const std::vector<int64_t> &encoder_seq_lod_cpu_shape,
const std::vector<int64_t> &decoder_seq_lod_cpu_shape,
const std::vector<int64_t> &encoder_batch_idx_cpu_shape,
const std::vector<int64_t> &decoder_batch_idx_cpu_shape,
const std::vector<int64_t> &enc_batch_tensor_shape,
const std::vector<int64_t> &dec_batch_tensor_shape,
const std::vector<int64_t> &len_info_cpu_shape,
const paddle::optional<std::vector<int64_t>> &output_padding_offset_shape) {
if (output_padding_offset_shape) {
PD_THROW("speculative decoding is not supported in XPU.");
@@ -156,28 +172,30 @@ std::vector<paddle::DataType> AdjustBatchInferDtype(
const paddle::DataType &x_dtype,
const paddle::DataType &cum_offsets_dtype,
const paddle::DataType &encoder_seq_lod_dtype,
const paddle::DataType &decoder_seq_lod_dtype,
const paddle::DataType &encoder_batch_idx_dtype,
const paddle::DataType &decoder_batch_idx_dtype,
const paddle::DataType &encoder_seq_lod_cpu_dtype,
const paddle::DataType &decoder_seq_lod_cpu_dtype,
const paddle::DataType &encoder_batch_idx_cpu_dtype,
const paddle::DataType &decoder_batch_idx_cpu_dtype,
const paddle::DataType &enc_batch_tensor_dtype,
const paddle::DataType &dec_batch_tensor_dtype,
const paddle::DataType &len_info_cpu_dtype,
const paddle::optional<paddle::DataType> &output_padding_offset_dtype) {
return {x_dtype};
}
PD_BUILD_OP(adjust_batch)
PD_BUILD_STATIC_OP(adjust_batch)
.Inputs({"x",
"cum_offsets",
"encoder_seq_lod",
"decoder_seq_lod",
"encoder_batch_idx",
"decoder_batch_idx",
"encoder_seq_lod_cpu",
"decoder_seq_lod_cpu",
"encoder_batch_idx_cpu",
"decoder_batch_idx_cpu",
"enc_batch_tensor",
"dec_batch_tensor",
"len_info_cpu",
paddle::Optional("output_padding_offset")})
.Outputs({"out"})
.Attrs({"max_input_length: int"})
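For reference, the kernel above reads only two scalars out of the new len_info_cpu input: index 0 as enc_batch and index 1 as dec_batch, replacing the former enc_batch_tensor/dec_batch_tensor pair. A minimal host-side sketch of packing such a tensor follows; the helper name and the two-element shape are assumptions for illustration, not part of this commit.

```cpp
#include "paddle/extension.h"

// Hypothetical helper: pack the batch counts that adjust_batch reads from
// len_info_cpu. Only indices 0 and 1 are consumed by the kernel above.
paddle::Tensor MakeLenInfoCpu(int enc_batch, int dec_batch) {
  auto len_info_cpu =
      paddle::empty({2}, paddle::DataType::INT32, paddle::CPUPlace());
  int32_t* data = len_info_cpu.data<int32_t>();
  data[0] = enc_batch;  // number of prefill (encoder) requests
  data[1] = dec_batch;  // number of decode requests
  return len_info_cpu;
}
```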

View File

@@ -722,7 +722,6 @@ std::vector<paddle::Tensor> BlockAttnKernel(
: quant_v_scale_inv,
nullptr, // o_maxptr
param.head_dim); // vo_head_dim
PD_CHECK(0, "speculative_attention unimplemented");
PD_CHECK(ret == api::SUCCESS,
"xfa::speculative_attention_decoder failed.");
if (!Eq_len) {

View File

@@ -13,107 +13,169 @@
// limitations under the License.
#include <paddle/phi/backends/xpu/xpu_context.h>
#include <xft/xdnn_plugin.h>
#include "paddle/extension.h"
#include "xpu/plugin.h"
#ifndef PD_BUILD_STATIC_OP
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
#endif
std::vector<paddle::Tensor> GatherNextToken(
const paddle::Tensor &tmp_out, // [token_num, dim_embed]
const paddle::Tensor &cum_offsets, // [bsz, 1]
const paddle::Tensor &encoder_seq_lod,
const paddle::Tensor &encoder_batch_map,
const paddle::Tensor &decoder_batch_map,
const paddle::Tensor &encoder_seq_lod_cpu,
const paddle::Tensor &encoder_batch_map_cpu,
const paddle::Tensor &decoder_batch_map_cpu,
const paddle::Tensor &enc_batch_tensor,
const paddle::Tensor &dec_batch_tensor,
const paddle::optional<paddle::Tensor> &output_padding_offset,
int max_input_length) {
const paddle::Tensor& x, // [token_num, dim_embed]
const paddle::Tensor& cum_offsets, // [bsz, 1]
const paddle::Tensor& encoder_seq_lod,
const paddle::Tensor& decoder_seq_lod,
const paddle::Tensor& encoder_batch_map,
const paddle::Tensor& decoder_batch_map,
const paddle::Tensor& encoder_seq_lod_cpu,
const paddle::Tensor& decoder_seq_lod_cpu,
const paddle::Tensor& encoder_batch_map_cpu,
const paddle::Tensor& decoder_batch_map_cpu,
const paddle::Tensor& len_info_cpu,
const paddle::optional<paddle::Tensor>& output_padding_offset,
int max_bsz) {
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
auto xpu_ctx = static_cast<const phi::XPUContext *>(dev_ctx);
auto ctx = static_cast<const phi::XPUContext*>(dev_ctx)->x_context();
if (x.is_cpu()) {
ctx = new baidu::xpu::api::Context(baidu::xpu::api::kCPU);
}
using XPUType =
typename XPUTypeTrait<bfloat16>::Type; // only support bfloat16
typedef paddle::bfloat16 data_t;
const int dim = tmp_out.dims()[1];
const int bsz = cum_offsets.shape()[0];
int enc_batch = enc_batch_tensor.data<int32_t>()[0];
int dec_batch = dec_batch_tensor.data<int32_t>()[0];
const int dim = x.dims()[1];
const int token_num = x.shape()[0];
int bsz = cum_offsets.shape()[0];
int enc_batch = len_info_cpu.data<int32_t>()[0];
int dec_batch = len_info_cpu.data<int32_t>()[1];
if (max_bsz > 0) {
PD_CHECK(encoder_batch_map_cpu.data<int32_t>()[enc_batch - 1] <= max_bsz,
"encoder_batch_map_cpu check failed");
PD_CHECK(decoder_batch_map_cpu.data<int32_t>()[dec_batch - 1] <= max_bsz,
"decoder_batch_map_cpu check failed");
bsz = max_bsz;
}
baidu::xpu::api::VectorParam<int32_t> encoder_seqs_lods_vp{
const_cast<int32_t *>(encoder_seq_lod_cpu.data<int32_t>()),
const_cast<int32_t*>(encoder_seq_lod_cpu.data<int32_t>()),
enc_batch + 1,
const_cast<int32_t *>(encoder_seq_lod.data<int32_t>())};
const_cast<int32_t*>(encoder_seq_lod.data<int32_t>())};
baidu::xpu::api::VectorParam<int32_t> decoder_seqs_lods_vp{
const_cast<int32_t*>(decoder_seq_lod_cpu.data<int32_t>()),
dec_batch + 1,
const_cast<int32_t*>(decoder_seq_lod.data<int32_t>())};
baidu::xpu::api::VectorParam<int32_t> encoder_batch_map_vp{
const_cast<int32_t *>(encoder_batch_map_cpu.data<int32_t>()),
const_cast<int32_t*>(encoder_batch_map_cpu.data<int32_t>()),
enc_batch,
const_cast<int32_t *>(encoder_batch_map.data<int32_t>())};
const_cast<int32_t*>(encoder_batch_map.data<int32_t>())};
baidu::xpu::api::VectorParam<int32_t> decoder_batch_map_vp{
const_cast<int32_t *>(decoder_batch_map_cpu.data<int32_t>()),
const_cast<int32_t*>(decoder_batch_map_cpu.data<int32_t>()),
dec_batch,
const_cast<int32_t *>(decoder_batch_map.data<int32_t>())};
const_cast<int32_t*>(decoder_batch_map.data<int32_t>())};
auto out = paddle::full({bsz, dim}, -2, tmp_out.type(), tmp_out.place());
paddle::Tensor out;
if (output_padding_offset) {
int need_delete_token_num = 0;
if (enc_batch > 0) {
need_delete_token_num =
encoder_seq_lod_cpu.data<int32_t>()[enc_batch] - enc_batch;
}
out = paddle::empty(
{token_num - need_delete_token_num, dim}, x.type(), x.place());
} else {
out = paddle::empty({bsz, dim}, x.type(), x.place());
}
if (x.shape()[0] == 0) {
return {out};
}
int r = baidu::xpu::api::plugin::eb_gather_next_token<XPUType, XPUType>(
xpu_ctx->x_context(),
reinterpret_cast<const XPUType *>(tmp_out.data<data_t>()),
reinterpret_cast<XPUType *>(out.data<data_t>()),
encoder_seqs_lods_vp,
encoder_batch_map_vp,
decoder_batch_map_vp,
dim);
if (enc_batch <= 0) {
out = x.copy_to(x.place(), false);
} else {
if (output_padding_offset) {
int r =
baidu::xpu::api::plugin::eb_mtp_gather_next_token<XPUType, XPUType>(
ctx,
reinterpret_cast<const XPUType*>(x.data<data_t>()),
reinterpret_cast<XPUType*>(out.data<data_t>()),
encoder_seqs_lods_vp,
decoder_seqs_lods_vp,
encoder_batch_map_vp,
decoder_batch_map_vp,
dim);
PD_CHECK(r == 0, "xpu::plugin::gather_next_token failed.");
} else {
int r = baidu::xpu::api::plugin::eb_gather_next_token<XPUType, XPUType>(
ctx,
reinterpret_cast<const XPUType*>(x.data<data_t>()),
reinterpret_cast<XPUType*>(out.data<data_t>()),
encoder_seqs_lods_vp,
encoder_batch_map_vp,
decoder_batch_map_vp,
dim);
PD_CHECK(r == 0, "xpu::plugin::gather_next_token failed.");
}
}
return {out};
}
std::vector<std::vector<int64_t>> GatherNextTokenInferShape(
const std::vector<int64_t> &tmp_out_shape,
const std::vector<int64_t> &cum_offsets_shape,
const std::vector<int64_t> &encoder_seq_lod_shape,
const std::vector<int64_t> &encoder_batch_map_shape,
const std::vector<int64_t> &decoder_batch_map_shape,
const std::vector<int64_t> &encoder_seq_lod_cpu_shape,
const std::vector<int64_t> &encoder_batch_map_cpu_shape,
const std::vector<int64_t> &decoder_batch_map_cpu_shape,
const std::vector<int64_t> &enc_batch_tensor_shape,
const std::vector<int64_t> &dec_batch_tensor_shape,
const paddle::optional<std::vector<int64_t>> &output_padding_offset_shape) {
if (output_padding_offset_shape) {
PD_THROW("speculative decoding is not supported in XPU.");
}
const std::vector<int64_t>& x_shape,
const std::vector<int64_t>& cum_offsets_shape,
const std::vector<int64_t>& encoder_seq_lod_shape,
const std::vector<int64_t>& decoder_seq_lod_shape,
const std::vector<int64_t>& encoder_batch_map_shape,
const std::vector<int64_t>& decoder_batch_map_shape,
const std::vector<int64_t>& encoder_seq_lod_cpu_shape,
const std::vector<int64_t>& decoder_seq_lod_cpu_shape,
const std::vector<int64_t>& encoder_batch_map_cpu_shape,
const std::vector<int64_t>& decoder_batch_map_cpu_shape,
const std::vector<int64_t>& len_info_cpu_shape,
const paddle::optional<std::vector<int64_t>>& output_padding_offset_shape) {
// if (output_padding_offset_shape) {
// PD_THROW("speculative decoding is not supported in XPU.");
// }
int64_t bsz = cum_offsets_shape[0];
int64_t dim_embed = tmp_out_shape[1];
return {{bsz, dim_embed}};
int64_t dim_embed = x_shape[1];
if (output_padding_offset_shape) {
return {{-1, dim_embed}};
} else {
int64_t bsz = cum_offsets_shape[0];
return {{bsz, dim_embed}};
}
}
std::vector<paddle::DataType> GatherNextTokenInferDtype(
const paddle::DataType &tmp_out_dtype,
const paddle::DataType &cum_offsets_dtype,
const paddle::DataType &encoder_seq_lod_dtype,
const paddle::DataType &encoder_batch_map_dtype,
const paddle::DataType &decoder_batch_map_dtype,
const paddle::DataType &encoder_seq_lod_cpu_dtype,
const paddle::DataType &encoder_batch_map_cpu_dtype,
const paddle::DataType &decoder_batch_map_cpu_dtype,
const paddle::DataType &enc_batch_tensor_dtype,
const paddle::DataType &dec_batch_tensor_dtype,
const paddle::optional<paddle::DataType> &output_padding_offset_dtype) {
return {tmp_out_dtype};
const paddle::DataType& x_dtype,
const paddle::DataType& cum_offsets_dtype,
const paddle::DataType& encoder_seq_lod_dtype,
const paddle::DataType& decoder_seq_lod_dtype,
const paddle::DataType& encoder_batch_map_dtype,
const paddle::DataType& decoder_batch_map_dtype,
const paddle::DataType& encoder_seq_lod_cpu_dtype,
const paddle::DataType& decoder_seq_lod_cpu_dtype,
const paddle::DataType& encoder_batch_map_cpu_dtype,
const paddle::DataType& decoder_batch_map_cpu_dtype,
const paddle::DataType& len_info_cpu_dtype,
const paddle::optional<paddle::DataType>& output_padding_offset_dtype) {
return {x_dtype};
}
PD_BUILD_OP(gather_next_token)
.Inputs({"tmp_out",
PD_BUILD_STATIC_OP(gather_next_token)
.Inputs({"x",
"cum_offsets",
"encoder_seq_lod",
"decoder_seq_lod",
"encoder_batch_map",
"decoder_batch_map",
"encoder_seq_lod_cpu",
"decoder_seq_lod_cpu",
"encoder_batch_map_cpu",
"decoder_batch_map_cpu",
"enc_batch_tensor",
"dec_batch_tensor",
"len_info_cpu",
paddle::Optional("output_padding_offset")})
.Outputs({"out"})
.Attrs({"max_input_length: int"})
.Attrs({"max_bsz: int"})
.SetKernelFn(PD_KERNEL(GatherNextToken))
.SetInferShapeFn(PD_INFER_SHAPE(GatherNextTokenInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(GatherNextTokenInferDtype));
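A quick worked example of the new output shape: with output_padding_offset present, GatherNextToken keeps every decoder row and, by the arithmetic in the kernel, one row per prefill request, so the row count is token_num - (encoder tokens - enc_batch); this is why the infer-shape function returns -1 for that dimension. The numbers below are made up for illustration.

```cpp
#include <cassert>

int main() {
  const int enc_batch = 2;
  const int encoder_seq_lod_cpu[] = {0, 3, 7};  // two prefill requests: 3 and 4 tokens
  const int dec_tokens = 2;                     // decoder rows pass through unchanged
  const int token_num = encoder_seq_lod_cpu[enc_batch] + dec_tokens;             // 9
  const int need_delete_token_num = encoder_seq_lod_cpu[enc_batch] - enc_batch;  // 7 - 2 = 5
  assert(token_num - need_delete_token_num == 4);  // 2 encoder rows + 2 decoder rows
  return 0;
}
```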

View File

@@ -29,21 +29,23 @@ void DraftModelPreprocess(const paddle::Tensor& draft_tokens,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& step_idx,
const paddle::Tensor& seq_lens_encoder_record,
const paddle::Tensor& seq_lens_decoder_record,
const paddle::Tensor& not_need_stop,
const paddle::Tensor& is_block_step,
const paddle::Tensor& batch_drop,
const paddle::Tensor& pre_ids,
const paddle::Tensor& accept_tokens,
const paddle::Tensor& accept_num,
const paddle::Tensor& base_model_seq_lens_this_time,
const paddle::Tensor& base_model_seq_lens_encoder,
const paddle::Tensor& base_model_seq_lens_decoder,
const paddle::Tensor& base_model_step_idx,
const paddle::Tensor& base_model_stop_flags,
const paddle::Tensor& base_model_is_block_step,
const paddle::Tensor& base_model_draft_tokens,
const int max_draft_token,
const int num_model_step,
const bool truncate_first_token,
const bool splitwise_prefill) {
const bool splitwise_prefill,
const bool kvcache_scheduler_v1) {
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
api::Context* ctx = static_cast<const phi::XPUContext*>(dev_ctx)->x_context();
@@ -54,6 +56,8 @@ void DraftModelPreprocess(const paddle::Tensor& draft_tokens,
int accept_tokens_len = accept_tokens.shape()[1];
int input_ids_len = input_ids.shape()[1];
int draft_tokens_len = draft_tokens.shape()[1];
int pre_ids_len = pre_ids.shape()[1];
constexpr int BlockSize = 512;
int base_model_draft_tokens_len = base_model_draft_tokens.shape()[1];
auto not_need_stop_gpu =
not_need_stop.copy_to(seq_lens_this_time.place(), false);
@@ -67,12 +71,13 @@ void DraftModelPreprocess(const paddle::Tensor& draft_tokens,
const_cast<int*>(seq_lens_encoder.data<int>()),
const_cast<int*>(seq_lens_decoder.data<int>()),
const_cast<int64_t*>(step_idx.data<int64_t>()),
const_cast<int*>(seq_lens_encoder_record.data<int>()),
const_cast<int*>(seq_lens_decoder_record.data<int>()),
const_cast<bool*>(not_need_stop_gpu.data<bool>()),
const_cast<bool*>(is_block_step.data<bool>()),
const_cast<bool*>(batch_drop.data<bool>()),
const_cast<int64_t*>(pre_ids.data<int64_t>()),
accept_tokens.data<int64_t>(),
accept_num.data<int>(),
base_model_seq_lens_this_time.data<int>(),
base_model_seq_lens_encoder.data<int>(),
base_model_seq_lens_decoder.data<int>(),
base_model_step_idx.data<int64_t>(),
@@ -80,13 +85,16 @@ void DraftModelPreprocess(const paddle::Tensor& draft_tokens,
base_model_is_block_step.data<bool>(),
const_cast<int64_t*>(base_model_draft_tokens.data<int64_t>()),
real_bsz,
max_draft_token,
num_model_step,
accept_tokens_len,
draft_tokens_len,
input_ids_len,
base_model_draft_tokens_len,
pre_ids_len,
truncate_first_token,
splitwise_prefill);
splitwise_prefill,
kvcache_scheduler_v1);
PD_CHECK(r == 0, "xpu::plugin::draft_model_preprocess failed.");
auto not_need_stop_cpu =
not_need_stop_gpu.copy_to(not_need_stop.place(), false);
@@ -102,12 +110,13 @@ PD_BUILD_STATIC_OP(draft_model_preprocess)
"seq_lens_encoder",
"seq_lens_decoder",
"step_idx",
"seq_lens_encoder_record",
"seq_lens_decoder_record",
"not_need_stop",
"is_block_step",
"batch_drop",
"pre_ids",
"accept_tokens",
"accept_num",
"base_model_seq_lens_this_time",
"base_model_seq_lens_encoder",
"base_model_seq_lens_decoder",
"base_model_step_idx",
@@ -123,11 +132,11 @@ PD_BUILD_STATIC_OP(draft_model_preprocess)
"step_idx_out",
"not_need_stop_out",
"batch_drop_out",
"seq_lens_encoder_record_out",
"seq_lens_decoder_record_out"})
.Attrs({"max_draft_token: int",
"pre_ids_out"})
.Attrs({"num_model_step: int",
"truncate_first_token: bool",
"splitwise_prefill: bool"})
"splitwise_prefill: bool",
"kvcache_scheduler_v1: bool"})
.SetInplaceMap({{"draft_tokens", "draft_tokens_out"},
{"input_ids", "input_ids_out"},
{"stop_flags", "stop_flags_out"},
@@ -137,6 +146,5 @@ PD_BUILD_STATIC_OP(draft_model_preprocess)
{"step_idx", "step_idx_out"},
{"not_need_stop", "not_need_stop_out"},
{"batch_drop", "batch_drop_out"},
{"seq_lens_encoder_record", "seq_lens_encoder_record_out"},
{"seq_lens_decoder_record", "seq_lens_decoder_record_out"}})
{"pre_ids", "pre_ids_out"}})
.SetKernelFn(PD_KERNEL(DraftModelPreprocess));

View File

@@ -43,6 +43,8 @@ std::vector<paddle::Tensor> SpeculateGetPaddingOffset(
{token_num_data}, paddle::DataType::INT64, input_ids.place());
auto padding_offset = paddle::empty(
{token_num_data}, paddle::DataType::INT32, input_ids.place());
auto batch_id_per_token = paddle::empty(
{token_num_data}, paddle::DataType::INT32, input_ids.place());
auto cu_seqlens_q =
paddle::empty({bsz + 1}, paddle::DataType::INT32, input_ids.place());
auto cu_seqlens_k =
@@ -57,7 +59,7 @@ std::vector<paddle::Tensor> SpeculateGetPaddingOffset(
int r = baidu::xpu::api::plugin::speculate_get_padding_offset(
xpu_ctx->x_context(),
padding_offset.data<int>(),
batch_id_per_token.data<int>(),
cum_offsets_out.data<int>(),
cu_seqlens_q.data<int>(),
cu_seqlens_k.data<int>(),
@@ -83,7 +85,7 @@ std::vector<paddle::Tensor> SpeculateGetPaddingOffset(
return {x_remove_padding,
cum_offsets_out,
padding_offset,
batch_id_per_token,
cu_seqlens_q,
cu_seqlens_k}; // , enc_token_num, dec_token_num};
}
@@ -123,7 +125,7 @@ PD_BUILD_STATIC_OP(speculate_get_padding_offset)
"seq_lens_encoder"})
.Outputs({"x_remove_padding",
"cum_offsets_out",
"padding_offset",
"batch_id_per_token",
"cu_seqlens_q",
"cu_seqlens_k"})
.SetKernelFn(PD_KERNEL(SpeculateGetPaddingOffset))
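The first output here is renamed from padding_offset to batch_id_per_token. Assuming that name carries its usual meaning (for each packed token, the index of the request it belongs to), a toy illustration of the expected contents, with invented values:

```cpp
// Two requests of lengths 3 and 2 packed back to back (assumed semantics):
const int seq_lens[] = {3, 2};
const int batch_id_per_token[] = {0, 0, 0, 1, 1};  // token -> request index
const int cu_seqlens_q[] = {0, 3, 5};              // prefix sums of the lengths
```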

View File

@@ -35,7 +35,7 @@ void SpeculateSaveWithOutputMsg(const paddle::Tensor& accept_tokens,
const paddle::Tensor& not_need_stop,
int64_t rank_id,
int msg_queue_id,
int save_each_rank) {
bool save_each_rank) {
// printf("enter save output");
if (!save_each_rank && rank_id > 0) {
return;

View File

@@ -0,0 +1,187 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <paddle/phi/backends/xpu/xpu_context.h>
#include "paddle/extension.h"
#include "paddle/phi/core/enforce.h"
#include "xpu/plugin.h"
#ifndef PD_BUILD_STATIC_OP
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
#endif
void SpeculateStepPaddle(
const paddle::Tensor &stop_flags,
const paddle::Tensor &seq_lens_this_time,
const paddle::Tensor &ori_seq_lens_encoder,
const paddle::Tensor &seq_lens_encoder,
const paddle::Tensor &seq_lens_decoder,
const paddle::Tensor &block_tables, // [bsz, block_num_per_seq]
const paddle::Tensor &encoder_block_lens,
const paddle::Tensor &is_block_step,
const paddle::Tensor &step_block_list,
const paddle::Tensor &step_lens,
const paddle::Tensor &recover_block_list,
const paddle::Tensor &recover_lens,
const paddle::Tensor &need_block_list,
const paddle::Tensor &need_block_len,
const paddle::Tensor &used_list_len,
const paddle::Tensor &free_list,
const paddle::Tensor &free_list_len,
const paddle::Tensor &input_ids,
const paddle::Tensor &pre_ids,
const paddle::Tensor &step_idx,
const paddle::Tensor &next_tokens,
const paddle::Tensor &first_token_ids,
const paddle::Tensor &accept_num,
const int block_size,
const int encoder_decoder_block_num,
const int max_draft_tokens) {
namespace api = baidu::xpu::api;
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
auto xpu_ctx = static_cast<const phi::XPUContext *>(dev_ctx);
api::Context *ctx = xpu_ctx->x_context();
if (seq_lens_this_time.is_cpu()) {
ctx = new api::Context(api::kCPU);
}
const int bsz = seq_lens_this_time.shape()[0];
PADDLE_ENFORCE_LE(
bsz,
640,
phi::errors::InvalidArgument(
"Only support bsz <= 640, but received bsz is %d", bsz));
const int block_num_per_seq = block_tables.shape()[1];
const int length = input_ids.shape()[1];
const int pre_id_length = pre_ids.shape()[1];
const int max_decoder_block_num = pre_id_length / block_size;
int r = baidu::xpu::api::plugin::speculate_free_and_dispatch_block(
ctx,
const_cast<bool *>(stop_flags.data<bool>()),
const_cast<int *>(seq_lens_this_time.data<int>()),
const_cast<int *>(seq_lens_decoder.data<int>()),
const_cast<int *>(block_tables.data<int>()),
const_cast<int *>(encoder_block_lens.data<int>()),
const_cast<bool *>(is_block_step.data<bool>()),
const_cast<int *>(step_block_list.data<int>()),
const_cast<int *>(step_lens.data<int>()),
const_cast<int *>(recover_block_list.data<int>()),
const_cast<int *>(recover_lens.data<int>()),
const_cast<int *>(need_block_list.data<int>()),
const_cast<int *>(need_block_len.data<int>()),
const_cast<int *>(used_list_len.data<int>()),
const_cast<int *>(free_list.data<int>()),
const_cast<int *>(free_list_len.data<int>()),
const_cast<int64_t *>(first_token_ids.data<int64_t>()),
const_cast<int *>(accept_num.data<int>()),
bsz,
block_size,
block_num_per_seq,
max_decoder_block_num,
max_draft_tokens);
PD_CHECK(r == 0, "speculate_free_and_dispatch_block failed.");
auto recover_lens_cpu = recover_lens.copy_to(paddle::CPUPlace(), false);
int recover_lens_cpu_data = recover_lens_cpu.data<int>()[0];
if (recover_lens_cpu_data > 0) {
r = baidu::xpu::api::plugin::speculate_recover_block(
ctx,
const_cast<int *>(recover_block_list.data<int>()),
const_cast<int *>(recover_lens.data<int>()),
const_cast<bool *>(stop_flags.data<bool>()),
const_cast<int *>(seq_lens_this_time.data<int>()),
ori_seq_lens_encoder.data<int>(),
const_cast<int *>(seq_lens_encoder.data<int>()),
seq_lens_decoder.data<int>(),
const_cast<int *>(block_tables.data<int>()),
const_cast<int *>(free_list.data<int>()),
const_cast<int *>(free_list_len.data<int>()),
const_cast<int64_t *>(input_ids.data<int64_t>()),
pre_ids.data<int64_t>(),
step_idx.data<int64_t>(),
encoder_block_lens.data<int>(),
used_list_len.data<int>(),
next_tokens.data<int64_t>(),
first_token_ids.data<int64_t>(),
bsz,
block_num_per_seq,
length,
pre_id_length);
PD_CHECK(r == 0, "speculate_recover_block failed.");
}
}
PD_BUILD_STATIC_OP(speculate_step_paddle)
.Inputs({"stop_flags",
"seq_lens_this_time",
"ori_seq_lens_encoder",
"seq_lens_encoder",
"seq_lens_decoder",
"block_tables",
"encoder_block_lens",
"is_block_step",
"step_block_list",
"step_lens",
"recover_block_list",
"recover_lens",
"need_block_list",
"need_block_len",
"used_list_len",
"free_list",
"free_list_len",
"input_ids",
"pre_ids",
"step_idx",
"next_tokens",
"first_token_ids",
"accept_num"})
.Attrs({"block_size: int",
"encoder_decoder_block_num: int",
"max_draft_tokens: int"})
.Outputs({"stop_flags_out",
"seq_lens_this_time_out",
"seq_lens_encoder_out",
"seq_lens_decoder_out",
"block_tables_out",
"encoder_block_lens_out",
"is_block_step_out",
"step_block_list_out",
"step_lens_out",
"recover_block_list_out",
"recover_lens_out",
"need_block_list_out",
"need_block_len_out",
"used_list_len_out",
"free_list_out",
"free_list_len_out",
"input_ids_out",
"first_token_ids_out"})
.SetInplaceMap({{"stop_flags", "stop_flags_out"},
{"seq_lens_this_time", "seq_lens_this_time_out"},
{"seq_lens_encoder", "seq_lens_encoder_out"},
{"seq_lens_decoder", "seq_lens_decoder_out"},
{"block_tables", "block_tables_out"},
{"encoder_block_lens", "encoder_block_lens_out"},
{"is_block_step", "is_block_step_out"},
{"step_block_list", "step_block_list_out"},
{"step_lens", "step_lens_out"},
{"recover_block_list", "recover_block_list_out"},
{"recover_lens", "recover_lens_out"},
{"need_block_list", "need_block_list_out"},
{"need_block_len", "need_block_len_out"},
{"used_list_len", "used_list_len_out"},
{"free_list", "free_list_out"},
{"free_list_len", "free_list_len_out"},
{"input_ids", "input_ids_out"},
{"first_token_ids", "first_token_ids_out"}})
.SetKernelFn(PD_KERNEL(SpeculateStepPaddle));
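Two sizing details of this new op are worth noting: the batch size is capped at 640 by the PADDLE_ENFORCE_LE above, and the max_decoder_block_num handed to the plugin is derived from the pre_ids width. A small worked example with made-up numbers:

```cpp
const int block_size = 64;         // "block_size" attribute
const int pre_id_length = 2048;    // pre_ids.shape()[1]
const int max_decoder_block_num = pre_id_length / block_size;  // 2048 / 64 = 32
// bsz must satisfy bsz <= 640; larger batches are rejected at runtime.
```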

View File

@@ -45,7 +45,10 @@ void SpeculateVerify(const paddle::Tensor &accept_tokens,
const paddle::Tensor &topp,
int max_seq_len,
int verify_window,
bool enable_topp) {
bool enable_topp,
bool benchmark_mode,
bool accept_all_drafts) {
  // TODO(chenhuan09): support accept_all_drafts
auto bsz = accept_tokens.shape()[0];
int real_bsz = seq_lens_this_time.shape()[0];
auto max_draft_tokens = draft_tokens.shape()[1];
@@ -133,7 +136,8 @@ void SpeculateVerify(const paddle::Tensor &accept_tokens,
max_seq_len,
max_candidate_len,
verify_window,
prefill_one_step_stop);
prefill_one_step_stop,
benchmark_mode);
} else {
baidu::xpu::api::plugin::speculate_verify<false, true>(
ctx,
@@ -161,7 +165,8 @@ void SpeculateVerify(const paddle::Tensor &accept_tokens,
max_seq_len,
max_candidate_len,
verify_window,
prefill_one_step_stop);
prefill_one_step_stop,
benchmark_mode);
}
} else {
if (enable_topp) {
@@ -191,7 +196,8 @@ void SpeculateVerify(const paddle::Tensor &accept_tokens,
max_seq_len,
max_candidate_len,
verify_window,
prefill_one_step_stop);
prefill_one_step_stop,
benchmark_mode);
} else {
baidu::xpu::api::plugin::speculate_verify<false, false>(
ctx,
@@ -219,7 +225,8 @@ void SpeculateVerify(const paddle::Tensor &accept_tokens,
max_seq_len,
max_candidate_len,
verify_window,
prefill_one_step_stop);
prefill_one_step_stop,
benchmark_mode);
}
}
}
@@ -246,7 +253,11 @@ PD_BUILD_STATIC_OP(speculate_verify)
"accept_num_out",
"step_idx_out",
"stop_flags_out"})
.Attrs({"max_seq_len: int", "verify_window: int", "enable_topp: bool"})
.Attrs({"max_seq_len: int",
"verify_window: int",
"enable_topp: bool",
"benchmark_mode: bool",
"accept_all_drafts: bool"})
.SetInplaceMap({{"accept_tokens", "accept_tokens_out"},
{"accept_num", "accept_num_out"},
{"step_idx", "step_idx_out"},

View File

@@ -37,13 +37,14 @@ std::vector<paddle::Tensor> AdjustBatch(
const paddle::Tensor& x, // [token_num, dim_embed]
const paddle::Tensor& cum_offsets, // [bsz, 1]
const paddle::Tensor& encoder_seq_lod,
const paddle::Tensor& decoder_seq_lod,
const paddle::Tensor& encoder_batch_idx,
const paddle::Tensor& decoder_batch_idx,
const paddle::Tensor& encoder_seq_lod_cpu,
const paddle::Tensor& decoder_seq_lod_cpu,
const paddle::Tensor& encoder_batch_idx_cpu,
const paddle::Tensor& decoder_batch_idx_cpu,
const paddle::Tensor& enc_batch_tensor,
const paddle::Tensor& dec_batch_tensor,
const paddle::Tensor& len_info_cpu,
const paddle::optional<paddle::Tensor>& output_padding_offset,
int max_input_length);
@@ -264,7 +265,9 @@ void SpeculateVerify(const paddle::Tensor& accept_tokens,
const paddle::Tensor& topp,
int max_seq_len,
int verify_window,
bool enable_topp);
bool enable_topp,
bool benchmark_mode,
bool accept_all_drafts);
void SpeculateClearAcceptNums(const paddle::Tensor& accept_num,
const paddle::Tensor& seq_lens_decoder);
@@ -285,21 +288,23 @@ void DraftModelPreprocess(const paddle::Tensor& draft_tokens,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& step_idx,
const paddle::Tensor& seq_lens_encoder_record,
const paddle::Tensor& seq_lens_decoder_record,
const paddle::Tensor& not_need_stop,
const paddle::Tensor& is_block_step,
const paddle::Tensor& batch_drop,
const paddle::Tensor& pre_ids,
const paddle::Tensor& accept_tokens,
const paddle::Tensor& accept_num,
const paddle::Tensor& base_model_seq_lens_this_time,
const paddle::Tensor& base_model_seq_lens_encoder,
const paddle::Tensor& base_model_seq_lens_decoder,
const paddle::Tensor& base_model_step_idx,
const paddle::Tensor& base_model_stop_flags,
const paddle::Tensor& base_model_is_block_step,
const paddle::Tensor& base_model_draft_tokens,
const int max_draft_token,
const int num_model_step,
const bool truncate_first_token,
const bool splitwise_prefill);
const bool splitwise_prefill,
const bool kvcache_scheduler_v1);
void DraftModelPostprocess(const paddle::Tensor& base_model_draft_tokens,
const paddle::Tensor& base_model_seq_lens_this_time,
@@ -324,18 +329,19 @@ std::vector<paddle::Tensor> EagleGetSelfHiddenStates(
const paddle::Tensor& step_idx);
std::vector<paddle::Tensor> GatherNextToken(
const paddle::Tensor& tmp_out, // [token_num, dim_embed]
const paddle::Tensor& x, // [token_num, dim_embed]
const paddle::Tensor& cum_offsets, // [bsz, 1]
const paddle::Tensor& encoder_seq_lod,
const paddle::Tensor& decoder_seq_lod,
const paddle::Tensor& encoder_batch_map,
const paddle::Tensor& decoder_batch_map,
const paddle::Tensor& encoder_seq_lod_cpu,
const paddle::Tensor& decoder_seq_lod_cpu,
const paddle::Tensor& encoder_batch_map_cpu,
const paddle::Tensor& decoder_batch_map_cpu,
const paddle::Tensor& enc_batch_tensor,
const paddle::Tensor& dec_batch_tensor,
const paddle::Tensor& len_info_cpu,
const paddle::optional<paddle::Tensor>& output_padding_offset,
int max_input_length);
int max_bsz);
std::vector<paddle::Tensor> GetImgBoundaries(
const paddle::Tensor& task_input_ids,
@@ -436,6 +442,34 @@ void MTPStepPaddle(
const int block_size,
const int max_draft_tokens);
void SpeculateStepPaddle(
const paddle::Tensor& stop_flags,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& ori_seq_lens_encoder,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& block_tables, // [bsz, block_num_per_seq]
const paddle::Tensor& encoder_block_lens,
const paddle::Tensor& is_block_step,
const paddle::Tensor& step_block_list,
const paddle::Tensor& step_lens,
const paddle::Tensor& recover_block_list,
const paddle::Tensor& recover_lens,
const paddle::Tensor& need_block_list,
const paddle::Tensor& need_block_len,
const paddle::Tensor& used_list_len,
const paddle::Tensor& free_list,
const paddle::Tensor& free_list_len,
const paddle::Tensor& input_ids,
const paddle::Tensor& pre_ids,
const paddle::Tensor& step_idx,
const paddle::Tensor& next_tokens,
const paddle::Tensor& first_token_ids,
const paddle::Tensor& accept_num,
const int block_size,
const int encoder_decoder_block_num,
const int max_draft_tokens);
void SaveOutMmsgStatic(const paddle::Tensor& x,
const paddle::Tensor& not_need_stop,
int64_t rank_id,
@@ -542,13 +576,14 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
py::arg("x"),
py::arg("cum_offsets"),
py::arg("encoder_seq_lod"),
py::arg("decoder_seq_lod"),
py::arg("encoder_batch_idx"),
py::arg("decoder_batch_idx"),
py::arg("encoder_seq_lod_cpu"),
py::arg("decoder_seq_lod_cpu"),
py::arg("encoder_batch_idx_cpu"),
py::arg("decoder_batch_idx_cpu"),
py::arg("enc_batch_tensor"),
py::arg("dec_batch_tensor"),
py::arg("len_info_cpu"),
py::arg("output_padding_offset"),
py::arg("max_input_length"),
"adjust batch in XPU");
@@ -620,21 +655,23 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
py::arg("seq_lens_encoder"),
py::arg("seq_lens_decoder"),
py::arg("step_idx"),
py::arg("seq_lens_encoder_record"),
py::arg("seq_lens_decoder_record"),
py::arg("not_need_stop"),
py::arg("is_block_step"),
py::arg("batch_drop"),
py::arg("pre_ids"),
py::arg("accept_tokens"),
py::arg("accept_num"),
py::arg("base_model_seq_lens_this_time"),
py::arg("base_model_seq_lens_encoder"),
py::arg("base_model_seq_lens_decoder"),
py::arg("base_model_step_idx"),
py::arg("base_model_stop_flags"),
py::arg("base_model_is_block_step"),
py::arg("base_model_draft_tokens"),
py::arg("max_draft_token"),
py::arg("num_model_step"),
py::arg("truncate_first_token"),
py::arg("splitwise_prefill"),
py::arg("kvcache_scheduler_v1"),
"Preprocess data for draft model in speculative decoding");
m.def("draft_model_postprocess",
@@ -727,18 +764,19 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
m.def("gather_next_token",
&GatherNextToken,
py::arg("tmp_out"),
py::arg("x"),
py::arg("cum_offsets"),
py::arg("encoder_seq_lod"),
py::arg("decoder_seq_lod"),
py::arg("encoder_batch_map"),
py::arg("decoder_batch_map"),
py::arg("encoder_seq_lod_cpu"),
py::arg("decoder_seq_lod_cpu"),
py::arg("encoder_batch_map_cpu"),
py::arg("decoder_batch_map_cpu"),
py::arg("enc_batch_tensor"),
py::arg("dec_batch_tensor"),
py::arg("len_info_cpu"),
py::arg("output_padding_offset"),
py::arg("max_input_length"),
py::arg("max_bsz"),
"Gather next token for XPU");
m.def("get_img_boundaries",
@@ -983,6 +1021,8 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
py::arg("max_seq_len"),
py::arg("verify_window"),
py::arg("enable_topp"),
py::arg("benchmark_mode"),
py::arg("accept_all_drafts"),
"Perform speculative verification for decoding");
m.def("speculate_clear_accept_nums",
@@ -1104,6 +1144,36 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
py::arg("encoder_decoder_block_num"),
"Step paddle function");
m.def("speculate_step_paddle",
&SpeculateStepPaddle,
py::arg("stop_flags"),
py::arg("seq_lens_this_time"),
py::arg("ori_seq_lens_encoder"),
py::arg("seq_lens_encoder"),
py::arg("seq_lens_decoder"),
py::arg("block_tables"),
py::arg("encoder_block_lens"),
py::arg("is_block_step"),
py::arg("step_block_list"),
py::arg("step_lens"),
py::arg("recover_block_list"),
py::arg("recover_lens"),
py::arg("need_block_list"),
py::arg("need_block_len"),
py::arg("used_list_len"),
py::arg("free_list"),
py::arg("free_list_len"),
py::arg("input_ids"),
py::arg("pre_ids"),
py::arg("step_idx"),
py::arg("next_tokens"),
py::arg("first_token_ids"),
py::arg("accept_num"),
py::arg("block_size"),
py::arg("encoder_decoder_block_num"),
py::arg("max_draft_tokens"),
"Step paddle function");
m.def("text_image_gather_scatter",
&TextImageGatherScatter,
py::arg("input"),

View File

@@ -75,6 +75,48 @@ DLL_EXPORT int get_padding_offset(Context* ctx,
const int max_seq_len,
const int bs);
DLL_EXPORT int speculate_get_padding_offset(Context* ctx,
int* batch_id_per_token,
int* cum_offsets_out,
int* cu_seqlens_q,
int* cu_seqlens_k,
const int* cum_offsets,
const int* seq_lens,
const int max_seq_len,
int bsz);
DLL_EXPORT int draft_model_preprocess(api::Context* ctx,
int64_t* draft_tokens,
int64_t* input_ids,
bool* stop_flags,
int* seq_lens_this_time,
int* seq_lens_encoder,
int* seq_lens_decoder,
int64_t* step_idx,
bool* not_need_stop,
bool* is_block_step,
bool* batch_drop,
int64_t* pre_ids,
const int64_t* accept_tokens,
const int* accept_num,
const int* base_model_seq_lens_this_time,
const int* base_model_seq_lens_encoder,
const int* base_model_seq_lens_decoder,
const int64_t* base_model_step_idx,
const bool* base_model_stop_flags,
const bool* base_model_is_block_step,
int64_t* base_model_draft_tokens,
const int bsz,
const int num_model_step,
const int accept_tokens_len,
const int draft_tokens_len,
const int input_ids_len,
const int base_model_draft_tokens_len,
const int pre_ids_len,
const bool truncate_first_token,
const bool splitwise_prefill,
const bool kvcache_scheduler_v1);
DLL_EXPORT int update_inputs(Context* ctx,
bool* not_need_stop,
int* seq_lens_this_time,
@@ -111,6 +153,31 @@ DLL_EXPORT int free_and_dispatch_block(Context* ctx,
const int block_num_per_seq,
const int max_decoder_block_num);
DLL_EXPORT int speculate_free_and_dispatch_block(
Context* ctx,
bool* stop_flags,
int* seq_lens_this_time,
int* seq_lens_decoder,
int* block_tables,
int* encoder_block_lens,
bool* is_block_step,
int* step_block_list, // [bsz]
int* step_len,
int* recover_block_list,
int* recover_len,
int* need_block_list,
int* need_block_len,
int* used_list_len,
int* free_list,
int* free_list_len,
int64_t* first_token_ids,
int* accept_num,
const int bsz,
const int block_size,
const int block_num_per_seq,
const int max_decoder_block_num,
const int max_draft_tokens);
DLL_EXPORT int recover_block(Context* ctx,
int* recover_block_list, // [bsz]
int* recover_len,
@@ -134,6 +201,29 @@ DLL_EXPORT int recover_block(Context* ctx,
const int length,
const int pre_id_length);
DLL_EXPORT int speculate_recover_block(Context* ctx,
int* recover_block_list, // [bsz]
int* recover_len,
bool* stop_flags,
int* seq_lens_this_time,
const int* ori_seq_lens_encoder,
int* seq_lens_encoder,
const int* seq_lens_decoder,
int* block_tables,
int* free_list,
int* free_list_len,
int64_t* input_ids,
const int64_t* pre_ids,
const int64_t* step_idx,
const int* encoder_block_lens,
const int* used_list_len,
const int64_t* next_tokens,
const int64_t* first_token_ids,
const int bsz,
const int block_num_per_seq,
const int length,
const int pre_id_length);
DLL_EXPORT int recover_decode_task(Context* ctx,
bool* stop_flags,
int* seq_lens_this_time,
@@ -172,6 +262,7 @@ DLL_EXPORT int eb_adjust_batch(
const TX* x,
TY* y,
VectorParam<int32_t>& encoder_seqs_lods, // NOLINT
VectorParam<int32_t>& decoder_seqs_lods, // NOLINT
VectorParam<int32_t>& encoder_batch_map, // NOLINT
VectorParam<int32_t>& decoder_batch_map, // NOLINT
int64_t hidden_dim);
@@ -186,6 +277,17 @@ DLL_EXPORT int eb_gather_next_token(
VectorParam<int32_t>& decoder_batch_map, // NOLINT
int64_t hidden_dim);
template <typename TX, typename TY>
DLL_EXPORT int eb_mtp_gather_next_token(
Context* ctx,
const TX* x,
TY* y,
VectorParam<int32_t>& encoder_seqs_lods, // NOLINT
VectorParam<int32_t>& decoder_seqs_lods, // NOLINT
VectorParam<int32_t>& encoder_batch_map, // NOLINT
VectorParam<int32_t>& decoder_batch_map, // NOLINT
int64_t hidden_dim);
template <typename TX, typename TSCALE = float, typename TY = int8_t>
DLL_EXPORT int quant2d_per_channel(api::Context* ctx,
const TX* x,
@@ -305,7 +407,8 @@ DLL_EXPORT int speculate_verify(Context* ctx,
const int max_seq_len,
const int max_candidate_len,
const int verify_window,
const bool prefill_one_step_stop);
const bool prefill_one_step_stop,
const bool benchmark_mode);
DLL_EXPORT int speculate_clear_accept_nums(Context* ctx,
int* accept_num,
@@ -342,35 +445,6 @@ DLL_EXPORT int draft_model_update(Context* ctx,
const int substep,
const bool prefill_one_step_stop);
DLL_EXPORT int draft_model_preprocess(api::Context* ctx,
int64_t* draft_tokens,
int64_t* input_ids,
bool* stop_flags,
int* seq_lens_this_time,
int* seq_lens_encoder,
int* seq_lens_decoder,
int64_t* step_idx,
int* seq_lens_encoder_record,
int* seq_lens_decoder_record,
bool* not_need_stop,
bool* batch_drop,
const int64_t* accept_tokens,
const int* accept_num,
const int* base_model_seq_lens_encoder,
const int* base_model_seq_lens_decoder,
const int64_t* base_model_step_idx,
const bool* base_model_stop_flags,
const bool* base_model_is_block_step,
int64_t* base_model_draft_tokens,
int real_bsz,
int max_draft_token,
int accept_tokens_len,
int draft_tokens_len,
int input_ids_len,
int base_model_draft_tokens_len,
bool truncate_first_token,
bool splitwise_prefill);
DLL_EXPORT int speculate_set_stop_value_multi_seqs(Context* ctx,
bool* stop_flags,
int64_t* accept_tokens,
@@ -411,16 +485,6 @@ DLL_EXPORT int speculate_remove_padding(Context* ctx,
int bsz,
int token_num_data);
DLL_EXPORT int speculate_get_padding_offset(Context* ctx,
int* padding_offset,
int* cum_offsets_out,
int* cu_seqlens_q,
int* cu_seqlens_k,
const int* cum_offsets,
const int* seq_lens,
const int max_seq_len,
int bsz);
DLL_EXPORT int compute_self_order(api::Context* ctx,
const int* last_seq_lens_this_time,
const int* seq_lens_this_time,

View File

@@ -4,7 +4,7 @@
namespace xpu3 {
namespace plugin {
#define MAX_LM_SIZE 28672
// One core has 32KB LMgroup LM), MAX_LM_SIZE = (32 - 4)KB / 2 = 30720, 4KB is
// One core has 32KB LM (group LM), MAX_LM_SIZE = (32 - 4)KB = 28672, 4KB is
// the stack space
#define MAX_BATCH 512
#define ALIGNMENT 64
@@ -53,6 +53,7 @@ template <typename TX, typename TY>
__global__ void eb_adjust_batch(TX* src,
TY* dst,
int* encoder_seqs_lods,
int* decoder_seqs_lods,
int* encoder_batch_map,
int* decoder_batch_map,
int en_batch,
@@ -61,9 +62,11 @@ __global__ void eb_adjust_batch(TX* src,
int tid = core_id() * cluster_num() + cluster_id();
int nthreads = core_num() * cluster_num();
__group_shared__ int local_lods_en[MAX_BATCH + 1];
__group_shared__ int local_lods_de[MAX_BATCH + 1];
__group_shared__ int local_map_en[MAX_BATCH];
__group_shared__ int local_map_de[MAX_BATCH];
GM2GSM_ASYNC(encoder_seqs_lods, local_lods_en, (en_batch + 1) * sizeof(int));
GM2GSM_ASYNC(decoder_seqs_lods, local_lods_de, (de_batch + 1) * sizeof(int));
if (en_batch > 0) {
GM2GSM_ASYNC(encoder_batch_map, local_map_en, en_batch * sizeof(int));
}
@@ -72,7 +75,8 @@ __global__ void eb_adjust_batch(TX* src,
}
mfence();
int max_encoder_len = local_lods_en[en_batch];
int seq_sum = max_encoder_len + de_batch;
int max_decoder_len = local_lods_de[de_batch];
int seq_sum = max_encoder_len + max_decoder_len;
int total_batch = en_batch + de_batch;
int start = 0;
int end = 0;
@@ -82,13 +86,16 @@ __global__ void eb_adjust_batch(TX* src,
while (i < end) {
if (i >= max_encoder_len) {
// dst decode part
int cur_de_bs = i - max_encoder_len;
int cur_de_bs = 0;
get_cur_batch(local_lods_de, de_batch, i - max_encoder_len, cur_de_bs);
int cur_en_bs = local_map_de[cur_de_bs] - cur_de_bs;
int cur_len =
min(end, local_lods_de[cur_de_bs + 1] + max_encoder_len) - i;
_global_ptr_ TY* cur_dst = dst + i * copy_size;
_global_ptr_ TX* cur_src =
src + (cur_de_bs + local_lods_en[cur_en_bs]) * copy_size;
do_memcpy_1d<TX, TY>(cur_src, cur_dst, copy_size);
i++;
src + (local_lods_en[cur_en_bs] + i - max_encoder_len) * copy_size;
do_memcpy_1d<TX, TY>(cur_src, cur_dst, copy_size * cur_len);
i += cur_len;
} else {
// dst encode part
int cur_en_bs = 0;
@@ -97,7 +104,8 @@ __global__ void eb_adjust_batch(TX* src,
cur_de_bs = local_map_en[cur_en_bs] - cur_en_bs;
int cur_len = min(end, local_lods_en[cur_en_bs + 1]) - i;
_global_ptr_ TY* cur_dst = dst + i * copy_size;
_global_ptr_ TX* cur_src = src + (cur_de_bs + i) * copy_size;
_global_ptr_ TX* cur_src =
src + (local_lods_de[cur_de_bs] + i) * copy_size;
do_memcpy_1d<TX, TY>(cur_src, cur_dst, copy_size * cur_len);
i += cur_len;
}
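The key change in the loop above is the row count: the old kernel assumed exactly one token per decode request (seq_sum = max_encoder_len + de_batch), while the MTP path lets a decode request contribute several tokens, so the kernel now walks max_encoder_len + local_lods_de[de_batch] rows and uses the decoder lods to locate each batch. A small worked example with invented sizes:

```cpp
const int local_lods_en[] = {0, 3, 7};  // en_batch = 2 prefill requests (3 + 4 tokens)
const int local_lods_de[] = {0, 2, 4};  // de_batch = 2 decode requests, 2 tokens each
const int max_encoder_len = local_lods_en[2];           // 7
const int max_decoder_len = local_lods_de[2];           // 4
const int seq_sum = max_encoder_len + max_decoder_len;  // 11 rows (was 7 + 2 = 9)
```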
@@ -108,6 +116,7 @@ __global__ void eb_adjust_batch(TX* src,
template __global__ void eb_adjust_batch<TX, TY>(TX * src, \
TY * dst, \
int* encoder_seqs_lods, \
int* decoder_seqs_lods, \
int* encoder_batch_map, \
int* decoder_batch_map, \
int en_batch, \

View File

@@ -20,6 +20,7 @@ __global__ void ComputeOrderKernel(const int* seq_lens_this_time,
return;
}
// 256 * int
char lm[6 * 1024];
int buf_size = 6 * 1024 / (6 * sizeof(int));
int* lm_base_model_seq_lens_this_time = (int*)lm;
@@ -68,10 +69,7 @@ __global__ void ComputeOrderKernel(const int* seq_lens_this_time,
in_offset += write_size;
}
mfence_lm();
// 2. base model encoder. Base step=0
} else if (cur_base_model_seq_lens_encoder != 0) {
// nothing happens
// 3. New end
    // 2. Base model stopped at the last verify step.
} else if (cur_base_model_seq_lens_this_time != 0 &&
cur_seq_lens_this_time == 0) {
in_offset += cur_base_model_seq_lens_this_time;
@@ -80,27 +78,16 @@ __global__ void ComputeOrderKernel(const int* seq_lens_this_time,
cur_seq_lens_this_time == 0) {
// nothing happens
} else {
if (accept_num <= actual_draft_token_num) {
int position_map_val = out_offset;
LM2GM(&position_map_val,
position_map + in_offset + accept_num - 1,
sizeof(int));
out_offset++;
in_offset += cur_base_model_seq_lens_this_time;
} else {
int position_map_val_1 = out_offset;
LM2GM(&position_map_val_1,
position_map + in_offset + accept_num - 2,
sizeof(int));
out_offset++;
int position_map_val_2 = out_offset;
LM2GM(&position_map_val_2,
position_map + in_offset + accept_num - 1,
sizeof(int));
out_offset++;
in_offset += cur_base_model_seq_lens_this_time;
        // accept_num << buf_size, so no split is needed
for (int i = 0; i < accept_num; i++) {
lm_position_map[i] = out_offset++;
}
mfence_lm();
LM2GM(lm_position_map,
position_map + in_offset,
accept_num * sizeof(int));
in_offset += cur_base_model_seq_lens_this_time;
mfence_lm();
}
}
}
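The rewritten else-branch above drops the old one-slot / two-slot special cases and maps every accepted token to its own consecutive output position. A host-side equivalent of what it writes (a sketch using the kernel's variable names, not the device code itself):

```cpp
// Each accepted token of this request gets its own slot in position_map;
// in_offset then skips past the rest of this request's draft positions.
void map_accepted_tokens(int* position_map, int& in_offset, int& out_offset,
                         int accept_num, int cur_base_model_seq_lens_this_time) {
  for (int i = 0; i < accept_num; ++i) {
    position_map[in_offset + i] = out_offset++;
  }
  in_offset += cur_base_model_seq_lens_this_time;
}
```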

View File

@@ -13,26 +13,29 @@ __global__ void draft_model_preprocess(int64_t* draft_tokens,
int* seq_lens_encoder,
int* seq_lens_decoder,
int64_t* step_idx,
int* seq_lens_encoder_record,
int* seq_lens_decoder_record,
bool* not_need_stop,
bool* is_block_step,
bool* batch_drop,
int64_t* pre_ids,
const int64_t* accept_tokens,
const int* accept_num,
const int* base_model_seq_lens_this_time,
const int* base_model_seq_lens_encoder,
const int* base_model_seq_lens_decoder,
const int64_t* base_model_step_idx,
const bool* base_model_stop_flags,
const bool* base_model_is_block_step,
int64_t* base_model_draft_tokens,
int real_bsz,
int max_draft_token,
int accept_tokens_len,
int draft_tokens_len,
int input_ids_len,
int base_model_draft_tokens_len,
bool truncate_first_token,
bool splitwise_prefill) {
const int bsz,
const int num_model_step,
const int accept_tokens_len,
const int draft_tokens_len,
const int input_ids_len,
const int base_model_draft_tokens_len,
const int pre_ids_len,
const bool truncate_first_token,
const bool splitwise_prefill,
const bool kvcache_scheduler_v1) {
int cid = core_id();
int ncores = core_num();
int clusterid = cluster_id();
@@ -46,7 +49,7 @@ __global__ void draft_model_preprocess(int64_t* draft_tokens,
int64_t value_fu = -1;
if (splitwise_prefill) {
for (; tid < real_bsz; tid += ncores * nclusters) {
for (; tid < bsz; tid += ncores * nclusters) {
int64_t base_model_step_idx_now = 0;
int seq_lens_encoder_now = 0;
int seq_lens_this_time_now = 0;
@@ -57,35 +60,25 @@ __global__ void draft_model_preprocess(int64_t* draft_tokens,
GM2LM_ASYNC(
base_model_step_idx + tid, &base_model_step_idx_now, sizeof(int64_t));
GM2LM_ASYNC(seq_lens_encoder_record + tid,
&seq_lens_encoder_record_now,
sizeof(int));
GM2LM_ASYNC(seq_lens_encoder + tid, &seq_lens_encoder_now, sizeof(int));
GM2LM(accept_tokens + tid * accept_tokens_len,
&base_model_first_token,
sizeof(int64_t));
if (base_model_step_idx_now == 1 && seq_lens_encoder_record_now > 0) {
if (seq_lens_encoder_now > 0) {
not_stop_flag_sm[cid] += 1;
int seq_len_encoder_record = seq_lens_encoder_record_now;
seq_lens_encoder_now = seq_len_encoder_record;
seq_lens_encoder_record_now = -1;
stop_flags_now = false;
int position = seq_len_encoder_record;
int position = seq_lens_encoder_now;
if (truncate_first_token) {
position = position - 1;
input_ids_now = base_model_first_token;
seq_lens_this_time_now = seq_len_encoder_record;
seq_lens_this_time_now = seq_lens_encoder_now;
} else {
input_ids_now = base_model_first_token;
seq_lens_this_time_now = seq_len_encoder_record + 1;
seq_lens_this_time_now = seq_lens_encoder_now + 1;
}
LM2GM_ASYNC(&input_ids_now,
input_ids + tid * input_ids_len + position,
sizeof(int64_t));
LM2GM_ASYNC(&seq_lens_encoder_record_now,
seq_lens_encoder_record + tid,
sizeof(int));
} else {
stop_flags_now = true;
seq_lens_this_time_now = 0;
@@ -98,21 +91,23 @@ __global__ void draft_model_preprocess(int64_t* draft_tokens,
LM2GM(&seq_lens_this_time_now, seq_lens_this_time + tid, sizeof(int));
}
} else {
for (; tid < real_bsz; tid += ncores * nclusters) {
for (; tid < bsz; tid += ncores * nclusters) {
bool base_model_stop_flags_now = false;
bool base_model_is_block_step_now = false;
bool batch_drop_now = false;
bool stop_flags_now = false;
bool is_block_step_now = false;
int seq_lens_this_time_now = 0;
int seq_lens_encoder_record_now = 0;
int seq_lens_encoder_now = 0;
int seq_lens_decoder_new = 0;
int seq_lens_decoder_record_now = 0;
int accept_num_now = 0;
int base_model_seq_lens_decoder_now = 0;
int base_model_seq_lens_this_time_now = 0;
int64_t step_id_now = 0;
int64_t base_model_step_idx_now;
int64_t pre_ids_now;
mfence();
GM2LM_ASYNC(is_block_step + tid, &is_block_step_now, sizeof(bool));
GM2LM_ASYNC(base_model_stop_flags + tid,
&base_model_stop_flags_now,
sizeof(bool));
@@ -121,12 +116,6 @@ __global__ void draft_model_preprocess(int64_t* draft_tokens,
sizeof(bool));
GM2LM_ASYNC(batch_drop + tid, &batch_drop_now, sizeof(bool));
GM2LM_ASYNC(stop_flags + tid, &stop_flags_now, sizeof(bool));
GM2LM_ASYNC(seq_lens_encoder_record + tid,
&seq_lens_encoder_record_now,
sizeof(int));
GM2LM_ASYNC(seq_lens_decoder_record + tid,
&seq_lens_decoder_record_now,
sizeof(int));
GM2LM_ASYNC(seq_lens_encoder + tid, &seq_lens_encoder_now, sizeof(int));
GM2LM_ASYNC(seq_lens_decoder + tid, &seq_lens_decoder_new, sizeof(int));
@@ -135,6 +124,9 @@ __global__ void draft_model_preprocess(int64_t* draft_tokens,
accept_tokens_len * sizeof(int64_t));
GM2LM_ASYNC(accept_num + tid, &accept_num_now, sizeof(int));
GM2LM_ASYNC(base_model_seq_lens_this_time + tid,
&base_model_seq_lens_this_time_now,
sizeof(int));
GM2LM_ASYNC(base_model_seq_lens_decoder + tid,
&base_model_seq_lens_decoder_now,
sizeof(int));
@@ -148,57 +140,67 @@ __global__ void draft_model_preprocess(int64_t* draft_tokens,
base_model_draft_tokens + tid * base_model_draft_tokens_len + i,
sizeof(int));
}
if (base_model_stop_flags_now && base_model_is_block_step_now) {
batch_drop_now = true;
stop_flags_now = true;
if (kvcache_scheduler_v1) {
if (base_model_stop_flags_now && base_model_is_block_step_now) {
stop_flags_now = true;
is_block_step_now = true;
}
} else {
if (base_model_stop_flags_now && base_model_is_block_step_now) {
batch_drop_now = true;
stop_flags_now = true;
}
}
if (!(base_model_stop_flags_now || batch_drop_now)) {
not_stop_flag_sm[cid] += 1;
if (base_model_step_idx_now == 0) {
seq_lens_this_time_now = 0;
        not_stop_flag_sm[cid] -= 1;  // undo the increment above; net count stays 0 for this case
} else if (base_model_step_idx_now == 1 &&
seq_lens_encoder_record_now > 0) {
int seq_len_encoder_record = seq_lens_encoder_record_now;
seq_lens_encoder_now = seq_len_encoder_record;
seq_lens_encoder_record_now = -1;
seq_lens_decoder_new = seq_lens_decoder_record_now;
seq_lens_decoder_record_now = 0;
if (seq_lens_encoder_now > 0) {
int seq_len_encoder = seq_lens_encoder_now;
stop_flags_now = false;
int64_t base_model_first_token = accept_tokens_now[0];
int position = seq_len_encoder_record;
LM2GM(&base_model_first_token,
pre_ids + tid * pre_ids_len,
sizeof(int64_t));
int position = seq_len_encoder;
if (truncate_first_token) {
LM2GM(&base_model_first_token,
input_ids + tid * input_ids_len + position - 1,
sizeof(int64_t));
seq_lens_this_time_now = seq_len_encoder_record;
seq_lens_this_time_now = seq_len_encoder;
} else {
LM2GM(&base_model_first_token,
input_ids + tid * input_ids_len + position,
sizeof(int64_t));
seq_lens_this_time_now = seq_len_encoder_record + 1;
seq_lens_this_time_now = seq_len_encoder + 1;
}
} else {
if (kvcache_scheduler_v1) {
if (!base_model_is_block_step_now && is_block_step_now) {
is_block_step_now = false;
}
}
} else if (accept_num_now <= max_draft_token) {
if (stop_flags_now) {
stop_flags_now = false;
seq_lens_decoder_new = base_model_seq_lens_decoder_now;
step_id_now = base_model_step_idx_now;
} else {
seq_lens_decoder_new -= max_draft_token - accept_num_now;
step_id_now -= max_draft_token - accept_num_now;
}
int64_t modified_token = accept_tokens_now[accept_num_now - 1];
LM2GM(&modified_token,
draft_tokens + tid * draft_tokens_len,
sizeof(int64_t));
seq_lens_this_time_now = 1;
seq_lens_decoder_new = base_model_seq_lens_decoder_now -
base_model_seq_lens_this_time_now;
step_id_now =
base_model_step_idx_now - base_model_seq_lens_this_time_now;
} else /*Accept all draft tokens*/ {
LM2GM(accept_tokens_now + max_draft_token,
draft_tokens + tid * draft_tokens_len + 1,
sizeof(int64_t));
seq_lens_this_time_now = 2;
} else {
seq_lens_decoder_new -= num_model_step - 1;
step_id_now -= num_model_step - 1;
}
for (int i = 0; i < accept_num_now; i++) {
const int pre_id_pos =
base_model_step_idx_now - (accept_num_now - i);
LM2GM(accept_tokens_now + i,
draft_tokens + tid * draft_tokens_len + i,
sizeof(int64_t));
LM2GM(accept_tokens_now + i,
pre_ids + tid * pre_ids_len + pre_id_pos,
sizeof(int64_t));
}
seq_lens_this_time_now = accept_num_now;
}
} else {
@@ -209,17 +211,11 @@ __global__ void draft_model_preprocess(int64_t* draft_tokens,
}
LM2GM_ASYNC(&stop_flags_now, stop_flags + tid, sizeof(bool));
LM2GM_ASYNC(&batch_drop_now, batch_drop + tid, sizeof(bool));
LM2GM_ASYNC(&is_block_step_now, is_block_step + tid, sizeof(bool));
LM2GM_ASYNC(&seq_lens_decoder_new, seq_lens_decoder + tid, sizeof(int));
LM2GM_ASYNC(
&seq_lens_this_time_now, seq_lens_this_time + tid, sizeof(int));
LM2GM_ASYNC(&seq_lens_encoder_now, seq_lens_encoder + tid, sizeof(int));
LM2GM_ASYNC(&seq_lens_encoder_record_now,
seq_lens_encoder_record + tid,
sizeof(int));
LM2GM_ASYNC(&seq_lens_decoder_record_now,
seq_lens_decoder_record + tid,
sizeof(int));
LM2GM_ASYNC(&step_id_now, step_idx + tid, sizeof(int64_t));
}
}

View File

@@ -60,10 +60,8 @@ __global__ void draft_model_update(const int64_t* inter_next_tokens,
token_this_time = next_tokens_start[seq_len_this_time - 1];
draft_token_now[0] = next_tokens_start[seq_len_this_time - 1];
base_model_draft_tokens_now[substep + 1] = token_this_time;
for (int i = 0; i < seq_len_this_time; ++i) {
pre_ids_now[step_idx[tid] + 1 + i] = next_tokens_start[i];
}
step_idx[tid] += seq_len_this_time;
pre_ids_now[step_idx[tid]] = token_this_time;
} else {
token_this_time = next_tokens_start[0];
seq_lens_decoder[tid] = seq_len_encoder + seq_len_decoder;

View File

@@ -0,0 +1,129 @@
#include "xpu/kernel/cluster.h"
#include "xpu/kernel/cluster_debug.h"
#include "xpu/kernel/cluster_primitive.h"
namespace xpu3 {
namespace plugin {
#define MAX_LM_SIZE 28672
// Each core group has 32KB of local memory (LM); MAX_LM_SIZE = (32 - 4)KB =
// 28672 bytes, where 4KB is reserved for the stack
#define MAX_BATCH 512
#define ALIGNMENT 64
template <typename TX, typename TY>
static __device__ void do_memcpy_1d(_global_ptr_ TX* src,
_global_ptr_ TY* dst,
int64_t copy_size) {
#ifdef __XPU3__
constexpr int buf_size = 2048;
#else
constexpr int buf_size = 512;
#endif
__group_shared__ __simd__ float double_lmx[2][buf_size];
int64_t pingpong = 0;
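// Double-buffer in group-shared memory: while one half is still draining to GM
// via GSM2GM_ASYNC, the next chunk is staged into the other half.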
for (int64_t i = 0; i < copy_size; i += buf_size) {
int real_size = min<int64_t>(buf_size, copy_size - i);
_group_shared_ptr_ float* lmx = double_lmx[pingpong];
GM2GSM(src + i, lmx, real_size * sizeof(TX));
if (!xpu_std::is_same<TX, TY>::value) {
primitive_cast_gsm<TX, float>(
(_group_shared_ptr_ TX*)lmx, lmx, real_size);
primitive_cast_gsm<float, TY>(
lmx, (_group_shared_ptr_ TY*)lmx, real_size);
}
GSM2GM_ASYNC((_group_shared_ptr_ TY*)lmx, dst + i, real_size * sizeof(TY));
pingpong = 1 - pingpong;
}
mfence();
}
template <typename TX, typename TY>
__global__ void eb_mtp_gather_next_token(TX* src,
TY* dst,
int* encoder_seqs_lods,
int* decoder_seqs_lods,
int* encoder_batch_map,
int* decoder_batch_map,
int en_batch,
int de_batch,
int64_t copy_size) {
int tid = core_id() * cluster_num() + cluster_id();
int nthreads = core_num() * cluster_num();
__group_shared__ int local_lods_en[MAX_BATCH + 1];
__group_shared__ int local_lods_de[MAX_BATCH + 1];
__group_shared__ int local_map_en[MAX_BATCH];
__group_shared__ int local_map_de[MAX_BATCH];
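// Stage the lod arrays and batch maps into group-shared memory once, so every
// core can walk them below without repeated global-memory reads.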
GM2GSM_ASYNC(encoder_seqs_lods, local_lods_en, (en_batch + 1) * sizeof(int));
GM2GSM_ASYNC(decoder_seqs_lods, local_lods_de, (de_batch + 1) * sizeof(int));
if (en_batch > 0) {
GM2GSM_ASYNC(encoder_batch_map, local_map_en, en_batch * sizeof(int));
}
if (de_batch > 0) {
GM2GSM_ASYNC(decoder_batch_map, local_map_de, de_batch * sizeof(int));
}
mfence();
int encoder_len_total = en_batch > 0 ? local_lods_en[en_batch] : 0;
int output_len = en_batch + local_lods_de[de_batch];
int start = 0;
int end = 0;
partition(tid, nthreads, output_len, 1, &start, &end);
for (int i = start; i < end; i++) {
int len = 0;
int enc_idx = 0, dec_idx = 0;
bool is_enc;
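// Merge-walk the encoder and decoder sequences in original batch order
// (smaller batch id first) until the cumulative output length passes row i:
// an encoder sequence contributes one output row (its last token), a decoder
// sequence contributes all of its tokens.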
while (i >= len) {
if (enc_idx >= en_batch) {
len += local_lods_de[dec_idx + 1] - local_lods_de[dec_idx];
dec_idx++;
is_enc = false;
continue;
}
if (dec_idx >= de_batch) {
len += 1;
enc_idx++;
is_enc = true;
continue;
}
if (local_map_en[enc_idx] < local_map_de[dec_idx]) {
len += 1;
enc_idx++;
is_enc = true;
} else {
len += local_lods_de[dec_idx + 1] - local_lods_de[dec_idx];
dec_idx++;
is_enc = false;
}
}
_global_ptr_ TX* cur_src = nullptr;
_global_ptr_ TY* cur_dst = dst + i * copy_size;
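// Source row: an encoder sequence supplies only its last token
// (lods[enc_idx] - 1); decoder tokens are stored contiguously after all
// encoder tokens, hence the encoder_len_total offset.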
if (is_enc) {
cur_src = src + (local_lods_en[enc_idx] - 1) * copy_size;
} else {
cur_src = src + (encoder_len_total + local_lods_de[dec_idx] - (len - i)) *
copy_size;
}
do_memcpy_1d<TX, TY>(cur_src, cur_dst, copy_size);
}
}
#define _XPU_DEF__EB_MTP_GATHER_NEXT_TOKEN(TX, TY) \
template __global__ void eb_mtp_gather_next_token<TX, TY>( \
TX * src, \
TY * dst, \
int* encoder_seqs_lods, \
int* decoder_seqs_lods, \
int* encoder_batch_map, \
int* decoder_batch_map, \
int en_batch, \
int de_batch, \
int64_t copy_size);
_XPU_DEF__EB_MTP_GATHER_NEXT_TOKEN(float16, float16);
_XPU_DEF__EB_MTP_GATHER_NEXT_TOKEN(bfloat16, bfloat16);
_XPU_DEF__EB_MTP_GATHER_NEXT_TOKEN(float, float);
_XPU_DEF__EB_MTP_GATHER_NEXT_TOKEN(float16, float);
_XPU_DEF__EB_MTP_GATHER_NEXT_TOKEN(float, float16);
_XPU_DEF__EB_MTP_GATHER_NEXT_TOKEN(bfloat16, float16);
_XPU_DEF__EB_MTP_GATHER_NEXT_TOKEN(float16, bfloat16);
_XPU_DEF__EB_MTP_GATHER_NEXT_TOKEN(bfloat16, float);
_XPU_DEF__EB_MTP_GATHER_NEXT_TOKEN(float, bfloat16);
} // namespace plugin
} // namespace xpu3
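
For reference, the following is a minimal host-side sketch of the row mapping the kernel above implements (illustrative only, not part of this commit; the lod and batch-map values are made up). Each encoder sequence contributes a single output row holding its last token, each decoder sequence contributes all of its tokens, and output rows are interleaved by original batch id:

#include <cstdio>
#include <vector>

int main() {
  // Assumed example: 2 encoder seqs (lengths 3 and 4, batch ids 0 and 3),
  // 2 decoder seqs (2 and 1 tokens this step, batch ids 1 and 2).
  std::vector<int> lods_en = {0, 3, 7}, lods_de = {0, 2, 3};
  std::vector<int> map_en = {0, 3}, map_de = {1, 2};
  const int en_batch = 2, de_batch = 2;
  const int encoder_len_total = lods_en[en_batch];      // 7 source rows
  const int output_len = en_batch + lods_de[de_batch];  // 5 output rows
  for (int i = 0; i < output_len; i++) {
    int len = 0, enc_idx = 0, dec_idx = 0;
    bool is_enc = false;
    while (i >= len) {  // same merge walk as eb_mtp_gather_next_token
      if (enc_idx >= en_batch) {
        len += lods_de[dec_idx + 1] - lods_de[dec_idx]; dec_idx++; is_enc = false;
      } else if (dec_idx >= de_batch) {
        len += 1; enc_idx++; is_enc = true;
      } else if (map_en[enc_idx] < map_de[dec_idx]) {
        len += 1; enc_idx++; is_enc = true;
      } else {
        len += lods_de[dec_idx + 1] - lods_de[dec_idx]; dec_idx++; is_enc = false;
      }
    }
    const int src_row = is_enc
        ? lods_en[enc_idx] - 1                               // last encoder token
        : encoder_len_total + lods_de[dec_idx] - (len - i);  // decoder token
    std::printf("dst row %d <- src row %d (%s)\n", i, src_row,
                is_enc ? "encoder" : "decoder");
  }
  return 0;
}

With these values the five output rows map to source rows 2, 7, 8, 9, 6: batch 0's last encoder token, batch 1's two decoder tokens, batch 2's decoder token, then batch 3's last encoder token.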

View File

@@ -0,0 +1,337 @@
#include "xpu/kernel/cluster.h"
#include "xpu/kernel/cluster_partition.h"
#include "xpu/kernel/cluster_primitive.h"
namespace xpu3 {
namespace plugin {
static __device__ inline int loada_float(_shared_ptr_ const int *ptr) {
int ret;
__asm__ __volatile__("loada.w %0,%1" : "=&r"(ret) : "r"(ptr));
return ret;
}
static __device__ inline bool storea_float(_shared_ptr_ int *ptr, int value) {
bool ret;
__asm__ __volatile__("storea.w %0,%1,%2" : "=&r"(ret) : "r"(value), "r"(ptr));
return ret;
}
static __device__ int atomic_add(_shared_ptr_ int *ptr, int value) {
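// LL/SC-style retry loop: reload, add, and conditionally store until storea.w
// reports success (returns false here), giving an atomic add on shared memory.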
bool fail = true;
int old_value;
while (fail) {
old_value = loada_float(ptr);
int new_value = old_value + value;
fail = storea_float(ptr, new_value);
}
return old_value;
}
static __device__ bool in_need_block_list(const int qid,
_shared_ptr_ int *need_block_list,
const int need_block_len) {
bool res = false;
for (int i = 0; i < need_block_len; i++) {
if (qid == need_block_list[i]) {
need_block_list[i] = -1;
res = true;
break;
}
}
return res;
}
__global__ void speculate_free_and_dispatch_block(
bool *stop_flags,
int *seq_lens_this_time,
int *seq_lens_decoder,
int *block_tables,
int *encoder_block_lens,
bool *is_block_step,
int *step_block_list, // [bsz]
int *step_len,
int *recover_block_list,
int *recover_len,
int *need_block_list,
int *need_block_len,
int *used_list_len,
int *free_list,
int *free_list_len,
int64_t *first_token_ids,
int *accept_num,
const int bsz,
const int block_size,
const int block_num_per_seq,
const int max_decoder_block_num,
const int max_draft_tokens) {
int cid = core_id();
int ncores = core_num();
int clusterid = cluster_id();
if (clusterid != 0 || cid >= bsz) return;
// assert bsz <= 640
const int max_bs = 640;
int value_zero = 0;
bool flag_true = true;
// 128 = seq_len(8192) / block_size(64)
// process at most 128 block_table entries per chunk
const int block_table_now_len = 128;
int block_table_now[block_table_now_len];
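// filled with -1 and reused below to clear recycled block_tables entries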
for (int i = 0; i < block_table_now_len; i++) {
block_table_now[i] = -1;
}
bool stop_flag_lm;
int seq_lens_decoder_lm;
__shared__ int free_list_len_sm;
// process at most block_table_now_len free_list entries per chunk
int free_list_now[block_table_now_len];
__shared__ int need_block_len_sm;
__shared__ int need_block_list_sm[max_bs];
__shared__ int used_list_len_sm[max_bs];
__shared__ bool step_max_block_flag;
__shared__ int in_need_block_list_len;
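// Core 0 stages the shared counters and lists from GM into SM; all cores wait
// at sync_cluster() before using them.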
if (cid == 0) {
step_max_block_flag = false;
in_need_block_list_len = 0;
GM2SM_ASYNC(free_list_len, &free_list_len_sm, sizeof(int));
GM2SM_ASYNC(need_block_len, &need_block_len_sm, sizeof(int));
mfence();
if (need_block_len_sm > 0) {
GM2SM_ASYNC(
need_block_list, need_block_list_sm, sizeof(int) * need_block_len_sm);
}
GM2SM_ASYNC(used_list_len, used_list_len_sm, sizeof(int) * bsz);
mfence();
}
sync_cluster();
for (int tid = cid; tid < bsz; tid += ncores) {
bool is_block_step_lm;
int seq_lens_this_time_lm;
mfence();
GM2LM_ASYNC(stop_flags + tid, &stop_flag_lm, sizeof(bool));
GM2LM_ASYNC(is_block_step + tid, &is_block_step_lm, sizeof(bool));
GM2LM_ASYNC(seq_lens_decoder + tid, &seq_lens_decoder_lm, sizeof(int));
GM2LM_ASYNC(seq_lens_this_time + tid, &seq_lens_this_time_lm, sizeof(int));
mfence();
int max_possible_block_idx =
(seq_lens_decoder_lm + max_draft_tokens + 1) / block_size;
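// Block index the sequence would reach after appending up to max_draft_tokens
// draft tokens plus one more token; that block must already be mapped.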
if (stop_flag_lm && !is_block_step_lm) {
// Reclaim this sequence's blocks
int64_t first_token_id_lm = -1;
mfence_lm();
LM2GM(&first_token_id_lm, first_token_ids + tid, sizeof(int64_t));
int encoder_block_len_lm;
int decoder_used_len_lm = used_list_len_sm[tid];
GM2LM(encoder_block_lens + tid, &encoder_block_len_lm, sizeof(int));
if (decoder_used_len_lm > 0) {
const int ori_free_list_len =
atomic_add(&free_list_len_sm, decoder_used_len_lm);
for (int i = 0; i < decoder_used_len_lm; i += block_table_now_len) {
int process_len = min(block_table_now_len, decoder_used_len_lm - i);
GM2LM(
block_tables + tid * block_num_per_seq + encoder_block_len_lm + i,
free_list_now,
process_len * sizeof(int));
LM2GM(free_list_now,
free_list + ori_free_list_len + i,
process_len * sizeof(int));
LM2GM(
block_table_now,
block_tables + tid * block_num_per_seq + encoder_block_len_lm + i,
process_len * sizeof(int));
}
used_list_len_sm[tid] = 0;
mfence();
LM2GM(&value_zero, encoder_block_lens + tid, sizeof(int));
}
} else if (seq_lens_this_time_lm != 0 &&
max_possible_block_idx < block_num_per_seq) {
int next_block_id;
GM2LM(block_tables + tid * block_num_per_seq +
(seq_lens_decoder_lm + max_draft_tokens + 1) / block_size,
&next_block_id,
sizeof(int));
if (next_block_id == -1) {
// Record which sequences need a new block and how many are needed
const int ori_need_block_len = atomic_add(&need_block_len_sm, 1);
need_block_list_sm[ori_need_block_len] = tid;
}
}
}
sync_cluster();
bool is_block_step_lm[max_bs];
int step_len_lm;
int step_block_list_lm[max_bs];
int recover_len_lm;
int recover_block_list_lm[max_bs];
if (cid == 0) {
GM2LM_ASYNC(is_block_step, is_block_step_lm, sizeof(bool) * bsz);
GM2LM_ASYNC(step_len, &step_len_lm, sizeof(int));
GM2LM_ASYNC(step_block_list, step_block_list_lm, sizeof(int) * bsz);
GM2LM_ASYNC(recover_len, &recover_len_lm, sizeof(int));
GM2LM_ASYNC(recover_block_list, recover_block_list_lm, sizeof(int) * bsz);
mfence();
}
if (cid == 0) {
while (need_block_len_sm > free_list_len_sm) {
// Schedule blocks: reclaim from sequences in descending order of used_list_len
// until need_block_len is satisfied; queries already decoding into their last
// block are about to finish and are excluded from scheduling
int max_used_list_len_id = 0;
int max_used_list_len = 0;
for (int i = 0; i < bsz; i++) {
if (!is_block_step_lm[i] &&
(step_max_block_flag ||
used_list_len_sm[i] != max_decoder_block_num) &&
(used_list_len_sm[i] > max_used_list_len)) {
max_used_list_len_id = i;
max_used_list_len = used_list_len_sm[i];
}
}
if (max_used_list_len == 0) {
step_max_block_flag = true;
} else {
int encoder_block_len;
GM2LM(encoder_block_lens + max_used_list_len_id,
&encoder_block_len,
sizeof(int));
for (int i = 0; i < max_used_list_len; i += block_table_now_len) {
int process_len = min(block_table_now_len, max_used_list_len - i);
GM2LM(block_tables + max_used_list_len_id * block_num_per_seq +
encoder_block_len + i,
free_list_now,
process_len * sizeof(int));
LM2GM(free_list_now,
free_list + free_list_len_sm + i,
process_len * sizeof(int));
LM2GM(block_table_now,
block_tables + max_used_list_len_id * block_num_per_seq +
encoder_block_len + i,
process_len * sizeof(int));
}
step_block_list_lm[step_len_lm] = max_used_list_len_id;
int need_block_len_all = need_block_len_sm + in_need_block_list_len;
if (in_need_block_list(
max_used_list_len_id, need_block_list_sm, need_block_len_all)) {
need_block_len_sm--;
in_need_block_list_len++;
}
step_len_lm++;
free_list_len_sm += max_used_list_len;
LM2GM_ASYNC(
&flag_true, stop_flags + max_used_list_len_id, sizeof(bool));
is_block_step_lm[max_used_list_len_id] = true;
LM2GM_ASYNC(&value_zero,
seq_lens_this_time + max_used_list_len_id,
sizeof(int));
LM2GM_ASYNC(
&value_zero, seq_lens_decoder + max_used_list_len_id, sizeof(int));
// Note(@wufeisheng): when this query is stepped out, its accept_num is not
// otherwise reset, so the save-output path would still stream tokens for it
// on the next step; set accept_num to 0 here to prevent that.
LM2GM_ASYNC(
&value_zero, accept_num + max_used_list_len_id, sizeof(int));
mfence();
}
}
}
sync_cluster();
int need_block_len_all = need_block_len_sm + in_need_block_list_len;
for (int tid = cid; tid < need_block_len_all; tid += ncores) {
// Allocate one block for each position that requested one
const int need_block_id = need_block_list_sm[tid];
if (need_block_id != -1) {
GM2LM(stop_flags + need_block_id, &stop_flag_lm, sizeof(bool));
if (!stop_flag_lm) {
// If this position was just released in the step above, skip it
used_list_len_sm[need_block_id]++;
const int ori_free_list_len = atomic_add(&free_list_len_sm, -1);
int tmp_seq_lens_decoder;
GM2LM(seq_lens_decoder + need_block_id,
&tmp_seq_lens_decoder,
sizeof(int));
int free_block_id;
GM2LM(free_list + ori_free_list_len - 1, &free_block_id, sizeof(int));
LM2GM(&free_block_id,
block_tables + need_block_id * block_num_per_seq +
(tmp_seq_lens_decoder + max_draft_tokens + 1) / block_size,
sizeof(int));
}
need_block_list_sm[tid] = -1;
}
}
sync_cluster();
// Find query ids that can be recovered
// recover at most max_recover_num queries per call
int max_recover_num = 1;
if (cid == 0 && step_len_lm > 0) {
int ori_free_list_len = free_list_len_sm;
int ori_step_block_id = step_block_list_lm[step_len_lm - 1];
int tmp_used_len = used_list_len_sm[ori_step_block_id];
int encoder_block_len_lm;
GM2LM(encoder_block_lens + ori_step_block_id,
&encoder_block_len_lm,
sizeof(int));
const int max_decoder_block_num_this_seq =
max_decoder_block_num - encoder_block_len_lm;
// Allocate one more block than at scheduling time, so a just-scheduled query
// is not recovered immediately (e.g. when its seq_id is in need_block_list)
int used_len = tmp_used_len + 1 < max_decoder_block_num_this_seq
? tmp_used_len + 1
: max_decoder_block_num_this_seq;
while (step_len_lm > 0 && ori_free_list_len >= used_len &&
max_recover_num-- > 0) {
recover_block_list_lm[recover_len_lm] = ori_step_block_id;
is_block_step_lm[ori_step_block_id] = false;
used_list_len_sm[ori_step_block_id] = used_len;
ori_free_list_len -= used_len;
step_block_list_lm[step_len_lm - 1] = -1;
step_len_lm--;
recover_len_lm++;
if (step_len_lm > 0) {
ori_step_block_id = step_block_list_lm[step_len_lm - 1];
tmp_used_len = used_list_len_sm[ori_step_block_id];
used_len = tmp_used_len + 1 < max_decoder_block_num_this_seq
? tmp_used_len + 1
: max_decoder_block_num_this_seq;
}
}
}
// TODO(zhupengyang):
// Before the operator: need_block_len is 0, need_block_list is -1
// After the operator: need_block_len is 0, need_block_list is -1
// Maybe need_block_len and need_block_list do not need to be updated here?
int ori_need_block_len;
if (cid == 0) {
ori_need_block_len = need_block_len_sm;
need_block_len_sm = 0;
}
if (cid == 0) {
mfence();
LM2GM_ASYNC(step_block_list_lm, step_block_list, sizeof(int) * bsz);
LM2GM_ASYNC(is_block_step_lm, is_block_step, sizeof(bool) * bsz);
LM2GM_ASYNC(&step_len_lm, step_len, sizeof(int));
LM2GM_ASYNC(&recover_len_lm, recover_len, sizeof(int));
LM2GM_ASYNC(recover_block_list_lm, recover_block_list, sizeof(int) * bsz);
SM2GM_ASYNC(&free_list_len_sm, free_list_len, sizeof(int));
SM2GM_ASYNC(&need_block_len_sm, need_block_len, sizeof(int));
if (ori_need_block_len > 0) {
SM2GM_ASYNC(need_block_list_sm,
need_block_list,
sizeof(int) * ori_need_block_len);
}
SM2GM_ASYNC(used_list_len_sm, used_list_len, sizeof(int) * bsz);
mfence();
}
}
} // namespace plugin
} // namespace xpu3

View File

@@ -65,7 +65,7 @@ __global__ void speculate_remove_padding(T* output_data,
}
}
__global__ void speculate_get_padding_offset(int* padding_offset,
__global__ void speculate_get_padding_offset(int* batch_id_per_token,
int* cum_offsets_out,
int* cu_seqlens_q,
int* cu_seqlens_k,
@@ -90,8 +90,8 @@ __global__ void speculate_get_padding_offset(int* padding_offset,
GM2LM(cum_offsets + bi, &cum_offsets_now_ind, sizeof(int));
for (int i = tid; i < seq_lens_now; i += ncores) {
LM2GM(&cum_offsets_now,
padding_offset + bi * max_seq_len - cum_offsets_now + i,
LM2GM(&bi,
batch_id_per_token + bi * max_seq_len - cum_offsets_now + i,
sizeof(int));
}
LM2GM(&cum_offsets_now, cum_offsets_out + bi, sizeof(int));

View File

@@ -0,0 +1,154 @@
#include "xpu/kernel/cluster.h"
#include "xpu/kernel/cluster_partition.h"
#include "xpu/kernel/cluster_primitive.h"
namespace xpu3 {
namespace plugin {
static __device__ inline int loada_float(_shared_ptr_ const int* ptr) {
int ret;
__asm__ __volatile__("loada.w %0,%1" : "=&r"(ret) : "r"(ptr));
return ret;
}
static __device__ inline bool storea_float(_shared_ptr_ int* ptr, int value) {
bool ret;
__asm__ __volatile__("storea.w %0,%1,%2" : "=&r"(ret) : "r"(value), "r"(ptr));
return ret;
}
static __device__ int atomic_add(_shared_ptr_ int* ptr, int value) {
bool fail = true;
int old_value;
while (fail) {
old_value = loada_float(ptr);
int new_value = old_value + value;
fail = storea_float(ptr, new_value);
}
return old_value;
}
__global__ void speculate_recover_block(int* recover_block_list, // [bsz]
int* recover_len,
bool* stop_flags,
int* seq_lens_this_time,
const int* ori_seq_lens_encoder,
int* seq_lens_encoder,
const int* seq_lens_decoder,
int* block_tables,
int* free_list,
int* free_list_len,
int64_t* input_ids,
const int64_t* pre_ids,
const int64_t* step_idx,
const int* encoder_block_lens,
const int* used_list_len,
const int64_t* next_tokens,
const int64_t* first_token_ids,
const int bsz,
const int block_num_per_seq,
const int length,
const int pre_id_length) {
int cid = core_id();
int ncores = core_num();
int clusterid = cluster_id();
if (clusterid != 0) return;
// 128 = seq_len(8192) / block_size(64)
// process at most 128 block_table entries per chunk
const int block_table_now_len = 128;
int block_table_now[block_table_now_len];
// max_seq_len == length
// max_seq_len == pre_id_length
// 32KB of local memory per 4 cores on KL2;
// not enough memory to hold 16382 input_ids at once.
const int buf_len = 256;
int64_t input_ids_now[buf_len];
bool flag_false = false;
__shared__ int free_list_len_sm;
// process at most block_table_now_len free_list entries per chunk
int free_list_now[block_table_now_len];
if (cid == 0) {
GM2SM(free_list_len, &free_list_len_sm, sizeof(int));
}
sync_cluster();
int recover_len_lm;
GM2LM(recover_len, &recover_len_lm, sizeof(int));
for (int bid = cid; bid < recover_len_lm; bid += ncores) {
int recover_id;
int ori_seq_len_encoder;
int step_idx_now;
int encoder_block_len;
int decoder_used_len;
int64_t next_token;
GM2LM(recover_block_list + bid, &recover_id, sizeof(int));
GM2LM_ASYNC(
ori_seq_lens_encoder + recover_id, &ori_seq_len_encoder, sizeof(int));
GM2LM_ASYNC(step_idx + recover_id, &step_idx_now, sizeof(int));
GM2LM_ASYNC(
encoder_block_lens + recover_id, &encoder_block_len, sizeof(int));
GM2LM_ASYNC(used_list_len + recover_id, &decoder_used_len, sizeof(int));
GM2LM_ASYNC(next_tokens + recover_id, &next_token, sizeof(int64_t));
mfence();
int seq_len = ori_seq_len_encoder + step_idx_now;
mfence();
LM2GM_ASYNC(&seq_len, seq_lens_this_time + recover_id, sizeof(int));
LM2GM_ASYNC(&seq_len, seq_lens_encoder + recover_id, sizeof(int));
LM2GM_ASYNC(&flag_false, stop_flags + recover_id, sizeof(bool));
mfence();
// // next tokens
// LM2GM_ASYNC(&next_token,
// input_ids + recover_id * length + seq_len - 1,
// sizeof(int64_t));
// set first prompt token
int64_t first_token_id;
GM2LM(first_token_ids + recover_id, &first_token_id, sizeof(int64_t));
LM2GM_ASYNC(
&first_token_id, input_ids + recover_id * length, sizeof(int64_t));
int ori_free_list_len = atomic_add(&free_list_len_sm, -decoder_used_len);
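// Reserve decoder_used_len entries from the tail of the shared free list for
// this sequence.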
// Restore the block table
for (int i = 0; i < decoder_used_len; i += block_table_now_len) {
int process_len = min(block_table_now_len, decoder_used_len - i);
GM2LM(free_list + ori_free_list_len - i - process_len,
free_list_now,
process_len * sizeof(int));
for (int j = 0; j < process_len; j++) {
block_table_now[j] = free_list_now[process_len - 1 - j];
}
mfence();
LM2GM(
block_table_now,
block_tables + recover_id * block_num_per_seq + encoder_block_len + i,
process_len * sizeof(int));
}
// Restore input_ids from pre_ids
for (int i = 0; i < step_idx_now; i += buf_len) {
int real_len = min(buf_len, step_idx_now - i);
GM2LM(pre_ids + recover_id * pre_id_length + i + 1,
input_ids_now,
sizeof(int64_t) * real_len);
LM2GM(input_ids_now,
input_ids + recover_id * length + ori_seq_len_encoder + i,
sizeof(int64_t) * real_len);
}
mfence();
}
if (cid == 0) {
recover_len_lm = 0;
mfence();
LM2GM_ASYNC(&recover_len_lm, recover_len, sizeof(int));
SM2GM_ASYNC(&free_list_len_sm, free_list_len, sizeof(int));
mfence();
}
}
} // namespace plugin
} // namespace xpu3

View File

@@ -138,7 +138,8 @@ __global__ void speculate_verify(
const int max_candidate_len, // scalar, max number of candidates per verify
// token (used for verification or sampling)
const int verify_window, // scalar, top-k verify window (allowed number of consecutive top-1 matches)
const bool prefill_one_step_stop) {
const bool prefill_one_step_stop,
const bool benchmark_mode) {
const int cid = core_id();
const int64_t tid = cluster_id() * core_num() + core_id();
const int64_t nthreads = cluster_num() * core_num();
@@ -161,6 +162,9 @@ __global__ void speculate_verify(
// printf("seq_lens_this_time[%d]-1: %d \n",bid,
// seq_lens_this_time[bid]-1);
for (; i < seq_lens_this_time[bid] - 1; i++) {
if (benchmark_mode) {
break;
}
if (seq_lens_encoder[bid] != 0) {
break;
}
@@ -326,7 +330,8 @@ __global__ void speculate_verify(
int max_seq_len, \
int max_candidate_len, \
int verify_window, \
bool prefill_one_step_stop);
bool prefill_one_step_stop, \
bool benchmark_mode);
SPECULATE_VERIFY_INSTANTIATE(true, true)
SPECULATE_VERIFY_INSTANTIATE(true, false)
SPECULATE_VERIFY_INSTANTIATE(false, true)

View File

@@ -23,6 +23,7 @@ template <typename TX, typename TY>
__attribute__((global)) void eb_adjust_batch(TX *src,
TY *dst,
int *encoder_seqs_lods,
int *decoder_seqs_lods,
int *encoder_batch_map,
int *decoder_batch_map,
int en_batch,
@@ -41,6 +42,7 @@ static int cpu_wrapper(api::Context *ctx,
const TX *x,
TY *y,
const int *encoder_seqs_lods,
const int *decoder_seqs_lods,
const int *encoder_batch_map,
const int *decoder_batch_map,
int en_batch,
@@ -56,11 +58,12 @@ static int cpu_wrapper(api::Context *ctx,
// get copy size && src_offset
int cpy_m = 0;
if (de_batch > 0 && decoder_batch_map[de_idx] == i) {
cpy_m = 1;
ret = api::cast<TX, TY>(ctx,
x + cur_offset * hidden_dim,
y + (encoder_len_total + de_idx) * hidden_dim,
cpy_m * hidden_dim);
cpy_m = decoder_seqs_lods[de_idx + 1] - decoder_seqs_lods[de_idx];
ret = api::cast<TX, TY>(
ctx,
x + cur_offset * hidden_dim,
y + (encoder_len_total + decoder_seqs_lods[de_idx]) * hidden_dim,
cpy_m * hidden_dim);
WRAPPER_ASSERT_SUCCESS(ctx, ret);
de_idx++;
}
@@ -84,6 +87,7 @@ static int xpu3_wrapper(api::Context *ctx,
const TX *x,
TY *y,
api::VectorParam<int32_t> &encoder_seqs_lods, // NOLINT
api::VectorParam<int32_t> &decoder_seqs_lods, // NOLINT
api::VectorParam<int32_t> &encoder_batch_map, // NOLINT
api::VectorParam<int32_t> &decoder_batch_map, // NOLINT
int en_batch,
@@ -98,6 +102,7 @@ static int xpu3_wrapper(api::Context *ctx,
reinterpret_cast<XPU_INDEX_TYPE_TX *>(const_cast<TX *>(x)),
reinterpret_cast<XPU_INDEX_TYPE_TY *>(y),
encoder_seqs_lods.xpu,
decoder_seqs_lods.xpu,
encoder_batch_map.xpu,
decoder_batch_map.xpu,
en_batch,
@@ -111,6 +116,7 @@ int eb_adjust_batch(api::Context *ctx,
const TX *x,
TY *y,
api::VectorParam<int32_t> &encoder_seqs_lods, // NOLINT
api::VectorParam<int32_t> &decoder_seqs_lods, // NOLINT
api::VectorParam<int32_t> &encoder_batch_map, // NOLINT
api::VectorParam<int32_t> &decoder_batch_map, // NOLINT
int64_t hidden_dim) {
@@ -119,28 +125,35 @@ int eb_adjust_batch(api::Context *ctx,
// if (dev_id ==0) {
// ctx->set_debug_level(0xA1);
// }
// std::cout << decoder_seqs_lods.cpu[0] << " " << decoder_seqs_lods.cpu[1] <<
// std::endl;
WRAPPER_CHECK_CTX(ctx);
WRAPPER_DUMP_FUNCTION_T2(ctx, "eb_adjust_batch", TX, TY);
WRAPPER_DUMP_PARAM6(ctx,
x,
y,
encoder_seqs_lods,
decoder_seqs_lods,
encoder_batch_map,
decoder_batch_map,
hidden_dim);
decoder_batch_map);
WRAPPER_DUMP_PARAM1(ctx, hidden_dim);
WRAPPER_DUMP(ctx);
int encoder_batch = encoder_batch_map.len;
int total_batch = encoder_batch + decoder_batch_map.len;
int decoder_batch = decoder_batch_map.len;
int total_batch = encoder_batch + decoder_batch;
int max_encoder_lod = encoder_seqs_lods.cpu[encoder_batch];
int m = max_encoder_lod + decoder_batch_map.len;
int max_decoder_lod = decoder_seqs_lods.cpu[decoder_batch];
int m = max_encoder_lod + max_decoder_lod;
WRAPPER_CHECK_PTR(ctx, TX, m * hidden_dim, x);
WRAPPER_CHECK_PTR(ctx, TY, m * hidden_dim, y);
WRAPPER_ASSERT_GT(ctx, hidden_dim, 0);
// check VectorParam
WRAPPER_ASSERT_EQ(ctx, encoder_seqs_lods.len, encoder_batch_map.len + 1);
WRAPPER_ASSERT_EQ(ctx, decoder_seqs_lods.len, decoder_batch_map.len + 1);
WRAPPER_ASSERT_GE(ctx, encoder_seqs_lods.cpu[0], 0);
WRAPPER_ASSERT_LE(ctx, encoder_seqs_lods.cpu[0], max_encoder_lod);
WRAPPER_ASSERT_GE(ctx, decoder_seqs_lods.cpu[0], 0);
WRAPPER_ASSERT_LE(ctx, decoder_seqs_lods.cpu[0], max_decoder_lod);
for (int i = 0; i < encoder_batch_map.len; ++i) {
WRAPPER_ASSERT_GE(ctx, encoder_batch_map.cpu[i], 0);
WRAPPER_ASSERT_LT(ctx, encoder_batch_map.cpu[i], total_batch)
@@ -150,12 +163,15 @@ int eb_adjust_batch(api::Context *ctx,
for (int i = 0; i < decoder_batch_map.len; ++i) {
WRAPPER_ASSERT_GE(ctx, decoder_batch_map.cpu[i], 0);
WRAPPER_ASSERT_LT(ctx, decoder_batch_map.cpu[i], total_batch)
WRAPPER_ASSERT_GE(ctx, decoder_seqs_lods.cpu[i + 1], 0);
WRAPPER_ASSERT_LE(ctx, decoder_seqs_lods.cpu[i + 1], max_decoder_lod);
}
if (ctx->dev().type() == api::kCPU) {
return cpu_wrapper<TX, TY>(ctx,
x,
y,
encoder_seqs_lods.cpu,
decoder_seqs_lods.cpu,
encoder_batch_map.cpu,
decoder_batch_map.cpu,
encoder_batch_map.len,
@@ -166,6 +182,8 @@ int eb_adjust_batch(api::Context *ctx,
api::ctx_guard RAII_GUARD(ctx);
api::VectorParam<int32_t> encoder_seqs_lods_xpu =
encoder_seqs_lods.to_xpu(RAII_GUARD);
api::VectorParam<int32_t> decoder_seqs_lods_xpu =
decoder_seqs_lods.to_xpu(RAII_GUARD);
api::VectorParam<int32_t> encoder_batch_map_xpu =
encoder_batch_map.to_xpu(RAII_GUARD);
api::VectorParam<int32_t> decoder_batch_map_xpu =
@@ -174,6 +192,7 @@ int eb_adjust_batch(api::Context *ctx,
x,
y,
encoder_seqs_lods_xpu,
decoder_seqs_lods_xpu,
encoder_batch_map_xpu,
decoder_batch_map_xpu,
encoder_batch_map.len,
@@ -190,6 +209,7 @@ int eb_adjust_batch(api::Context *ctx,
api::VectorParam<int32_t> &, \
api::VectorParam<int32_t> &, \
api::VectorParam<int32_t> &, \
api::VectorParam<int32_t> &, \
int64_t);
INSTANTIATION_EB_ADJUST_BATCH(float16, float16);

View File

@@ -27,26 +27,29 @@ __attribute__((global)) void draft_model_preprocess(
int* seq_lens_encoder,
int* seq_lens_decoder,
int64_t* step_idx,
int* seq_lens_encoder_record,
int* seq_lens_decoder_record,
bool* not_need_stop,
bool* is_block_step,
bool* batch_drop,
int64_t* pre_ids,
const int64_t* accept_tokens,
const int* accept_num,
const int* base_model_seq_lens_this_time,
const int* base_model_seq_lens_encoder,
const int* base_model_seq_lens_decoder,
const int64_t* base_model_step_idx,
const bool* base_model_stop_flags,
const bool* base_model_is_block_step,
int64_t* base_model_draft_tokens,
int real_bsz,
int max_draft_token,
int accept_tokens_len,
int draft_tokens_len,
int input_ids_len,
int base_model_draft_tokens_len,
bool truncate_first_token,
bool splitwise_prefill);
const int bsz,
const int num_model_step,
const int accept_tokens_len,
const int draft_tokens_len,
const int input_ids_len,
const int base_model_draft_tokens_len,
const int pre_ids_len,
const bool truncate_first_token,
const bool splitwise_prefill,
const bool kvcache_scheduler_v1);
} // namespace plugin
} // namespace xpu3
@@ -67,49 +70,47 @@ static int cpu_wrapper(api::Context* ctx,
int* seq_lens_encoder,
int* seq_lens_decoder,
int64_t* step_idx,
int* seq_lens_encoder_record,
int* seq_lens_decoder_record,
bool* not_need_stop,
bool* is_block_step,
bool* batch_drop,
int64_t* pre_ids,
const int64_t* accept_tokens,
const int* accept_num,
const int* base_model_seq_lens_this_time,
const int* base_model_seq_lens_encoder,
const int* base_model_seq_lens_decoder,
const int64_t* base_model_step_idx,
const bool* base_model_stop_flags,
const bool* base_model_is_block_step,
int64_t* base_model_draft_tokens,
int real_bsz,
int max_draft_token,
int accept_tokens_len,
int draft_tokens_len,
int input_ids_len,
int base_model_draft_tokens_len,
bool truncate_first_token,
bool splitwise_prefill) {
const int bsz,
const int num_model_step,
const int accept_tokens_len,
const int draft_tokens_len,
const int input_ids_len,
const int base_model_draft_tokens_len,
const int pre_ids_len,
const bool truncate_first_token,
const bool splitwise_prefill,
const bool kvcache_scheduler_v1) {
int64_t not_stop_flag_sum = 0;
int64_t not_stop_flag = 0;
for (int tid = 0; tid < real_bsz; tid++) {
for (int tid = 0; tid < bsz; tid++) {
if (splitwise_prefill) {
int base_model_step_idx_now = base_model_step_idx[tid];
auto* input_ids_now = input_ids + tid * input_ids_len;
auto* accept_tokens_now = accept_tokens + tid * accept_tokens_len;
// printf("bid: %d, base_model_step_idx_now: %d seq_lens_encoder_record:
// %d\n", tid, base_model_step_idx_now, seq_lens_encoder_record[tid]);
if (base_model_step_idx_now == 1 && seq_lens_encoder_record[tid] > 0) {
if (seq_lens_encoder[tid] > 0) {
not_stop_flag = 1;
int seq_len_encoder_record = seq_lens_encoder_record[tid];
seq_lens_encoder[tid] = seq_len_encoder_record;
seq_lens_encoder_record[tid] = -1;
int seq_len_encoder = seq_lens_encoder[tid];
stop_flags[tid] = false;
int64_t base_model_first_token = accept_tokens_now[0];
int position = seq_len_encoder_record;
int position = seq_len_encoder;
if (truncate_first_token) {
input_ids_now[position - 1] = base_model_first_token;
seq_lens_this_time[tid] = seq_len_encoder_record;
seq_lens_this_time[tid] = seq_len_encoder;
} else {
input_ids_now[position] = base_model_first_token;
seq_lens_this_time[tid] = seq_len_encoder_record + 1;
seq_lens_this_time[tid] = seq_len_encoder + 1;
}
} else {
stop_flags[tid] = true;
@@ -120,63 +121,77 @@ static int cpu_wrapper(api::Context* ctx,
}
not_stop_flag_sum += not_stop_flag;
} else {
auto base_model_step_idx_now = base_model_step_idx[tid];
auto* accept_tokens_now = accept_tokens + tid * accept_tokens_len;
auto* draft_tokens_now = draft_tokens + tid * draft_tokens_len;
auto accept_num_now = accept_num[tid];
auto* input_ids_now = input_ids + tid * input_ids_len;
auto* base_model_draft_tokens_now =
base_model_draft_tokens + tid * base_model_draft_tokens_len;
auto base_model_seq_len_decoder = base_model_seq_lens_decoder[tid];
const int32_t base_model_seq_len_this_time =
base_model_seq_lens_this_time[tid];
auto* pre_ids_now = pre_ids + tid * pre_ids_len;
for (int i = 1; i < base_model_draft_tokens_len; i++) {
base_model_draft_tokens_now[i] = -1;
}
if (base_model_stop_flags[tid] && base_model_is_block_step[tid]) {
batch_drop[tid] = true;
stop_flags[tid] = true;
if (kvcache_scheduler_v1) {
if (base_model_stop_flags[tid] && base_model_is_block_step[tid]) {
stop_flags[tid] = true;
is_block_step[tid] = true;
// Need to continue inference
}
} else {
if (base_model_stop_flags[tid] && base_model_is_block_step[tid]) {
batch_drop[tid] = true;
stop_flags[tid] = true;
}
}
if (!(base_model_stop_flags[tid] || batch_drop[tid])) {
not_stop_flag = 1;
// 1. first token
if (base_model_step_idx_now == 0) {
seq_lens_this_time[tid] = 0;
not_stop_flag = 0;
} else if (base_model_step_idx_now == 1 &&
seq_lens_encoder_record[tid] > 0) {
// prefill generation
if (seq_lens_encoder[tid] > 0) {
// Can be extended to first few tokens
int seq_len_encoder_record = seq_lens_encoder_record[tid];
seq_lens_encoder[tid] = seq_len_encoder_record;
seq_lens_encoder_record[tid] = -1;
seq_lens_decoder[tid] = seq_lens_decoder_record[tid];
seq_lens_decoder_record[tid] = 0;
int seq_len_encoder = seq_lens_encoder[tid];
stop_flags[tid] = false;
int64_t base_model_first_token = accept_tokens_now[0];
int position = seq_len_encoder_record;
pre_ids_now[0] = base_model_first_token;
int position = seq_len_encoder;
if (truncate_first_token) {
input_ids_now[position - 1] = base_model_first_token;
seq_lens_this_time[tid] = seq_len_encoder_record;
seq_lens_this_time[tid] = seq_len_encoder;
} else {
input_ids_now[position] = base_model_first_token;
seq_lens_this_time[tid] = seq_len_encoder_record + 1;
seq_lens_this_time[tid] = seq_len_encoder + 1;
}
} else { // decode generation
if (kvcache_scheduler_v1) {
// 3. try to recover mtp infer in V1 mode
if (!base_model_is_block_step[tid] && is_block_step[tid]) {
is_block_step[tid] = false;
}
}
} else if (accept_num_now <=
max_draft_token) /*Accept partial draft tokens*/ {
// Base Model reject stop
if (stop_flags[tid]) {
stop_flags[tid] = false;
seq_lens_decoder[tid] = base_model_seq_lens_decoder[tid];
step_idx[tid] = base_model_step_idx[tid];
// TODO: check
seq_lens_decoder[tid] =
base_model_seq_len_decoder - base_model_seq_len_this_time;
step_idx[tid] =
base_model_step_idx[tid] - base_model_seq_len_this_time;
} else {
seq_lens_decoder[tid] -= max_draft_token - accept_num_now;
step_idx[tid] -= max_draft_token - accept_num_now;
// 2: Last base model generated token and first MTP
// token
seq_lens_decoder[tid] -= num_model_step - 1;
step_idx[tid] -= num_model_step - 1;
}
int64_t modified_token = accept_tokens_now[accept_num_now - 1];
draft_tokens_now[0] = modified_token;
seq_lens_this_time[tid] = 1;
} else /*Accept all draft tokens*/ {
draft_tokens_now[1] = accept_tokens_now[max_draft_token];
seq_lens_this_time[tid] = 2;
for (int i = 0; i < accept_num_now; i++) {
draft_tokens_now[i] = accept_tokens_now[i];
const int pre_id_pos =
base_model_step_idx[tid] - (accept_num_now - i);
const int64_t accept_token = accept_tokens_now[i];
pre_ids_now[pre_id_pos] = accept_token;
}
seq_lens_this_time[tid] = accept_num_now;
}
} else {
stop_flags[tid] = true;
@@ -199,26 +214,29 @@ static int xpu3_wrapper(api::Context* ctx,
int* seq_lens_encoder,
int* seq_lens_decoder,
int64_t* step_idx,
int* seq_lens_encoder_record,
int* seq_lens_decoder_record,
bool* not_need_stop,
bool* is_block_step,
bool* batch_drop,
int64_t* pre_ids,
const int64_t* accept_tokens,
const int* accept_num,
const int* base_model_seq_lens_this_time,
const int* base_model_seq_lens_encoder,
const int* base_model_seq_lens_decoder,
const int64_t* base_model_step_idx,
const bool* base_model_stop_flags,
const bool* base_model_is_block_step,
int64_t* base_model_draft_tokens,
int real_bsz,
int max_draft_token,
int accept_tokens_len,
int draft_tokens_len,
int input_ids_len,
int base_model_draft_tokens_len,
bool truncate_first_token,
bool splitwise_prefill) {
const int bsz,
const int num_model_step,
const int accept_tokens_len,
const int draft_tokens_len,
const int input_ids_len,
const int base_model_draft_tokens_len,
const int pre_ids_len,
const bool truncate_first_token,
const bool splitwise_prefill,
const bool kvcache_scheduler_v1) {
using XPU_INT64 = typename XPUIndexType<int64_t>::type;
// NOTE: Don't change 16 to 64, because the kernel uses GSM
@@ -230,26 +248,29 @@ static int xpu3_wrapper(api::Context* ctx,
seq_lens_encoder,
seq_lens_decoder,
reinterpret_cast<XPU_INT64*>(step_idx),
seq_lens_encoder_record,
seq_lens_decoder_record,
not_need_stop,
is_block_step,
batch_drop,
reinterpret_cast<XPU_INT64*>(pre_ids),
reinterpret_cast<const XPU_INT64*>(accept_tokens),
accept_num,
base_model_seq_lens_this_time,
base_model_seq_lens_encoder,
base_model_seq_lens_decoder,
reinterpret_cast<const XPU_INT64*>(base_model_step_idx),
base_model_stop_flags,
base_model_is_block_step,
reinterpret_cast<XPU_INT64*>(base_model_draft_tokens),
real_bsz,
max_draft_token,
bsz,
num_model_step,
accept_tokens_len,
draft_tokens_len,
input_ids_len,
base_model_draft_tokens_len,
pre_ids_len,
truncate_first_token,
splitwise_prefill);
splitwise_prefill,
kvcache_scheduler_v1);
return api::SUCCESS;
}
@@ -261,26 +282,29 @@ int draft_model_preprocess(api::Context* ctx,
int* seq_lens_encoder,
int* seq_lens_decoder,
int64_t* step_idx,
int* seq_lens_encoder_record,
int* seq_lens_decoder_record,
bool* not_need_stop,
bool* is_block_step,
bool* batch_drop,
int64_t* pre_ids,
const int64_t* accept_tokens,
const int* accept_num,
const int* base_model_seq_lens_this_time,
const int* base_model_seq_lens_encoder,
const int* base_model_seq_lens_decoder,
const int64_t* base_model_step_idx,
const bool* base_model_stop_flags,
const bool* base_model_is_block_step,
int64_t* base_model_draft_tokens,
int real_bsz,
int max_draft_token,
int accept_tokens_len,
int draft_tokens_len,
int input_ids_len,
int base_model_draft_tokens_len,
bool truncate_first_token,
bool splitwise_prefill) {
const int bsz,
const int num_model_step,
const int accept_tokens_len,
const int draft_tokens_len,
const int input_ids_len,
const int base_model_draft_tokens_len,
const int pre_ids_len,
const bool truncate_first_token,
const bool splitwise_prefill,
const bool kvcache_scheduler_v1) {
WRAPPER_CHECK_CTX(ctx);
WRAPPER_DUMP_FUNCTION_T1(ctx, "draft_model_preprocess", int64_t);
WRAPPER_DUMP_PARAM6(ctx,
@@ -290,37 +314,34 @@ int draft_model_preprocess(api::Context* ctx,
seq_lens_this_time,
seq_lens_encoder,
seq_lens_decoder);
WRAPPER_DUMP_PARAM5(ctx,
step_idx,
seq_lens_encoder_record,
seq_lens_decoder_record,
not_need_stop,
batch_drop);
WRAPPER_DUMP_PARAM5(
ctx, step_idx, not_need_stop, is_block_step, batch_drop, pre_ids);
WRAPPER_DUMP_PARAM3(
ctx, accept_tokens, accept_num, base_model_seq_lens_encoder);
WRAPPER_DUMP_PARAM3(ctx,
WRAPPER_DUMP_PARAM4(ctx,
base_model_seq_lens_encoder,
base_model_seq_lens_decoder,
base_model_step_idx,
base_model_stop_flags);
WRAPPER_DUMP_PARAM3(
ctx, base_model_is_block_step, base_model_draft_tokens, real_bsz);
WRAPPER_DUMP_PARAM3(
ctx, max_draft_token, accept_tokens_len, draft_tokens_len);
WRAPPER_DUMP_PARAM3(
ctx, input_ids_len, base_model_draft_tokens_len, truncate_first_token);
WRAPPER_DUMP_PARAM1(ctx, splitwise_prefill);
ctx, base_model_is_block_step, base_model_draft_tokens, bsz);
WRAPPER_DUMP_PARAM3(ctx, num_model_step, accept_tokens_len, draft_tokens_len);
WRAPPER_DUMP_PARAM4(ctx,
input_ids_len,
base_model_draft_tokens_len,
pre_ids_len,
truncate_first_token);
WRAPPER_DUMP_PARAM2(ctx, splitwise_prefill, kvcache_scheduler_v1);
WRAPPER_DUMP(ctx);
WRAPPER_CHECK_PTR(ctx, int, real_bsz, seq_lens_this_time);
WRAPPER_CHECK_PTR(ctx, int64_t, real_bsz * accept_tokens_len, accept_tokens);
WRAPPER_CHECK_PTR(ctx, int64_t, real_bsz * input_ids_len, input_ids);
WRAPPER_CHECK_PTR(ctx, int64_t, real_bsz * draft_tokens_len, draft_tokens);
WRAPPER_CHECK_PTR(ctx,
int64_t,
real_bsz * base_model_draft_tokens_len,
base_model_draft_tokens);
WRAPPER_CHECK_PTR(ctx, int, bsz, seq_lens_this_time);
WRAPPER_CHECK_PTR(ctx, int64_t, bsz * accept_tokens_len, accept_tokens);
WRAPPER_CHECK_PTR(ctx, int64_t, bsz * input_ids_len, input_ids);
WRAPPER_CHECK_PTR(ctx, int64_t, bsz * draft_tokens_len, draft_tokens);
WRAPPER_CHECK_PTR(
ctx, int64_t, bsz * base_model_draft_tokens_len, base_model_draft_tokens);
WRAPPER_ASSERT_GT(ctx, real_bsz, 0);
WRAPPER_ASSERT_GT(ctx, bsz, 0);
WRAPPER_ASSERT_LT(ctx, accept_tokens_len, 128);
if (ctx->dev().type() == api::kCPU) {
@@ -332,26 +353,29 @@ int draft_model_preprocess(api::Context* ctx,
seq_lens_encoder,
seq_lens_decoder,
step_idx,
seq_lens_encoder_record,
seq_lens_decoder_record,
not_need_stop,
is_block_step,
batch_drop,
pre_ids,
accept_tokens,
accept_num,
base_model_seq_lens_this_time,
base_model_seq_lens_encoder,
base_model_seq_lens_decoder,
base_model_step_idx,
base_model_stop_flags,
base_model_is_block_step,
base_model_draft_tokens,
real_bsz,
max_draft_token,
bsz,
num_model_step,
accept_tokens_len,
draft_tokens_len,
input_ids_len,
base_model_draft_tokens_len,
pre_ids_len,
truncate_first_token,
splitwise_prefill);
splitwise_prefill,
kvcache_scheduler_v1);
}
if (ctx->dev().type() == api::kXPU3) {
return xpu3_wrapper(ctx,
@@ -362,26 +386,29 @@ int draft_model_preprocess(api::Context* ctx,
seq_lens_encoder,
seq_lens_decoder,
step_idx,
seq_lens_encoder_record,
seq_lens_decoder_record,
not_need_stop,
is_block_step,
batch_drop,
pre_ids,
accept_tokens,
accept_num,
base_model_seq_lens_this_time,
base_model_seq_lens_encoder,
base_model_seq_lens_decoder,
base_model_step_idx,
base_model_stop_flags,
base_model_is_block_step,
base_model_draft_tokens,
real_bsz,
max_draft_token,
bsz,
num_model_step,
accept_tokens_len,
draft_tokens_len,
input_ids_len,
base_model_draft_tokens_len,
pre_ids_len,
truncate_first_token,
splitwise_prefill);
splitwise_prefill,
kvcache_scheduler_v1);
}
WRAPPER_UNIMPLEMENTED(ctx);
}

View File

@@ -0,0 +1,227 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "xpu/plugin.h"
#include "xpu/refactor/impl/launch_strategy.h"
#include "xpu/refactor/impl_public/wrapper_check.h"
#include "xpu/xdnn.h"
namespace xpu3 {
namespace plugin {
template <typename TX, typename TY>
__attribute__((global)) void eb_mtp_gather_next_token(TX *src,
TY *dst,
int *encoder_seqs_lods,
int *decoder_seqs_lods,
int *encoder_batch_map,
int *decoder_batch_map,
int en_batch,
int de_batch,
int64_t copy_size);
} // namespace plugin
} // namespace xpu3
namespace baidu {
namespace xpu {
namespace api {
namespace plugin {
template <typename TX, typename TY>
static int cpu_wrapper(api::Context *ctx,
const TX *x,
TY *y,
const int *encoder_seqs_lods,
const int *decoder_seqs_lods,
const int *encoder_batch_map,
const int *decoder_batch_map,
int en_batch,
int de_batch,
int64_t hidden_dim) {
int ret = 0;
int encoder_len_total = encoder_seqs_lods[en_batch];
int decoder_len_total = decoder_seqs_lods[de_batch];
int output_token_num = en_batch + decoder_len_total;
for (int i = 0; i < output_token_num; i++) {
int len = 0;
int enc_idx = 0, dec_idx = 0;
bool is_enc;
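// Same merge walk as the device kernel: interleave encoder and decoder
// sequences by original batch id to locate the source row for output row i.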
while (i >= len) {
if (enc_idx >= en_batch) {
len += decoder_seqs_lods[dec_idx + 1] - decoder_seqs_lods[dec_idx];
dec_idx++;
is_enc = false;
continue;
}
if (dec_idx >= de_batch) {
len += 1;
enc_idx++;
is_enc = true;
continue;
}
if ((encoder_batch_map[enc_idx] < decoder_batch_map[dec_idx])) {
len += 1;
enc_idx++;
is_enc = true;
} else {
len += decoder_seqs_lods[dec_idx + 1] - decoder_seqs_lods[dec_idx];
dec_idx++;
is_enc = false;
}
}
const TX *src = nullptr;
if (is_enc) {
src = x + (encoder_seqs_lods[enc_idx] - 1) * hidden_dim;
} else {
src = x + (encoder_len_total + decoder_seqs_lods[dec_idx] - (len - i)) *
hidden_dim;
}
ret = api::cast<TX, TY>(ctx, src, y + i * hidden_dim, hidden_dim);
WRAPPER_ASSERT_SUCCESS(ctx, ret);
}
return api::SUCCESS;
}
template <typename TX, typename TY>
static int xpu3_wrapper(api::Context *ctx,
const TX *x,
TY *y,
api::VectorParam<int32_t> &encoder_seqs_lods, // NOLINT
api::VectorParam<int32_t> &decoder_seqs_lods, // NOLINT
api::VectorParam<int32_t> &encoder_batch_map, // NOLINT
api::VectorParam<int32_t> &decoder_batch_map, // NOLINT
int en_batch,
int de_batch,
int64_t hidden_dim) {
auto eb_mtp_gather_next_token_kernel =
xpu3::plugin::eb_mtp_gather_next_token<TX, TY>;
// NOTE: Don't change 16 to 64, because the kernel uses GSM
eb_mtp_gather_next_token_kernel<<<ctx->ncluster(), 16, ctx->xpu_stream>>>(
const_cast<TX *>(x),
y,
encoder_seqs_lods.xpu,
decoder_seqs_lods.xpu,
encoder_batch_map.xpu,
decoder_batch_map.xpu,
en_batch,
de_batch,
hidden_dim);
return api::SUCCESS;
}
template <typename TX, typename TY>
int eb_mtp_gather_next_token(
api::Context *ctx,
const TX *x,
TY *y,
api::VectorParam<int32_t> &encoder_seqs_lods, // NOLINT
api::VectorParam<int32_t> &decoder_seqs_lods, // NOLINT
api::VectorParam<int32_t> &encoder_batch_map, // NOLINT
api::VectorParam<int32_t> &decoder_batch_map, // NOLINT
int64_t hidden_dim) {
WRAPPER_CHECK_CTX(ctx);
WRAPPER_DUMP_FUNCTION_T2(ctx, "eb_mtp_gather_next_token", TX, TY);
WRAPPER_DUMP_PARAM6(ctx,
x,
y,
encoder_seqs_lods,
decoder_seqs_lods,
encoder_batch_map,
decoder_batch_map);
WRAPPER_DUMP_PARAM1(ctx, hidden_dim);
WRAPPER_DUMP(ctx);
int encoder_batch = encoder_batch_map.len;
int decoder_batch = decoder_batch_map.len;
int max_encoder_lod = encoder_seqs_lods.cpu[encoder_batch];
int max_decoder_lod = decoder_seqs_lods.cpu[decoder_batch];
int m = encoder_seqs_lods.cpu[encoder_batch] +
decoder_seqs_lods.cpu[decoder_batch];
int out_m = encoder_batch + decoder_seqs_lods.cpu[decoder_batch];
WRAPPER_CHECK_PTR(ctx, TX, m * hidden_dim, x);
WRAPPER_CHECK_PTR(ctx, TY, out_m * hidden_dim, y);
WRAPPER_ASSERT_GT(ctx, hidden_dim, 0);
// check VectorParam
WRAPPER_ASSERT_EQ(ctx, encoder_seqs_lods.len, encoder_batch_map.len + 1);
WRAPPER_ASSERT_EQ(ctx, decoder_seqs_lods.len, decoder_batch_map.len + 1);
WRAPPER_ASSERT_GE(ctx, encoder_seqs_lods.cpu[0], 0);
WRAPPER_ASSERT_LE(ctx, encoder_seqs_lods.cpu[0], max_encoder_lod);
WRAPPER_ASSERT_GE(ctx, decoder_seqs_lods.cpu[0], 0);
WRAPPER_ASSERT_LE(ctx, decoder_seqs_lods.cpu[0], max_decoder_lod);
// Note: encoder/decoder batch-map values may exceed the batch count because
// the restored batch layout can be sparse, so only non-negativity is checked here
for (int i = 0; i < encoder_batch_map.len; ++i) {
WRAPPER_ASSERT_GE(ctx, encoder_batch_map.cpu[i], 0);
WRAPPER_ASSERT_GE(ctx, encoder_seqs_lods.cpu[i + 1], 0);
WRAPPER_ASSERT_LE(ctx, encoder_seqs_lods.cpu[i + 1], max_encoder_lod);
}
for (int i = 0; i < decoder_batch_map.len; ++i) {
WRAPPER_ASSERT_GE(ctx, decoder_batch_map.cpu[i], 0);
WRAPPER_ASSERT_GE(ctx, decoder_seqs_lods.cpu[i + 1], 0);
WRAPPER_ASSERT_LE(ctx, decoder_seqs_lods.cpu[i + 1], max_decoder_lod);
}
if (ctx->dev().type() == api::kCPU) {
return cpu_wrapper<TX, TY>(ctx,
x,
y,
encoder_seqs_lods.cpu,
decoder_seqs_lods.cpu,
encoder_batch_map.cpu,
decoder_batch_map.cpu,
encoder_batch_map.len,
decoder_batch_map.len,
hidden_dim);
}
if (ctx->dev().type() == api::kXPU3) {
api::ctx_guard RAII_GUARD(ctx);
api::VectorParam<int32_t> encoder_seqs_lods_xpu =
encoder_seqs_lods.to_xpu(RAII_GUARD);
api::VectorParam<int32_t> decoder_seqs_lods_xpu =
decoder_seqs_lods.to_xpu(RAII_GUARD);
api::VectorParam<int32_t> encoder_batch_map_xpu =
encoder_batch_map.to_xpu(RAII_GUARD);
api::VectorParam<int32_t> decoder_batch_map_xpu =
decoder_batch_map.to_xpu(RAII_GUARD);
return xpu3_wrapper<TX, TY>(ctx,
x,
y,
encoder_seqs_lods_xpu,
decoder_seqs_lods_xpu,
encoder_batch_map_xpu,
decoder_batch_map_xpu,
encoder_batch_map.len,
decoder_batch_map.len,
hidden_dim);
}
WRAPPER_UNIMPLEMENTED(ctx);
}
#define INSTANTIATION_EB_MTP_GATHER_NEXT_TOKEN(TX, TY) \
template int eb_mtp_gather_next_token<TX, TY>(api::Context *, \
const TX *, \
TY *, \
api::VectorParam<int32_t> &, \
api::VectorParam<int32_t> &, \
api::VectorParam<int32_t> &, \
api::VectorParam<int32_t> &, \
int64_t);
INSTANTIATION_EB_MTP_GATHER_NEXT_TOKEN(float16, float16);
INSTANTIATION_EB_MTP_GATHER_NEXT_TOKEN(bfloat16, bfloat16);
INSTANTIATION_EB_MTP_GATHER_NEXT_TOKEN(float, float);
INSTANTIATION_EB_MTP_GATHER_NEXT_TOKEN(float16, float);
INSTANTIATION_EB_MTP_GATHER_NEXT_TOKEN(float, float16);
INSTANTIATION_EB_MTP_GATHER_NEXT_TOKEN(bfloat16, float16);
INSTANTIATION_EB_MTP_GATHER_NEXT_TOKEN(float16, bfloat16);
INSTANTIATION_EB_MTP_GATHER_NEXT_TOKEN(bfloat16, float);
INSTANTIATION_EB_MTP_GATHER_NEXT_TOKEN(float, bfloat16);
} // namespace plugin
} // namespace api
} // namespace xpu
} // namespace baidu

View File

@@ -0,0 +1,340 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <algorithm>
#include <numeric>
#include "xpu/plugin.h"
#include "xpu/refactor/impl_public/wrapper_check.h"
namespace xpu3 {
namespace plugin {
__attribute__((global)) void speculate_free_and_dispatch_block(
bool *stop_flags,
int *seq_lens_this_time,
int *seq_lens_decoder,
int *block_tables,
int *encoder_block_lens,
bool *is_block_step,
int *step_block_list, // [bsz]
int *step_len,
int *recover_block_list,
int *recover_len,
int *need_block_list,
int *need_block_len,
int *used_list_len,
int *free_list,
int *free_list_len,
int64_t *first_token_ids,
int *accept_num,
const int bsz,
const int block_size,
const int block_num_per_seq,
const int max_decoder_block_num,
const int max_draft_tokens);
} // namespace plugin
} // namespace xpu3
namespace baidu {
namespace xpu {
namespace api {
namespace plugin {
static int cpu_wrapper(Context *ctx,
bool *stop_flags,
int *seq_lens_this_time,
int *seq_lens_decoder,
int *block_tables,
int *encoder_block_lens,
bool *is_block_step,
int *step_block_list, // [bsz]
int *step_len,
int *recover_block_list,
int *recover_len,
int *need_block_list,
int *need_block_len,
int *used_list_len,
int *free_list,
int *free_list_len,
int64_t *first_token_ids,
int *accept_num,
const int bsz,
const int block_size,
const int block_num_per_seq,
const int max_decoder_block_num,
const int max_draft_tokens) {
for (int i = 0; i < bsz; i++) {
int *block_table_now = block_tables + i * block_num_per_seq;
int max_possible_block_idx =
(seq_lens_decoder[i] + max_draft_tokens + 1) / block_size;
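// Block index the sequence would reach after appending up to max_draft_tokens
// draft tokens plus one more token; that block must already be mapped.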
if (stop_flags[i] && !is_block_step[i]) {
// Reclaim this sequence's blocks
first_token_ids[i] = -1;
const int encoder_block_len = encoder_block_lens[i];
const int decoder_used_len = used_list_len[i];
if (decoder_used_len > 0) {
const int ori_free_list_len = free_list_len[0];
free_list_len[0] += decoder_used_len;
for (int j = 0; j < decoder_used_len; j++) {
free_list[ori_free_list_len + j] =
block_table_now[encoder_block_len + j];
block_table_now[encoder_block_len + j] = -1;
}
encoder_block_lens[i] = 0;
used_list_len[i] = 0;
}
} else if (seq_lens_this_time[i] != 0 &&
max_possible_block_idx < block_num_per_seq &&
block_table_now[(seq_lens_decoder[i] + max_draft_tokens + 1) /
block_size] == -1) {
// Record which sequences need a new block and how many are needed
const int ori_need_block_len = need_block_len[0];
need_block_len[0] += 1;
need_block_list[ori_need_block_len] = i;
}
}
while (need_block_len[0] > free_list_len[0]) {
// Schedule blocks: reclaim from sequences in descending order of used_list_len until need_block_len is satisfied
int max_used_list_len_id = 0;
int max_used_list_len = 0;
for (int i = 0; i < bsz; i++) {
const int used_block_num = !is_block_step[i] ? used_list_len[i] : 0;
if (used_block_num > max_used_list_len) {
max_used_list_len_id = i;
max_used_list_len = used_block_num;
}
}
const int encoder_block_len = encoder_block_lens[max_used_list_len_id];
int *block_table_now =
block_tables + max_used_list_len_id * block_num_per_seq;
for (int i = 0; i < max_used_list_len; i++) {
free_list[free_list_len[0] + i] = block_table_now[encoder_block_len + i];
block_table_now[encoder_block_len + i] = -1;
}
step_block_list[step_len[0]] = max_used_list_len_id;
step_len[0] += 1;
free_list_len[0] += max_used_list_len;
stop_flags[max_used_list_len_id] = true;
is_block_step[max_used_list_len_id] = true;
seq_lens_this_time[max_used_list_len_id] = 0;
seq_lens_decoder[max_used_list_len_id] = 0;
accept_num[max_used_list_len_id] = 0;
}
// Allocate one block for each position that requested one
for (int i = 0; i < bsz; i++) {
if (i < need_block_len[0]) {
const int need_block_id = need_block_list[i];
if (!stop_flags[need_block_id]) {
// If this position was just released in the step above, skip it
used_list_len[need_block_id] += 1;
const int ori_free_list_len = free_list_len[0];
free_list_len[0]--;
int *block_table_now = block_tables + need_block_id * block_num_per_seq;
block_table_now[(seq_lens_decoder[need_block_id] + max_draft_tokens +
1) /
block_size] = free_list[ori_free_list_len - 1];
}
need_block_list[i] = -1;
}
}
// Find query ids that can be recovered
int ori_step_len = step_len[0];
if (ori_step_len > 0) {
int ori_free_list_len = free_list_len[0];
int ori_step_block_id = step_block_list[ori_step_len - 1];
int tmp_used_len = used_list_len[ori_step_block_id];
// Allocate one more block than at scheduling time, so a just-scheduled query
// is not recovered immediately (e.g. when its seq_id is in need_block_list)
int used_len =
tmp_used_len < max_decoder_block_num ? tmp_used_len + 1 : tmp_used_len;
if (ori_step_len > 0 && ori_free_list_len >= used_len) {
recover_block_list[recover_len[0]] = ori_step_block_id;
is_block_step[ori_step_block_id] = false;
used_list_len[ori_step_block_id] = used_len;
ori_free_list_len -= used_len;
step_block_list[ori_step_len - 1] = -1;
step_len[0] -= 1;
recover_len[0] += 1;
ori_step_len = step_len[0];
if (ori_step_len > 0) {
ori_step_block_id = step_block_list[ori_step_len - 1];
tmp_used_len = used_list_len[ori_step_block_id];
used_len = tmp_used_len < max_decoder_block_num ? tmp_used_len + 1
: tmp_used_len;
}
}
need_block_len[0] = 0;
}
return api::SUCCESS;
}
static int xpu3_wrapper(Context *ctx,
bool *stop_flags,
int *seq_lens_this_time,
int *seq_lens_decoder,
int *block_tables,
int *encoder_block_lens,
bool *is_block_step,
int *step_block_list, // [bsz]
int *step_len,
int *recover_block_list,
int *recover_len,
int *need_block_list,
int *need_block_len,
int *used_list_len,
int *free_list,
int *free_list_len,
int64_t *first_token_ids,
int *accept_num,
const int bsz,
const int block_size,
const int block_num_per_seq,
const int max_decoder_block_num,
const int max_draft_tokens) {
using XPU_INT64 = typename XPUIndexType<int64_t>::type;
auto speculate_free_and_dispatch_block_kernel =
xpu3::plugin::speculate_free_and_dispatch_block;
speculate_free_and_dispatch_block_kernel<<<ctx->ncluster(),
64,
ctx->xpu_stream>>>(
stop_flags,
seq_lens_this_time,
seq_lens_decoder,
block_tables,
encoder_block_lens,
is_block_step,
step_block_list,
step_len,
recover_block_list,
recover_len,
need_block_list,
need_block_len,
used_list_len,
free_list,
free_list_len,
reinterpret_cast<XPU_INT64 *>(first_token_ids),
accept_num,
bsz,
block_size,
block_num_per_seq,
max_decoder_block_num,
max_draft_tokens);
return api::SUCCESS;
}
int speculate_free_and_dispatch_block(Context *ctx,
bool *stop_flags,
int *seq_lens_this_time,
int *seq_lens_decoder,
int *block_tables,
int *encoder_block_lens,
bool *is_block_step,
int *step_block_list, // [bsz]
int *step_len,
int *recover_block_list,
int *recover_len,
int *need_block_list,
int *need_block_len,
int *used_list_len,
int *free_list,
int *free_list_len,
int64_t *first_token_ids,
int *accept_num,
const int bsz,
const int block_size,
const int block_num_per_seq,
const int max_decoder_block_num,
const int max_draft_tokens) {
WRAPPER_CHECK_CTX(ctx);
WRAPPER_DUMP_FUNCTION_T1(ctx, "speculate_free_and_dispatch_block", float);
WRAPPER_DUMP_PARAM6(ctx,
stop_flags,
seq_lens_this_time,
seq_lens_decoder,
block_tables,
encoder_block_lens,
is_block_step);
WRAPPER_DUMP_PARAM6(ctx,
step_block_list,
step_len,
recover_block_list,
recover_len,
need_block_list,
need_block_len);
WRAPPER_DUMP_PARAM4(
ctx, used_list_len, free_list, free_list_len, first_token_ids);
WRAPPER_DUMP_PARAM4(
ctx, bsz, block_size, block_num_per_seq, max_decoder_block_num);
WRAPPER_DUMP(ctx);
if (ctx->dev().type() == api::kCPU) {
return cpu_wrapper(ctx,
stop_flags,
seq_lens_this_time,
seq_lens_decoder,
block_tables,
encoder_block_lens,
is_block_step,
step_block_list,
step_len,
recover_block_list,
recover_len,
need_block_list,
need_block_len,
used_list_len,
free_list,
free_list_len,
first_token_ids,
accept_num,
bsz,
block_size,
block_num_per_seq,
max_decoder_block_num,
max_draft_tokens);
}
if (ctx->dev().type() == api::kXPU2 || ctx->dev().type() == api::kXPU3) {
return xpu3_wrapper(ctx,
stop_flags,
seq_lens_this_time,
seq_lens_decoder,
block_tables,
encoder_block_lens,
is_block_step,
step_block_list,
step_len,
recover_block_list,
recover_len,
need_block_list,
need_block_len,
used_list_len,
free_list,
free_list_len,
first_token_ids,
accept_num,
bsz,
block_size,
block_num_per_seq,
max_decoder_block_num,
max_draft_tokens);
}
WRAPPER_UNIMPLEMENTED(ctx);
}
} // namespace plugin
} // namespace api
} // namespace xpu
} // namespace baidu

View File

@@ -33,7 +33,7 @@ __attribute__((global)) void speculate_remove_padding(
int token_num_data);
__attribute__((global)) void speculate_get_padding_offset(
int* padding_offset,
int* batch_id_per_token,
int* cum_offsets_out,
int* cu_seqlens_q,
int* cu_seqlens_k,
@@ -78,7 +78,7 @@ static int cpu_wrapper_remove_padding(Context* ctx,
}
static int cpu_wrapper_get_padding_offset(Context* ctx,
int* padding_offset,
int* batch_id_per_token,
int* cum_offsets_out,
int* cu_seqlens_q,
int* cu_seqlens_k,
@@ -89,7 +89,7 @@ static int cpu_wrapper_get_padding_offset(Context* ctx,
for (int bi = 0; bi < bsz; ++bi) {
int cum_offset = bi == 0 ? 0 : cum_offsets[bi - 1];
for (int i = 0; i < seq_lens[bi]; i++) {
padding_offset[bi * max_seq_len - cum_offset + i] = cum_offset;
batch_id_per_token[bi * max_seq_len - cum_offset + i] = bi;
}
cum_offsets_out[bi] = cum_offset;
int cum_seq_len = (bi + 1) * max_seq_len - cum_offsets[bi];
@@ -129,7 +129,7 @@ static int xpu3_wrapper_remove_padding(Context* ctx,
}
static int xpu3_wrapper_get_padding_offset(Context* ctx,
int* padding_offset,
int* batch_id_per_token,
int* cum_offsets_out,
int* cu_seqlens_q,
int* cu_seqlens_k,
@@ -139,7 +139,7 @@ static int xpu3_wrapper_get_padding_offset(Context* ctx,
int bsz) {
xpu3::plugin::
speculate_get_padding_offset<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
padding_offset,
batch_id_per_token,
cum_offsets_out,
cu_seqlens_q,
cu_seqlens_k,
@@ -215,7 +215,7 @@ int speculate_remove_padding(Context* ctx,
}
int speculate_get_padding_offset(Context* ctx,
int* padding_offset,
int* batch_id_per_token,
int* cum_offsets_out,
int* cu_seqlens_q,
int* cu_seqlens_k,
@@ -227,7 +227,7 @@ int speculate_get_padding_offset(Context* ctx,
WRAPPER_DUMP_FUNCTION_T1(ctx, "speculate_get_padding_offset", float);
WRAPPER_DUMP_PARAM6(ctx,
padding_offset,
batch_id_per_token,
cum_offsets_out,
cu_seqlens_q,
cu_seqlens_k,
@@ -247,7 +247,7 @@ int speculate_get_padding_offset(Context* ctx,
if (ctx->dev().type() == api::kCPU) {
return cpu_wrapper_get_padding_offset(ctx,
padding_offset,
batch_id_per_token,
cum_offsets_out,
cu_seqlens_q,
cu_seqlens_k,
@@ -258,7 +258,7 @@ int speculate_get_padding_offset(Context* ctx,
}
if (ctx->dev().type() == api::kXPU3) {
return xpu3_wrapper_get_padding_offset(ctx,
padding_offset,
batch_id_per_token,
cum_offsets_out,
cu_seqlens_q,
cu_seqlens_k,

View File

@@ -0,0 +1,258 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <algorithm>
#include <numeric>
#include "xpu/plugin.h"
#include "xpu/refactor/impl_public/wrapper_check.h"
namespace xpu3 {
namespace plugin {
__attribute__((global)) void speculate_recover_block(
int *recover_block_list, // [bsz]
int *recover_len,
bool *stop_flags,
int *seq_lens_this_time,
const int *ori_seq_lens_encoder,
int *seq_lens_encoder,
const int *seq_lens_decoder,
int *block_tables,
int *free_list,
int *free_list_len,
int64_t *input_ids,
const int64_t *pre_ids,
const int64_t *step_idx,
const int *encoder_block_lens,
const int *used_list_len,
const int64_t *next_tokens,
const int64_t *first_token_ids,
const int bsz,
const int block_num_per_seq,
const int length,
const int pre_id_length);
} // namespace plugin
} // namespace xpu3
namespace baidu {
namespace xpu {
namespace api {
namespace plugin {
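// CPU reference for speculate_recover_block: for each batch id queued in
// recover_block_list, re-activate the sequence by restoring its sequence lengths,
// re-claiming decoder blocks from the tail of the free list, and rebuilding
// input_ids from first_token_ids and pre_ids.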
static int cpu_wrapper(Context *ctx,
int *recover_block_list, // [bsz]
int *recover_len,
bool *stop_flags,
int *seq_lens_this_time,
const int *ori_seq_lens_encoder,
int *seq_lens_encoder,
const int *seq_lens_decoder,
int *block_tables,
int *free_list,
int *free_list_len,
int64_t *input_ids,
const int64_t *pre_ids,
const int64_t *step_idx,
const int *encoder_block_lens,
const int *used_list_len,
const int64_t *next_tokens,
const int64_t *first_token_ids,
const int bsz,
const int block_num_per_seq,
const int length,
const int pre_id_length) {
for (int bid = 0; bid < recover_len[0]; bid++) {
const int recover_id = recover_block_list[bid];
const int ori_seq_len_encoder = ori_seq_lens_encoder[recover_id];
const int step_idx_now = step_idx[recover_id];
const int seq_len = ori_seq_len_encoder + step_idx_now;
const int encoder_block_len = encoder_block_lens[recover_id];
const int decoder_used_len = used_list_len[recover_id];
int *block_table_now = block_tables + recover_id * block_num_per_seq;
int64_t *input_ids_now = input_ids + recover_id * length;
const int64_t *pre_ids_now = pre_ids + recover_id * pre_id_length;
seq_lens_this_time[recover_id] = seq_len;
seq_lens_encoder[recover_id] = seq_len;
stop_flags[recover_id] = false;
// input_ids_now[seq_len - 1] = next_tokens[recover_id]; // next tokens
input_ids_now[0] = first_token_ids[recover_id]; // set first prompt token
int ori_free_list_len = free_list_len[0];
free_list_len[0] -= decoder_used_len;
// Restore the block table: re-claim decoder blocks from the tail of the free list
for (int i = 0; i < decoder_used_len; i++) {
block_table_now[encoder_block_len + i] =
free_list[ori_free_list_len - i - 1];
}
// Restore input_ids: copy previously generated tokens back from pre_ids
for (int i = 0; i < step_idx_now; i++) {
input_ids_now[ori_seq_len_encoder + i] = pre_ids_now[i + 1];
}
}
recover_len[0] = 0;
return api::SUCCESS;
}
static int xpu3_wrapper(Context *ctx,
int *recover_block_list, // [bsz]
int *recover_len,
bool *stop_flags,
int *seq_lens_this_time,
const int *ori_seq_lens_encoder,
int *seq_lens_encoder,
const int *seq_lens_decoder,
int *block_tables,
int *free_list,
int *free_list_len,
int64_t *input_ids,
const int64_t *pre_ids,
const int64_t *step_idx,
const int *encoder_block_lens,
const int *used_list_len,
const int64_t *next_tokens,
const int64_t *first_token_ids,
const int bsz,
const int block_num_per_seq,
const int length,
const int pre_id_length) {
using XPU_INT64 = typename XPUIndexType<int64_t>::type;
auto recover_block_kernel = xpu3::plugin::speculate_recover_block;
recover_block_kernel<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
recover_block_list, // [bsz]
recover_len,
stop_flags,
seq_lens_this_time,
ori_seq_lens_encoder,
seq_lens_encoder,
seq_lens_decoder,
block_tables,
free_list,
free_list_len,
reinterpret_cast<XPU_INT64 *>(input_ids),
reinterpret_cast<const XPU_INT64 *>(pre_ids),
reinterpret_cast<const XPU_INT64 *>(step_idx),
encoder_block_lens,
used_list_len,
reinterpret_cast<const XPU_INT64 *>(next_tokens),
reinterpret_cast<const XPU_INT64 *>(first_token_ids),
bsz,
block_num_per_seq,
length,
pre_id_length);
return api::SUCCESS;
}
int speculate_recover_block(Context *ctx,
int *recover_block_list, // [bsz]
int *recover_len,
bool *stop_flags,
int *seq_lens_this_time,
const int *ori_seq_lens_encoder,
int *seq_lens_encoder,
const int *seq_lens_decoder,
int *block_tables,
int *free_list,
int *free_list_len,
int64_t *input_ids,
const int64_t *pre_ids,
const int64_t *step_idx,
const int *encoder_block_lens,
const int *used_list_len,
const int64_t *next_tokens,
const int64_t *first_token_ids,
const int bsz,
const int block_num_per_seq,
const int length,
const int pre_id_length) {
WRAPPER_CHECK_CTX(ctx);
WRAPPER_DUMP_FUNCTION_T1(ctx, "speculate_recover_block", float);
WRAPPER_DUMP_PARAM6(ctx,
recover_block_list,
recover_len,
stop_flags,
seq_lens_this_time,
ori_seq_lens_encoder,
seq_lens_encoder);
WRAPPER_DUMP_PARAM6(ctx,
seq_lens_decoder,
block_tables,
free_list,
free_list_len,
input_ids,
pre_ids);
WRAPPER_DUMP_PARAM5(ctx,
step_idx,
encoder_block_lens,
used_list_len,
next_tokens,
first_token_ids);
WRAPPER_DUMP_PARAM4(ctx, bsz, block_num_per_seq, length, pre_id_length);
WRAPPER_DUMP(ctx);
if (ctx->dev().type() == api::kCPU) {
return cpu_wrapper(ctx,
recover_block_list, // [bsz]
recover_len,
stop_flags,
seq_lens_this_time,
ori_seq_lens_encoder,
seq_lens_encoder,
seq_lens_decoder,
block_tables,
free_list,
free_list_len,
input_ids,
pre_ids,
step_idx,
encoder_block_lens,
used_list_len,
next_tokens,
first_token_ids,
bsz,
block_num_per_seq,
length,
pre_id_length);
}
if (ctx->dev().type() == api::kXPU2 || ctx->dev().type() == api::kXPU3) {
return xpu3_wrapper(ctx,
recover_block_list, // [bsz]
recover_len,
stop_flags,
seq_lens_this_time,
ori_seq_lens_encoder,
seq_lens_encoder,
seq_lens_decoder,
block_tables,
free_list,
free_list_len,
input_ids,
pre_ids,
step_idx,
encoder_block_lens,
used_list_len,
next_tokens,
first_token_ids,
bsz,
block_num_per_seq,
length,
pre_id_length);
}
WRAPPER_UNIMPLEMENTED(ctx);
}
} // namespace plugin
} // namespace api
} // namespace xpu
} // namespace baidu

View File

@@ -48,7 +48,8 @@ __attribute__((global)) void speculate_verify(
const int max_seq_len,
const int max_candidate_len,
const int verify_window,
const bool prefill_one_step_stop);
const bool prefill_one_step_stop,
const bool benchmark_mode);
} // namespace plugin
} // namespace xpu3
@@ -136,14 +137,15 @@ static int cpu_wrapper(Context *ctx,
const int max_seq_len,
const int max_candidate_len,
const int verify_window,
const bool prefill_one_step_stop) {
const bool prefill_one_step_stop,
const bool benchmark_mode) {
for (int bid = 0; bid < real_bsz; ++bid) {
const int start_token_id = bid * max_seq_len - output_cum_offsets[bid];
// verify and set stop flags
int accept_num_now = 1;
int stop_flag_now_int = 0;
if (!(is_block_step[bid] || bid >= real_bsz)) {
const int start_token_id = bid * max_seq_len - output_cum_offsets[bid];
// printf("debug cpu bid:%d,start_token_id:%d\n",bid, start_token_id);
// printf("bid %d\n", bid);
@@ -160,6 +162,9 @@ static int cpu_wrapper(Context *ctx,
// printf("seq_lens_this_time[%d]-1: %d \n",bid,
// seq_lens_this_time[bid]-1);
for (; i < seq_lens_this_time[bid] - 1; i++) {
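// In benchmark mode, skip draft-token verification entirely so the accepted
// length keeps its initial value.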
if (benchmark_mode) {
break;
}
if (seq_lens_encoder[bid] != 0) {
break;
}
@@ -326,7 +331,8 @@ static int xpu3_wrapper(Context *ctx,
const int max_seq_len,
const int max_candidate_len,
const int verify_window,
const bool prefill_one_step_stop) {
const bool prefill_one_step_stop,
const bool benchmark_mode) {
using XPU_INT64 = typename XPUIndexType<int64_t>::type;
xpu3::plugin::speculate_verify<ENABLE_TOPP, USE_TOPK>
<<<1, 64, ctx->xpu_stream>>>(
@@ -354,7 +360,8 @@ static int xpu3_wrapper(Context *ctx,
max_seq_len,
max_candidate_len,
verify_window,
prefill_one_step_stop);
prefill_one_step_stop,
benchmark_mode);
return api::SUCCESS;
}
template <bool ENABLE_TOPP, bool USE_TOPK>
@@ -383,7 +390,8 @@ int speculate_verify(Context *ctx,
const int max_seq_len,
const int max_candidate_len,
const int verify_window,
const bool prefill_one_step_stop) {
const bool prefill_one_step_stop,
const bool benchmark_mode) {
WRAPPER_CHECK_CTX(ctx);
WRAPPER_DUMP_FUNCTION_T1(ctx, "speculate_verify", int64_t);
WRAPPER_DUMP_PARAM3(ctx, accept_tokens, accept_num, step_idx);
@@ -406,12 +414,13 @@ int speculate_verify(Context *ctx,
actual_candidate_len,
real_bsz,
max_draft_tokens);
WRAPPER_DUMP_PARAM5(ctx,
WRAPPER_DUMP_PARAM6(ctx,
end_length,
max_seq_len,
max_candidate_len,
verify_window,
prefill_one_step_stop);
prefill_one_step_stop,
benchmark_mode);
WRAPPER_DUMP(ctx);
WRAPPER_CHECK_PTR(ctx, int64_t, real_bsz * max_draft_tokens, accept_tokens);
WRAPPER_CHECK_PTR(ctx, int, real_bsz, accept_num);
@@ -469,7 +478,8 @@ int speculate_verify(Context *ctx,
max_seq_len,
max_candidate_len,
verify_window,
prefill_one_step_stop);
prefill_one_step_stop,
benchmark_mode);
}
if (ctx->dev().type() == api::kXPU3) {
return xpu3_wrapper<ENABLE_TOPP, USE_TOPK>(ctx,
@@ -497,7 +507,8 @@ int speculate_verify(Context *ctx,
max_seq_len,
max_candidate_len,
verify_window,
prefill_one_step_stop);
prefill_one_step_stop,
benchmark_mode);
}
WRAPPER_UNIMPLEMENTED(ctx);
}
@@ -530,7 +541,8 @@ int speculate_verify(Context *ctx,
int, /* max_seq_len */ \
int, /* max_candidate_len */ \
int, /* verify_window */ \
bool); /* prefill_one_step_stop */
bool,  /* prefill_one_step_stop */ \
bool); /* benchmark_mode */
INSTANTIATE_SPECULATE_VERIFY(false, false)
INSTANTIATE_SPECULATE_VERIFY(false, true)

View File

@@ -0,0 +1,180 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import paddle
from fastdeploy.model_executor.ops.xpu import (
adjust_batch,
gather_next_token,
get_infer_param,
)
def _run_test_base(seq_lens_this_time_data, output_padding_offset):
"""
通用的基础测试执行函数,包含了两个场景共有的逻辑。
"""
seq_lens_encoder = paddle.to_tensor([100, 0, 0, 0, 120, 140, 0], dtype="int32")
seq_lens_decoder = paddle.to_tensor([0, 5, 0, 25, 64, 0, 128], dtype="int32")
seq_lens_this_time = paddle.to_tensor(seq_lens_this_time_data, dtype="int32")
bsz = seq_lens_this_time.shape[0]
cum_offsets = paddle.zeros(bsz, dtype="int32")
block_table = paddle.arange(0, 56, dtype="int32").reshape((bsz, 8))
infer_params = get_infer_param(seq_lens_encoder, seq_lens_decoder, seq_lens_this_time, block_table, 64)
(
encoder_batch_map,
decoder_batch_map,
encoder_batch_idx,
decoder_batch_idx,
encoder_seq_lod,
decoder_seq_lod,
_,
_,
_,
_,
_,
encoder_batch_map_cpu,
decoder_batch_map_cpu,
encoder_batch_idx_cpu,
decoder_batch_idx_cpu,
encoder_seq_lod_cpu,
decoder_seq_lod_cpu,
_,
_,
_,
_,
len_info_cpu,
) = infer_params
token_num = seq_lens_this_time.sum().cpu().item()
hidden_dim = 8192
row_indices = paddle.arange(token_num, dtype="int32")
row_indices_bf16 = row_indices.astype("bfloat16")
input_tensor = paddle.unsqueeze(row_indices_bf16, axis=1).expand(shape=[token_num, hidden_dim])
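# Each row of the input is filled with its own row index (cast to bfloat16), which
# makes the reordering done by adjust_batch / gather_next_token easy to verify.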
# Test adjust_batch (XPU vs. CPU)
adjusted_output = adjust_batch(
input_tensor,
cum_offsets,
encoder_seq_lod,
decoder_seq_lod,
encoder_batch_idx,
decoder_batch_idx,
encoder_seq_lod_cpu,
decoder_seq_lod_cpu,
encoder_batch_idx_cpu,
decoder_batch_idx_cpu,
len_info_cpu,
None, # output_padding_offset
-1, # max_input_length
)
adjusted_output_cpu = adjust_batch(
input_tensor.cpu(),
cum_offsets,
encoder_seq_lod,
decoder_seq_lod,
encoder_batch_idx,
decoder_batch_idx,
encoder_seq_lod_cpu,
decoder_seq_lod_cpu,
encoder_batch_idx_cpu,
decoder_batch_idx_cpu,
len_info_cpu,
None, # output_padding_offset
-1, # max_input_length
)
# Use np.testing instead of a bare assert for friendlier error messages
adjusted_output_np = adjusted_output.astype("float32").cpu().numpy()
adjusted_output_cpu_np = adjusted_output_cpu.astype("float32").cpu().numpy()
np.testing.assert_allclose(adjusted_output_np, adjusted_output_cpu_np, err_msg="adjust_batch check failed!")
# Test gather_next_token (XPU vs. CPU)
gather_out = gather_next_token(
adjusted_output,
cum_offsets,
encoder_seq_lod,
decoder_seq_lod,
encoder_batch_map,
decoder_batch_map,
encoder_seq_lod_cpu,
decoder_seq_lod_cpu,
encoder_batch_map_cpu,
decoder_batch_map_cpu,
len_info_cpu,
output_padding_offset,
-1,
)
gather_out_cpu = gather_next_token(
adjusted_output.cpu(),
cum_offsets,
encoder_seq_lod,
decoder_seq_lod,
encoder_batch_map,
decoder_batch_map,
encoder_seq_lod_cpu,
decoder_seq_lod_cpu,
encoder_batch_map_cpu,
decoder_batch_map_cpu,
len_info_cpu,
output_padding_offset,
-1,
)
gather_out_np = gather_out.astype("float32").cpu().numpy()
gather_out_cpu_np = gather_out_cpu.astype("float32").cpu().numpy()
if output_padding_offset is not None:
np.testing.assert_allclose(gather_out_np, gather_out_cpu_np, err_msg="gather_next_token check failed!")
else:
for i in range(gather_out_cpu.shape[0]):
if seq_lens_this_time[i] > 0:
np.testing.assert_allclose(
gather_out_np[i], gather_out_cpu_np[i], err_msg=f"gather_next_token check failed at index {i}!"
)
class TestXPUOps(unittest.TestCase):
"""Tests for the XPU adjust_batch and gather_next_token ops."""
def test_mix_with_mtp(self):
"""测试混合批次处理中的 MTP (Multi-Token Prediction) 场景"""
print("\nRunning test: test_mix_with_mtp")
seq_lens_this_time_data = [100, 2, 0, 1, 120, 140, 3]
bsz = len(seq_lens_this_time_data)
output_padding_offset = paddle.zeros(bsz, dtype="int32")
_run_test_base(seq_lens_this_time_data, output_padding_offset)
print("Test passed for scenario: With MTP")
def test_mix_without_mtp(self):
"""测试非 MTP (Single-Token Prediction) 场景下的功能"""
print("\nRunning test: test_mix_without_mtp")
seq_lens_this_time_data = [100, 1, 0, 1, 120, 140, 1]
output_padding_offset = None  # None when MTP is not used
_run_test_base(seq_lens_this_time_data, output_padding_offset)
print("Test passed for scenario: Without MTP")
if __name__ == "__main__":
unittest.main()

View File

@@ -12,50 +12,284 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import paddle
from fastdeploy.model_executor.ops.xpu import draft_model_preprocess
def run_test(device="xpu"):
paddle.seed(2022)
def process_splitwise_prefill(
draft_tokens,
input_ids,
stop_flags,
seq_lens_this_time,
seq_lens_encoder,
seq_lens_decoder,
step_idx,
not_need_stop,
is_block_step,
batch_drop,
pre_ids,
accept_tokens,
accept_num,
base_model_seq_lens_this_time,
base_model_seq_lens_encoder,
base_model_seq_lens_decoder,
base_model_step_idx,
base_model_stop_flags,
base_model_is_block_step,
base_model_draft_tokens,
bsz,
num_model_step,
base_model_draft_tokens_len,
truncate_first_token,
kvcache_scheduler_v1,
):
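# Reference for the splitwise-prefill branch (kept in sync with the XPU kernel by
# assumption): only requests with a non-zero encoder length stay active; all other
# slots are marked stopped and their sequence lengths are cleared.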
not_stop_flag_sum = 0
# Define parameters
bsz = 10
draft_tokens_len = 4
input_ids_len = 8
max_draft_token = 10
for tid in range(bsz):
not_stop_flag = 0
input_ids_now = input_ids[tid]
accept_tokens_now = accept_tokens[tid]
if seq_lens_encoder[tid] > 0:
not_stop_flag = 1
seq_len_encoder = seq_lens_encoder[tid]
stop_flags[tid] = False
base_model_first_token = accept_tokens_now[0]
position = seq_len_encoder
if truncate_first_token:
input_ids_now[position - 1] = base_model_first_token
seq_lens_this_time[tid] = seq_len_encoder
else:
input_ids_now[position] = base_model_first_token
seq_lens_this_time[tid] = seq_len_encoder + 1
else:
stop_flags[tid] = True
seq_lens_this_time[tid] = 0
seq_lens_decoder[tid] = 0
seq_lens_encoder[tid] = 0
not_stop_flag = 0
not_stop_flag_sum = not_stop_flag_sum + not_stop_flag
not_need_stop[0] = not_stop_flag_sum > 0
truncate_first_token = True
splitwise_prefill = False
# Create input tensors
if device == "cpu":
paddle.set_device(device)
draft_tokens = paddle.randint(0, 100, [bsz, draft_tokens_len], dtype="int64")
input_ids = paddle.randint(0, 100, [bsz, input_ids_len], dtype="int64")
stop_flags = paddle.randint(0, 1, [bsz], dtype="int").cast("bool")
seq_lens_this_time = paddle.randint(0, 100, [bsz], dtype="int32")
seq_lens_encoder = paddle.randint(0, 100, [bsz], dtype="int32")
seq_lens_decoder = paddle.randint(0, 100, [bsz], dtype="int32")
step_idx = paddle.randint(0, 100, [bsz], dtype="int64")
seq_lens_encoder_record = paddle.randint(0, 100, [bsz], dtype="int32")
seq_lens_decoder_record = paddle.randint(0, 100, [bsz], dtype="int32")
not_need_stop = paddle.zeros([1], dtype="bool").cpu()
batch_drop = paddle.zeros([bsz], dtype="bool")
def draft_model_preprocess_kernel(
draft_tokens,
input_ids,
stop_flags,
seq_lens_this_time,
seq_lens_encoder,
seq_lens_decoder,
step_idx,
not_need_stop,
is_block_step,
batch_drop,
pre_ids,
accept_tokens,
accept_num,
base_model_seq_lens_this_time,
base_model_seq_lens_encoder,
base_model_seq_lens_decoder,
base_model_step_idx,
base_model_stop_flags,
base_model_is_block_step,
base_model_draft_tokens,
bsz,
num_model_step,
base_model_draft_tokens_len,
truncate_first_token,
kvcache_scheduler_v1,
):
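# Reference for the regular preprocess path (kept in sync with the XPU kernel by
# assumption): copy the accepted tokens into draft_tokens / pre_ids, roll
# seq_lens_decoder and step_idx back by (num_model_step - 1) for running decoder
# requests, and handle block-step recovery and stopped slots.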
not_stop_flag_sum = 0
# Output tensors
accept_tokens = paddle.randint(0, 100, [bsz, 100], dtype="int64")
accept_num = paddle.randint(1, max_draft_token + 5, [bsz], dtype="int32")
base_model_seq_lens_encoder = paddle.randint(0, 100, [bsz], dtype="int32")
base_model_seq_lens_decoder = paddle.randint(0, 100, [bsz], dtype="int32")
base_model_step_idx = paddle.randint(0, 100, [bsz], dtype="int64")
base_model_stop_flags = paddle.zeros([bsz], dtype="bool")
base_model_is_block_step = paddle.zeros([bsz], dtype="bool")
base_model_draft_tokens = paddle.zeros([bsz, max_draft_token], dtype="int64")
# Run the op
outputs = draft_model_preprocess(
for tid in range(bsz):
not_stop_flag = 0
accept_tokens_now = accept_tokens[tid]
draft_tokens_now = draft_tokens[tid]
accept_num_now = accept_num[tid]
input_ids_now = input_ids[tid]
base_model_draft_tokens_now = base_model_draft_tokens[tid]
base_model_seq_len_decoder = base_model_seq_lens_decoder[tid]
base_model_seq_len_this_time = base_model_seq_lens_this_time[tid]
pre_ids_now = pre_ids[tid]
base_model_draft_tokens_now[1:base_model_draft_tokens_len] = -1
if kvcache_scheduler_v1:
if base_model_stop_flags[tid] and base_model_is_block_step[tid]:
stop_flags[tid] = True
is_block_step[tid] = True
# Need to continue infer
else:
if base_model_stop_flags[tid] and base_model_is_block_step[tid]:
batch_drop[tid] = True
stop_flags[tid] = True
if not (base_model_stop_flags[tid] or batch_drop[tid]):
not_stop_flag = 1
# 1. first token
if seq_lens_encoder[tid] > 0:
# Can be extended to first few tokens
seq_len_encoder = seq_lens_encoder[tid]
stop_flags[tid] = False
base_model_first_token = accept_tokens_now[0]
pre_ids_now[0] = base_model_first_token
position = seq_len_encoder
if truncate_first_token:
input_ids_now[position - 1] = base_model_first_token
seq_lens_this_time[tid] = seq_len_encoder
else:
input_ids_now[position] = base_model_first_token
seq_lens_this_time[tid] = seq_len_encoder + 1
else:
if kvcache_scheduler_v1:
# 3. try to recover mtp infer in V1 mode
if not (base_model_is_block_step[tid] and is_block_step[tid]):
is_block_step[tid] = False
if stop_flags[tid]:
stop_flags[tid] = False
# TODO: check
seq_lens_decoder[tid] = base_model_seq_len_decoder - base_model_seq_len_this_time
step_idx[tid] = base_model_step_idx[tid] - base_model_seq_len_this_time
else:
# 2: Last base model generated token and first MTP token
seq_lens_decoder[tid] -= num_model_step - 1
step_idx[tid] -= num_model_step - 1
for i in range(accept_num_now):
draft_tokens_now[i] = accept_tokens_now[i]
pre_id_pos = base_model_step_idx[tid] - (accept_num_now - i)
accept_token = accept_tokens_now[i]
pre_ids_now[pre_id_pos] = accept_token
seq_lens_this_time[tid] = accept_num_now
else:
stop_flags[tid] = True
seq_lens_this_time[tid] = 0
seq_lens_decoder[tid] = 0
seq_lens_encoder[tid] = 0
not_stop_flag_sum = not_stop_flag_sum + not_stop_flag
not_need_stop[0] = not_stop_flag_sum > 0
def DispatchRunner(
draft_tokens,
input_ids,
stop_flags,
seq_lens_this_time,
seq_lens_encoder,
seq_lens_decoder,
step_idx,
not_need_stop,
is_block_step,
batch_drop,
pre_ids,
accept_tokens,
accept_num,
base_model_seq_lens_this_time,
base_model_seq_lens_encoder,
base_model_seq_lens_decoder,
base_model_step_idx,
base_model_stop_flags,
base_model_is_block_step,
base_model_draft_tokens,
bsz,
num_model_step,
truncate_first_token,
splitwise_prefill,
kvcache_scheduler_v1,
):
base_model_draft_tokens_len = base_model_draft_tokens.shape[1]
if splitwise_prefill:
process_splitwise_prefill(
draft_tokens,
input_ids,
stop_flags,
seq_lens_this_time,
seq_lens_encoder,
seq_lens_decoder,
step_idx,
not_need_stop,
is_block_step,
batch_drop,
pre_ids,
accept_tokens,
accept_num,
base_model_seq_lens_this_time,
base_model_seq_lens_encoder,
base_model_seq_lens_decoder,
base_model_step_idx,
base_model_stop_flags,
base_model_is_block_step,
base_model_draft_tokens,
bsz,
num_model_step,
base_model_draft_tokens_len,
truncate_first_token,
kvcache_scheduler_v1,
)
else:
draft_model_preprocess_kernel(
draft_tokens,
input_ids,
stop_flags,
seq_lens_this_time,
seq_lens_encoder,
seq_lens_decoder,
step_idx,
not_need_stop,
is_block_step,
batch_drop,
pre_ids,
accept_tokens,
accept_num,
base_model_seq_lens_this_time,
base_model_seq_lens_encoder,
base_model_seq_lens_decoder,
base_model_step_idx,
base_model_stop_flags,
base_model_is_block_step,
base_model_draft_tokens,
bsz,
num_model_step,
base_model_draft_tokens_len,
truncate_first_token,
kvcache_scheduler_v1,
)
def draft_model_preprocess_ref(
draft_tokens,
input_ids,
stop_flags,
seq_lens_this_time,
seq_lens_encoder,
seq_lens_decoder,
step_idx,
not_need_stop,
is_block_step,
batch_drop,
pre_ids,
accept_tokens,
accept_num,
base_model_seq_lens_this_time,
base_model_seq_lens_encoder,
base_model_seq_lens_decoder,
base_model_step_idx,
base_model_stop_flags,
base_model_is_block_step,
base_model_draft_tokens,
num_model_step,
truncate_first_token,
splitwise_prefill,
kvcache_scheduler_v1,
):
real_bsz = seq_lens_this_time.shape[0]
DispatchRunner(
draft_tokens,
input_ids,
stop_flags,
@@ -63,73 +297,110 @@ def run_test(device="xpu"):
seq_lens_encoder,
seq_lens_decoder,
step_idx,
seq_lens_encoder_record,
seq_lens_decoder_record,
not_need_stop,
is_block_step,
batch_drop,
pre_ids,
accept_tokens,
accept_num,
base_model_seq_lens_this_time,
base_model_seq_lens_encoder,
base_model_seq_lens_decoder,
base_model_step_idx,
base_model_stop_flags,
base_model_is_block_step,
base_model_draft_tokens,
max_draft_token=max_draft_token,
truncate_first_token=truncate_first_token,
splitwise_prefill=splitwise_prefill,
real_bsz,
num_model_step,
truncate_first_token,
splitwise_prefill,
kvcache_scheduler_v1,
)
# Return results for comparison
results = {
"draft_tokens": draft_tokens.numpy(),
"input_ids": input_ids.numpy(),
"stop_flags": stop_flags.numpy(),
"seq_lens_this_time": seq_lens_this_time.numpy(),
"accept_tokens": accept_tokens.numpy(),
"accept_num": accept_num.numpy(),
"not_need_stop": not_need_stop.numpy(),
"outputs": [x.numpy() for x in outputs],
}
return results
class TestDraftModelPreprocess(unittest.TestCase):
def _run_tests(self):
paddle.seed(2022)
def compare_results(cpu_results, xpu_results):
# Compare all outputs
for key in cpu_results:
if key == "outputs":
for i, (cpu_out, xpu_out) in enumerate(zip(cpu_results[key], xpu_results[key])):
np.testing.assert_allclose(
cpu_out,
xpu_out,
rtol=1e-5,
atol=1e-8,
err_msg=f"Output {i} mismatch between CPU and GPU",
)
else:
np.testing.assert_allclose(
cpu_results[key],
xpu_results[key],
rtol=1e-5,
atol=1e-8,
err_msg=f"{key} mismatch between CPU and GPU",
)
print("CPU and GPU results match!")
# Define parameters
bsz = 10
draft_tokens_len = 4
input_ids_len = 100
max_draft_token = 10
truncate_first_token = True
splitwise_prefill = False
def test_draft_model_preprocess():
draft_tokens = paddle.randint(0, 100, [bsz, draft_tokens_len], dtype="int64")
input_ids = paddle.randint(0, 100, [bsz, input_ids_len], dtype="int64")
stop_flags = paddle.randint(0, 1, [bsz], dtype="int").cast("bool")
seq_lens_this_time = paddle.randint(0, 100, [bsz], dtype="int32")
seq_lens_encoder = paddle.randint(0, input_ids_len, [bsz], dtype="int32")
seq_lens_decoder = paddle.randint(0, input_ids_len, [bsz], dtype="int32")
step_idx = paddle.randint(0, 100, [bsz], dtype="int64")
seq_lens_encoder_record = paddle.randint(0, 100, [bsz], dtype="int32") # noqa: F841
seq_lens_decoder_record = paddle.randint(0, 100, [bsz], dtype="int32") # noqa: F841
not_need_stop = paddle.zeros([1], dtype="bool").cpu()
is_block_step = paddle.zeros([bsz], dtype="bool")
batch_drop = paddle.zeros([bsz], dtype="bool")
print("Running XPU test...")
xpu_results = run_test("xpu")
# Output tensors
accept_tokens = paddle.randint(0, 100, [bsz, 100], dtype="int64")
accept_num = paddle.randint(1, max_draft_token + 5, [bsz], dtype="int32")
base_model_seq_lens_encoder = paddle.randint(0, 100, [bsz], dtype="int32")
base_model_seq_lens_decoder = paddle.randint(0, 100, [bsz], dtype="int32")
base_model_step_idx = paddle.randint(0, 100, [bsz], dtype="int64")
base_model_stop_flags = paddle.zeros([bsz], dtype="bool")
base_model_is_block_step = paddle.zeros([bsz], dtype="bool")
base_model_draft_tokens = paddle.zeros([bsz, max_draft_token], dtype="int64")
# Run the op
pre_ids = input_ids.clone()
base_model_seq_lens_this_time = seq_lens_this_time
num_model_step = max_draft_token
print("Running CPU test...")
cpu_results = run_test("cpu")
kvcache_scheduler_v1 = True
inputs = (
draft_tokens,
input_ids,
stop_flags,
seq_lens_this_time,
seq_lens_encoder,
seq_lens_decoder,
step_idx,
not_need_stop,
is_block_step,
batch_drop,
pre_ids,
accept_tokens,
accept_num,
base_model_seq_lens_this_time,
base_model_seq_lens_encoder,
base_model_seq_lens_decoder,
base_model_step_idx,
base_model_stop_flags,
base_model_is_block_step,
base_model_draft_tokens,
num_model_step,
truncate_first_token,
splitwise_prefill,
kvcache_scheduler_v1,
)
# Both paths modify the tensors in place, so run the op on a cloned copy of the inputs
inputs_clone = [x.clone() if isinstance(x, paddle.Tensor) else x for x in inputs]
draft_model_preprocess_ref(*inputs)
draft_model_preprocess(*inputs_clone)
return inputs, inputs_clone
print("Comparing results...")
compare_results(cpu_results, xpu_results)
print("Test passed!")
def test_draft_model_preprocess(self):
results1, results2 = self._run_tests()
np.testing.assert_allclose(results1[0], results2[0]) # draft_tokens
np.testing.assert_allclose(results1[1], results2[1]) # input_ids
np.testing.assert_allclose(results1[2], results2[2]) # stop_flags
np.testing.assert_allclose(results1[3], results2[3]) # seq_lens_this_time
np.testing.assert_allclose(results1[11], results2[11]) # accept_tokens
np.testing.assert_allclose(results1[12], results2[12]) # accept_num
np.testing.assert_allclose(results1[7], results2[7]) # not_need_stop
if __name__ == "__main__":
test_draft_model_preprocess()
unittest.main()

View File

@@ -0,0 +1,312 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import unittest
import numpy as np
import paddle
from fastdeploy.model_executor.ops.xpu import speculate_step_paddle
# Fix the random seeds so the test is reproducible
np.random.seed(2023)
paddle.seed(2023)
def generate_test_data():
"""
生成测试数据的辅助函数。
这部分逻辑从 pytest 的 fixture 转换而来,作为一个普通函数供测试方法调用。
"""
# max_bs = 128
max_bs = 8
bs = max_bs
max_seq_len = 8192
block_size = 64
block_bs = 8
block_ratio = 0.75
max_draft_tokens = 1
encoder_decoder_block_num = 1
# Generate the raw test data (reusing the original logic unchanged)
stop_flags = np.random.randint(0, 2, [max_bs]).astype("bool")
seq_lens_this_time = np.zeros([bs], "int32")
seq_lens_encoder = np.zeros([max_bs], "int32")
seq_lens_decoder = np.zeros([max_bs], "int32")
accept_num = np.random.randint(1, 3, [max_bs]).astype("int32")
for i in range(bs):
seq_lens_decoder[i] = 2 + i * 2
seq_lens_this_time[i] = 1
ori_seq_lens_encoder = np.zeros([max_bs], "int32")
ori_seq_lens_encoder[:] = seq_lens_decoder[:] // 2
step_idx = (seq_lens_decoder - ori_seq_lens_encoder).astype("int64")
max_block_num = block_bs * max_seq_len // block_size
free_list_len = int(max_block_num * (1 - block_ratio))
free_list_len = np.full([1], free_list_len, "int32")
free_list = np.arange(
max_block_num - 1, max_block_num - free_list_len.item() - 1, -1, dtype="int32"  # .item() converts to a Python scalar
)
encoder_block_lens = np.zeros([max_bs], "int32")
used_list_len = np.zeros([max_bs], "int32")
block_tables = np.full([max_bs, 128], -1, "int32")
encoder_block_id = 0
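# Lay out each sequence's block table: encoder blocks are assigned contiguously
# from the front, while decoder blocks are taken from the tail of the free list.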
for i in range(bs):
enc_block_num = (ori_seq_lens_encoder[i] + block_size - 1) // block_size
encoder_block_lens[i] = enc_block_num
dec_block_num = (seq_lens_decoder[i] + block_size - 1) // block_size - enc_block_num
used_list_len[i] = dec_block_num
block_tables[i, :enc_block_num] = np.arange(encoder_block_id, encoder_block_id + enc_block_num, 1, "int32")
encoder_block_id += enc_block_num
if dec_block_num > 0:
block_tables[i, enc_block_num : enc_block_num + dec_block_num] = free_list[
free_list_len[0] - 1 - dec_block_num : free_list_len[0] - 1
]
free_list[free_list_len[0] - 1 - dec_block_num : free_list_len[0] - 1] = -1
free_list_len[0] -= dec_block_num
assert free_list_len[0] >= 0, "free_list_len should not be negative"
is_block_step = np.zeros([max_bs], "bool")
is_block_step[:bs] = np.random.randint(0, 2, [bs]).astype("bool")
step_block_list = np.full([max_bs], -1, "int32")
step_lens = np.full([1], 0, "int32")
for i in range(bs):
if is_block_step[i]:
step_block_list[step_lens[0]] = i
step_lens[0] += 1
recover_lens = np.full([1], 0, "int32")
recover_block_list = np.full([max_bs], -1, "int32")
need_block_len = np.full([1], 0, "int32")
need_block_list = np.full([max_bs], -1, "int32")
input_ids = np.random.randint(0, 1000, [max_bs, max_seq_len], "int64")
pre_ids = np.random.randint(0, 1000, [max_bs, max_seq_len], "int64")
next_tokens = np.random.randint(0, 1000, [max_bs], "int64")
first_token_ids = np.random.randint(0, 1000, [max_bs], "int64")
paddle.set_device("cpu")
# Convert to paddle tensors (keeping the original logic)
data_cpu = {
"stop_flags": paddle.to_tensor(stop_flags),
"seq_lens_this_time": paddle.to_tensor(seq_lens_this_time),
"seq_lens_encoder": paddle.to_tensor(seq_lens_encoder),
"seq_lens_decoder": paddle.to_tensor(seq_lens_decoder),
"ori_seq_lens_encoder": paddle.to_tensor(ori_seq_lens_encoder),
"block_tables": paddle.to_tensor(block_tables),
"encoder_block_lens": paddle.to_tensor(encoder_block_lens),
"is_block_step": paddle.to_tensor(is_block_step),
"step_block_list": paddle.to_tensor(step_block_list),
"step_lens": paddle.to_tensor(step_lens),
"recover_block_list": paddle.to_tensor(recover_block_list),
"recover_lens": paddle.to_tensor(recover_lens),
"need_block_list": paddle.to_tensor(need_block_list),
"need_block_len": paddle.to_tensor(need_block_len),
"used_list_len": paddle.to_tensor(used_list_len),
"free_list_len": paddle.to_tensor(free_list_len),
"free_list": paddle.to_tensor(free_list),
"input_ids": paddle.to_tensor(input_ids),
"pre_ids": paddle.to_tensor(pre_ids),
"step_idx": paddle.to_tensor(step_idx),
"next_tokens": paddle.to_tensor(next_tokens),
"first_token_ids": paddle.to_tensor(first_token_ids),
"accept_num": paddle.to_tensor(accept_num),
"block_size": block_size,
"encoder_decoder_block_num": encoder_decoder_block_num,
"max_draft_tokens": max_draft_tokens,
}
paddle.set_device("xpu:0")
data_xpu = {
"stop_flags": paddle.to_tensor(stop_flags),
"seq_lens_this_time": paddle.to_tensor(seq_lens_this_time),
"seq_lens_encoder": paddle.to_tensor(seq_lens_encoder),
"seq_lens_decoder": paddle.to_tensor(seq_lens_decoder),
"ori_seq_lens_encoder": paddle.to_tensor(ori_seq_lens_encoder),
"block_tables": paddle.to_tensor(block_tables),
"encoder_block_lens": paddle.to_tensor(encoder_block_lens),
"is_block_step": paddle.to_tensor(is_block_step),
"step_block_list": paddle.to_tensor(step_block_list),
"step_lens": paddle.to_tensor(step_lens),
"recover_block_list": paddle.to_tensor(recover_block_list),
"recover_lens": paddle.to_tensor(recover_lens),
"need_block_list": paddle.to_tensor(need_block_list),
"need_block_len": paddle.to_tensor(need_block_len),
"used_list_len": paddle.to_tensor(used_list_len),
"free_list_len": paddle.to_tensor(free_list_len),
"free_list": paddle.to_tensor(free_list),
"input_ids": paddle.to_tensor(input_ids),
"pre_ids": paddle.to_tensor(pre_ids),
"step_idx": paddle.to_tensor(step_idx),
"next_tokens": paddle.to_tensor(next_tokens),
"first_token_ids": paddle.to_tensor(first_token_ids),
"accept_num": paddle.to_tensor(accept_num),
"block_size": block_size,
"encoder_decoder_block_num": encoder_decoder_block_num,
"max_draft_tokens": max_draft_tokens,
}
# Restore the default device so other tests are not affected
paddle.set_device("cpu")
return data_cpu, data_xpu
def speculate_step_paddle_execution(test_data):
"""测试 speculate_step_paddle 函数的执行性和输出合理性"""
# Unpack the input data
stop_flags = test_data["stop_flags"]
seq_lens_this_time = test_data["seq_lens_this_time"]
ori_seq_lens_encoder = test_data["ori_seq_lens_encoder"]
seq_lens_encoder = test_data["seq_lens_encoder"]
seq_lens_decoder = test_data["seq_lens_decoder"]
block_tables = test_data["block_tables"]
encoder_block_lens = test_data["encoder_block_lens"]
is_block_step = test_data["is_block_step"]
step_block_list = test_data["step_block_list"]
step_lens = test_data["step_lens"]
recover_block_list = test_data["recover_block_list"]
recover_lens = test_data["recover_lens"]
need_block_list = test_data["need_block_list"]
need_block_len = test_data["need_block_len"]
used_list_len = test_data["used_list_len"]
free_list = test_data["free_list"]
free_list_len = test_data["free_list_len"]
input_ids = test_data["input_ids"]
pre_ids = test_data["pre_ids"]
step_idx = test_data["step_idx"]
next_tokens = test_data["next_tokens"]
first_token_ids = test_data["first_token_ids"]
accept_num = test_data["accept_num"]
block_size = test_data["block_size"]
encoder_decoder_block_num = test_data["encoder_decoder_block_num"]
max_draft_tokens = test_data["max_draft_tokens"]
# Optional: print key state before the op (enable for debugging)
if os.environ.get("STEP_TEST_DEBUG", "0") == "1":
print("-" * 50 + "before step op" + "-" * 50)
# ... (printing omitted for brevity)
# Run the op under test (the core test step)
speculate_step_paddle(
stop_flags,
seq_lens_this_time,
ori_seq_lens_encoder,
seq_lens_encoder,
seq_lens_decoder,
block_tables,
encoder_block_lens,
is_block_step,
step_block_list,
step_lens,
recover_block_list,
recover_lens,
need_block_list,
need_block_len,
used_list_len,
free_list,
free_list_len,
input_ids,
pre_ids,
step_idx,
next_tokens,
first_token_ids,
accept_num,
block_size,
encoder_decoder_block_num,
max_draft_tokens,
)
# Optional: print key state after the op (enable for debugging)
if os.environ.get("STEP_TEST_DEBUG", "0") == "1":
print("-" * 50 + "after step op" + "-" * 50)
# ... (printing omitted for brevity)
return test_data
class TestSpeculateStepPaddle(unittest.TestCase):
"""
测试类,继承自 unittest.TestCase。
所有以 'test_' 开头的方法都会被视为测试用例。
"""
def assert_test_data_equal(self, test_data1, test_data2, rtol=1e-05, atol=1e-08):
"""
自定义的断言方法,用于比较两个 test_data 结构和数据。
在 unittest 中,自定义断言通常以 'assert' 开头。
"""
# 1. Check that both test_data dicts have exactly the same keys
keys1 = set(test_data1.keys())
keys2 = set(test_data2.keys())
self.assertEqual(
keys1,
keys2,
msg=f"两个 test_data 字段不一致!\n仅在第一个中存在:{keys1 - keys2}\n仅在第二个中存在:{keys2 - keys1}",
)
# 2. Compare the data field by field
for key in keys1:
data1 = test_data1[key]
data2 = test_data2[key]
# Distinguish paddle Tensors (convert to numpy) from plain scalars/arrays (used as-is)
if isinstance(data1, paddle.Tensor):
np1 = data1.detach().cpu().numpy()
else:
np1 = np.asarray(data1)
if isinstance(data2, paddle.Tensor):
np2 = data2.detach().cpu().numpy()
else:
np2 = np.asarray(data2)
# 3. Check the values
if np1.dtype in (np.bool_, np.int8, np.int16, np.int32, np.int64, np.uint8):
# Bool / integer dtypes must match exactly
np.testing.assert_array_equal(np1, np2, err_msg=f"Field {key} mismatch!")
else:
# Floating point: allow error within rtol/atol
np.testing.assert_allclose(np1, np2, rtol=rtol, atol=atol, err_msg=f"Field {key} floating-point mismatch!")
print("✅ 两个 test_data 结构和数据完全一致!")
def test_speculate_step_paddle_execution(self):
"""
核心测试用例方法。
该方法会调用 generate_test_data 获取数据,
分别在 CPU 和 XPU 上执行测试函数,
并使用自定义的断言方法比较结果。
"""
print("\nRunning test: test_speculate_step_paddle_execution")
# 1. Get the test data
data_cpu, data_xpu = generate_test_data()
# 2. Run the op on both devices
result_xpu = speculate_step_paddle_execution(data_xpu)
result_cpu = speculate_step_paddle_execution(data_cpu)
# 3. Assert that the results match
self.assert_test_data_equal(result_xpu, result_cpu)
if __name__ == "__main__":
# Run all test cases via the unittest main program
unittest.main()

View File

@@ -12,101 +12,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import unittest
# tests/test_speculate_update_v3.py
import numpy as np
import paddle
# The custom op under test
from fastdeploy.model_executor.ops.xpu import speculate_update_v3
# ---------------- NumPy reference implementation ----------------
def speculate_update_v3_np(
seq_lens_encoder,
seq_lens_decoder,
not_need_stop,
draft_tokens,
actual_draft_token_nums,
accept_tokens,
accept_num,
stop_flags,
seq_lens_this_time,
is_block_step,
stop_nums,
):
"""
完全复现 CPU / CUDA 逻辑的 NumPy 参考版本(就地修改)。
"""
stop_sum = 0
real_bsz = seq_lens_this_time.shape[0]
max_bsz = stop_flags.shape[0]
max_draft_tokens = draft_tokens.shape[1]
for bid in range(max_bsz):
stop_flag_now_int = 0
inactive = bid >= real_bsz
block_step = (not inactive) and is_block_step[bid]
if (not block_step) and (not inactive):
if stop_flags[bid]:
stop_flag_now_int = 1
# When the encoder length is 0, add accept_num directly onto the decoder length
if seq_lens_encoder[bid] == 0:
seq_lens_decoder[bid] += accept_num[bid]
# Adapt the draft token length
if (seq_lens_encoder[bid] == 0) and (seq_lens_this_time[bid] > 1):
cur_len = actual_draft_token_nums[bid]
if accept_num[bid] - 1 == cur_len:  # all drafts accepted
if cur_len + 2 <= max_draft_tokens - 1:
cur_len += 2
elif cur_len + 1 <= max_draft_tokens - 1:
cur_len += 1
else:
cur_len = max_draft_tokens - 1
else:  # some drafts rejected
cur_len = max(1, cur_len - 1)
actual_draft_token_nums[bid] = cur_len
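# Worked example: with max_draft_tokens=16 and cur_len=4, accept_num=5
# (all drafts accepted) grows cur_len to 6, while any rejection shrinks it
# to max(1, 4 - 1) = 3.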
# Fold the pending encoder length into the decoder length
if seq_lens_encoder[bid] != 0:
seq_lens_decoder[bid] += seq_lens_encoder[bid]
seq_lens_encoder[bid] = 0
# Write back the first token for the next round
draft_tokens[bid, 0] = accept_tokens[bid, accept_num[bid] - 1]
# When stopped, reset the decoder length
if stop_flag_now_int:
seq_lens_decoder[bid] = 0
elif inactive:
stop_flag_now_int = 1  # padding slots count as stopped
stop_sum += stop_flag_now_int
# print("stop_sum: ", stop_sum)
not_need_stop[0] = stop_sum < stop_nums[0]
# Return the references for interface consistency only
return (
seq_lens_encoder,
seq_lens_decoder,
not_need_stop,
draft_tokens,
actual_draft_token_nums,
)
# ---------------- Generate random inputs ----------------
def gen_inputs(
max_bsz=512,  # aligned with the CUDA BlockSize
max_draft_tokens=16,
real_bsz=123,  # adjustable; must be <= max_bsz
seed=2022,
):
"""生成随机测试输入数据"""
rng = np.random.default_rng(seed)
# Basic tensors
@@ -122,89 +43,91 @@ def gen_inputs(
stop_nums = np.array([5], dtype=np.int64)  # arbitrary threshold
# seq_lens_this_time only covers the first real_bsz entries
seq_lens_this_time = rng.integers(1, max_draft_tokens, size=real_bsz, dtype=np.int32)
seq_lens_this_time = rng.integers(1, max_draft_tokens + 1, size=real_bsz, dtype=np.int32)
return {
"seq_lens_encoder": seq_lens_encoder,
"seq_lens_decoder": seq_lens_decoder,
"not_need_stop": not_need_stop,
"draft_tokens": draft_tokens,
"actual_draft_token_nums": actual_draft_nums,
"accept_tokens": accept_tokens,
"accept_num": accept_num,
"stop_flags": stop_flags,
"seq_lens_this_time": seq_lens_this_time,
"is_block_step": is_block_step,
"stop_nums": stop_nums,
# real_bsz = real_bsz,
# max_bsz = max_bsz,
# max_draft_tokens = max_draft_tokens
paddle.set_device("xpu:0")
data_xpu = {
"seq_lens_encoder": paddle.to_tensor(seq_lens_encoder),
"seq_lens_decoder": paddle.to_tensor(seq_lens_decoder),
"not_need_stop": paddle.to_tensor(not_need_stop).cpu(),
"draft_tokens": paddle.to_tensor(draft_tokens),
"actual_draft_token_nums": paddle.to_tensor(actual_draft_nums),
"accept_tokens": paddle.to_tensor(accept_tokens),
"accept_num": paddle.to_tensor(accept_num),
"stop_flags": paddle.to_tensor(stop_flags),
"seq_lens_this_time": paddle.to_tensor(seq_lens_this_time),
"is_block_step": paddle.to_tensor(is_block_step),
"stop_nums": paddle.to_tensor(stop_nums),
}
# ------------------- Test body -------------------
inputs = gen_inputs(max_bsz=512, max_draft_tokens=32, real_bsz=201)
# ---- Paddle side ----
paddle_inputs = {}
for k, v in inputs.items():
if k in ("real_bsz", "max_bsz", "max_draft_tokens"):
paddle_inputs[k] = v  # plain Python ints
else:
if k == "not_need_stop":
paddle_inputs[k] = paddle.to_tensor(v, place=paddle.CPUPlace())
else:
# other tensors keep the default place; to test on GPU, pass place=paddle.CUDAPlace(0) explicitly
paddle_inputs[k] = paddle.to_tensor(v)
# ---- NumPy side ----
# To keep the initial values consistent, copy the numpy values of the Paddle inputs before passing them to the reference implementation
np_inputs = {
k: (paddle_inputs[k].numpy().copy() if isinstance(paddle_inputs[k], paddle.Tensor) else paddle_inputs[k])
for k in paddle_inputs
}
# Call the custom op
# print("seq_lens_encoder_xpu_before: ", paddle_inputs["seq_lens_encoder"])
out_pd = speculate_update_v3(**paddle_inputs)
# print("seq_lens_encoder_xpu_after: ", out_pd[0])
# print("not_need_stop: ", out_pd[2])
# speculate_update_v3 returns 5 tensors (matching its Outputs)
(
seq_lens_encoder_pd,
seq_lens_decoder_pd,
not_need_stop_pd,
draft_tokens_pd,
actual_draft_nums_pd,
) = out_pd
# print("seq_lens_encoder_np_before: ", np_inputs["seq_lens_encoder"])
out_np = speculate_update_v3_np(**np_inputs)
# print("seq_lens_encoder_np_after: ", out_np[0])
# print("not_need_stop: ", out_np[2])
paddle.set_device("cpu")
data_cpu = {
"seq_lens_encoder": paddle.to_tensor(seq_lens_encoder),
"seq_lens_decoder": paddle.to_tensor(seq_lens_decoder),
"not_need_stop": paddle.to_tensor(not_need_stop),
"draft_tokens": paddle.to_tensor(draft_tokens),
"actual_draft_token_nums": paddle.to_tensor(actual_draft_nums),
"accept_tokens": paddle.to_tensor(accept_tokens),
"accept_num": paddle.to_tensor(accept_num),
"stop_flags": paddle.to_tensor(stop_flags),
"seq_lens_this_time": paddle.to_tensor(seq_lens_this_time),
"is_block_step": paddle.to_tensor(is_block_step),
"stop_nums": paddle.to_tensor(stop_nums),
}
return data_xpu, data_cpu
# ---------------- Comparison ----------------
names = [
"seq_lens_encoder",
"seq_lens_decoder",
"not_need_stop",
"draft_tokens",
"actual_draft_token_nums",
]
pd_tensors = [
seq_lens_encoder_pd,
seq_lens_decoder_pd,
not_need_stop_pd,
draft_tokens_pd,
actual_draft_nums_pd,
]
class TestSpeculateUpdateV3(unittest.TestCase):
"""测试 speculate_update_v3 算子"""
for name, pd_val, np_val in zip(names, pd_tensors, out_np):
pd_arr = pd_val.numpy()
ok = np.array_equal(pd_arr, np_val)
print(f"{name:25s} equal :", ok)
def test_op_vs_golden(self, max_bsz=512, max_draft_tokens=16, real_bsz=123):
"""
核心测试:比较自定义算子的输出与纯 NumPy 参考实现的输出。
"""
# 1. gen inputs for cpu/xpu
data_xpu, data_cpu = gen_inputs(max_bsz=max_bsz, max_draft_tokens=max_draft_tokens, real_bsz=real_bsz)
# An assert could also be added here for use with pytest
# assert all(np.array_equal(p.numpy(), n) for p,n in zip(pd_tensors, out_np))
# 3. run xpu kernel
speculate_update_v3(**data_xpu)
# 4. run cpu kernel
speculate_update_v3(**data_cpu)
# 5. format outputs
outputs_xpu = [
data_xpu["seq_lens_encoder"].cpu().numpy(),
data_xpu["seq_lens_decoder"].cpu().numpy(),
data_xpu["not_need_stop"].cpu().numpy(),
data_xpu["draft_tokens"].cpu().numpy(),
data_xpu["actual_draft_token_nums"].cpu().numpy(),
]
outputs_cpu = [
data_cpu["seq_lens_encoder"].numpy(),
data_cpu["seq_lens_decoder"].numpy(),
data_cpu["not_need_stop"].numpy(),
data_cpu["draft_tokens"].numpy(),
data_cpu["actual_draft_token_nums"].numpy(),
]
output_names = [
"seq_lens_encoder",
"seq_lens_decoder",
"not_need_stop",
"draft_tokens",
"actual_draft_token_nums",
]
# 6. check outputs
for name, pd_out, np_out in zip(output_names, outputs_xpu, outputs_cpu):
with self.subTest(output_name=name):
np.testing.assert_allclose(
pd_out,
np_out,
atol=0,
rtol=1e-6,
err_msg=f"Output mismatch for tensor '{name}'.\nPaddle Output:\n{pd_out}\nGolden Output:\n{np_out}",
)
if __name__ == "__main__":
unittest.main()

View File

@@ -0,0 +1,315 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from typing import Dict, Optional
import paddle
from fastdeploy import envs
from fastdeploy.model_executor.forward_meta import XPUForwardMeta
from fastdeploy.platforms import current_platform
from fastdeploy.worker.output import ModelOutputData
if current_platform.is_xpu():
from fastdeploy.model_executor.ops.xpu import (
adjust_batch,
gather_next_token,
get_infer_param,
get_padding_offset,
limit_thinking_content_length_v1,
limit_thinking_content_length_v2,
update_inputs_v1,
)
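# Pre/post-processing helpers shared by the XPU model runner: padding removal and
# batch adjustment before the forward pass, next-token gathering and input-buffer
# updates after sampling, and the step op that recycles KV cache blocks.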
def xpu_pre_process(
input_ids: paddle.Tensor,
seq_lens_this_time: int,
share_inputs: Dict,
use_speculate_method: bool,
block_size: int,
draft_tokens: Optional[paddle.Tensor] = None,
seq_lens_encoder: Optional[paddle.Tensor] = None,
seq_lens_decoder: Optional[paddle.Tensor] = None,
is_profiling: bool = False,
) -> XPUForwardMeta:
""" """
max_len = input_ids.shape[1]
cum_offsets_now = paddle.cumsum(max_len - seq_lens_this_time, dtype="int32")
token_num = paddle.sum(seq_lens_this_time)
(
ids_remove_padding,
cum_offsets,
batch_id_per_token,
cu_seqlens_q,
cu_seqlens_k,
) = get_padding_offset(input_ids, cum_offsets_now, token_num, seq_lens_this_time)
share_inputs["ids_remove_padding"] = None # set this after adjust batch
share_inputs["cum_offsets"] = cum_offsets
share_inputs["batch_id_per_token"] = batch_id_per_token
share_inputs["cu_seqlens_q"] = cu_seqlens_q
share_inputs["cu_seqlens_k"] = cu_seqlens_k
xpu_forward_meta = XPUForwardMeta(
ids_remove_padding=share_inputs["ids_remove_padding"],
rotary_embs=share_inputs["rope_emb"],
attn_backend=None,
seq_lens_encoder=share_inputs["seq_lens_encoder"],
seq_lens_decoder=share_inputs["seq_lens_decoder"],
seq_lens_this_time=share_inputs["seq_lens_this_time"],
cum_offsets=share_inputs["cum_offsets"],
batch_id_per_token=share_inputs["batch_id_per_token"],
cu_seqlens_q=share_inputs["cu_seqlens_q"],
cu_seqlens_k=share_inputs["cu_seqlens_k"],
block_tables=share_inputs["block_tables"],
caches=share_inputs["caches"],
)
(
xpu_forward_meta.encoder_batch_map,
xpu_forward_meta.decoder_batch_map,
xpu_forward_meta.encoder_batch_idx,
xpu_forward_meta.decoder_batch_idx,
xpu_forward_meta.encoder_seq_lod,
xpu_forward_meta.decoder_seq_lod,
xpu_forward_meta.encoder_kv_lod,
xpu_forward_meta.prefix_len,
xpu_forward_meta.decoder_context_len,
xpu_forward_meta.decoder_context_len_cache,
xpu_forward_meta.prefix_block_tables,
xpu_forward_meta.encoder_batch_map_cpu,
xpu_forward_meta.decoder_batch_map_cpu,
xpu_forward_meta.encoder_batch_idx_cpu,
xpu_forward_meta.decoder_batch_idx_cpu,
xpu_forward_meta.encoder_seq_lod_cpu,
xpu_forward_meta.decoder_seq_lod_cpu,
xpu_forward_meta.encoder_kv_lod_cpu,
xpu_forward_meta.prefix_len_cpu,
xpu_forward_meta.decoder_context_len_cpu,
xpu_forward_meta.decoder_context_len_cache_cpu,
xpu_forward_meta.len_info_cpu,
) = get_infer_param(
seq_lens_encoder, seq_lens_decoder, seq_lens_this_time, xpu_forward_meta.block_tables, block_size
)
xpu_forward_meta.enc_batch = xpu_forward_meta.len_info_cpu[0]
xpu_forward_meta.dec_batch = xpu_forward_meta.len_info_cpu[1]
xpu_forward_meta.total_enc_len = xpu_forward_meta.len_info_cpu[2]
adjusted_input = adjust_batch(
ids_remove_padding.reshape([-1, 1]),
cum_offsets,
xpu_forward_meta.encoder_seq_lod,
xpu_forward_meta.decoder_seq_lod,
xpu_forward_meta.encoder_batch_idx,
xpu_forward_meta.decoder_batch_idx,
xpu_forward_meta.encoder_seq_lod_cpu,
xpu_forward_meta.decoder_seq_lod_cpu,
xpu_forward_meta.encoder_batch_idx_cpu,
xpu_forward_meta.decoder_batch_idx_cpu,
xpu_forward_meta.len_info_cpu,
None, # output_padding_offset
-1,  # max_input_length
)
adjusted_input = adjusted_input.squeeze(1)
share_inputs["ids_remove_padding"] = adjusted_input
xpu_forward_meta.ids_remove_padding = adjusted_input
# Set forward_meta.is_profiling to True to skip init_kv_signal_per_query for attention backends
xpu_forward_meta.is_profiling = is_profiling
return xpu_forward_meta
def xpu_process_output(
forward_output,
cum_offsets: paddle.Tensor,
xpu_forward_meta: XPUForwardMeta,
share_inputs,
) -> paddle.Tensor:
""" """
output_padding_offset = share_inputs.get("output_padding_offset", None)
hidden_states = gather_next_token(
forward_output,
cum_offsets,
xpu_forward_meta.encoder_seq_lod,
xpu_forward_meta.decoder_seq_lod,
xpu_forward_meta.encoder_batch_map,
xpu_forward_meta.decoder_batch_map,
xpu_forward_meta.encoder_seq_lod_cpu,
xpu_forward_meta.decoder_seq_lod_cpu,
xpu_forward_meta.encoder_batch_map_cpu,
xpu_forward_meta.decoder_batch_map_cpu,
xpu_forward_meta.len_info_cpu,
output_padding_offset, # output_padding_offset
-1, # max_input_length
)
return hidden_states
def xpu_post_process_normal(
sampled_token_ids: paddle.Tensor,
model_output: ModelOutputData,
share_inputs: Dict[str, paddle.Tensor],
block_size: int = 64,
skip_save_output: bool = False,
think_end_id: int = None,
line_break_id: int = None,
) -> None:
""" """
from fastdeploy.model_executor.ops.xpu import (
save_output,
set_stop_value_multi_ends,
update_inputs,
)
if think_end_id > 0:
limit_strategy = envs.FD_LIMIT_THINKING_CONTENT_TRUNCATE_STR
max_think_lens = share_inputs["max_think_lens"]
step_idx = share_inputs["step_idx"]
limit_think_status = share_inputs["limit_think_status"]
stop_flags = share_inputs["stop_flags"]
eos_token_ids = share_inputs["eos_token_id"]
if limit_strategy == "</think>":
# for ernie-45-vl
limit_thinking_content_length_v1(
sampled_token_ids,
max_think_lens,
step_idx,
limit_think_status,
stop_flags,
eos_token_ids,  # handle the case where the model emits an eos token inside the thinking section
think_end_id,
)
elif limit_strategy == "\n</think>\n\n":
# for ernie-x1
assert line_break_id > 0
limit_thinking_content_length_v2(
sampled_token_ids,
max_think_lens,
step_idx,
limit_think_status,
stop_flags,
think_end_id,
line_break_id,
)
else:
raise NotImplementedError(f"Not support {limit_strategy=} for limit thinking content length.")
# 1. Set stop value
paddle.assign(
paddle.where(
model_output.stop_flags,
model_output.step_idx,
model_output.step_idx + 1,
),
model_output.step_idx,
)
length_cond = paddle.greater_equal(model_output.step_idx, model_output.max_dec_len)
paddle.assign(
paddle.logical_or(model_output.stop_flags, length_cond),
model_output.stop_flags,
)
set_stop_value_multi_ends(
sampled_token_ids,
model_output.stop_flags,
model_output.seq_lens_this_time,
model_output.eos_token_id,
model_output.next_tokens,
False,
) # multi ends
# 2. Update the input buffer of the model
with paddle.framework._no_check_dy2st_diff():
if envs.ENABLE_V1_KVCACHE_SCHEDULER and not skip_save_output:
update_inputs_v1(
model_output.stop_flags,
model_output.not_need_stop,
model_output.seq_lens_this_time,
model_output.seq_lens_encoder,
model_output.seq_lens_decoder,
share_inputs["step_seq_lens_decoder"],
share_inputs["prompt_lens"],
sampled_token_ids,
model_output.input_ids,
share_inputs["block_tables"],
model_output.stop_nums,
model_output.next_tokens,
model_output.is_block_step,
block_size,
)
else:
update_inputs(
model_output.stop_flags,
model_output.not_need_stop,
model_output.seq_lens_this_time,
model_output.seq_lens_encoder,
model_output.seq_lens_decoder,
model_output.input_ids,
model_output.stop_nums,
sampled_token_ids,
model_output.is_block_step,
)
# 3. Transmit the model's output and stop generation signal via message queue.
# In the future, we will abandon this approach.
if not skip_save_output:
save_output(
sampled_token_ids,
model_output.not_need_stop,
model_output.mp_rank,
False, # use_ep
)
def step_xpu(
share_inputs: Dict[str, paddle.Tensor],
block_size: int,
enc_dec_block_num: int,
) -> None:
"""
TODO(gongshaotian): normalization name
"""
from fastdeploy.model_executor.ops.xpu import step_paddle
step_paddle(
share_inputs["stop_flags"],
share_inputs["seq_lens_this_time"],
share_inputs["step_seq_lens_encoder"],
share_inputs["seq_lens_encoder"],
share_inputs["seq_lens_decoder"],
share_inputs["block_tables"],
share_inputs["encoder_block_lens"],
share_inputs["is_block_step"],
share_inputs["step_block_list"],
share_inputs["step_lens"],
share_inputs["recover_block_list"],
share_inputs["recover_lens"],
share_inputs["need_block_list"],
share_inputs["need_block_len"],
share_inputs["used_list_len"],
share_inputs["free_list"],
share_inputs["free_list_len"],
share_inputs["input_ids"],
share_inputs["pre_ids"],
share_inputs["step_idx"],
share_inputs["next_tokens"],
share_inputs["first_token_ids"],
block_size,
enc_dec_block_num,
)

View File

@@ -17,7 +17,7 @@
import os
import random
import time
from typing import Dict, List, Optional
from typing import List, Optional
import numpy as np
import paddle
@@ -28,7 +28,7 @@ from fastdeploy.config import FDConfig
from fastdeploy.engine.request import Request, RequestType
from fastdeploy.input.ernie4_5_vl_processor import DataProcessor
from fastdeploy.inter_communicator import IPCSignal
from fastdeploy.model_executor.forward_meta import ForwardMeta, XPUForwardMeta
from fastdeploy.model_executor.forward_meta import ForwardMeta
from fastdeploy.model_executor.graph_optimization.utils import (
profile_run_guard,
sot_warmup_guard,
@@ -43,17 +43,17 @@ from fastdeploy.model_executor.layers.sample.sampler import Sampler
from fastdeploy.model_executor.model_loader import get_model_loader
from fastdeploy.model_executor.models.ernie4_5_vl.modeling_resampler import ScatterOp
from fastdeploy.model_executor.ops.xpu import (
adjust_batch,
create_kv_signal_sender,
destroy_kv_signal_sender,
get_infer_param,
get_padding_offset,
limit_thinking_content_length_v1,
limit_thinking_content_length_v2,
recover_decode_task,
set_data_ipc,
share_external_data,
update_inputs_v1,
)
from fastdeploy.model_executor.xpu_pre_and_post_process import (  # xpu_post_process_speculate, # TODO(chenhuan09): add xpu_post_process_speculate
step_xpu,
xpu_post_process_normal,
xpu_pre_process,
xpu_process_output,
)
from fastdeploy.utils import get_logger
from fastdeploy.worker.model_runner_base import ModelRunnerBase
@@ -62,282 +62,6 @@ from fastdeploy.worker.output import ModelOutputData, ModelRunnerOutput
logger = get_logger("xpu_model_runner", "xpu_model_runner.log")
def xpu_pre_process(
input_ids: paddle.Tensor,
seq_lens_this_time: int,
share_inputs: Dict,
use_speculate_method: bool,
block_size: int,
draft_tokens: Optional[paddle.Tensor] = None,
seq_lens_encoder: Optional[paddle.Tensor] = None,
seq_lens_decoder: Optional[paddle.Tensor] = None,
is_profiling: bool = False,
) -> XPUForwardMeta:
""" """
max_len = input_ids.shape[1]
cum_offsets_now = paddle.cumsum(max_len - seq_lens_this_time, dtype="int32")
token_num = paddle.sum(seq_lens_this_time)
(
ids_remove_padding,
cum_offsets,
batch_id_per_token,
cu_seqlens_q,
cu_seqlens_k,
) = get_padding_offset(input_ids, cum_offsets_now, token_num, seq_lens_this_time)
share_inputs["ids_remove_padding"] = None # set this after adjust batch
share_inputs["cum_offsets"] = cum_offsets
share_inputs["batch_id_per_token"] = batch_id_per_token
share_inputs["cu_seqlens_q"] = cu_seqlens_q
share_inputs["cu_seqlens_k"] = cu_seqlens_k
xpu_forward_meta = XPUForwardMeta(
ids_remove_padding=share_inputs["ids_remove_padding"],
rotary_embs=share_inputs["rope_emb"],
attn_backend=None,
seq_lens_encoder=share_inputs["seq_lens_encoder"],
seq_lens_decoder=share_inputs["seq_lens_decoder"],
seq_lens_this_time=share_inputs["seq_lens_this_time"],
cum_offsets=share_inputs["cum_offsets"],
batch_id_per_token=share_inputs["batch_id_per_token"],
cu_seqlens_q=share_inputs["cu_seqlens_q"],
cu_seqlens_k=share_inputs["cu_seqlens_k"],
block_tables=share_inputs["block_tables"],
caches=share_inputs["caches"],
)
(
xpu_forward_meta.encoder_batch_map,
xpu_forward_meta.decoder_batch_map,
xpu_forward_meta.encoder_batch_idx,
xpu_forward_meta.decoder_batch_idx,
xpu_forward_meta.encoder_seq_lod,
xpu_forward_meta.decoder_seq_lod,
xpu_forward_meta.encoder_kv_lod,
xpu_forward_meta.prefix_len,
xpu_forward_meta.decoder_context_len,
xpu_forward_meta.decoder_context_len_cache,
xpu_forward_meta.prefix_block_tables,
xpu_forward_meta.encoder_batch_map_cpu,
xpu_forward_meta.decoder_batch_map_cpu,
xpu_forward_meta.encoder_batch_idx_cpu,
xpu_forward_meta.decoder_batch_idx_cpu,
xpu_forward_meta.encoder_seq_lod_cpu,
xpu_forward_meta.decoder_seq_lod_cpu,
xpu_forward_meta.encoder_kv_lod_cpu,
xpu_forward_meta.prefix_len_cpu,
xpu_forward_meta.decoder_context_len_cpu,
xpu_forward_meta.decoder_context_len_cache_cpu,
xpu_forward_meta.len_info_cpu,
) = get_infer_param(
seq_lens_encoder, seq_lens_decoder, seq_lens_this_time, xpu_forward_meta.block_tables, block_size
)
xpu_forward_meta.enc_batch = xpu_forward_meta.len_info_cpu[0]
xpu_forward_meta.dec_batch = xpu_forward_meta.len_info_cpu[1]
xpu_forward_meta.total_enc_len = xpu_forward_meta.len_info_cpu[2]
adjusted_input = adjust_batch(
ids_remove_padding.reshape([-1, 1]),
cum_offsets,
xpu_forward_meta.encoder_seq_lod,
xpu_forward_meta.encoder_batch_idx,
xpu_forward_meta.decoder_batch_idx,
xpu_forward_meta.encoder_seq_lod_cpu,
xpu_forward_meta.encoder_batch_idx_cpu,
xpu_forward_meta.decoder_batch_idx_cpu,
xpu_forward_meta.enc_batch,
xpu_forward_meta.dec_batch,
None, # output_padding_offset
-1, # max_input_length
)
adjusted_input = adjusted_input.squeeze(1)
share_inputs["ids_remove_padding"] = adjusted_input
xpu_forward_meta.ids_remove_padding = adjusted_input
# Set forward_meta.is_profiling to True to skip init_kv_signal_per_query for attention backends
xpu_forward_meta.is_profiling = is_profiling
return xpu_forward_meta
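A toy walk-through (hypothetical batch) of the padding-offset math at the top of xpu_pre_process: with a padded width of max_len and per-request lengths seq_lens_this_time, cum_offsets_now accumulates the padding each request contributes, and token_num counts the real tokens kept after padding removal.

import paddle

seq_lens_this_time = paddle.to_tensor([4, 1, 2], dtype="int32")  # three requests
max_len = 6  # padded buffer width
cum_offsets_now = paddle.cumsum(max_len - seq_lens_this_time, dtype="int32")  # [2, 7, 11]
token_num = paddle.sum(seq_lens_this_time)  # 7 real tokens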
def xpu_process_output(
forward_output,
cum_offsets: paddle.Tensor,
xpu_forward_meta: XPUForwardMeta,
) -> paddle.Tensor:
""" """
from fastdeploy.model_executor.ops.xpu import gather_next_token
hidden_states = gather_next_token(
forward_output,
cum_offsets,
xpu_forward_meta.encoder_seq_lod,
xpu_forward_meta.encoder_batch_map,
xpu_forward_meta.decoder_batch_map,
xpu_forward_meta.encoder_seq_lod_cpu,
xpu_forward_meta.encoder_batch_map_cpu,
xpu_forward_meta.decoder_batch_map_cpu,
xpu_forward_meta.enc_batch,
xpu_forward_meta.dec_batch,
None, # output_padding_offset
-1, # max_input_length
)
return hidden_states
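An assumed toy equivalent of what gather_next_token returns: one hidden state per running request, taken at the last position of that request's slice in the flattened [token_num, dim] output, since that row is what compute_logits consumes next. The real XPU kernel derives the indices from the encoder/decoder lod metadata rather than the explicit index built here.

import paddle

hidden = paddle.rand([7, 4])  # token_num = 7, hidden dim = 4 (toy values)
seq_lod = [0, 4, 5, 7]  # three requests holding 4, 1 and 2 tokens
last_idx = paddle.to_tensor([seq_lod[i + 1] - 1 for i in range(3)], dtype="int64")
next_token_hidden = paddle.gather(hidden, last_idx)  # shape [3, 4], one row per request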
def xpu_post_process(
sampled_token_ids: paddle.Tensor,
model_output: ModelOutputData,
share_inputs: Dict[str, paddle.Tensor],
block_size: int = 64,
skip_save_output: bool = False,
think_end_id: int = None,
line_break_id: int = None,
) -> None:
""" """
from fastdeploy.model_executor.ops.xpu import (
save_output,
set_stop_value_multi_ends,
update_inputs,
)
if think_end_id > 0:
limit_strategy = envs.FD_LIMIT_THINKING_CONTENT_TRUNCATE_STR
max_think_lens = share_inputs["max_think_lens"]
step_idx = share_inputs["step_idx"]
limit_think_status = share_inputs["limit_think_status"]
stop_flags = share_inputs["stop_flags"]
eos_token_ids = share_inputs["eos_token_id"]
if limit_strategy == "</think>":
# for ernie-45-vl
limit_thinking_content_length_v1(
sampled_token_ids,
max_think_lens,
step_idx,
limit_think_status,
stop_flags,
eos_token_ids,  # handle the case where, due to model-quality issues, an eos token is emitted during the thinking phase
think_end_id,
)
elif limit_strategy == "\n</think>\n\n":
# for ernie-x1
assert line_break_id > 0
limit_thinking_content_length_v2(
sampled_token_ids,
max_think_lens,
step_idx,
limit_think_status,
stop_flags,
think_end_id,
line_break_id,
)
else:
raise NotImplementedError(f"Not support {limit_strategy=} for limit thinking content length.")
# 1. Set stop value
paddle.assign(
paddle.where(
model_output.stop_flags,
model_output.step_idx,
model_output.step_idx + 1,
),
model_output.step_idx,
)
length_cond = paddle.greater_equal(model_output.step_idx, model_output.max_dec_len)
paddle.assign(
paddle.logical_or(model_output.stop_flags, length_cond),
model_output.stop_flags,
)
set_stop_value_multi_ends(
sampled_token_ids,
model_output.stop_flags,
model_output.seq_lens_this_time,
model_output.eos_token_id,
model_output.next_tokens,
False,
) # multi ends
# 2. Update the input buffer of the model
with paddle.framework._no_check_dy2st_diff():
if envs.ENABLE_V1_KVCACHE_SCHEDULER and not skip_save_output:
update_inputs_v1(
model_output.stop_flags,
model_output.not_need_stop,
model_output.seq_lens_this_time,
model_output.seq_lens_encoder,
model_output.seq_lens_decoder,
share_inputs["step_seq_lens_decoder"],
share_inputs["prompt_lens"],
sampled_token_ids,
model_output.input_ids,
share_inputs["block_tables"],
model_output.stop_nums,
model_output.next_tokens,
model_output.is_block_step,
block_size,
)
else:
update_inputs(
model_output.stop_flags,
model_output.not_need_stop,
model_output.seq_lens_this_time,
model_output.seq_lens_encoder,
model_output.seq_lens_decoder,
model_output.input_ids,
model_output.stop_nums,
sampled_token_ids,
model_output.is_block_step,
)
# 3. Transmit the model's output and stop generation signal via message queue.
# In the future, we will abandon this approach.
if not skip_save_output:
save_output(
sampled_token_ids,
model_output.not_need_stop,
model_output.mp_rank,
False, # use_ep
)
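A small worked example (toy values) of the stop/step bookkeeping in step 1 of xpu_post_process above: already-stopped requests keep their step_idx, the rest advance by one, and any request reaching its max_dec_len is flagged as finished.

import paddle

step_idx = paddle.to_tensor([[3], [7], [0]], dtype="int64")
stop_flags = paddle.to_tensor([[False], [True], [False]])
max_dec_len = paddle.to_tensor([[4], [8], [8]], dtype="int64")

step_idx = paddle.where(stop_flags, step_idx, step_idx + 1)  # [[4], [7], [1]]
stop_flags = paddle.logical_or(stop_flags, paddle.greater_equal(step_idx, max_dec_len))
# [[True], [True], [False]]: request 0 just hit its limit of 4 decode steps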
def step_paddle(
share_inputs: Dict[str, paddle.Tensor],
block_size: int,
enc_dec_block_num: int,
) -> None:
"""
TODO(gongshaotian): normalization name
"""
from fastdeploy.model_executor.ops.xpu import step_paddle
step_paddle(
share_inputs["stop_flags"],
share_inputs["seq_lens_this_time"],
share_inputs["step_seq_lens_encoder"],
share_inputs["seq_lens_encoder"],
share_inputs["seq_lens_decoder"],
share_inputs["block_tables"],
share_inputs["encoder_block_lens"],
share_inputs["is_block_step"],
share_inputs["step_block_list"],
share_inputs["step_lens"],
share_inputs["recover_block_list"],
share_inputs["recover_lens"],
share_inputs["need_block_list"],
share_inputs["need_block_len"],
share_inputs["used_list_len"],
share_inputs["free_list"],
share_inputs["free_list_len"],
share_inputs["input_ids"],
share_inputs["pre_ids"],
share_inputs["step_idx"],
share_inputs["next_tokens"],
share_inputs["first_token_ids"],
block_size,
enc_dec_block_num,
)
class XPUModelRunner(ModelRunnerBase):
""" """
@@ -1212,8 +936,9 @@ class XPUModelRunner(ModelRunnerBase):
forward_meta=self.forward_meta,
)
hidden_states = xpu_process_output(model_output, self.share_inputs["cum_offsets"], self.forward_meta)
hidden_states = xpu_process_output(
model_output, self.share_inputs["cum_offsets"], self.forward_meta, self.share_inputs
)
# 4. Compute logits, Sample
logits = self.model.compute_logits(hidden_states)
sampler_output = self.sampler(logits, self.sampling_metadata)
@@ -1247,7 +972,7 @@ class XPUModelRunner(ModelRunnerBase):
stop_token_ids=self.share_inputs["stop_seqs"],
stop_seqs_len=self.share_inputs["stop_seqs_len"],
)
xpu_post_process(
xpu_post_process_normal(
sampled_token_ids=sampler_output.sampled_token_ids,
model_output=model_output_data,
share_inputs=self.share_inputs,
@@ -1260,7 +985,7 @@ class XPUModelRunner(ModelRunnerBase):
# 7. Update 'infer_seed' and step_paddle()
self.share_inputs["infer_seed"].add_(self.infer_seed_increment)
self.share_inputs["infer_seed"][:] %= self.MAX_INFER_SEED
step_paddle(
step_xpu(
self.share_inputs,
self.cache_config.block_size,
self.cache_config.enc_dec_block_num,