Mirror of https://github.com/PaddlePaddle/FastDeploy.git
[XPU] support kernel for mtp(base) (#4748)
* [XPU] support kernel for mtp(base)
* [XPU] support kernel for mtp(base)
* format
* format
* format
* fix gather next token
* fix step && add test
* fix
* mv pre/post process
* add adjust batch / gather next token for mtp
* fix code style
* fix mtp kernel name
* fix mtp kernel test
* mv xpu pre/post process
* mv xpu pre/post process
@@ -18,38 +18,49 @@
#include "utility/helper.h"
#include "xpu/plugin.h"

#ifndef PD_BUILD_STATIC_OP
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
#endif

template <paddle::DataType T>
std::vector<paddle::Tensor> AdjustBatchKernel(
    const paddle::Tensor &x,            // [token_num, dim_embed]
    const paddle::Tensor &cum_offsets,  // [bsz, 1]
    const paddle::Tensor &encoder_seq_lod,
    const paddle::Tensor &decoder_seq_lod,
    const paddle::Tensor &encoder_batch_idx,
    const paddle::Tensor &decoder_batch_idx,
    const paddle::Tensor &encoder_seq_lod_cpu,
    const paddle::Tensor &decoder_seq_lod_cpu,
    const paddle::Tensor &encoder_batch_idx_cpu,
    const paddle::Tensor &decoder_batch_idx_cpu,
    const paddle::Tensor &enc_batch_tensor,
    const paddle::Tensor &dec_batch_tensor,
    const paddle::Tensor &len_info_cpu,
    const paddle::optional<paddle::Tensor> &output_padding_offset,
    int max_input_length) {
  phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
  auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
  auto xpu_ctx = static_cast<const phi::XPUContext *>(dev_ctx);
  auto ctx = static_cast<const phi::XPUContext *>(dev_ctx)->x_context();
  PD_CHECK(x.dtype() == T);
  PD_CHECK(x.dims().size() == 2);

  if (x.is_cpu()) {
    ctx = new baidu::xpu::api::Context(baidu::xpu::api::kCPU);
  }
  using XPUType = typename XPUTypeTrait<typename PDTraits<T>::DataType>::Type;
  using data_t = typename PDTraits<T>::data_t;
  const int token_num = x.dims()[0];
  const int dim = x.dims()[1];
  const int bsz = cum_offsets.shape()[0];
-  int enc_batch = enc_batch_tensor.data<int32_t>()[0];
-  int dec_batch = dec_batch_tensor.data<int32_t>()[0];
+  int enc_batch = len_info_cpu.data<int32_t>()[0];
+  int dec_batch = len_info_cpu.data<int32_t>()[1];

  baidu::xpu::api::VectorParam<int32_t> encoder_seqs_lods_vp{
      const_cast<int32_t *>(encoder_seq_lod_cpu.data<int32_t>()),
      enc_batch + 1,
      const_cast<int32_t *>(encoder_seq_lod.data<int32_t>())};
  baidu::xpu::api::VectorParam<int32_t> decoder_seqs_lods_vp{
      const_cast<int32_t *>(decoder_seq_lod_cpu.data<int32_t>()),
      dec_batch + 1,
      const_cast<int32_t *>(decoder_seq_lod.data<int32_t>())};
  baidu::xpu::api::VectorParam<int32_t> encoder_batch_map_vp{
      const_cast<int32_t *>(encoder_batch_idx_cpu.data<int32_t>()),
      enc_batch,
@@ -59,13 +70,14 @@ std::vector<paddle::Tensor> AdjustBatchKernel(
      dec_batch,
      const_cast<int32_t *>(decoder_batch_idx.data<int32_t>())};

-  auto out = paddle::full({token_num, dim}, -2, x.type(), x.place());
+  auto out = paddle::empty({token_num, dim}, x.type(), x.place());

  int r = baidu::xpu::api::plugin::eb_adjust_batch<XPUType, XPUType>(
-      xpu_ctx->x_context(),
+      ctx,
      reinterpret_cast<const XPUType *>(x.data<data_t>()),
      reinterpret_cast<XPUType *>(out.data<data_t>()),
      encoder_seqs_lods_vp,
      decoder_seqs_lods_vp,
      encoder_batch_map_vp,
      decoder_batch_map_vp,
      dim);
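A note on the VectorParam<int32_t> initializers above: each one bundles a host copy of a lod or batch-map array, its element count, and the device copy of the same array, so the plugin can read lengths on the CPU while the kernel reads the data on the XPU. The sketch below is a minimal stand-in for that layout; the struct and its field names are assumptions chosen to match the three initializers, not the actual xdnn definition.

#include <cstdint>

// Hypothetical mirror of baidu::xpu::api::VectorParam<int32_t>; the three
// fields line up with the host pointer, length, and device pointer used above.
template <typename T>
struct VectorParamSketch {
  T* cpu;       // host-side copy (e.g. encoder_seq_lod_cpu)
  int32_t len;  // number of valid elements (enc_batch + 1 for a lod array)
  T* xpu;       // device-side copy (e.g. encoder_seq_lod)
};

int main() {
  // Lod for enc_batch = 2 prefill sequences of lengths 5 and 4: [0, 5, 9].
  int32_t lod_host[3] = {0, 5, 9};
  // Device pointer omitted (nullptr) in this host-only sketch.
  VectorParamSketch<int32_t> vp{lod_host, 3, nullptr};
  return vp.cpu[vp.len - 1] == 9 ? 0 : 1;
}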
@@ -76,13 +88,14 @@ using AdjustBatchKernelFuncPtr = std::vector<paddle::Tensor> (*)(
    const paddle::Tensor &x,            // [token_num, dim_embed]
    const paddle::Tensor &cum_offsets,  // [bsz, 1]
    const paddle::Tensor &encoder_seq_lod,
    const paddle::Tensor &decoder_seq_lod,
    const paddle::Tensor &encoder_batch_idx,
    const paddle::Tensor &decoder_batch_idx,
    const paddle::Tensor &encoder_seq_lod_cpu,
    const paddle::Tensor &decoder_seq_lod_cpu,
    const paddle::Tensor &encoder_batch_idx_cpu,
    const paddle::Tensor &decoder_batch_idx_cpu,
    const paddle::Tensor &enc_batch_tensor,
    const paddle::Tensor &dec_batch_tensor,
    const paddle::Tensor &len_info_cpu,
    const paddle::optional<paddle::Tensor> &output_padding_offset,
    int max_input_length);

@@ -90,13 +103,14 @@ std::vector<paddle::Tensor> AdjustBatch(
    const paddle::Tensor &x,            // [token_num, dim_embed]
    const paddle::Tensor &cum_offsets,  // [bsz, 1]
    const paddle::Tensor &encoder_seq_lod,
    const paddle::Tensor &decoder_seq_lod,
    const paddle::Tensor &encoder_batch_idx,
    const paddle::Tensor &decoder_batch_idx,
    const paddle::Tensor &encoder_seq_lod_cpu,
    const paddle::Tensor &decoder_seq_lod_cpu,
    const paddle::Tensor &encoder_batch_idx_cpu,
    const paddle::Tensor &decoder_batch_idx_cpu,
    const paddle::Tensor &enc_batch_tensor,
    const paddle::Tensor &dec_batch_tensor,
    const paddle::Tensor &len_info_cpu,
    const paddle::optional<paddle::Tensor> &output_padding_offset,
    int max_input_length) {
  AdjustBatchKernelFuncPtr func = nullptr;
@@ -108,12 +122,12 @@ std::vector<paddle::Tensor> AdjustBatch(
    case paddle::DataType::FLOAT16:
      func = &AdjustBatchKernel<paddle::DataType::FLOAT16>;
      break;
    case paddle::DataType::FLOAT32:
      func = &AdjustBatchKernel<paddle::DataType::FLOAT32>;
      break;
    case paddle::DataType::INT64:
      func = &AdjustBatchKernel<paddle::DataType::INT64>;
      break;
    default:
      PD_THROW("Unsupported data type: ", x.dtype());
  }
@@ -121,13 +135,14 @@ std::vector<paddle::Tensor> AdjustBatch(
  return func(x,
              cum_offsets,
              encoder_seq_lod,
              decoder_seq_lod,
              encoder_batch_idx,
              decoder_batch_idx,
              encoder_seq_lod_cpu,
              decoder_seq_lod_cpu,
              encoder_batch_idx_cpu,
              decoder_batch_idx_cpu,
              enc_batch_tensor,
              dec_batch_tensor,
              len_info_cpu,
              output_padding_offset,
              max_input_length);
}
@@ -136,13 +151,14 @@ std::vector<std::vector<int64_t>> AdjustBatchInferShape(
    const std::vector<int64_t> &x_shape,
    const std::vector<int64_t> &cum_offsets_shape,
    const std::vector<int64_t> &encoder_seq_lod_shape,
    const std::vector<int64_t> &decoder_seq_lod_shape,
    const std::vector<int64_t> &encoder_batch_idx_shape,
    const std::vector<int64_t> &decoder_batch_idx_shape,
    const std::vector<int64_t> &encoder_seq_lod_cpu_shape,
    const std::vector<int64_t> &decoder_seq_lod_cpu_shape,
    const std::vector<int64_t> &encoder_batch_idx_cpu_shape,
    const std::vector<int64_t> &decoder_batch_idx_cpu_shape,
    const std::vector<int64_t> &enc_batch_tensor_shape,
    const std::vector<int64_t> &dec_batch_tensor_shape,
    const std::vector<int64_t> &len_info_cpu_shape,
    const paddle::optional<std::vector<int64_t>> &output_padding_offset_shape) {
  if (output_padding_offset_shape) {
    PD_THROW("speculative decoding is not supported in XPU.");
@@ -156,28 +172,30 @@ std::vector<paddle::DataType> AdjustBatchInferDtype(
    const paddle::DataType &x_dtype,
    const paddle::DataType &cum_offsets_dtype,
    const paddle::DataType &encoder_seq_lod_dtype,
    const paddle::DataType &decoder_seq_lod_dtype,
    const paddle::DataType &encoder_batch_idx_dtype,
    const paddle::DataType &decoder_batch_idx_dtype,
    const paddle::DataType &encoder_seq_lod_cpu_dtype,
    const paddle::DataType &decoder_seq_lod_cpu_dtype,
    const paddle::DataType &encoder_batch_idx_cpu_dtype,
    const paddle::DataType &decoder_batch_idx_cpu_dtype,
    const paddle::DataType &enc_batch_tensor_dtype,
    const paddle::DataType &dec_batch_tensor_dtype,
    const paddle::DataType &len_info_cpu_dtype,
    const paddle::optional<paddle::DataType> &output_padding_offset_dtype) {
  return {x_dtype};
}

-PD_BUILD_OP(adjust_batch)
+PD_BUILD_STATIC_OP(adjust_batch)
    .Inputs({"x",
             "cum_offsets",
             "encoder_seq_lod",
             "decoder_seq_lod",
             "encoder_batch_idx",
             "decoder_batch_idx",
             "encoder_seq_lod_cpu",
             "decoder_seq_lod_cpu",
             "encoder_batch_idx_cpu",
             "decoder_batch_idx_cpu",
             "enc_batch_tensor",
             "dec_batch_tensor",
             "len_info_cpu",
             paddle::Optional("output_padding_offset")})
    .Outputs({"out"})
    .Attrs({"max_input_length: int"})
@@ -722,7 +722,6 @@ std::vector<paddle::Tensor> BlockAttnKernel(
            : quant_v_scale_inv,
        nullptr,           // o_maxptr
        param.head_dim);   // vo_head_dim
-    PD_CHECK(0, "speculative_attention unimplemented");
    PD_CHECK(ret == api::SUCCESS,
             "xfa::speculative_attention_decoder failed.");
    if (!Eq_len) {
@@ -13,107 +13,169 @@
// limitations under the License.

#include <paddle/phi/backends/xpu/xpu_context.h>
#include <xft/xdnn_plugin.h>
#include "paddle/extension.h"
#include "xpu/plugin.h"

#ifndef PD_BUILD_STATIC_OP
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
#endif

std::vector<paddle::Tensor> GatherNextToken(
-    const paddle::Tensor &tmp_out,      // [token_num, dim_embed]
-    const paddle::Tensor &cum_offsets,  // [bsz, 1]
-    const paddle::Tensor &encoder_seq_lod,
-    const paddle::Tensor &encoder_batch_map,
-    const paddle::Tensor &decoder_batch_map,
-    const paddle::Tensor &encoder_seq_lod_cpu,
-    const paddle::Tensor &encoder_batch_map_cpu,
-    const paddle::Tensor &decoder_batch_map_cpu,
-    const paddle::Tensor &enc_batch_tensor,
-    const paddle::Tensor &dec_batch_tensor,
-    const paddle::optional<paddle::Tensor> &output_padding_offset,
-    int max_input_length) {
+    const paddle::Tensor& x,            // [token_num, dim_embed]
+    const paddle::Tensor& cum_offsets,  // [bsz, 1]
+    const paddle::Tensor& encoder_seq_lod,
+    const paddle::Tensor& decoder_seq_lod,
+    const paddle::Tensor& encoder_batch_map,
+    const paddle::Tensor& decoder_batch_map,
+    const paddle::Tensor& encoder_seq_lod_cpu,
+    const paddle::Tensor& decoder_seq_lod_cpu,
+    const paddle::Tensor& encoder_batch_map_cpu,
+    const paddle::Tensor& decoder_batch_map_cpu,
+    const paddle::Tensor& len_info_cpu,
+    const paddle::optional<paddle::Tensor>& output_padding_offset,
+    int max_bsz) {
  phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
  auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
-  auto xpu_ctx = static_cast<const phi::XPUContext *>(dev_ctx);
+  auto ctx = static_cast<const phi::XPUContext*>(dev_ctx)->x_context();
+  if (x.is_cpu()) {
+    ctx = new baidu::xpu::api::Context(baidu::xpu::api::kCPU);
+  }
  using XPUType =
      typename XPUTypeTrait<bfloat16>::Type;  // only support bfloat16
  typedef paddle::bfloat16 data_t;
-  const int dim = tmp_out.dims()[1];
-  const int bsz = cum_offsets.shape()[0];
-  int enc_batch = enc_batch_tensor.data<int32_t>()[0];
-  int dec_batch = dec_batch_tensor.data<int32_t>()[0];
+  const int dim = x.dims()[1];
+  const int token_num = x.shape()[0];
+  int bsz = cum_offsets.shape()[0];
+  int enc_batch = len_info_cpu.data<int32_t>()[0];
+  int dec_batch = len_info_cpu.data<int32_t>()[1];
+  if (max_bsz > 0) {
+    PD_CHECK(encoder_batch_map_cpu.data<int32_t>()[enc_batch - 1] <= max_bsz,
+             "encoder_batch_map_cpu check failed");
+    PD_CHECK(decoder_batch_map_cpu.data<int32_t>()[dec_batch - 1] <= max_bsz,
+             "decoder_batch_map_cpu check failed");
+    bsz = max_bsz;
+  }
  baidu::xpu::api::VectorParam<int32_t> encoder_seqs_lods_vp{
-      const_cast<int32_t *>(encoder_seq_lod_cpu.data<int32_t>()),
+      const_cast<int32_t*>(encoder_seq_lod_cpu.data<int32_t>()),
      enc_batch + 1,
-      const_cast<int32_t *>(encoder_seq_lod.data<int32_t>())};
+      const_cast<int32_t*>(encoder_seq_lod.data<int32_t>())};
+  baidu::xpu::api::VectorParam<int32_t> decoder_seqs_lods_vp{
+      const_cast<int32_t*>(decoder_seq_lod_cpu.data<int32_t>()),
+      dec_batch + 1,
+      const_cast<int32_t*>(decoder_seq_lod.data<int32_t>())};
  baidu::xpu::api::VectorParam<int32_t> encoder_batch_map_vp{
-      const_cast<int32_t *>(encoder_batch_map_cpu.data<int32_t>()),
+      const_cast<int32_t*>(encoder_batch_map_cpu.data<int32_t>()),
      enc_batch,
-      const_cast<int32_t *>(encoder_batch_map.data<int32_t>())};
+      const_cast<int32_t*>(encoder_batch_map.data<int32_t>())};
  baidu::xpu::api::VectorParam<int32_t> decoder_batch_map_vp{
-      const_cast<int32_t *>(decoder_batch_map_cpu.data<int32_t>()),
+      const_cast<int32_t*>(decoder_batch_map_cpu.data<int32_t>()),
      dec_batch,
-      const_cast<int32_t *>(decoder_batch_map.data<int32_t>())};
+      const_cast<int32_t*>(decoder_batch_map.data<int32_t>())};

-  auto out = paddle::full({bsz, dim}, -2, tmp_out.type(), tmp_out.place());
+  paddle::Tensor out;
+  if (output_padding_offset) {
+    int need_delete_token_num = 0;
+    if (enc_batch > 0) {
+      need_delete_token_num =
+          encoder_seq_lod_cpu.data<int32_t>()[enc_batch] - enc_batch;
+    }
+    out = paddle::empty(
+        {token_num - need_delete_token_num, dim}, x.type(), x.place());
+  } else {
+    out = paddle::empty({bsz, dim}, x.type(), x.place());
+  }
+  if (x.shape()[0] == 0) {
+    return {out};
+  }

-  int r = baidu::xpu::api::plugin::eb_gather_next_token<XPUType, XPUType>(
-      xpu_ctx->x_context(),
-      reinterpret_cast<const XPUType *>(tmp_out.data<data_t>()),
-      reinterpret_cast<XPUType *>(out.data<data_t>()),
-      encoder_seqs_lods_vp,
-      encoder_batch_map_vp,
-      decoder_batch_map_vp,
-      dim);
+  if (enc_batch <= 0) {
+    out = x.copy_to(x.place(), false);
+  } else {
+    if (output_padding_offset) {
+      int r =
+          baidu::xpu::api::plugin::eb_mtp_gather_next_token<XPUType, XPUType>(
+              ctx,
+              reinterpret_cast<const XPUType*>(x.data<data_t>()),
+              reinterpret_cast<XPUType*>(out.data<data_t>()),
+              encoder_seqs_lods_vp,
+              decoder_seqs_lods_vp,
+              encoder_batch_map_vp,
+              decoder_batch_map_vp,
+              dim);
+      PD_CHECK(r == 0, "xpu::plugin::gather_next_token failed.");
+    } else {
+      int r = baidu::xpu::api::plugin::eb_gather_next_token<XPUType, XPUType>(
+          ctx,
+          reinterpret_cast<const XPUType*>(x.data<data_t>()),
+          reinterpret_cast<XPUType*>(out.data<data_t>()),
+          encoder_seqs_lods_vp,
+          encoder_batch_map_vp,
+          decoder_batch_map_vp,
+          dim);
+      PD_CHECK(r == 0, "xpu::plugin::gather_next_token failed.");
+    }
+  }
  return {out};
}
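The need_delete_token_num arithmetic above appears to keep only the last hidden state of each prefill sequence when an output padding offset is present: encoder_seq_lod_cpu[enc_batch] - enc_batch rows drop out of the output. A small self-contained check of that reading, with illustrative values:

#include <cassert>

int main() {
  // Two prefill sequences of lengths 5 and 4 -> lod [0, 5, 9], plus
  // 3 decode tokens: token_num = 9 + 3 = 12.
  int enc_batch = 2;
  int encoder_seq_lod_cpu[3] = {0, 5, 9};
  int token_num = 12;

  // 9 prefill tokens minus one kept row per sequence = 7 rows removed.
  int need_delete_token_num = encoder_seq_lod_cpu[enc_batch] - enc_batch;
  assert(need_delete_token_num == 7);

  // Output keeps 2 prefill last-tokens + 3 decode tokens = 5 rows.
  assert(token_num - need_delete_token_num == 5);
  return 0;
}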

std::vector<std::vector<int64_t>> GatherNextTokenInferShape(
-    const std::vector<int64_t> &tmp_out_shape,
-    const std::vector<int64_t> &cum_offsets_shape,
-    const std::vector<int64_t> &encoder_seq_lod_shape,
-    const std::vector<int64_t> &encoder_batch_map_shape,
-    const std::vector<int64_t> &decoder_batch_map_shape,
-    const std::vector<int64_t> &encoder_seq_lod_cpu_shape,
-    const std::vector<int64_t> &encoder_batch_map_cpu_shape,
-    const std::vector<int64_t> &decoder_batch_map_cpu_shape,
-    const std::vector<int64_t> &enc_batch_tensor_shape,
-    const std::vector<int64_t> &dec_batch_tensor_shape,
-    const paddle::optional<std::vector<int64_t>> &output_padding_offset_shape) {
-  if (output_padding_offset_shape) {
-    PD_THROW("speculative decoding is not supported in XPU.");
-  }
-  int64_t bsz = cum_offsets_shape[0];
-  int64_t dim_embed = tmp_out_shape[1];
-  return {{bsz, dim_embed}};
+    const std::vector<int64_t>& x_shape,
+    const std::vector<int64_t>& cum_offsets_shape,
+    const std::vector<int64_t>& encoder_seq_lod_shape,
+    const std::vector<int64_t>& decoder_seq_lod_shape,
+    const std::vector<int64_t>& encoder_batch_map_shape,
+    const std::vector<int64_t>& decoder_batch_map_shape,
+    const std::vector<int64_t>& encoder_seq_lod_cpu_shape,
+    const std::vector<int64_t>& decoder_seq_lod_cpu_shape,
+    const std::vector<int64_t>& encoder_batch_map_cpu_shape,
+    const std::vector<int64_t>& decoder_batch_map_cpu_shape,
+    const std::vector<int64_t>& len_info_cpu_shape,
+    const paddle::optional<std::vector<int64_t>>& output_padding_offset_shape) {
+  // if (output_padding_offset_shape) {
+  //   PD_THROW("speculative decoding is not supported in XPU.");
+  // }
+  int64_t dim_embed = x_shape[1];
+  if (output_padding_offset_shape) {
+    return {{-1, dim_embed}};
+  } else {
+    int64_t bsz = cum_offsets_shape[0];
+    return {{bsz, dim_embed}};
+  }
}

std::vector<paddle::DataType> GatherNextTokenInferDtype(
-    const paddle::DataType &tmp_out_dtype,
-    const paddle::DataType &cum_offsets_dtype,
-    const paddle::DataType &encoder_seq_lod_dtype,
-    const paddle::DataType &encoder_batch_map_dtype,
-    const paddle::DataType &decoder_batch_map_dtype,
-    const paddle::DataType &encoder_seq_lod_cpu_dtype,
-    const paddle::DataType &encoder_batch_map_cpu_dtype,
-    const paddle::DataType &decoder_batch_map_cpu_dtype,
-    const paddle::DataType &enc_batch_tensor_dtype,
-    const paddle::DataType &dec_batch_tensor_dtype,
-    const paddle::optional<paddle::DataType> &output_padding_offset_dtype) {
-  return {tmp_out_dtype};
+    const paddle::DataType& x_dtype,
+    const paddle::DataType& cum_offsets_dtype,
+    const paddle::DataType& encoder_seq_lod_dtype,
+    const paddle::DataType& decoder_seq_lod_dtype,
+    const paddle::DataType& encoder_batch_map_dtype,
+    const paddle::DataType& decoder_batch_map_dtype,
+    const paddle::DataType& encoder_seq_lod_cpu_dtype,
+    const paddle::DataType& decoder_seq_lod_cpu_dtype,
+    const paddle::DataType& encoder_batch_map_cpu_dtype,
+    const paddle::DataType& decoder_batch_map_cpu_dtype,
+    const paddle::DataType& len_info_cpu_dtype,
+    const paddle::optional<paddle::DataType>& output_padding_offset_dtype) {
+  return {x_dtype};
}

-PD_BUILD_OP(gather_next_token)
-    .Inputs({"tmp_out",
+PD_BUILD_STATIC_OP(gather_next_token)
+    .Inputs({"x",
             "cum_offsets",
             "encoder_seq_lod",
             "decoder_seq_lod",
             "encoder_batch_map",
             "decoder_batch_map",
             "encoder_seq_lod_cpu",
             "decoder_seq_lod_cpu",
             "encoder_batch_map_cpu",
             "decoder_batch_map_cpu",
             "enc_batch_tensor",
             "dec_batch_tensor",
             "len_info_cpu",
             paddle::Optional("output_padding_offset")})
    .Outputs({"out"})
-    .Attrs({"max_input_length: int"})
+    .Attrs({"max_bsz: int"})
    .SetKernelFn(PD_KERNEL(GatherNextToken))
    .SetInferShapeFn(PD_INFER_SHAPE(GatherNextTokenInferShape))
    .SetInferDtypeFn(PD_INFER_DTYPE(GatherNextTokenInferDtype));
@@ -29,21 +29,23 @@ void DraftModelPreprocess(const paddle::Tensor& draft_tokens,
                          const paddle::Tensor& seq_lens_encoder,
                          const paddle::Tensor& seq_lens_decoder,
                          const paddle::Tensor& step_idx,
-                         const paddle::Tensor& seq_lens_encoder_record,
-                         const paddle::Tensor& seq_lens_decoder_record,
                          const paddle::Tensor& not_need_stop,
+                         const paddle::Tensor& is_block_step,
                          const paddle::Tensor& batch_drop,
+                         const paddle::Tensor& pre_ids,
                          const paddle::Tensor& accept_tokens,
                          const paddle::Tensor& accept_num,
+                         const paddle::Tensor& base_model_seq_lens_this_time,
                          const paddle::Tensor& base_model_seq_lens_encoder,
                          const paddle::Tensor& base_model_seq_lens_decoder,
                          const paddle::Tensor& base_model_step_idx,
                          const paddle::Tensor& base_model_stop_flags,
                          const paddle::Tensor& base_model_is_block_step,
                          const paddle::Tensor& base_model_draft_tokens,
                          const int max_draft_token,
                          const int num_model_step,
                          const bool truncate_first_token,
-                         const bool splitwise_prefill) {
+                         const bool splitwise_prefill,
+                         const bool kvcache_scheduler_v1) {
  phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
  auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
  api::Context* ctx = static_cast<const phi::XPUContext*>(dev_ctx)->x_context();
@@ -54,6 +56,8 @@ void DraftModelPreprocess(const paddle::Tensor& draft_tokens,
  int accept_tokens_len = accept_tokens.shape()[1];
  int input_ids_len = input_ids.shape()[1];
  int draft_tokens_len = draft_tokens.shape()[1];
+  int pre_ids_len = pre_ids.shape()[1];
+  constexpr int BlockSize = 512;
  int base_model_draft_tokens_len = base_model_draft_tokens.shape()[1];
  auto not_need_stop_gpu =
      not_need_stop.copy_to(seq_lens_this_time.place(), false);
@@ -67,12 +71,13 @@ void DraftModelPreprocess(const paddle::Tensor& draft_tokens,
      const_cast<int*>(seq_lens_encoder.data<int>()),
      const_cast<int*>(seq_lens_decoder.data<int>()),
      const_cast<int64_t*>(step_idx.data<int64_t>()),
-      const_cast<int*>(seq_lens_encoder_record.data<int>()),
-      const_cast<int*>(seq_lens_decoder_record.data<int>()),
      const_cast<bool*>(not_need_stop_gpu.data<bool>()),
+      const_cast<bool*>(is_block_step.data<bool>()),
      const_cast<bool*>(batch_drop.data<bool>()),
+      const_cast<int64_t*>(pre_ids.data<int64_t>()),
      accept_tokens.data<int64_t>(),
      accept_num.data<int>(),
+      base_model_seq_lens_this_time.data<int>(),
      base_model_seq_lens_encoder.data<int>(),
      base_model_seq_lens_decoder.data<int>(),
      base_model_step_idx.data<int64_t>(),
@@ -80,13 +85,16 @@ void DraftModelPreprocess(const paddle::Tensor& draft_tokens,
      base_model_is_block_step.data<bool>(),
      const_cast<int64_t*>(base_model_draft_tokens.data<int64_t>()),
      real_bsz,
-      max_draft_token,
+      num_model_step,
      accept_tokens_len,
      draft_tokens_len,
      input_ids_len,
      base_model_draft_tokens_len,
+      pre_ids_len,
      truncate_first_token,
-      splitwise_prefill);
+      splitwise_prefill,
+      kvcache_scheduler_v1);

  PD_CHECK(r == 0, "xpu::plugin::draft_model_preprocess failed.");
  auto not_need_stop_cpu =
      not_need_stop_gpu.copy_to(not_need_stop.place(), false);
@@ -102,12 +110,13 @@ PD_BUILD_STATIC_OP(draft_model_preprocess)
             "seq_lens_encoder",
             "seq_lens_decoder",
             "step_idx",
-             "seq_lens_encoder_record",
-             "seq_lens_decoder_record",
             "not_need_stop",
+             "is_block_step",
             "batch_drop",
+             "pre_ids",
             "accept_tokens",
             "accept_num",
+             "base_model_seq_lens_this_time",
             "base_model_seq_lens_encoder",
             "base_model_seq_lens_decoder",
             "base_model_step_idx",
@@ -123,11 +132,11 @@ PD_BUILD_STATIC_OP(draft_model_preprocess)
              "step_idx_out",
              "not_need_stop_out",
              "batch_drop_out",
-              "seq_lens_encoder_record_out",
-              "seq_lens_decoder_record_out"})
-    .Attrs({"max_draft_token: int",
+              "pre_ids_out"})
+    .Attrs({"num_model_step: int",
            "truncate_first_token: bool",
-            "splitwise_prefill: bool"})
+            "splitwise_prefill: bool",
+            "kvcache_scheduler_v1: bool"})
    .SetInplaceMap({{"draft_tokens", "draft_tokens_out"},
                    {"input_ids", "input_ids_out"},
                    {"stop_flags", "stop_flags_out"},
@@ -137,6 +146,5 @@ PD_BUILD_STATIC_OP(draft_model_preprocess)
                    {"step_idx", "step_idx_out"},
                    {"not_need_stop", "not_need_stop_out"},
                    {"batch_drop", "batch_drop_out"},
-                    {"seq_lens_encoder_record", "seq_lens_encoder_record_out"},
-                    {"seq_lens_decoder_record", "seq_lens_decoder_record_out"}})
+                    {"pre_ids", "pre_ids_out"}})
    .SetKernelFn(PD_KERNEL(DraftModelPreprocess));
@@ -43,6 +43,8 @@ std::vector<paddle::Tensor> SpeculateGetPaddingOffset(
      {token_num_data}, paddle::DataType::INT64, input_ids.place());
  auto padding_offset = paddle::empty(
      {token_num_data}, paddle::DataType::INT32, input_ids.place());
+  auto batch_id_per_token = paddle::empty(
+      {token_num_data}, paddle::DataType::INT32, input_ids.place());
  auto cu_seqlens_q =
      paddle::empty({bsz + 1}, paddle::DataType::INT32, input_ids.place());
  auto cu_seqlens_k =
@@ -57,7 +59,7 @@ std::vector<paddle::Tensor> SpeculateGetPaddingOffset(

  int r = baidu::xpu::api::plugin::speculate_get_padding_offset(
      xpu_ctx->x_context(),
-      padding_offset.data<int>(),
+      batch_id_per_token.data<int>(),
      cum_offsets_out.data<int>(),
      cu_seqlens_q.data<int>(),
      cu_seqlens_k.data<int>(),
@@ -83,7 +85,7 @@ std::vector<paddle::Tensor> SpeculateGetPaddingOffset(

  return {x_remove_padding,
          cum_offsets_out,
-          padding_offset,
+          batch_id_per_token,
          cu_seqlens_q,
          cu_seqlens_k};  // , enc_token_num, dec_token_num};
}
@@ -123,7 +125,7 @@ PD_BUILD_STATIC_OP(speculate_get_padding_offset)
             "seq_lens_encoder"})
    .Outputs({"x_remove_padding",
              "cum_offsets_out",
-              "padding_offset",
+              "batch_id_per_token",
              "cu_seqlens_q",
              "cu_seqlens_k"})
    .SetKernelFn(PD_KERNEL(SpeculateGetPaddingOffset))
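For orientation, cu_seqlens_q and cu_seqlens_k above follow the usual cumulative-sequence-length convention: entry i is the total length of the first i sequences, giving bsz + 1 entries with a leading zero. A quick illustration with invented lengths:

#include <cstdio>

int main() {
  // seq_lens for bsz = 3 requests.
  int seq_lens[3] = {5, 1, 4};
  int cu_seqlens[4] = {0};  // bsz + 1 entries, cu_seqlens[0] = 0
  for (int i = 0; i < 3; ++i) {
    cu_seqlens[i + 1] = cu_seqlens[i] + seq_lens[i];
  }
  // Prints: 0 5 6 10 -- request i owns token rows [cu[i], cu[i+1]).
  for (int v : cu_seqlens) std::printf("%d ", v);
  std::printf("\n");
  return 0;
}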
@@ -35,7 +35,7 @@ void SpeculateSaveWithOutputMsg(const paddle::Tensor& accept_tokens,
                                const paddle::Tensor& not_need_stop,
                                int64_t rank_id,
                                int msg_queue_id,
-                                int save_each_rank) {
+                                bool save_each_rank) {
  // printf("enter save output");
  if (!save_each_rank && rank_id > 0) {
    return;
custom_ops/xpu_ops/src/ops/mtp/speculate_step_paddle.cc (new file, 187 lines)
@@ -0,0 +1,187 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <paddle/phi/backends/xpu/xpu_context.h>
#include "paddle/extension.h"
#include "paddle/phi/core/enforce.h"
#include "xpu/plugin.h"

#ifndef PD_BUILD_STATIC_OP
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
#endif

void SpeculateStepPaddle(
    const paddle::Tensor &stop_flags,
    const paddle::Tensor &seq_lens_this_time,
    const paddle::Tensor &ori_seq_lens_encoder,
    const paddle::Tensor &seq_lens_encoder,
    const paddle::Tensor &seq_lens_decoder,
    const paddle::Tensor &block_tables,  // [bsz, block_num_per_seq]
    const paddle::Tensor &encoder_block_lens,
    const paddle::Tensor &is_block_step,
    const paddle::Tensor &step_block_list,
    const paddle::Tensor &step_lens,
    const paddle::Tensor &recover_block_list,
    const paddle::Tensor &recover_lens,
    const paddle::Tensor &need_block_list,
    const paddle::Tensor &need_block_len,
    const paddle::Tensor &used_list_len,
    const paddle::Tensor &free_list,
    const paddle::Tensor &free_list_len,
    const paddle::Tensor &input_ids,
    const paddle::Tensor &pre_ids,
    const paddle::Tensor &step_idx,
    const paddle::Tensor &next_tokens,
    const paddle::Tensor &first_token_ids,
    const paddle::Tensor &accept_num,
    const int block_size,
    const int encoder_decoder_block_num,
    const int max_draft_tokens) {
  namespace api = baidu::xpu::api;
  phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
  auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
  auto xpu_ctx = static_cast<const phi::XPUContext *>(dev_ctx);
  api::Context *ctx = xpu_ctx->x_context();
  if (seq_lens_this_time.is_cpu()) {
    ctx = new api::Context(api::kCPU);
  }
  const int bsz = seq_lens_this_time.shape()[0];
  PADDLE_ENFORCE_LE(
      bsz,
      640,
      phi::errors::InvalidArgument(
          "Only support bsz <= 640, but received bsz is %d", bsz));
  const int block_num_per_seq = block_tables.shape()[1];
  const int length = input_ids.shape()[1];
  const int pre_id_length = pre_ids.shape()[1];
  const int max_decoder_block_num = pre_id_length / block_size;
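  // Worked example of the accounting above (illustrative values, not taken
  // from this diff): with pre_id_length = 2048 and block_size = 64, each
  // request can hold at most max_decoder_block_num = 2048 / 64 = 32 decoder
  // KV-cache blocks.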
  int r = baidu::xpu::api::plugin::speculate_free_and_dispatch_block(
      ctx,
      const_cast<bool *>(stop_flags.data<bool>()),
      const_cast<int *>(seq_lens_this_time.data<int>()),
      const_cast<int *>(seq_lens_decoder.data<int>()),
      const_cast<int *>(block_tables.data<int>()),
      const_cast<int *>(encoder_block_lens.data<int>()),
      const_cast<bool *>(is_block_step.data<bool>()),
      const_cast<int *>(step_block_list.data<int>()),
      const_cast<int *>(step_lens.data<int>()),
      const_cast<int *>(recover_block_list.data<int>()),
      const_cast<int *>(recover_lens.data<int>()),
      const_cast<int *>(need_block_list.data<int>()),
      const_cast<int *>(need_block_len.data<int>()),
      const_cast<int *>(used_list_len.data<int>()),
      const_cast<int *>(free_list.data<int>()),
      const_cast<int *>(free_list_len.data<int>()),
      const_cast<int64_t *>(first_token_ids.data<int64_t>()),
      const_cast<int *>(accept_num.data<int>()),
      bsz,
      block_size,
      block_num_per_seq,
      max_decoder_block_num,
      max_draft_tokens);
  PD_CHECK(r == 0, "speculate_free_and_dispatch_block failed.");
  auto recover_lens_cpu = recover_lens.copy_to(paddle::CPUPlace(), false);
  int recover_lens_cpu_data = recover_lens_cpu.data<int>()[0];
  if (recover_lens_cpu_data > 0) {
    r = baidu::xpu::api::plugin::speculate_recover_block(
        ctx,
        const_cast<int *>(recover_block_list.data<int>()),
        const_cast<int *>(recover_lens.data<int>()),
        const_cast<bool *>(stop_flags.data<bool>()),
        const_cast<int *>(seq_lens_this_time.data<int>()),
        ori_seq_lens_encoder.data<int>(),
        const_cast<int *>(seq_lens_encoder.data<int>()),
        seq_lens_decoder.data<int>(),
        const_cast<int *>(block_tables.data<int>()),
        const_cast<int *>(free_list.data<int>()),
        const_cast<int *>(free_list_len.data<int>()),
        const_cast<int64_t *>(input_ids.data<int64_t>()),
        pre_ids.data<int64_t>(),
        step_idx.data<int64_t>(),
        encoder_block_lens.data<int>(),
        used_list_len.data<int>(),
        next_tokens.data<int64_t>(),
        first_token_ids.data<int64_t>(),
        bsz,
        block_num_per_seq,
        length,
        pre_id_length);
    PD_CHECK(r == 0, "speculate_recover_block failed.");
  }
}

PD_BUILD_STATIC_OP(speculate_step_paddle)
    .Inputs({"stop_flags",
             "seq_lens_this_time",
             "ori_seq_lens_encoder",
             "seq_lens_encoder",
             "seq_lens_decoder",
             "block_tables",
             "encoder_block_lens",
             "is_block_step",
             "step_block_list",
             "step_lens",
             "recover_block_list",
             "recover_lens",
             "need_block_list",
             "need_block_len",
             "used_list_len",
             "free_list",
             "free_list_len",
             "input_ids",
             "pre_ids",
             "step_idx",
             "next_tokens",
             "first_token_ids",
             "accept_num"})
    .Attrs({"block_size: int",
            "encoder_decoder_block_num: int",
            "max_draft_tokens: int"})
    .Outputs({"stop_flags_out",
              "seq_lens_this_time_out",
              "seq_lens_encoder_out",
              "seq_lens_decoder_out",
              "block_tables_out",
              "encoder_block_lens_out",
              "is_block_step_out",
              "step_block_list_out",
              "step_lens_out",
              "recover_block_list_out",
              "recover_lens_out",
              "need_block_list_out",
              "need_block_len_out",
              "used_list_len_out",
              "free_list_out",
              "free_list_len_out",
              "input_ids_out",
              "first_token_ids_out"})
    .SetInplaceMap({{"stop_flags", "stop_flags_out"},
                    {"seq_lens_this_time", "seq_lens_this_time_out"},
                    {"seq_lens_encoder", "seq_lens_encoder_out"},
                    {"seq_lens_decoder", "seq_lens_decoder_out"},
                    {"block_tables", "block_tables_out"},
                    {"encoder_block_lens", "encoder_block_lens_out"},
                    {"is_block_step", "is_block_step_out"},
                    {"step_block_list", "step_block_list_out"},
                    {"step_lens", "step_lens_out"},
                    {"recover_block_list", "recover_block_list_out"},
                    {"recover_lens", "recover_lens_out"},
                    {"need_block_list", "need_block_list_out"},
                    {"need_block_len", "need_block_len_out"},
                    {"used_list_len", "used_list_len_out"},
                    {"free_list", "free_list_out"},
                    {"free_list_len", "free_list_len_out"},
                    {"input_ids", "input_ids_out"},
                    {"first_token_ids", "first_token_ids_out"}})
    .SetKernelFn(PD_KERNEL(SpeculateStepPaddle));
@@ -45,7 +45,10 @@ void SpeculateVerify(const paddle::Tensor &accept_tokens,
                     const paddle::Tensor &topp,
                     int max_seq_len,
                     int verify_window,
-                     bool enable_topp) {
+                     bool enable_topp,
+                     bool benchmark_mode,
+                     bool accept_all_drafts) {
+  // TODO(chenhuan09): support accept_all_drafts
  auto bsz = accept_tokens.shape()[0];
  int real_bsz = seq_lens_this_time.shape()[0];
  auto max_draft_tokens = draft_tokens.shape()[1];
@@ -133,7 +136,8 @@ void SpeculateVerify(const paddle::Tensor &accept_tokens,
          max_seq_len,
          max_candidate_len,
          verify_window,
-          prefill_one_step_stop);
+          prefill_one_step_stop,
+          benchmark_mode);
    } else {
      baidu::xpu::api::plugin::speculate_verify<false, true>(
          ctx,
@@ -161,7 +165,8 @@ void SpeculateVerify(const paddle::Tensor &accept_tokens,
          max_seq_len,
          max_candidate_len,
          verify_window,
-          prefill_one_step_stop);
+          prefill_one_step_stop,
+          benchmark_mode);
    }
  } else {
    if (enable_topp) {
@@ -191,7 +196,8 @@ void SpeculateVerify(const paddle::Tensor &accept_tokens,
          max_seq_len,
          max_candidate_len,
          verify_window,
-          prefill_one_step_stop);
+          prefill_one_step_stop,
+          benchmark_mode);
    } else {
      baidu::xpu::api::plugin::speculate_verify<false, false>(
          ctx,
@@ -219,7 +225,8 @@ void SpeculateVerify(const paddle::Tensor &accept_tokens,
          max_seq_len,
          max_candidate_len,
          verify_window,
-          prefill_one_step_stop);
+          prefill_one_step_stop,
+          benchmark_mode);
    }
  }
}
@@ -246,7 +253,11 @@ PD_BUILD_STATIC_OP(speculate_verify)
              "accept_num_out",
              "step_idx_out",
              "stop_flags_out"})
-    .Attrs({"max_seq_len: int", "verify_window: int", "enable_topp: bool"})
+    .Attrs({"max_seq_len: int",
+            "verify_window: int",
+            "enable_topp: bool",
+            "benchmark_mode: bool",
+            "accept_all_drafts: bool"})
    .SetInplaceMap({{"accept_tokens", "accept_tokens_out"},
                    {"accept_num", "accept_num_out"},
                    {"step_idx", "step_idx_out"},
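The four call sites above all funnel into plugin::speculate_verify with two compile-time bool template parameters; only the <false, true> and <false, false> instantiations are visible in these hunks, so the flag meanings in the sketch below are placeholders. A condensed, self-contained version of the runtime-to-template dispatch idiom being used:

#include <cstdio>

// Runtime flags select one of four compile-time instantiations, so each
// instantiation's branches fold away at compile time.
template <bool kFlagA, bool kFlagB>
int speculate_verify_sketch(int n) {
  if (kFlagA) n += 1;  // resolved at compile time per instantiation
  if (kFlagB) n += 2;
  return n;
}

int dispatch(bool a, bool b, int n) {
  if (a) return b ? speculate_verify_sketch<true, true>(n)
                  : speculate_verify_sketch<true, false>(n);
  return b ? speculate_verify_sketch<false, true>(n)
           : speculate_verify_sketch<false, false>(n);
}

int main() {
  std::printf("%d\n", dispatch(false, true, 0));  // prints 2
  return 0;
}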
@@ -37,13 +37,14 @@ std::vector<paddle::Tensor> AdjustBatch(
    const paddle::Tensor& x,            // [token_num, dim_embed]
    const paddle::Tensor& cum_offsets,  // [bsz, 1]
    const paddle::Tensor& encoder_seq_lod,
    const paddle::Tensor& decoder_seq_lod,
    const paddle::Tensor& encoder_batch_idx,
    const paddle::Tensor& decoder_batch_idx,
    const paddle::Tensor& encoder_seq_lod_cpu,
    const paddle::Tensor& decoder_seq_lod_cpu,
    const paddle::Tensor& encoder_batch_idx_cpu,
    const paddle::Tensor& decoder_batch_idx_cpu,
    const paddle::Tensor& enc_batch_tensor,
    const paddle::Tensor& dec_batch_tensor,
    const paddle::Tensor& len_info_cpu,
    const paddle::optional<paddle::Tensor>& output_padding_offset,
    int max_input_length);

@@ -264,7 +265,9 @@ void SpeculateVerify(const paddle::Tensor& accept_tokens,
                     const paddle::Tensor& topp,
                     int max_seq_len,
                     int verify_window,
-                     bool enable_topp);
+                     bool enable_topp,
+                     bool benchmark_mode,
+                     bool accept_all_drafts);

void SpeculateClearAcceptNums(const paddle::Tensor& accept_num,
                              const paddle::Tensor& seq_lens_decoder);
@@ -285,21 +288,23 @@ void DraftModelPreprocess(const paddle::Tensor& draft_tokens,
                          const paddle::Tensor& seq_lens_encoder,
                          const paddle::Tensor& seq_lens_decoder,
                          const paddle::Tensor& step_idx,
                          const paddle::Tensor& seq_lens_encoder_record,
                          const paddle::Tensor& seq_lens_decoder_record,
                          const paddle::Tensor& not_need_stop,
                          const paddle::Tensor& is_block_step,
                          const paddle::Tensor& batch_drop,
                          const paddle::Tensor& pre_ids,
                          const paddle::Tensor& accept_tokens,
                          const paddle::Tensor& accept_num,
                          const paddle::Tensor& base_model_seq_lens_this_time,
                          const paddle::Tensor& base_model_seq_lens_encoder,
                          const paddle::Tensor& base_model_seq_lens_decoder,
                          const paddle::Tensor& base_model_step_idx,
                          const paddle::Tensor& base_model_stop_flags,
                          const paddle::Tensor& base_model_is_block_step,
                          const paddle::Tensor& base_model_draft_tokens,
                          const int max_draft_token,
                          const int num_model_step,
                          const bool truncate_first_token,
-                         const bool splitwise_prefill);
+                         const bool splitwise_prefill,
+                         const bool kvcache_scheduler_v1);

void DraftModelPostprocess(const paddle::Tensor& base_model_draft_tokens,
                           const paddle::Tensor& base_model_seq_lens_this_time,
@@ -324,18 +329,19 @@ std::vector<paddle::Tensor> EagleGetSelfHiddenStates(
    const paddle::Tensor& step_idx);

std::vector<paddle::Tensor> GatherNextToken(
-    const paddle::Tensor& tmp_out,      // [token_num, dim_embed]
+    const paddle::Tensor& x,            // [token_num, dim_embed]
    const paddle::Tensor& cum_offsets,  // [bsz, 1]
    const paddle::Tensor& encoder_seq_lod,
    const paddle::Tensor& decoder_seq_lod,
    const paddle::Tensor& encoder_batch_map,
    const paddle::Tensor& decoder_batch_map,
    const paddle::Tensor& encoder_seq_lod_cpu,
    const paddle::Tensor& decoder_seq_lod_cpu,
    const paddle::Tensor& encoder_batch_map_cpu,
    const paddle::Tensor& decoder_batch_map_cpu,
    const paddle::Tensor& enc_batch_tensor,
    const paddle::Tensor& dec_batch_tensor,
    const paddle::Tensor& len_info_cpu,
    const paddle::optional<paddle::Tensor>& output_padding_offset,
-    int max_input_length);
+    int max_bsz);

std::vector<paddle::Tensor> GetImgBoundaries(
    const paddle::Tensor& task_input_ids,
@@ -436,6 +442,34 @@ void MTPStepPaddle(
    const int block_size,
    const int max_draft_tokens);

void SpeculateStepPaddle(
    const paddle::Tensor& stop_flags,
    const paddle::Tensor& seq_lens_this_time,
    const paddle::Tensor& ori_seq_lens_encoder,
    const paddle::Tensor& seq_lens_encoder,
    const paddle::Tensor& seq_lens_decoder,
    const paddle::Tensor& block_tables,  // [bsz, block_num_per_seq]
    const paddle::Tensor& encoder_block_lens,
    const paddle::Tensor& is_block_step,
    const paddle::Tensor& step_block_list,
    const paddle::Tensor& step_lens,
    const paddle::Tensor& recover_block_list,
    const paddle::Tensor& recover_lens,
    const paddle::Tensor& need_block_list,
    const paddle::Tensor& need_block_len,
    const paddle::Tensor& used_list_len,
    const paddle::Tensor& free_list,
    const paddle::Tensor& free_list_len,
    const paddle::Tensor& input_ids,
    const paddle::Tensor& pre_ids,
    const paddle::Tensor& step_idx,
    const paddle::Tensor& next_tokens,
    const paddle::Tensor& first_token_ids,
    const paddle::Tensor& accept_num,
    const int block_size,
    const int encoder_decoder_block_num,
    const int max_draft_tokens);

void SaveOutMmsgStatic(const paddle::Tensor& x,
                       const paddle::Tensor& not_need_stop,
                       int64_t rank_id,
@@ -542,13 +576,14 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
        py::arg("x"),
        py::arg("cum_offsets"),
        py::arg("encoder_seq_lod"),
        py::arg("decoder_seq_lod"),
        py::arg("encoder_batch_idx"),
        py::arg("decoder_batch_idx"),
        py::arg("encoder_seq_lod_cpu"),
        py::arg("decoder_seq_lod_cpu"),
        py::arg("encoder_batch_idx_cpu"),
        py::arg("decoder_batch_idx_cpu"),
        py::arg("enc_batch_tensor"),
        py::arg("dec_batch_tensor"),
        py::arg("len_info_cpu"),
        py::arg("output_padding_offset"),
        py::arg("max_input_length"),
        "adjust batch in XPU");
@@ -620,21 +655,23 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
        py::arg("seq_lens_encoder"),
        py::arg("seq_lens_decoder"),
        py::arg("step_idx"),
        py::arg("seq_lens_encoder_record"),
        py::arg("seq_lens_decoder_record"),
        py::arg("not_need_stop"),
        py::arg("is_block_step"),
        py::arg("batch_drop"),
        py::arg("pre_ids"),
        py::arg("accept_tokens"),
        py::arg("accept_num"),
        py::arg("base_model_seq_lens_this_time"),
        py::arg("base_model_seq_lens_encoder"),
        py::arg("base_model_seq_lens_decoder"),
        py::arg("base_model_step_idx"),
        py::arg("base_model_stop_flags"),
        py::arg("base_model_is_block_step"),
        py::arg("base_model_draft_tokens"),
        py::arg("max_draft_token"),
        py::arg("num_model_step"),
        py::arg("truncate_first_token"),
        py::arg("splitwise_prefill"),
        py::arg("kvcache_scheduler_v1"),
        "Preprocess data for draft model in speculative decoding");

  m.def("draft_model_postprocess",
@@ -727,18 +764,19 @@ PYBIND11_MODULE(fastdeploy_ops, m) {

  m.def("gather_next_token",
        &GatherNextToken,
-        py::arg("tmp_out"),
+        py::arg("x"),
        py::arg("cum_offsets"),
        py::arg("encoder_seq_lod"),
        py::arg("decoder_seq_lod"),
        py::arg("encoder_batch_map"),
        py::arg("decoder_batch_map"),
        py::arg("encoder_seq_lod_cpu"),
        py::arg("decoder_seq_lod_cpu"),
        py::arg("encoder_batch_map_cpu"),
        py::arg("decoder_batch_map_cpu"),
        py::arg("enc_batch_tensor"),
        py::arg("dec_batch_tensor"),
        py::arg("len_info_cpu"),
        py::arg("output_padding_offset"),
-        py::arg("max_input_length"),
+        py::arg("max_bsz"),
        "Gather next token for XPU");

  m.def("get_img_boundaries",
@@ -983,6 +1021,8 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
        py::arg("max_seq_len"),
        py::arg("verify_window"),
        py::arg("enable_topp"),
+        py::arg("benchmark_mode"),
+        py::arg("accept_all_drafts"),
        "Perform speculative verification for decoding");

  m.def("speculate_clear_accept_nums",
@@ -1104,6 +1144,36 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
        py::arg("encoder_decoder_block_num"),
        "Step paddle function");

  m.def("speculate_step_paddle",
        &SpeculateStepPaddle,
        py::arg("stop_flags"),
        py::arg("seq_lens_this_time"),
        py::arg("ori_seq_lens_encoder"),
        py::arg("seq_lens_encoder"),
        py::arg("seq_lens_decoder"),
        py::arg("block_tables"),
        py::arg("encoder_block_lens"),
        py::arg("is_block_step"),
        py::arg("step_block_list"),
        py::arg("step_lens"),
        py::arg("recover_block_list"),
        py::arg("recover_lens"),
        py::arg("need_block_list"),
        py::arg("need_block_len"),
        py::arg("used_list_len"),
        py::arg("free_list"),
        py::arg("free_list_len"),
        py::arg("input_ids"),
        py::arg("pre_ids"),
        py::arg("step_idx"),
        py::arg("next_tokens"),
        py::arg("first_token_ids"),
        py::arg("accept_num"),
        py::arg("block_size"),
        py::arg("encoder_decoder_block_num"),
        py::arg("max_draft_tokens"),
        "Step paddle function");

  m.def("text_image_gather_scatter",
        &TextImageGatherScatter,
        py::arg("input"),
@@ -75,6 +75,48 @@ DLL_EXPORT int get_padding_offset(Context* ctx,
                                  const int max_seq_len,
                                  const int bs);

DLL_EXPORT int speculate_get_padding_offset(Context* ctx,
                                            int* batch_id_per_token,
                                            int* cum_offsets_out,
                                            int* cu_seqlens_q,
                                            int* cu_seqlens_k,
                                            const int* cum_offsets,
                                            const int* seq_lens,
                                            const int max_seq_len,
                                            int bsz);

DLL_EXPORT int draft_model_preprocess(api::Context* ctx,
                                      int64_t* draft_tokens,
                                      int64_t* input_ids,
                                      bool* stop_flags,
                                      int* seq_lens_this_time,
                                      int* seq_lens_encoder,
                                      int* seq_lens_decoder,
                                      int64_t* step_idx,
                                      bool* not_need_stop,
                                      bool* is_block_step,
                                      bool* batch_drop,
                                      int64_t* pre_ids,
                                      const int64_t* accept_tokens,
                                      const int* accept_num,
                                      const int* base_model_seq_lens_this_time,
                                      const int* base_model_seq_lens_encoder,
                                      const int* base_model_seq_lens_decoder,
                                      const int64_t* base_model_step_idx,
                                      const bool* base_model_stop_flags,
                                      const bool* base_model_is_block_step,
                                      int64_t* base_model_draft_tokens,
                                      const int bsz,
                                      const int num_model_step,
                                      const int accept_tokens_len,
                                      const int draft_tokens_len,
                                      const int input_ids_len,
                                      const int base_model_draft_tokens_len,
                                      const int pre_ids_len,
                                      const bool truncate_first_token,
                                      const bool splitwise_prefill,
                                      const bool kvcache_scheduler_v1);

DLL_EXPORT int update_inputs(Context* ctx,
                             bool* not_need_stop,
                             int* seq_lens_this_time,
@@ -111,6 +153,31 @@ DLL_EXPORT int free_and_dispatch_block(Context* ctx,
                                       const int block_num_per_seq,
                                       const int max_decoder_block_num);

DLL_EXPORT int speculate_free_and_dispatch_block(
    Context* ctx,
    bool* stop_flags,
    int* seq_lens_this_time,
    int* seq_lens_decoder,
    int* block_tables,
    int* encoder_block_lens,
    bool* is_block_step,
    int* step_block_list,  // [bsz]
    int* step_len,
    int* recover_block_list,
    int* recover_len,
    int* need_block_list,
    int* need_block_len,
    int* used_list_len,
    int* free_list,
    int* free_list_len,
    int64_t* first_token_ids,
    int* accept_num,
    const int bsz,
    const int block_size,
    const int block_num_per_seq,
    const int max_decoder_block_num,
    const int max_draft_tokens);

DLL_EXPORT int recover_block(Context* ctx,
                             int* recover_block_list,  // [bsz]
                             int* recover_len,
@@ -134,6 +201,29 @@ DLL_EXPORT int recover_block(Context* ctx,
                             const int length,
                             const int pre_id_length);

DLL_EXPORT int speculate_recover_block(Context* ctx,
                                       int* recover_block_list,  // [bsz]
                                       int* recover_len,
                                       bool* stop_flags,
                                       int* seq_lens_this_time,
                                       const int* ori_seq_lens_encoder,
                                       int* seq_lens_encoder,
                                       const int* seq_lens_decoder,
                                       int* block_tables,
                                       int* free_list,
                                       int* free_list_len,
                                       int64_t* input_ids,
                                       const int64_t* pre_ids,
                                       const int64_t* step_idx,
                                       const int* encoder_block_lens,
                                       const int* used_list_len,
                                       const int64_t* next_tokens,
                                       const int64_t* first_token_ids,
                                       const int bsz,
                                       const int block_num_per_seq,
                                       const int length,
                                       const int pre_id_length);

DLL_EXPORT int recover_decode_task(Context* ctx,
                                   bool* stop_flags,
                                   int* seq_lens_this_time,
@@ -172,6 +262,7 @@ DLL_EXPORT int eb_adjust_batch(
    const TX* x,
    TY* y,
    VectorParam<int32_t>& encoder_seqs_lods,  // NOLINT
+    VectorParam<int32_t>& decoder_seqs_lods,  // NOLINT
    VectorParam<int32_t>& encoder_batch_map,  // NOLINT
    VectorParam<int32_t>& decoder_batch_map,  // NOLINT
    int64_t hidden_dim);
@@ -186,6 +277,17 @@ DLL_EXPORT int eb_gather_next_token(
    VectorParam<int32_t>& decoder_batch_map,  // NOLINT
    int64_t hidden_dim);

template <typename TX, typename TY>
DLL_EXPORT int eb_mtp_gather_next_token(
    Context* ctx,
    const TX* x,
    TY* y,
    VectorParam<int32_t>& encoder_seqs_lods,  // NOLINT
    VectorParam<int32_t>& decoder_seqs_lods,  // NOLINT
    VectorParam<int32_t>& encoder_batch_map,  // NOLINT
    VectorParam<int32_t>& decoder_batch_map,  // NOLINT
    int64_t hidden_dim);

template <typename TX, typename TSCALE = float, typename TY = int8_t>
DLL_EXPORT int quant2d_per_channel(api::Context* ctx,
                                   const TX* x,
@@ -305,7 +407,8 @@ DLL_EXPORT int speculate_verify(Context* ctx,
                                const int max_seq_len,
                                const int max_candidate_len,
                                const int verify_window,
-                                const bool prefill_one_step_stop);
+                                const bool prefill_one_step_stop,
+                                const bool benchmark_mode);

DLL_EXPORT int speculate_clear_accept_nums(Context* ctx,
                                           int* accept_num,
@@ -342,35 +445,6 @@ DLL_EXPORT int draft_model_update(Context* ctx,
                                  const int substep,
                                  const bool prefill_one_step_stop);

-DLL_EXPORT int draft_model_preprocess(api::Context* ctx,
-                                      int64_t* draft_tokens,
-                                      int64_t* input_ids,
-                                      bool* stop_flags,
-                                      int* seq_lens_this_time,
-                                      int* seq_lens_encoder,
-                                      int* seq_lens_decoder,
-                                      int64_t* step_idx,
-                                      int* seq_lens_encoder_record,
-                                      int* seq_lens_decoder_record,
-                                      bool* not_need_stop,
-                                      bool* batch_drop,
-                                      const int64_t* accept_tokens,
-                                      const int* accept_num,
-                                      const int* base_model_seq_lens_encoder,
-                                      const int* base_model_seq_lens_decoder,
-                                      const int64_t* base_model_step_idx,
-                                      const bool* base_model_stop_flags,
-                                      const bool* base_model_is_block_step,
-                                      int64_t* base_model_draft_tokens,
-                                      int real_bsz,
-                                      int max_draft_token,
-                                      int accept_tokens_len,
-                                      int draft_tokens_len,
-                                      int input_ids_len,
-                                      int base_model_draft_tokens_len,
-                                      bool truncate_first_token,
-                                      bool splitwise_prefill);

DLL_EXPORT int speculate_set_stop_value_multi_seqs(Context* ctx,
                                                   bool* stop_flags,
                                                   int64_t* accept_tokens,
@@ -411,16 +485,6 @@ DLL_EXPORT int speculate_remove_padding(Context* ctx,
                                        int bsz,
                                        int token_num_data);

-DLL_EXPORT int speculate_get_padding_offset(Context* ctx,
-                                            int* padding_offset,
-                                            int* cum_offsets_out,
-                                            int* cu_seqlens_q,
-                                            int* cu_seqlens_k,
-                                            const int* cum_offsets,
-                                            const int* seq_lens,
-                                            const int max_seq_len,
-                                            int bsz);

DLL_EXPORT int compute_self_order(api::Context* ctx,
                                  const int* last_seq_lens_this_time,
                                  const int* seq_lens_this_time,
@@ -4,7 +4,7 @@
namespace xpu3 {
namespace plugin {
#define MAX_LM_SIZE 28672
// One core has 32KB LM (group LM), MAX_LM_SIZE = (32 - 4)KB / 2 = 30720, 4KB is
// the stack space
#define MAX_BATCH 512
#define ALIGNMENT 64
@@ -53,6 +53,7 @@ template <typename TX, typename TY>
__global__ void eb_adjust_batch(TX* src,
                                TY* dst,
                                int* encoder_seqs_lods,
+                                int* decoder_seqs_lods,
                                int* encoder_batch_map,
                                int* decoder_batch_map,
                                int en_batch,
@@ -61,9 +62,11 @@ __global__ void eb_adjust_batch(TX* src,
  int tid = core_id() * cluster_num() + cluster_id();
  int nthreads = core_num() * cluster_num();
  __group_shared__ int local_lods_en[MAX_BATCH + 1];
+  __group_shared__ int local_lods_de[MAX_BATCH + 1];
  __group_shared__ int local_map_en[MAX_BATCH];
  __group_shared__ int local_map_de[MAX_BATCH];
  GM2GSM_ASYNC(encoder_seqs_lods, local_lods_en, (en_batch + 1) * sizeof(int));
+  GM2GSM_ASYNC(decoder_seqs_lods, local_lods_de, (de_batch + 1) * sizeof(int));
  if (en_batch > 0) {
    GM2GSM_ASYNC(encoder_batch_map, local_map_en, en_batch * sizeof(int));
  }
@@ -72,7 +75,8 @@ __global__ void eb_adjust_batch(TX* src,
  }
  mfence();
  int max_encoder_len = local_lods_en[en_batch];
-  int seq_sum = max_encoder_len + de_batch;
+  int max_decoder_len = local_lods_de[de_batch];
+  int seq_sum = max_encoder_len + max_decoder_len;
  int total_batch = en_batch + de_batch;
  int start = 0;
  int end = 0;
@@ -82,13 +86,16 @@ __global__ void eb_adjust_batch(TX* src,
  while (i < end) {
    if (i >= max_encoder_len) {
      // dst decode part
-      int cur_de_bs = i - max_encoder_len;
+      int cur_de_bs = 0;
+      get_cur_batch(local_lods_de, de_batch, i - max_encoder_len, cur_de_bs);
      int cur_en_bs = local_map_de[cur_de_bs] - cur_de_bs;
+      int cur_len =
+          min(end, local_lods_de[cur_de_bs + 1] + max_encoder_len) - i;
      _global_ptr_ TY* cur_dst = dst + i * copy_size;
      _global_ptr_ TX* cur_src =
-          src + (cur_de_bs + local_lods_en[cur_en_bs]) * copy_size;
-      do_memcpy_1d<TX, TY>(cur_src, cur_dst, copy_size);
-      i++;
+          src + (local_lods_en[cur_en_bs] + i - max_encoder_len) * copy_size;
+      do_memcpy_1d<TX, TY>(cur_src, cur_dst, copy_size * cur_len);
+      i += cur_len;
    } else {
      // dst encode part
      int cur_en_bs = 0;
@@ -97,7 +104,8 @@ __global__ void eb_adjust_batch(TX* src,
      cur_de_bs = local_map_en[cur_en_bs] - cur_en_bs;
      int cur_len = min(end, local_lods_en[cur_en_bs + 1]) - i;
      _global_ptr_ TY* cur_dst = dst + i * copy_size;
-      _global_ptr_ TX* cur_src = src + (cur_de_bs + i) * copy_size;
+      _global_ptr_ TX* cur_src =
+          src + (local_lods_de[cur_de_bs] + i) * copy_size;
      do_memcpy_1d<TX, TY>(cur_src, cur_dst, copy_size * cur_len);
      i += cur_len;
    }
@@ -108,6 +116,7 @@ __global__ void eb_adjust_batch(TX* src,
  template __global__ void eb_adjust_batch<TX, TY>(TX * src, \
                                                   TY * dst, \
                                                   int* encoder_seqs_lods, \
+                                                  int* decoder_seqs_lods, \
                                                   int* encoder_batch_map, \
                                                   int* decoder_batch_map, \
                                                   int en_batch, \
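Read together with the lods and batch maps, eb_adjust_batch is a row shuffle: src holds prefill and decode token rows grouped per request, and dst receives all prefill rows first (indices [0, max_encoder_len)) followed by the decode rows. The host-side toy below illustrates that target layout only; it is a sketch of the intent and deliberately does not reproduce the kernel's exact index arithmetic.

#include <cstdio>
#include <vector>

int main() {
  // Toy batch: request 0 is a prefill with 3 tokens, request 1 a decode with
  // 1 token, request 2 a prefill with 2 tokens. src rows in request order:
  // [P0a P0b P0c | D1 | P2a P2b]  (one "row" = one hidden-state vector).
  std::vector<const char*> src = {"P0a", "P0b", "P0c", "D1", "P2a", "P2b"};
  struct Req { int tokens; bool prefill; };
  std::vector<Req> reqs = {{3, true}, {1, false}, {2, true}};

  std::vector<const char*> dst;
  std::vector<int> decode_rows;
  int row = 0;
  // Encoder part first: all prefill rows, in request order.
  for (const Req& r : reqs) {
    for (int t = 0; t < r.tokens; ++t, ++row) {
      if (r.prefill) dst.push_back(src[row]);
      else decode_rows.push_back(row);
    }
  }
  // Decoder part after the encoder block.
  for (int r : decode_rows) dst.push_back(src[r]);

  for (const char* s : dst) std::printf("%s ", s);  // P0a P0b P0c P2a P2b D1
  std::printf("\n");
  return 0;
}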
@@ -20,6 +20,7 @@ __global__ void ComputeOrderKernel(const int* seq_lens_this_time,
return;
}

// 256 * int
char lm[6 * 1024];
int buf_size = 6 * 1024 / (6 * sizeof(int));
int* lm_base_model_seq_lens_this_time = (int*)lm;
@@ -68,10 +69,7 @@ __global__ void ComputeOrderKernel(const int* seq_lens_this_time,
in_offset += write_size;
}
mfence_lm();
// 2. base model encoder. Base step=0
} else if (cur_base_model_seq_lens_encoder != 0) {
// nothing happens
// 3. New end
// 2. Base model stop at last verify-step.
} else if (cur_base_model_seq_lens_this_time != 0 &&
cur_seq_lens_this_time == 0) {
in_offset += cur_base_model_seq_lens_this_time;
@@ -80,27 +78,16 @@ __global__ void ComputeOrderKernel(const int* seq_lens_this_time,
cur_seq_lens_this_time == 0) {
// nothing happens
} else {
if (accept_num <= actual_draft_token_num) {
int position_map_val = out_offset;
LM2GM(&position_map_val,
position_map + in_offset + accept_num - 1,
sizeof(int));
out_offset++;
in_offset += cur_base_model_seq_lens_this_time;
} else {
int position_map_val_1 = out_offset;
LM2GM(&position_map_val_1,
position_map + in_offset + accept_num - 2,
sizeof(int));
out_offset++;
int position_map_val_2 = out_offset;
LM2GM(&position_map_val_2,
position_map + in_offset + accept_num - 1,
sizeof(int));
out_offset++;
in_offset += cur_base_model_seq_lens_this_time;
// accept_num << buf_size, so no split is needed
for (int i = 0; i < accept_num; i++) {
lm_position_map[i] = out_offset++;
}
mfence_lm();
LM2GM(lm_position_map,
position_map + in_offset,
accept_num * sizeof(int));
in_offset += cur_base_model_seq_lens_this_time;
mfence_lm();
}
}
}

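// Editor's sketch (not part of the commit) of the position_map update above:
// after the change, every accepted token position of a batch member maps to a
// consecutive output slot, instead of only the last one or two. Names are
// illustrative.
static void compute_order_ref(int accept_num, int base_seq_len_this_time,
                              int* position_map, int* in_offset,
                              int* out_offset) {
  for (int i = 0; i < accept_num; ++i) {
    position_map[*in_offset + i] = (*out_offset)++;
  }
  *in_offset += base_seq_len_this_time;
}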
@@ -13,26 +13,29 @@ __global__ void draft_model_preprocess(int64_t* draft_tokens,
int* seq_lens_encoder,
int* seq_lens_decoder,
int64_t* step_idx,
int* seq_lens_encoder_record,
int* seq_lens_decoder_record,
bool* not_need_stop,
bool* is_block_step,
bool* batch_drop,
int64_t* pre_ids,
const int64_t* accept_tokens,
const int* accept_num,
const int* base_model_seq_lens_this_time,
const int* base_model_seq_lens_encoder,
const int* base_model_seq_lens_decoder,
const int64_t* base_model_step_idx,
const bool* base_model_stop_flags,
const bool* base_model_is_block_step,
int64_t* base_model_draft_tokens,
int real_bsz,
int max_draft_token,
int accept_tokens_len,
int draft_tokens_len,
int input_ids_len,
int base_model_draft_tokens_len,
bool truncate_first_token,
bool splitwise_prefill) {
const int bsz,
const int num_model_step,
const int accept_tokens_len,
const int draft_tokens_len,
const int input_ids_len,
const int base_model_draft_tokens_len,
const int pre_ids_len,
const bool truncate_first_token,
const bool splitwise_prefill,
const bool kvcache_scheduler_v1) {
int cid = core_id();
int ncores = core_num();
int clusterid = cluster_id();
@@ -46,7 +49,7 @@ __global__ void draft_model_preprocess(int64_t* draft_tokens,
int64_t value_fu = -1;

if (splitwise_prefill) {
for (; tid < real_bsz; tid += ncores * nclusters) {
for (; tid < bsz; tid += ncores * nclusters) {
int64_t base_model_step_idx_now = 0;
int seq_lens_encoder_now = 0;
int seq_lens_this_time_now = 0;
@@ -57,35 +60,25 @@ __global__ void draft_model_preprocess(int64_t* draft_tokens,

GM2LM_ASYNC(
base_model_step_idx + tid, &base_model_step_idx_now, sizeof(int64_t));
GM2LM_ASYNC(seq_lens_encoder_record + tid,
&seq_lens_encoder_record_now,
sizeof(int));
GM2LM_ASYNC(seq_lens_encoder + tid, &seq_lens_encoder_now, sizeof(int));
GM2LM(accept_tokens + tid * accept_tokens_len,
&base_model_first_token,
sizeof(int64_t));

if (base_model_step_idx_now == 1 && seq_lens_encoder_record_now > 0) {
if (seq_lens_encoder_now > 0) {
not_stop_flag_sm[cid] += 1;
int seq_len_encoder_record = seq_lens_encoder_record_now;
seq_lens_encoder_now = seq_len_encoder_record;
seq_lens_encoder_record_now = -1;
stop_flags_now = false;
int position = seq_len_encoder_record;
int position = seq_lens_encoder_now;
if (truncate_first_token) {
position = position - 1;
input_ids_now = base_model_first_token;
seq_lens_this_time_now = seq_len_encoder_record;
seq_lens_this_time_now = seq_lens_encoder_now;
} else {
input_ids_now = base_model_first_token;
seq_lens_this_time_now = seq_len_encoder_record + 1;
seq_lens_this_time_now = seq_lens_encoder_now + 1;
}
LM2GM_ASYNC(&input_ids_now,
input_ids + tid * input_ids_len + position,
sizeof(int64_t));
LM2GM_ASYNC(&seq_lens_encoder_record_now,
seq_lens_encoder_record + tid,
sizeof(int));

} else {
stop_flags_now = true;
seq_lens_this_time_now = 0;
@@ -98,21 +91,23 @@ __global__ void draft_model_preprocess(int64_t* draft_tokens,
LM2GM(&seq_lens_this_time_now, seq_lens_this_time + tid, sizeof(int));
}
} else {
for (; tid < real_bsz; tid += ncores * nclusters) {
for (; tid < bsz; tid += ncores * nclusters) {
bool base_model_stop_flags_now = false;
bool base_model_is_block_step_now = false;
bool batch_drop_now = false;
bool stop_flags_now = false;
bool is_block_step_now = false;
int seq_lens_this_time_now = 0;
int seq_lens_encoder_record_now = 0;
int seq_lens_encoder_now = 0;
int seq_lens_decoder_new = 0;
int seq_lens_decoder_record_now = 0;
int accept_num_now = 0;
int base_model_seq_lens_decoder_now = 0;
int base_model_seq_lens_this_time_now = 0;
int64_t step_id_now = 0;
int64_t base_model_step_idx_now;
int64_t pre_ids_now;
mfence();
GM2LM_ASYNC(is_block_step + tid, &is_block_step_now, sizeof(bool));
GM2LM_ASYNC(base_model_stop_flags + tid,
&base_model_stop_flags_now,
sizeof(bool));
@@ -121,12 +116,6 @@ __global__ void draft_model_preprocess(int64_t* draft_tokens,
sizeof(bool));
GM2LM_ASYNC(batch_drop + tid, &batch_drop_now, sizeof(bool));
GM2LM_ASYNC(stop_flags + tid, &stop_flags_now, sizeof(bool));
GM2LM_ASYNC(seq_lens_encoder_record + tid,
&seq_lens_encoder_record_now,
sizeof(int));
GM2LM_ASYNC(seq_lens_decoder_record + tid,
&seq_lens_decoder_record_now,
sizeof(int));
GM2LM_ASYNC(seq_lens_encoder + tid, &seq_lens_encoder_now, sizeof(int));
GM2LM_ASYNC(seq_lens_decoder + tid, &seq_lens_decoder_new, sizeof(int));

@@ -135,6 +124,9 @@ __global__ void draft_model_preprocess(int64_t* draft_tokens,
accept_tokens_len * sizeof(int64_t));
GM2LM_ASYNC(accept_num + tid, &accept_num_now, sizeof(int));

GM2LM_ASYNC(base_model_seq_lens_this_time + tid,
&base_model_seq_lens_this_time_now,
sizeof(int));
GM2LM_ASYNC(base_model_seq_lens_decoder + tid,
&base_model_seq_lens_decoder_now,
sizeof(int));
@@ -148,57 +140,67 @@ __global__ void draft_model_preprocess(int64_t* draft_tokens,
base_model_draft_tokens + tid * base_model_draft_tokens_len + i,
sizeof(int));
}
if (base_model_stop_flags_now && base_model_is_block_step_now) {
batch_drop_now = true;
stop_flags_now = true;
if (kvcache_scheduler_v1) {
if (base_model_stop_flags_now && base_model_is_block_step_now) {
stop_flags_now = true;
is_block_step_now = true;
}
} else {
if (base_model_stop_flags_now && base_model_is_block_step_now) {
batch_drop_now = true;
stop_flags_now = true;
}
}

if (!(base_model_stop_flags_now || batch_drop_now)) {
not_stop_flag_sm[cid] += 1;
if (base_model_step_idx_now == 0) {
seq_lens_this_time_now = 0;
not_stop_flag_sm[cid] -= 1;  // incremented above; undo it so the step_idx == 0 case counts as stopped
} else if (base_model_step_idx_now == 1 &&
seq_lens_encoder_record_now > 0) {
int seq_len_encoder_record = seq_lens_encoder_record_now;
seq_lens_encoder_now = seq_len_encoder_record;
seq_lens_encoder_record_now = -1;
seq_lens_decoder_new = seq_lens_decoder_record_now;
seq_lens_decoder_record_now = 0;
if (seq_lens_encoder_now > 0) {
int seq_len_encoder = seq_lens_encoder_now;
stop_flags_now = false;
int64_t base_model_first_token = accept_tokens_now[0];
int position = seq_len_encoder_record;
LM2GM(&base_model_first_token,
pre_ids + tid * pre_ids_len,
sizeof(int64_t));
int position = seq_len_encoder;
if (truncate_first_token) {
LM2GM(&base_model_first_token,
input_ids + tid * input_ids_len + position - 1,
sizeof(int64_t));
seq_lens_this_time_now = seq_len_encoder_record;
seq_lens_this_time_now = seq_len_encoder;
} else {
LM2GM(&base_model_first_token,
input_ids + tid * input_ids_len + position,
sizeof(int64_t));
seq_lens_this_time_now = seq_len_encoder_record + 1;
seq_lens_this_time_now = seq_len_encoder + 1;
}
} else {
if (kvcache_scheduler_v1) {
if (!base_model_is_block_step_now && is_block_step_now) {
is_block_step_now = false;
}
}
} else if (accept_num_now <= max_draft_token) {
if (stop_flags_now) {
stop_flags_now = false;
seq_lens_decoder_new = base_model_seq_lens_decoder_now;
step_id_now = base_model_step_idx_now;
} else {
seq_lens_decoder_new -= max_draft_token - accept_num_now;
step_id_now -= max_draft_token - accept_num_now;
}
int64_t modified_token = accept_tokens_now[accept_num_now - 1];
LM2GM(&modified_token,
draft_tokens + tid * draft_tokens_len,
sizeof(int64_t));
seq_lens_this_time_now = 1;
seq_lens_decoder_new = base_model_seq_lens_decoder_now -
base_model_seq_lens_this_time_now;
step_id_now =
base_model_step_idx_now - base_model_seq_lens_this_time_now;

} else /*Accept all draft tokens*/ {
LM2GM(accept_tokens_now + max_draft_token,
draft_tokens + tid * draft_tokens_len + 1,
sizeof(int64_t));
seq_lens_this_time_now = 2;
} else {
seq_lens_decoder_new -= num_model_step - 1;
step_id_now -= num_model_step - 1;
}
for (int i = 0; i < accept_num_now; i++) {
const int pre_id_pos =
base_model_step_idx_now - (accept_num_now - i);
LM2GM(accept_tokens_now + i,
draft_tokens + tid * draft_tokens_len + i,
sizeof(int64_t));
LM2GM(accept_tokens_now + i,
pre_ids + tid * pre_ids_len + pre_id_pos,
sizeof(int64_t));
}
seq_lens_this_time_now = accept_num_now;
}

} else {
@@ -209,17 +211,11 @@ __global__ void draft_model_preprocess(int64_t* draft_tokens,
}
LM2GM_ASYNC(&stop_flags_now, stop_flags + tid, sizeof(bool));
LM2GM_ASYNC(&batch_drop_now, batch_drop + tid, sizeof(bool));

LM2GM_ASYNC(&is_block_step_now, is_block_step + tid, sizeof(bool));
LM2GM_ASYNC(&seq_lens_decoder_new, seq_lens_decoder + tid, sizeof(int));
LM2GM_ASYNC(
&seq_lens_this_time_now, seq_lens_this_time + tid, sizeof(int));
LM2GM_ASYNC(&seq_lens_encoder_now, seq_lens_encoder + tid, sizeof(int));
LM2GM_ASYNC(&seq_lens_encoder_record_now,
seq_lens_encoder_record + tid,
sizeof(int));
LM2GM_ASYNC(&seq_lens_decoder_record_now,
seq_lens_decoder_record + tid,
sizeof(int));
LM2GM_ASYNC(&step_id_now, step_idx + tid, sizeof(int64_t));
}
}

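// Editor's sketch (not part of the commit), condensing the decode-branch
// accept handling above for one batch member; it mirrors the cpu_wrapper
// later in this commit, and all names are illustrative.
#include <cstdint>
static void apply_accept_ref(const int64_t* accept_tokens, int accept_num,
                             int num_model_step, int64_t base_step_idx,
                             int base_seq_len_decoder,
                             int base_seq_len_this_time, bool* stop_flag,
                             int* seq_len_decoder, int64_t* step_idx,
                             int* seq_len_this_time, int64_t* draft_tokens,
                             int64_t* pre_ids) {
  if (*stop_flag) {
    // The base model rejected the stop: restart from its position.
    *stop_flag = false;
    *seq_len_decoder = base_seq_len_decoder - base_seq_len_this_time;
    *step_idx = base_step_idx - base_seq_len_this_time;
  } else {
    // Roll back the num_model_step - 1 speculative steps of this round.
    *seq_len_decoder -= num_model_step - 1;
    *step_idx -= num_model_step - 1;
  }
  // Replay the accepted tokens into the draft buffer and the id history.
  for (int i = 0; i < accept_num; ++i) {
    draft_tokens[i] = accept_tokens[i];
    pre_ids[base_step_idx - (accept_num - i)] = accept_tokens[i];
  }
  *seq_len_this_time = accept_num;
}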
@@ -60,10 +60,8 @@ __global__ void draft_model_update(const int64_t* inter_next_tokens,
token_this_time = next_tokens_start[seq_len_this_time - 1];
draft_token_now[0] = next_tokens_start[seq_len_this_time - 1];
base_model_draft_tokens_now[substep + 1] = token_this_time;
for (int i = 0; i < seq_len_this_time; ++i) {
pre_ids_now[step_idx[tid] + 1 + i] = next_tokens_start[i];
}
step_idx[tid] += seq_len_this_time;
pre_ids_now[step_idx[tid]] = token_this_time;
} else {
token_this_time = next_tokens_start[0];
seq_lens_decoder[tid] = seq_len_encoder + seq_len_decoder;

@@ -0,0 +1,129 @@
#include "xpu/kernel/cluster.h"
#include "xpu/kernel/cluster_debug.h"
#include "xpu/kernel/cluster_primitive.h"
namespace xpu3 {
namespace plugin {
#define MAX_LM_SIZE 28672
// One core has 32KB of LM (group LM); 4KB is reserved for the stack, so
// MAX_LM_SIZE = (32 - 4) * 1024 = 28672 bytes.
#define MAX_BATCH 512
#define ALIGNMENT 64

template <typename TX, typename TY>
static __device__ void do_memcpy_1d(_global_ptr_ TX* src,
_global_ptr_ TY* dst,
int64_t copy_size) {
#ifdef __XPU3__
constexpr int buf_size = 2048;
#else
constexpr int buf_size = 512;
#endif
__group_shared__ __simd__ float double_lmx[2][buf_size];
int64_t pingpong = 0;
for (int64_t i = 0; i < copy_size; i += buf_size) {
int real_size = min<int64_t>(buf_size, copy_size - i);
_group_shared_ptr_ float* lmx = double_lmx[pingpong];
GM2GSM(src + i, lmx, real_size * sizeof(TX));
if (!xpu_std::is_same<TX, TY>::value) {
primitive_cast_gsm<TX, float>(
(_group_shared_ptr_ TX*)lmx, lmx, real_size);
primitive_cast_gsm<float, TY>(
lmx, (_group_shared_ptr_ TY*)lmx, real_size);
}
GSM2GM_ASYNC((_group_shared_ptr_ TY*)lmx, dst + i, real_size * sizeof(TY));
pingpong = 1 - pingpong;
}
mfence();
}

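// Editor's sketch (not part of the commit) of the double-buffering pattern in
// do_memcpy_1d: while chunk k drains to global memory asynchronously, chunk
// k+1 is staged into the other half of the buffer, overlapping load and store
// latency. On the host the two halves degenerate to a plain staged copy; they
// only pay off with the asynchronous GM/GSM primitives above.
#include <cstddef>
#include <cstring>
static void pingpong_copy_ref(const float* src, float* dst, size_t n) {
  constexpr size_t kBuf = 2048;
  static float buf[2][kBuf];
  size_t pp = 0;
  for (size_t i = 0; i < n; i += kBuf) {
    const size_t len = (n - i < kBuf) ? (n - i) : kBuf;
    std::memcpy(buf[pp], src + i, len * sizeof(float));  // stands in for GM2GSM
    std::memcpy(dst + i, buf[pp], len * sizeof(float));  // stands in for GSM2GM_ASYNC
    pp = 1 - pp;  // switch halves so the next chunk uses the idle buffer
  }
}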
template <typename TX, typename TY>
__global__ void eb_mtp_gather_next_token(TX* src,
TY* dst,
int* encoder_seqs_lods,
int* decoder_seqs_lods,
int* encoder_batch_map,
int* decoder_batch_map,
int en_batch,
int de_batch,
int64_t copy_size) {
int tid = core_id() * cluster_num() + cluster_id();
int nthreads = core_num() * cluster_num();
__group_shared__ int local_lods_en[MAX_BATCH + 1];
__group_shared__ int local_lods_de[MAX_BATCH + 1];
__group_shared__ int local_map_en[MAX_BATCH];
__group_shared__ int local_map_de[MAX_BATCH];
GM2GSM_ASYNC(encoder_seqs_lods, local_lods_en, (en_batch + 1) * sizeof(int));
GM2GSM_ASYNC(decoder_seqs_lods, local_lods_de, (de_batch + 1) * sizeof(int));
if (en_batch > 0) {
GM2GSM_ASYNC(encoder_batch_map, local_map_en, en_batch * sizeof(int));
}
if (de_batch > 0) {
GM2GSM_ASYNC(decoder_batch_map, local_map_de, de_batch * sizeof(int));
}
mfence();
int encoder_len_total = en_batch > 0 ? local_lods_en[en_batch] : 0;
int output_len = en_batch + local_lods_de[de_batch];
int start = 0;
int end = 0;
partition(tid, nthreads, output_len, 1, &start, &end);
for (int i = start; i < end; i++) {
int len = 0;
int enc_idx = 0, dec_idx = 0;
bool is_enc;
while (i >= len) {
if (enc_idx >= en_batch) {
len += local_lods_de[dec_idx + 1] - local_lods_de[dec_idx];
dec_idx++;
is_enc = false;
continue;
}
if (dec_idx >= de_batch) {
len += 1;
enc_idx++;
is_enc = true;
continue;
}
if (local_map_en[enc_idx] < local_map_de[dec_idx]) {
len += 1;
enc_idx++;
is_enc = true;
} else {
len += local_lods_de[dec_idx + 1] - local_lods_de[dec_idx];
dec_idx++;
is_enc = false;
}
}
_global_ptr_ TX* cur_src = nullptr;
_global_ptr_ TY* cur_dst = dst + i * copy_size;
if (is_enc) {
cur_src = src + (local_lods_en[enc_idx] - 1) * copy_size;
} else {
cur_src = src + (encoder_len_total + local_lods_de[dec_idx] - (len - i)) *
copy_size;
}
do_memcpy_1d<TX, TY>(cur_src, cur_dst, copy_size);
}
}
#define _XPU_DEF__EB_MTP_GATHER_NEXT_TOKEN(TX, TY) \
template __global__ void eb_mtp_gather_next_token<TX, TY>( \
TX * src, \
TY * dst, \
int* encoder_seqs_lods, \
int* decoder_seqs_lods, \
int* encoder_batch_map, \
int* decoder_batch_map, \
int en_batch, \
int de_batch, \
int64_t copy_size);

_XPU_DEF__EB_MTP_GATHER_NEXT_TOKEN(float16, float16);
_XPU_DEF__EB_MTP_GATHER_NEXT_TOKEN(bfloat16, bfloat16);
_XPU_DEF__EB_MTP_GATHER_NEXT_TOKEN(float, float);
_XPU_DEF__EB_MTP_GATHER_NEXT_TOKEN(float16, float);
_XPU_DEF__EB_MTP_GATHER_NEXT_TOKEN(float, float16);
_XPU_DEF__EB_MTP_GATHER_NEXT_TOKEN(bfloat16, float16);
_XPU_DEF__EB_MTP_GATHER_NEXT_TOKEN(float16, bfloat16);
_XPU_DEF__EB_MTP_GATHER_NEXT_TOKEN(bfloat16, float);
_XPU_DEF__EB_MTP_GATHER_NEXT_TOKEN(float, bfloat16);
} // namespace plugin
} // namespace xpu3
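// Editor's sketch (not part of the commit): a host-side reference of
// eb_mtp_gather_next_token. Each encoder member contributes exactly one
// output row (its last hidden state, from which the next token is predicted),
// while each decoder member keeps all of its rows; members are merged in
// combined-batch order via the two batch maps. Names are illustrative.
#include <cstdint>
#include <cstring>
static void gather_next_token_ref(const float* src, float* dst,
                                  const int* lods_en, const int* map_en,
                                  int en_batch, const int* lods_de,
                                  const int* map_de, int de_batch,
                                  int64_t dim) {
  const int enc_total = en_batch > 0 ? lods_en[en_batch] : 0;
  int e = 0, d = 0;
  int64_t out = 0;
  for (int b = 0; b < en_batch + de_batch; ++b) {
    if (e < en_batch && (d >= de_batch || map_en[e] < map_de[d])) {
      // Keep only the last token of encoder member e.
      std::memcpy(dst + out * dim, src + (int64_t)(lods_en[e + 1] - 1) * dim,
                  sizeof(float) * dim);
      out += 1;
      ++e;
    } else {
      // Keep the full span of decoder member d.
      const int len = lods_de[d + 1] - lods_de[d];
      std::memcpy(dst + out * dim,
                  src + (enc_total + (int64_t)lods_de[d]) * dim,
                  sizeof(float) * len * dim);
      out += len;
      ++d;
    }
  }
}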
@@ -0,0 +1,337 @@
#include "xpu/kernel/cluster.h"
#include "xpu/kernel/cluster_partition.h"
#include "xpu/kernel/cluster_primitive.h"

namespace xpu3 {
namespace plugin {

static __device__ inline int loada_float(_shared_ptr_ const int *ptr) {
int ret;
__asm__ __volatile__("loada.w %0,%1" : "=&r"(ret) : "r"(ptr));
return ret;
}

static __device__ inline bool storea_float(_shared_ptr_ int *ptr, int value) {
bool ret;
__asm__ __volatile__("storea.w %0,%1,%2" : "=&r"(ret) : "r"(value), "r"(ptr));
return ret;
}

static __device__ int atomic_add(_shared_ptr_ int *ptr, int value) {
bool fail = true;
int old_value;
while (fail) {
old_value = loada_float(ptr);
int new_value = old_value + value;
fail = storea_float(ptr, new_value);
}
return old_value;
}

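// Editor's note: loada.w / storea.w appear to act as a load-linked /
// store-conditional pair on shared memory -- storea_float reports failure
// (true) when the word changed since the matching loada_float -- so
// atomic_add simply retries its read-modify-write until it lands, returning
// the value observed before the update. Usage sketch (illustrative only):
//   __shared__ int counter;                  // shared across the cluster
//   int my_slot = atomic_add(&counter, 1);   // reserve one slot
//   atomic_add(&counter, -1);                // give the slot back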
static __device__ bool in_need_block_list(const int qid,
_shared_ptr_ int *need_block_list,
const int need_block_len) {
bool res = false;
for (int i = 0; i < need_block_len; i++) {
if (qid == need_block_list[i]) {
need_block_list[i] = -1;
res = true;
break;
}
}
return res;
}

__global__ void speculate_free_and_dispatch_block(
bool *stop_flags,
int *seq_lens_this_time,
int *seq_lens_decoder,
int *block_tables,
int *encoder_block_lens,
bool *is_block_step,
int *step_block_list, // [bsz]
int *step_len,
int *recover_block_list,
int *recover_len,
int *need_block_list,
int *need_block_len,
int *used_list_len,
int *free_list,
int *free_list_len,
int64_t *first_token_ids,
int *accept_num,
const int bsz,
const int block_size,
const int block_num_per_seq,
const int max_decoder_block_num,
const int max_draft_tokens) {
int cid = core_id();
int ncores = core_num();
int clusterid = cluster_id();
if (clusterid != 0 || cid >= bsz) return;

// assert bsz <= 640
const int max_bs = 640;
int value_zero = 0;
bool flag_true = true;

// 128 = seq_len(8192) / block_size(64)
// process at most 128 block_table entries per pass
const int block_table_now_len = 128;
int block_table_now[block_table_now_len];
for (int i = 0; i < block_table_now_len; i++) {
block_table_now[i] = -1;
}
bool stop_flag_lm;
int seq_lens_decoder_lm;

__shared__ int free_list_len_sm;
// process at most block_table_now_len free_list entries per pass
int free_list_now[block_table_now_len];
__shared__ int need_block_len_sm;
__shared__ int need_block_list_sm[max_bs];
__shared__ int used_list_len_sm[max_bs];
__shared__ bool step_max_block_flag;
__shared__ int in_need_block_list_len;
if (cid == 0) {
step_max_block_flag = false;
in_need_block_list_len = 0;
GM2SM_ASYNC(free_list_len, &free_list_len_sm, sizeof(int));
GM2SM_ASYNC(need_block_len, &need_block_len_sm, sizeof(int));
mfence();
if (need_block_len_sm > 0) {
GM2SM_ASYNC(
need_block_list, need_block_list_sm, sizeof(int) * need_block_len_sm);
}
GM2SM_ASYNC(used_list_len, used_list_len_sm, sizeof(int) * bsz);
mfence();
}
sync_cluster();

for (int tid = cid; tid < bsz; tid += ncores) {
bool is_block_step_lm;
int seq_lens_this_time_lm;
mfence();
GM2LM_ASYNC(stop_flags + tid, &stop_flag_lm, sizeof(bool));
GM2LM_ASYNC(is_block_step + tid, &is_block_step_lm, sizeof(bool));
GM2LM_ASYNC(seq_lens_decoder + tid, &seq_lens_decoder_lm, sizeof(int));
GM2LM_ASYNC(seq_lens_this_time + tid, &seq_lens_this_time_lm, sizeof(int));
mfence();
int max_possible_block_idx =
(seq_lens_decoder_lm + max_draft_tokens + 1) / block_size;
if (stop_flag_lm && !is_block_step_lm) {
// reclaim this query's blocks
int64_t first_token_id_lm = -1;
mfence_lm();
LM2GM(&first_token_id_lm, first_token_ids + tid, sizeof(int64_t));
int encoder_block_len_lm;
int decoder_used_len_lm = used_list_len_sm[tid];
GM2LM(encoder_block_lens + tid, &encoder_block_len_lm, sizeof(int));
if (decoder_used_len_lm > 0) {
const int ori_free_list_len =
atomic_add(&free_list_len_sm, decoder_used_len_lm);
for (int i = 0; i < decoder_used_len_lm; i += block_table_now_len) {
int process_len = min(block_table_now_len, decoder_used_len_lm - i);
GM2LM(
block_tables + tid * block_num_per_seq + encoder_block_len_lm + i,
free_list_now,
process_len * sizeof(int));
LM2GM(free_list_now,
free_list + ori_free_list_len + i,
process_len * sizeof(int));
LM2GM(
block_table_now,
block_tables + tid * block_num_per_seq + encoder_block_len_lm + i,
process_len * sizeof(int));
}
used_list_len_sm[tid] = 0;
mfence();
LM2GM(&value_zero, encoder_block_lens + tid, sizeof(int));
}
} else if (seq_lens_this_time_lm != 0 &&
max_possible_block_idx < block_num_per_seq) {
int next_block_id;
GM2LM(block_tables + tid * block_num_per_seq +
(seq_lens_decoder_lm + max_draft_tokens + 1) / block_size,
&next_block_id,
sizeof(int));
if (next_block_id == -1) {
// record which slots need a block, and how many in total
const int ori_need_block_len = atomic_add(&need_block_len_sm, 1);
need_block_list_sm[ori_need_block_len] = tid;
}
}
}
sync_cluster();

bool is_block_step_lm[max_bs];
int step_len_lm;
int step_block_list_lm[max_bs];
int recover_len_lm;
int recover_block_list_lm[max_bs];
if (cid == 0) {
GM2LM_ASYNC(is_block_step, is_block_step_lm, sizeof(bool) * bsz);
GM2LM_ASYNC(step_len, &step_len_lm, sizeof(int));
GM2LM_ASYNC(step_block_list, step_block_list_lm, sizeof(int) * bsz);
GM2LM_ASYNC(recover_len, &recover_len_lm, sizeof(int));
GM2LM_ASYNC(recover_block_list, recover_block_list_lm, sizeof(int) * bsz);
mfence();
}

if (cid == 0) {
while (need_block_len_sm > free_list_len_sm) {
// Schedule blocks: reclaim from the query with the largest used_list_len
// first, until need_block_len is satisfied. Queries already decoding in
// their last block are skipped (they will finish soon anyway).
int max_used_list_len_id = 0;
int max_used_list_len = 0;
for (int i = 0; i < bsz; i++) {
if (!is_block_step_lm[i] &&
(step_max_block_flag ||
used_list_len_sm[i] != max_decoder_block_num) &&
(used_list_len_sm[i] > max_used_list_len)) {
max_used_list_len_id = i;
max_used_list_len = used_list_len_sm[i];
}
}

if (max_used_list_len == 0) {
step_max_block_flag = true;
} else {
int encoder_block_len;
GM2LM(encoder_block_lens + max_used_list_len_id,
&encoder_block_len,
sizeof(int));
for (int i = 0; i < max_used_list_len; i += block_table_now_len) {
int process_len = min(block_table_now_len, max_used_list_len - i);
GM2LM(block_tables + max_used_list_len_id * block_num_per_seq +
encoder_block_len + i,
free_list_now,
process_len * sizeof(int));
LM2GM(free_list_now,
free_list + free_list_len_sm + i,
process_len * sizeof(int));
LM2GM(block_table_now,
block_tables + max_used_list_len_id * block_num_per_seq +
encoder_block_len + i,
process_len * sizeof(int));
}
step_block_list_lm[step_len_lm] = max_used_list_len_id;
int need_block_len_all = need_block_len_sm + in_need_block_list_len;
if (in_need_block_list(
max_used_list_len_id, need_block_list_sm, need_block_len_all)) {
need_block_len_sm--;
in_need_block_list_len++;
}
step_len_lm++;
free_list_len_sm += max_used_list_len;
LM2GM_ASYNC(
&flag_true, stop_flags + max_used_list_len_id, sizeof(bool));
is_block_step_lm[max_used_list_len_id] = true;
LM2GM_ASYNC(&value_zero,
seq_lens_this_time + max_used_list_len_id,
sizeof(int));
LM2GM_ASYNC(
&value_zero, seq_lens_decoder + max_used_list_len_id, sizeof(int));
// Note(@wufeisheng): when a query is stepped, its accept_num is not
// naturally 0, so on the next step save_output would still stream output
// for it; reset accept_num to 0 explicitly.
LM2GM_ASYNC(
&value_zero, accept_num + max_used_list_len_id, sizeof(int));
mfence();
mfence();
}
}
}
sync_cluster();

int need_block_len_all = need_block_len_sm + in_need_block_list_len;
for (int tid = cid; tid < need_block_len_all; tid += ncores) {
// allocate one block to each slot that needs one
const int need_block_id = need_block_list_sm[tid];
if (need_block_id != -1) {
GM2LM(stop_flags + need_block_id, &stop_flag_lm, sizeof(bool));
if (!stop_flag_lm) {
// if the slot was freed by the scheduling step above, do nothing
used_list_len_sm[need_block_id]++;
const int ori_free_list_len = atomic_add(&free_list_len_sm, -1);
int tmp_seq_lens_decoder;
GM2LM(seq_lens_decoder + need_block_id,
&tmp_seq_lens_decoder,
sizeof(int));
int free_block_id;
GM2LM(free_list + ori_free_list_len - 1, &free_block_id, sizeof(int));
LM2GM(&free_block_id,
block_tables + need_block_id * block_num_per_seq +
(tmp_seq_lens_decoder + max_draft_tokens + 1) / block_size,
sizeof(int));
}
need_block_list_sm[tid] = -1;
}
}
sync_cluster();

// compute which query ids can be recovered
// recover at most max_recover_num queries per call
int max_recover_num = 1;
if (cid == 0 && step_len_lm > 0) {
int ori_free_list_len = free_list_len_sm;
int ori_step_block_id = step_block_list_lm[step_len_lm - 1];
int tmp_used_len = used_list_len_sm[ori_step_block_id];
int encoder_block_len_lm;
GM2LM(encoder_block_lens + ori_step_block_id,
&encoder_block_len_lm,
sizeof(int));
const int max_decoder_block_num_this_seq =
max_decoder_block_num - encoder_block_len_lm;
// Allocate one more block than at scheduling time, so that a query that
// was just stepped is not recovered immediately (e.g. when the reclaimed
// seq_id is still in need_block_list).
int used_len = tmp_used_len + 1 < max_decoder_block_num_this_seq
? tmp_used_len + 1
: max_decoder_block_num_this_seq;
while (step_len_lm > 0 && ori_free_list_len >= used_len &&
max_recover_num-- > 0) {
recover_block_list_lm[recover_len_lm] = ori_step_block_id;
is_block_step_lm[ori_step_block_id] = false;
used_list_len_sm[ori_step_block_id] = used_len;
ori_free_list_len -= used_len;
step_block_list_lm[step_len_lm - 1] = -1;
step_len_lm--;
recover_len_lm++;
if (step_len_lm > 0) {
ori_step_block_id = step_block_list_lm[step_len_lm - 1];
tmp_used_len = used_list_len_sm[ori_step_block_id];
used_len = tmp_used_len + 1 < max_decoder_block_num_this_seq
? tmp_used_len + 1
: max_decoder_block_num_this_seq;
}
}
}

// TODO(zhupengyang):
// Before the operator: need_block_len is 0, need_block_list is -1
// After the operator: need_block_len is 0, need_block_list is -1
// Maybe need_block_len and need_block_list do not need to be updated?
int ori_need_block_len;
if (cid == 0) {
ori_need_block_len = need_block_len_sm;
need_block_len_sm = 0;
}

if (cid == 0) {
mfence();
LM2GM_ASYNC(step_block_list_lm, step_block_list, sizeof(int) * bsz);
LM2GM_ASYNC(is_block_step_lm, is_block_step, sizeof(bool) * bsz);
LM2GM_ASYNC(&step_len_lm, step_len, sizeof(int));
LM2GM_ASYNC(&recover_len_lm, recover_len, sizeof(int));
LM2GM_ASYNC(recover_block_list_lm, recover_block_list, sizeof(int) * bsz);
SM2GM_ASYNC(&free_list_len_sm, free_list_len, sizeof(int));
SM2GM_ASYNC(&need_block_len_sm, need_block_len, sizeof(int));
if (ori_need_block_len > 0) {
SM2GM_ASYNC(need_block_list_sm,
need_block_list,
sizeof(int) * ori_need_block_len);
}
SM2GM_ASYNC(used_list_len_sm, used_list_len, sizeof(int) * bsz);
mfence();
}
}

} // namespace plugin
} // namespace xpu3
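// Editor's sketch (not part of the commit): the bookkeeping invariant the
// kernel above tries to maintain -- decoder blocks only move between the
// global free list and the per-query used lists, so the totals must stay
// balanced across a free/dispatch step. `total_blocks` is an assumed input.
static bool check_block_accounting_ref(const int* used_list_len, int bsz,
                                       int free_list_len, int total_blocks) {
  int used = 0;
  for (int i = 0; i < bsz; ++i) used += used_list_len[i];
  // Encoder blocks are tracked separately via encoder_block_lens, so only
  // the decoder-side blocks handled by this kernel are balanced here.
  return used + free_list_len <= total_blocks;
}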
@@ -65,7 +65,7 @@ __global__ void speculate_remove_padding(T* output_data,
}
}

__global__ void speculate_get_padding_offset(int* padding_offset,
__global__ void speculate_get_padding_offset(int* batch_id_per_token,
int* cum_offsets_out,
int* cu_seqlens_q,
int* cu_seqlens_k,
@@ -90,8 +90,8 @@ __global__ void speculate_get_padding_offset(int* padding_offset,
GM2LM(cum_offsets + bi, &cum_offsets_now_ind, sizeof(int));

for (int i = tid; i < seq_lens_now; i += ncores) {
LM2GM(&cum_offsets_now,
padding_offset + bi * max_seq_len - cum_offsets_now + i,
LM2GM(&bi,
batch_id_per_token + bi * max_seq_len - cum_offsets_now + i,
sizeof(int));
}
LM2GM(&cum_offsets_now, cum_offsets_out + bi, sizeof(int));

@@ -0,0 +1,154 @@
#include "xpu/kernel/cluster.h"
#include "xpu/kernel/cluster_partition.h"
#include "xpu/kernel/cluster_primitive.h"

namespace xpu3 {
namespace plugin {

static __device__ inline int loada_float(_shared_ptr_ const int* ptr) {
int ret;
__asm__ __volatile__("loada.w %0,%1" : "=&r"(ret) : "r"(ptr));
return ret;
}

static __device__ inline bool storea_float(_shared_ptr_ int* ptr, int value) {
bool ret;
__asm__ __volatile__("storea.w %0,%1,%2" : "=&r"(ret) : "r"(value), "r"(ptr));
return ret;
}

static __device__ int atomic_add(_shared_ptr_ int* ptr, int value) {
bool fail = true;
int old_value;
while (fail) {
old_value = loada_float(ptr);
int new_value = old_value + value;
fail = storea_float(ptr, new_value);
}
return old_value;
}

__global__ void speculate_recover_block(int* recover_block_list, // [bsz]
int* recover_len,
bool* stop_flags,
int* seq_lens_this_time,
const int* ori_seq_lens_encoder,
int* seq_lens_encoder,
const int* seq_lens_decoder,
int* block_tables,
int* free_list,
int* free_list_len,
int64_t* input_ids,
const int64_t* pre_ids,
const int64_t* step_idx,
const int* encoder_block_lens,
const int* used_list_len,
const int64_t* next_tokens,
const int64_t* first_token_ids,
const int bsz,
const int block_num_per_seq,
const int length,
const int pre_id_length) {
int cid = core_id();
int ncores = core_num();
int clusterid = cluster_id();
if (clusterid != 0) return;

// 128 = seq_len(8192) / block_size(64)
// process at most 128 block_table entries per pass
const int block_table_now_len = 128;
int block_table_now[block_table_now_len];
// max_seq_len == length
// max_seq_len == pre_id_length

// 32KB of local memory per 4 cores on KL2.
// Not enough memory for 16382 input_ids.
const int buf_len = 256;
int64_t input_ids_now[buf_len];

bool flag_false = false;

__shared__ int free_list_len_sm;
// process at most block_table_now_len free_list entries per pass
int free_list_now[block_table_now_len];
if (cid == 0) {
GM2SM(free_list_len, &free_list_len_sm, sizeof(int));
}
sync_cluster();

int recover_len_lm;
GM2LM(recover_len, &recover_len_lm, sizeof(int));

for (int bid = cid; bid < recover_len_lm; bid += ncores) {
int recover_id;
int ori_seq_len_encoder;
int step_idx_now;
int encoder_block_len;
int decoder_used_len;
int64_t next_token;
GM2LM(recover_block_list + bid, &recover_id, sizeof(int));
GM2LM_ASYNC(
ori_seq_lens_encoder + recover_id, &ori_seq_len_encoder, sizeof(int));
GM2LM_ASYNC(step_idx + recover_id, &step_idx_now, sizeof(int));
GM2LM_ASYNC(
encoder_block_lens + recover_id, &encoder_block_len, sizeof(int));
GM2LM_ASYNC(used_list_len + recover_id, &decoder_used_len, sizeof(int));
GM2LM_ASYNC(next_tokens + recover_id, &next_token, sizeof(int64_t));
mfence();

int seq_len = ori_seq_len_encoder + step_idx_now;
mfence();
LM2GM_ASYNC(&seq_len, seq_lens_this_time + recover_id, sizeof(int));
LM2GM_ASYNC(&seq_len, seq_lens_encoder + recover_id, sizeof(int));
LM2GM_ASYNC(&flag_false, stop_flags + recover_id, sizeof(bool));
mfence();
// // next tokens
// LM2GM_ASYNC(&next_token,
// input_ids + recover_id * length + seq_len - 1,
// sizeof(int64_t));
// set first prompt token
int64_t first_token_id;
GM2LM(first_token_ids + recover_id, &first_token_id, sizeof(int64_t));
LM2GM_ASYNC(
&first_token_id, input_ids + recover_id * length, sizeof(int64_t));

int ori_free_list_len = atomic_add(&free_list_len_sm, -decoder_used_len);
// restore the block table
for (int i = 0; i < decoder_used_len; i += block_table_now_len) {
int process_len = min(block_table_now_len, decoder_used_len - i);
GM2LM(free_list + ori_free_list_len - i - process_len,
free_list_now,
process_len * sizeof(int));
for (int j = 0; j < process_len; j++) {
block_table_now[j] = free_list_now[process_len - 1 - j];
}
mfence();
LM2GM(
block_table_now,
block_tables + recover_id * block_num_per_seq + encoder_block_len + i,
process_len * sizeof(int));
}
// restore input_ids
for (int i = 0; i < step_idx_now; i += buf_len) {
int real_len = min(buf_len, step_idx_now - i);
GM2LM(pre_ids + recover_id * pre_id_length + i + 1,
input_ids_now,
sizeof(int64_t) * real_len);
LM2GM(input_ids_now,
input_ids + recover_id * length + ori_seq_len_encoder + i,
sizeof(int64_t) * real_len);
}
mfence();
}

if (cid == 0) {
recover_len_lm = 0;
mfence();
LM2GM_ASYNC(&recover_len_lm, recover_len, sizeof(int));
SM2GM_ASYNC(&free_list_len_sm, free_list_len, sizeof(int));
mfence();
}
}

} // namespace plugin
} // namespace xpu3
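// Editor's sketch (not part of the commit) of the recovery above: a stepped
// query's input row is re-materialized as [first prompt token, generated
// tokens from pre_ids], so the recovered request re-enters prefill with its
// full history. Names are illustrative.
#include <cstdint>
static void rebuild_input_ids_ref(int64_t* input_ids, const int64_t* pre_ids,
                                  int64_t first_token,
                                  int ori_seq_len_encoder, int step_idx) {
  input_ids[0] = first_token;
  for (int i = 0; i < step_idx; ++i) {
    // pre_ids[0] holds the first token; generated tokens start at index 1.
    input_ids[ori_seq_len_encoder + i] = pre_ids[i + 1];
  }
}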
@@ -138,7 +138,8 @@ __global__ void speculate_verify(
const int max_candidate_len, // scalar, max number of candidates per
// verify token (used for verification or sampling)
const int verify_window, // scalar, TopK verification window (number of
// consecutive top-1 matches allowed)
const bool prefill_one_step_stop) {
const bool prefill_one_step_stop,
const bool benchmark_mode) {
const int cid = core_id();
const int64_t tid = cluster_id() * core_num() + core_id();
const int64_t nthreads = cluster_num() * core_num();
@@ -161,6 +162,9 @@ __global__ void speculate_verify(
// printf("seq_lens_this_time[%d]-1: %d \n",bid,
// seq_lens_this_time[bid]-1);
for (; i < seq_lens_this_time[bid] - 1; i++) {
if (benchmark_mode) {
break;
}
if (seq_lens_encoder[bid] != 0) {
break;
}
@@ -326,7 +330,8 @@ __global__ void speculate_verify(
int max_seq_len, \
int max_candidate_len, \
int verify_window, \
bool prefill_one_step_stop);
bool prefill_one_step_stop, \
bool benchmark_mode);
SPECULATE_VERIFY_INSTANTIATE(true, true)
SPECULATE_VERIFY_INSTANTIATE(true, false)
SPECULATE_VERIFY_INSTANTIATE(false, true)

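// Editor's note: with benchmark_mode set, the verify loop above breaks before
// comparing any draft token, so each step accepts the minimum and end-to-end
// timing no longer depends on the acceptance rate; this appears to be a
// measurement-only switch.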
@@ -23,6 +23,7 @@ template <typename TX, typename TY>
__attribute__((global)) void eb_adjust_batch(TX *src,
TY *dst,
int *encoder_seqs_lods,
int *decoder_seqs_lods,
int *encoder_batch_map,
int *decoder_batch_map,
int en_batch,
@@ -41,6 +42,7 @@ static int cpu_wrapper(api::Context *ctx,
const TX *x,
TY *y,
const int *encoder_seqs_lods,
const int *decoder_seqs_lods,
const int *encoder_batch_map,
const int *decoder_batch_map,
int en_batch,
@@ -56,11 +58,12 @@ static int cpu_wrapper(api::Context *ctx,
// get copy size && src_offset
int cpy_m = 0;
if (de_batch > 0 && decoder_batch_map[de_idx] == i) {
cpy_m = 1;
ret = api::cast<TX, TY>(ctx,
x + cur_offset * hidden_dim,
y + (encoder_len_total + de_idx) * hidden_dim,
cpy_m * hidden_dim);
cpy_m = decoder_seqs_lods[de_idx + 1] - decoder_seqs_lods[de_idx];
ret = api::cast<TX, TY>(
ctx,
x + cur_offset * hidden_dim,
y + (encoder_len_total + decoder_seqs_lods[de_idx]) * hidden_dim,
cpy_m * hidden_dim);
WRAPPER_ASSERT_SUCCESS(ctx, ret);
de_idx++;
}
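// Editor's note: the change above switches the decoder copy in cpu_wrapper
// from one row per decoder member (cpy_m = 1) to the member's full LoD span,
// matching MTP where a decoder member can carry several tokens per step. For
// example, with decoder_seqs_lods = {0, 2, 5}, member 0 now copies 2 rows and
// member 1 copies 3.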
@@ -84,6 +87,7 @@ static int xpu3_wrapper(api::Context *ctx,
const TX *x,
TY *y,
api::VectorParam<int32_t> &encoder_seqs_lods, // NOLINT
api::VectorParam<int32_t> &decoder_seqs_lods, // NOLINT
api::VectorParam<int32_t> &encoder_batch_map, // NOLINT
api::VectorParam<int32_t> &decoder_batch_map, // NOLINT
int en_batch,
@@ -98,6 +102,7 @@ static int xpu3_wrapper(api::Context *ctx,
reinterpret_cast<XPU_INDEX_TYPE_TX *>(const_cast<TX *>(x)),
reinterpret_cast<XPU_INDEX_TYPE_TY *>(y),
encoder_seqs_lods.xpu,
decoder_seqs_lods.xpu,
encoder_batch_map.xpu,
decoder_batch_map.xpu,
en_batch,
@@ -111,6 +116,7 @@ int eb_adjust_batch(api::Context *ctx,
const TX *x,
TY *y,
api::VectorParam<int32_t> &encoder_seqs_lods, // NOLINT
api::VectorParam<int32_t> &decoder_seqs_lods, // NOLINT
api::VectorParam<int32_t> &encoder_batch_map, // NOLINT
api::VectorParam<int32_t> &decoder_batch_map, // NOLINT
int64_t hidden_dim) {
@@ -119,28 +125,35 @@ int eb_adjust_batch(api::Context *ctx,
// if (dev_id ==0) {
// ctx->set_debug_level(0xA1);
// }

// std::cout << decoder_seqs_lods.cpu[0] << " " << decoder_seqs_lods.cpu[1] <<
// std::endl;
WRAPPER_CHECK_CTX(ctx);
WRAPPER_DUMP_FUNCTION_T2(ctx, "eb_adjust_batch", TX, TY);
WRAPPER_DUMP_PARAM6(ctx,
x,
y,
encoder_seqs_lods,
decoder_seqs_lods,
encoder_batch_map,
decoder_batch_map,
hidden_dim);
decoder_batch_map);
WRAPPER_DUMP_PARAM1(ctx, hidden_dim);
WRAPPER_DUMP(ctx);
int encoder_batch = encoder_batch_map.len;
int total_batch = encoder_batch + decoder_batch_map.len;
int decoder_batch = decoder_batch_map.len;
int total_batch = encoder_batch + decoder_batch;
int max_encoder_lod = encoder_seqs_lods.cpu[encoder_batch];
int m = max_encoder_lod + decoder_batch_map.len;
int max_decoder_lod = decoder_seqs_lods.cpu[decoder_batch];
int m = max_encoder_lod + max_decoder_lod;
WRAPPER_CHECK_PTR(ctx, TX, m * hidden_dim, x);
WRAPPER_CHECK_PTR(ctx, TY, m * hidden_dim, y);
WRAPPER_ASSERT_GT(ctx, hidden_dim, 0);
// check VectorParam
WRAPPER_ASSERT_EQ(ctx, encoder_seqs_lods.len, encoder_batch_map.len + 1);
WRAPPER_ASSERT_EQ(ctx, decoder_seqs_lods.len, decoder_batch_map.len + 1);
WRAPPER_ASSERT_GE(ctx, encoder_seqs_lods.cpu[0], 0);
WRAPPER_ASSERT_LE(ctx, encoder_seqs_lods.cpu[0], max_encoder_lod);
WRAPPER_ASSERT_GE(ctx, decoder_seqs_lods.cpu[0], 0);
WRAPPER_ASSERT_LE(ctx, decoder_seqs_lods.cpu[0], max_decoder_lod);
for (int i = 0; i < encoder_batch_map.len; ++i) {
WRAPPER_ASSERT_GE(ctx, encoder_batch_map.cpu[i], 0);
WRAPPER_ASSERT_LT(ctx, encoder_batch_map.cpu[i], total_batch)
@@ -150,12 +163,15 @@ int eb_adjust_batch(api::Context *ctx,
for (int i = 0; i < decoder_batch_map.len; ++i) {
WRAPPER_ASSERT_GE(ctx, decoder_batch_map.cpu[i], 0);
WRAPPER_ASSERT_LT(ctx, decoder_batch_map.cpu[i], total_batch)
WRAPPER_ASSERT_GE(ctx, decoder_seqs_lods.cpu[i + 1], 0);
WRAPPER_ASSERT_LE(ctx, decoder_seqs_lods.cpu[i + 1], max_decoder_lod);
}
if (ctx->dev().type() == api::kCPU) {
return cpu_wrapper<TX, TY>(ctx,
x,
y,
encoder_seqs_lods.cpu,
decoder_seqs_lods.cpu,
encoder_batch_map.cpu,
decoder_batch_map.cpu,
encoder_batch_map.len,
@@ -166,6 +182,8 @@ int eb_adjust_batch(api::Context *ctx,
api::ctx_guard RAII_GUARD(ctx);
api::VectorParam<int32_t> encoder_seqs_lods_xpu =
encoder_seqs_lods.to_xpu(RAII_GUARD);
api::VectorParam<int32_t> decoder_seqs_lods_xpu =
decoder_seqs_lods.to_xpu(RAII_GUARD);
api::VectorParam<int32_t> encoder_batch_map_xpu =
encoder_batch_map.to_xpu(RAII_GUARD);
api::VectorParam<int32_t> decoder_batch_map_xpu =
@@ -174,6 +192,7 @@ int eb_adjust_batch(api::Context *ctx,
x,
y,
encoder_seqs_lods_xpu,
decoder_seqs_lods_xpu,
encoder_batch_map_xpu,
decoder_batch_map_xpu,
encoder_batch_map.len,
@@ -190,6 +209,7 @@ int eb_adjust_batch(api::Context *ctx,
api::VectorParam<int32_t> &, \
api::VectorParam<int32_t> &, \
api::VectorParam<int32_t> &, \
api::VectorParam<int32_t> &, \
int64_t);

INSTANTIATION_EB_ADJUST_BATCH(float16, float16);

@@ -27,26 +27,29 @@ __attribute__((global)) void draft_model_preprocess(
int* seq_lens_encoder,
int* seq_lens_decoder,
int64_t* step_idx,
int* seq_lens_encoder_record,
int* seq_lens_decoder_record,
bool* not_need_stop,
bool* is_block_step,
bool* batch_drop,
int64_t* pre_ids,
const int64_t* accept_tokens,
const int* accept_num,
const int* base_model_seq_lens_this_time,
const int* base_model_seq_lens_encoder,
const int* base_model_seq_lens_decoder,
const int64_t* base_model_step_idx,
const bool* base_model_stop_flags,
const bool* base_model_is_block_step,
int64_t* base_model_draft_tokens,
int real_bsz,
int max_draft_token,
int accept_tokens_len,
int draft_tokens_len,
int input_ids_len,
int base_model_draft_tokens_len,
bool truncate_first_token,
bool splitwise_prefill);
const int bsz,
const int num_model_step,
const int accept_tokens_len,
const int draft_tokens_len,
const int input_ids_len,
const int base_model_draft_tokens_len,
const int pre_ids_len,
const bool truncate_first_token,
const bool splitwise_prefill,
const bool kvcache_scheduler_v1);
} // namespace plugin
} // namespace xpu3

@@ -67,49 +70,47 @@ static int cpu_wrapper(api::Context* ctx,
int* seq_lens_encoder,
int* seq_lens_decoder,
int64_t* step_idx,
int* seq_lens_encoder_record,
int* seq_lens_decoder_record,
bool* not_need_stop,
bool* is_block_step,
bool* batch_drop,
int64_t* pre_ids,
const int64_t* accept_tokens,
const int* accept_num,
const int* base_model_seq_lens_this_time,
const int* base_model_seq_lens_encoder,
const int* base_model_seq_lens_decoder,
const int64_t* base_model_step_idx,
const bool* base_model_stop_flags,
const bool* base_model_is_block_step,
int64_t* base_model_draft_tokens,
int real_bsz,
int max_draft_token,
int accept_tokens_len,
int draft_tokens_len,
int input_ids_len,
int base_model_draft_tokens_len,
bool truncate_first_token,
bool splitwise_prefill) {
const int bsz,
const int num_model_step,
const int accept_tokens_len,
const int draft_tokens_len,
const int input_ids_len,
const int base_model_draft_tokens_len,
const int pre_ids_len,
const bool truncate_first_token,
const bool splitwise_prefill,
const bool kvcache_scheduler_v1) {
int64_t not_stop_flag_sum = 0;
int64_t not_stop_flag = 0;
for (int tid = 0; tid < real_bsz; tid++) {
for (int tid = 0; tid < bsz; tid++) {
if (splitwise_prefill) {
int base_model_step_idx_now = base_model_step_idx[tid];
auto* input_ids_now = input_ids + tid * input_ids_len;
auto* accept_tokens_now = accept_tokens + tid * accept_tokens_len;
// printf("bid: %d, base_model_step_idx_now: %d seq_lens_encoder_record:
// %d\n", tid, base_model_step_idx_now, seq_lens_encoder_record[tid]);
if (base_model_step_idx_now == 1 && seq_lens_encoder_record[tid] > 0) {
if (seq_lens_encoder[tid] > 0) {
not_stop_flag = 1;
int seq_len_encoder_record = seq_lens_encoder_record[tid];
seq_lens_encoder[tid] = seq_len_encoder_record;
seq_lens_encoder_record[tid] = -1;
int seq_len_encoder = seq_lens_encoder[tid];
stop_flags[tid] = false;
int64_t base_model_first_token = accept_tokens_now[0];
int position = seq_len_encoder_record;
int position = seq_len_encoder;
if (truncate_first_token) {
input_ids_now[position - 1] = base_model_first_token;
seq_lens_this_time[tid] = seq_len_encoder_record;
seq_lens_this_time[tid] = seq_len_encoder;
} else {
input_ids_now[position] = base_model_first_token;
seq_lens_this_time[tid] = seq_len_encoder_record + 1;
seq_lens_this_time[tid] = seq_len_encoder + 1;
}
} else {
stop_flags[tid] = true;
@@ -120,63 +121,77 @@ static int cpu_wrapper(api::Context* ctx,
}
not_stop_flag_sum += not_stop_flag;
} else {
auto base_model_step_idx_now = base_model_step_idx[tid];
auto* accept_tokens_now = accept_tokens + tid * accept_tokens_len;
auto* draft_tokens_now = draft_tokens + tid * draft_tokens_len;
auto accept_num_now = accept_num[tid];
auto* input_ids_now = input_ids + tid * input_ids_len;
auto* base_model_draft_tokens_now =
base_model_draft_tokens + tid * base_model_draft_tokens_len;
auto base_model_seq_len_decoder = base_model_seq_lens_decoder[tid];
const int32_t base_model_seq_len_this_time =
base_model_seq_lens_this_time[tid];
auto* pre_ids_now = pre_ids + tid * pre_ids_len;
for (int i = 1; i < base_model_draft_tokens_len; i++) {
base_model_draft_tokens_now[i] = -1;
}
if (base_model_stop_flags[tid] && base_model_is_block_step[tid]) {
batch_drop[tid] = true;
stop_flags[tid] = true;
if (kvcache_scheduler_v1) {
if (base_model_stop_flags[tid] && base_model_is_block_step[tid]) {
stop_flags[tid] = true;
is_block_step[tid] = true;
// Need to continue inference
}
} else {
if (base_model_stop_flags[tid] && base_model_is_block_step[tid]) {
batch_drop[tid] = true;
stop_flags[tid] = true;
}
}

if (!(base_model_stop_flags[tid] || batch_drop[tid])) {
not_stop_flag = 1;
// 1. first token

if (base_model_step_idx_now == 0) {
seq_lens_this_time[tid] = 0;
not_stop_flag = 0;
} else if (base_model_step_idx_now == 1 &&
seq_lens_encoder_record[tid] > 0) {
// prefill generation
if (seq_lens_encoder[tid] > 0) {
// Can be extended to the first few tokens
int seq_len_encoder_record = seq_lens_encoder_record[tid];
seq_lens_encoder[tid] = seq_len_encoder_record;
seq_lens_encoder_record[tid] = -1;
seq_lens_decoder[tid] = seq_lens_decoder_record[tid];
seq_lens_decoder_record[tid] = 0;
int seq_len_encoder = seq_lens_encoder[tid];
stop_flags[tid] = false;
int64_t base_model_first_token = accept_tokens_now[0];
int position = seq_len_encoder_record;
pre_ids_now[0] = base_model_first_token;
int position = seq_len_encoder;
if (truncate_first_token) {
input_ids_now[position - 1] = base_model_first_token;
seq_lens_this_time[tid] = seq_len_encoder_record;
seq_lens_this_time[tid] = seq_len_encoder;
} else {
input_ids_now[position] = base_model_first_token;
seq_lens_this_time[tid] = seq_len_encoder_record + 1;
seq_lens_this_time[tid] = seq_len_encoder + 1;
}
} else { // decode generation
if (kvcache_scheduler_v1) {
// 3. try to recover mtp infer in V1 mode
if (!base_model_is_block_step[tid] && is_block_step[tid]) {
is_block_step[tid] = false;
}
}
} else if (accept_num_now <=
max_draft_token) /*Accept partial draft tokens*/ {
// The base model rejected the stop
if (stop_flags[tid]) {
stop_flags[tid] = false;
seq_lens_decoder[tid] = base_model_seq_lens_decoder[tid];
step_idx[tid] = base_model_step_idx[tid];
// TODO: check
seq_lens_decoder[tid] =
base_model_seq_len_decoder - base_model_seq_len_this_time;
step_idx[tid] =
base_model_step_idx[tid] - base_model_seq_len_this_time;
} else {
seq_lens_decoder[tid] -= max_draft_token - accept_num_now;
step_idx[tid] -= max_draft_token - accept_num_now;
// 2: last base-model-generated token plus the first MTP token
seq_lens_decoder[tid] -= num_model_step - 1;
step_idx[tid] -= num_model_step - 1;
}
int64_t modified_token = accept_tokens_now[accept_num_now - 1];
draft_tokens_now[0] = modified_token;
seq_lens_this_time[tid] = 1;
} else /*Accept all draft tokens*/ {
draft_tokens_now[1] = accept_tokens_now[max_draft_token];
seq_lens_this_time[tid] = 2;
for (int i = 0; i < accept_num_now; i++) {
draft_tokens_now[i] = accept_tokens_now[i];
const int pre_id_pos =
base_model_step_idx[tid] - (accept_num_now - i);
const int64_t accept_token = accept_tokens_now[i];
pre_ids_now[pre_id_pos] = accept_token;
}
seq_lens_this_time[tid] = accept_num_now;
}
} else {
stop_flags[tid] = true;
@@ -199,26 +214,29 @@ static int xpu3_wrapper(api::Context* ctx,
int* seq_lens_encoder,
int* seq_lens_decoder,
int64_t* step_idx,
int* seq_lens_encoder_record,
int* seq_lens_decoder_record,
bool* not_need_stop,
bool* is_block_step,
bool* batch_drop,
int64_t* pre_ids,
const int64_t* accept_tokens,
const int* accept_num,
const int* base_model_seq_lens_this_time,
const int* base_model_seq_lens_encoder,
const int* base_model_seq_lens_decoder,
const int64_t* base_model_step_idx,
const bool* base_model_stop_flags,
const bool* base_model_is_block_step,
int64_t* base_model_draft_tokens,
int real_bsz,
int max_draft_token,
int accept_tokens_len,
int draft_tokens_len,
int input_ids_len,
int base_model_draft_tokens_len,
bool truncate_first_token,
bool splitwise_prefill) {
const int bsz,
const int num_model_step,
const int accept_tokens_len,
const int draft_tokens_len,
const int input_ids_len,
const int base_model_draft_tokens_len,
const int pre_ids_len,
const bool truncate_first_token,
const bool splitwise_prefill,
const bool kvcache_scheduler_v1) {
using XPU_INT64 = typename XPUIndexType<int64_t>::type;

// NOTE: Don't change 16 to 64, because the kernel uses GSM
@@ -230,26 +248,29 @@ static int xpu3_wrapper(api::Context* ctx,
seq_lens_encoder,
seq_lens_decoder,
reinterpret_cast<XPU_INT64*>(step_idx),
seq_lens_encoder_record,
seq_lens_decoder_record,
not_need_stop,
is_block_step,
batch_drop,
reinterpret_cast<XPU_INT64*>(pre_ids),
reinterpret_cast<const XPU_INT64*>(accept_tokens),
accept_num,
base_model_seq_lens_this_time,
base_model_seq_lens_encoder,
base_model_seq_lens_decoder,
reinterpret_cast<const XPU_INT64*>(base_model_step_idx),
base_model_stop_flags,
base_model_is_block_step,
reinterpret_cast<XPU_INT64*>(base_model_draft_tokens),
real_bsz,
max_draft_token,
bsz,
num_model_step,
accept_tokens_len,
draft_tokens_len,
input_ids_len,
base_model_draft_tokens_len,
pre_ids_len,
truncate_first_token,
splitwise_prefill);
splitwise_prefill,
kvcache_scheduler_v1);
return api::SUCCESS;
}

@@ -261,26 +282,29 @@ int draft_model_preprocess(api::Context* ctx,
                           int* seq_lens_encoder,
                           int* seq_lens_decoder,
                           int64_t* step_idx,
                           int* seq_lens_encoder_record,
                           int* seq_lens_decoder_record,
                           bool* not_need_stop,
                           bool* is_block_step,
                           bool* batch_drop,
                           int64_t* pre_ids,
                           const int64_t* accept_tokens,
                           const int* accept_num,
                           const int* base_model_seq_lens_this_time,
                           const int* base_model_seq_lens_encoder,
                           const int* base_model_seq_lens_decoder,
                           const int64_t* base_model_step_idx,
                           const bool* base_model_stop_flags,
                           const bool* base_model_is_block_step,
                           int64_t* base_model_draft_tokens,
                           int real_bsz,
                           int max_draft_token,
                           int accept_tokens_len,
                           int draft_tokens_len,
                           int input_ids_len,
                           int base_model_draft_tokens_len,
                           bool truncate_first_token,
                           bool splitwise_prefill) {
                           const int bsz,
                           const int num_model_step,
                           const int accept_tokens_len,
                           const int draft_tokens_len,
                           const int input_ids_len,
                           const int base_model_draft_tokens_len,
                           const int pre_ids_len,
                           const bool truncate_first_token,
                           const bool splitwise_prefill,
                           const bool kvcache_scheduler_v1) {
  WRAPPER_CHECK_CTX(ctx);
  WRAPPER_DUMP_FUNCTION_T1(ctx, "draft_model_preprocess", int64_t);
  WRAPPER_DUMP_PARAM6(ctx,
@@ -290,37 +314,34 @@ int draft_model_preprocess(api::Context* ctx,
                      seq_lens_this_time,
                      seq_lens_encoder,
                      seq_lens_decoder);
  WRAPPER_DUMP_PARAM5(ctx,
                      step_idx,
                      seq_lens_encoder_record,
                      seq_lens_decoder_record,
                      not_need_stop,
                      batch_drop);
  WRAPPER_DUMP_PARAM5(
      ctx, step_idx, not_need_stop, is_block_step, batch_drop, pre_ids);
  WRAPPER_DUMP_PARAM3(
      ctx, accept_tokens, accept_num, base_model_seq_lens_encoder);
  WRAPPER_DUMP_PARAM3(ctx,
  WRAPPER_DUMP_PARAM4(ctx,
                      base_model_seq_lens_encoder,
                      base_model_seq_lens_decoder,
                      base_model_step_idx,
                      base_model_stop_flags);
  WRAPPER_DUMP_PARAM3(
      ctx, base_model_is_block_step, base_model_draft_tokens, real_bsz);
  WRAPPER_DUMP_PARAM3(
      ctx, max_draft_token, accept_tokens_len, draft_tokens_len);
  WRAPPER_DUMP_PARAM3(
      ctx, input_ids_len, base_model_draft_tokens_len, truncate_first_token);
  WRAPPER_DUMP_PARAM1(ctx, splitwise_prefill);
      ctx, base_model_is_block_step, base_model_draft_tokens, bsz);
  WRAPPER_DUMP_PARAM3(ctx, num_model_step, accept_tokens_len, draft_tokens_len);
  WRAPPER_DUMP_PARAM4(ctx,
                      input_ids_len,
                      base_model_draft_tokens_len,
                      pre_ids_len,
                      truncate_first_token);
  WRAPPER_DUMP_PARAM2(ctx, splitwise_prefill, kvcache_scheduler_v1);
  WRAPPER_DUMP(ctx);

  WRAPPER_CHECK_PTR(ctx, int, real_bsz, seq_lens_this_time);
  WRAPPER_CHECK_PTR(ctx, int64_t, real_bsz * accept_tokens_len, accept_tokens);
  WRAPPER_CHECK_PTR(ctx, int64_t, real_bsz * input_ids_len, input_ids);
  WRAPPER_CHECK_PTR(ctx, int64_t, real_bsz * draft_tokens_len, draft_tokens);
  WRAPPER_CHECK_PTR(ctx,
                    int64_t,
                    real_bsz * base_model_draft_tokens_len,
                    base_model_draft_tokens);
  WRAPPER_CHECK_PTR(ctx, int, bsz, seq_lens_this_time);
  WRAPPER_CHECK_PTR(ctx, int64_t, bsz * accept_tokens_len, accept_tokens);
  WRAPPER_CHECK_PTR(ctx, int64_t, bsz * input_ids_len, input_ids);
  WRAPPER_CHECK_PTR(ctx, int64_t, bsz * draft_tokens_len, draft_tokens);
  WRAPPER_CHECK_PTR(
      ctx, int64_t, bsz * base_model_draft_tokens_len, base_model_draft_tokens);

  WRAPPER_ASSERT_GT(ctx, real_bsz, 0);
  WRAPPER_ASSERT_GT(ctx, bsz, 0);
  WRAPPER_ASSERT_LT(ctx, accept_tokens_len, 128);

  if (ctx->dev().type() == api::kCPU) {
@@ -332,26 +353,29 @@ int draft_model_preprocess(api::Context* ctx,
                       seq_lens_encoder,
                       seq_lens_decoder,
                       step_idx,
                       seq_lens_encoder_record,
                       seq_lens_decoder_record,
                       not_need_stop,
                       is_block_step,
                       batch_drop,
                       pre_ids,
                       accept_tokens,
                       accept_num,
                       base_model_seq_lens_this_time,
                       base_model_seq_lens_encoder,
                       base_model_seq_lens_decoder,
                       base_model_step_idx,
                       base_model_stop_flags,
                       base_model_is_block_step,
                       base_model_draft_tokens,
                       real_bsz,
                       max_draft_token,
                       bsz,
                       num_model_step,
                       accept_tokens_len,
                       draft_tokens_len,
                       input_ids_len,
                       base_model_draft_tokens_len,
                       pre_ids_len,
                       truncate_first_token,
                       splitwise_prefill);
                       splitwise_prefill,
                       kvcache_scheduler_v1);
  }
  if (ctx->dev().type() == api::kXPU3) {
    return xpu3_wrapper(ctx,
@@ -362,26 +386,29 @@ int draft_model_preprocess(api::Context* ctx,
                        seq_lens_encoder,
                        seq_lens_decoder,
                        step_idx,
                        seq_lens_encoder_record,
                        seq_lens_decoder_record,
                        not_need_stop,
                        is_block_step,
                        batch_drop,
                        pre_ids,
                        accept_tokens,
                        accept_num,
                        base_model_seq_lens_this_time,
                        base_model_seq_lens_encoder,
                        base_model_seq_lens_decoder,
                        base_model_step_idx,
                        base_model_stop_flags,
                        base_model_is_block_step,
                        base_model_draft_tokens,
                        real_bsz,
                        max_draft_token,
                        bsz,
                        num_model_step,
                        accept_tokens_len,
                        draft_tokens_len,
                        input_ids_len,
                        base_model_draft_tokens_len,
                        pre_ids_len,
                        truncate_first_token,
                        splitwise_prefill);
                        splitwise_prefill,
                        kvcache_scheduler_v1);
  }
  WRAPPER_UNIMPLEMENTED(ctx);
}
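
Every wrapper touched by this commit keeps the same dispatch shape: validate and dump the arguments, then route to a host reference implementation or to the XPU3 kernel, falling through to WRAPPER_UNIMPLEMENTED otherwise. A minimal Python sketch of that shape (hypothetical names, not the commit's API), useful for reading the remaining files:

# A minimal sketch (hypothetical names) of the wrapper dispatch pattern.
def cpu_wrapper(ctx, *args):
    return "ran host reference implementation"

def xpu3_wrapper(ctx, *args):
    return "launched XPU3 kernel"

def dispatch(ctx, *args):
    if ctx["device"] == "cpu":
        return cpu_wrapper(ctx, *args)
    if ctx["device"] == "xpu3":
        return xpu3_wrapper(ctx, *args)
    raise NotImplementedError(ctx["device"])

print(dispatch({"device": "cpu"}))  # ran host reference implementation
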

@@ -0,0 +1,227 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "xpu/plugin.h"
#include "xpu/refactor/impl/launch_strategy.h"
#include "xpu/refactor/impl_public/wrapper_check.h"
#include "xpu/xdnn.h"

namespace xpu3 {
namespace plugin {
template <typename TX, typename TY>
__attribute__((global)) void eb_mtp_gather_next_token(TX *src,
                                                      TY *dst,
                                                      int *encoder_seqs_lods,
                                                      int *decoder_seqs_lods,
                                                      int *encoder_batch_map,
                                                      int *decoder_batch_map,
                                                      int en_batch,
                                                      int de_batch,
                                                      int64_t copy_size);
}  // namespace plugin
}  // namespace xpu3

namespace baidu {
namespace xpu {
namespace api {
namespace plugin {
template <typename TX, typename TY>
static int cpu_wrapper(api::Context *ctx,
                       const TX *x,
                       TY *y,
                       const int *encoder_seqs_lods,
                       const int *decoder_seqs_lods,
                       const int *encoder_batch_map,
                       const int *decoder_batch_map,
                       int en_batch,
                       int de_batch,
                       int64_t hidden_dim) {
  int ret = 0;
  int encoder_len_total = encoder_seqs_lods[en_batch];
  int decoder_len_total = decoder_seqs_lods[de_batch];
  int output_token_num = en_batch + decoder_len_total;
  for (int i = 0; i < output_token_num; i++) {
    int len = 0;
    int enc_idx = 0, dec_idx = 0;
    bool is_enc;
    while (i >= len) {
      if (enc_idx >= en_batch) {
        len += decoder_seqs_lods[dec_idx + 1] - decoder_seqs_lods[dec_idx];
        dec_idx++;
        is_enc = false;
        continue;
      }
      if (dec_idx >= de_batch) {
        len += 1;
        enc_idx++;
        is_enc = true;
        continue;
      }
      if ((encoder_batch_map[enc_idx] < decoder_batch_map[dec_idx])) {
        len += 1;
        enc_idx++;
        is_enc = true;
      } else {
        len += decoder_seqs_lods[dec_idx + 1] - decoder_seqs_lods[dec_idx];
        dec_idx++;
        is_enc = false;
      }
    }
    const TX *src = nullptr;
    if (is_enc) {
      src = x + (encoder_seqs_lods[enc_idx] - 1) * hidden_dim;
    } else {
      src = x + (encoder_len_total + decoder_seqs_lods[dec_idx] - (len - i)) *
                    hidden_dim;
    }
    ret = api::cast<TX, TY>(ctx, src, y + i * hidden_dim, hidden_dim);
    WRAPPER_ASSERT_SUCCESS(ctx, ret);
  }
  return api::SUCCESS;
}
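
The walk above merges encoder (prefill) and decoder sequences in batch-id order, emitting one row per prefill sequence (its last token) and every row of each decode sequence. A standalone Python sketch of the same index computation (not part of the commit; toy lod arrays assumed):

# Returns, for each output row, the source row that eb_mtp_gather_next_token
# copies: the last token of every encoder (prefill) sequence and every token
# of every decoder sequence, merged by batch id.
def gather_next_token_rows(enc_lods, dec_lods, enc_map, dec_map):
    en_batch, de_batch = len(enc_map), len(dec_map)
    enc_total = enc_lods[en_batch]      # decoder rows start after all encoder rows
    rows = []
    e, d = 0, 0
    while e < en_batch or d < de_batch:
        take_enc = d >= de_batch or (e < en_batch and enc_map[e] < dec_map[d])
        if take_enc:
            rows.append(enc_lods[e + 1] - 1)   # last token of this prefill seq
            e += 1
        else:
            start, end = dec_lods[d], dec_lods[d + 1]
            rows.extend(enc_total + i for i in range(start, end))  # all decode tokens
            d += 1
    return rows

# Two prefill seqs (rows 0..2 and 3..4) and one decode seq (rows 5..6):
print(gather_next_token_rows([0, 3, 5], [0, 2], [0, 2], [1]))  # [2, 5, 6, 4]
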
template <typename TX, typename TY>
static int xpu3_wrapper(api::Context *ctx,
                        const TX *x,
                        TY *y,
                        api::VectorParam<int32_t> &encoder_seqs_lods,  // NOLINT
                        api::VectorParam<int32_t> &decoder_seqs_lods,  // NOLINT
                        api::VectorParam<int32_t> &encoder_batch_map,  // NOLINT
                        api::VectorParam<int32_t> &decoder_batch_map,  // NOLINT
                        int en_batch,
                        int de_batch,
                        int64_t hidden_dim) {
  auto eb_mtp_gather_next_token_kernel =
      xpu3::plugin::eb_mtp_gather_next_token<TX, TY>;
  // NOTE: Don't change 16 to 64, because the kernel uses gsm
  eb_mtp_gather_next_token_kernel<<<ctx->ncluster(), 16, ctx->xpu_stream>>>(
      const_cast<TX *>(x),
      y,
      encoder_seqs_lods.xpu,
      decoder_seqs_lods.xpu,
      encoder_batch_map.xpu,
      decoder_batch_map.xpu,
      en_batch,
      de_batch,
      hidden_dim);
  return api::SUCCESS;
}

template <typename TX, typename TY>
int eb_mtp_gather_next_token(
    api::Context *ctx,
    const TX *x,
    TY *y,
    api::VectorParam<int32_t> &encoder_seqs_lods,  // NOLINT
    api::VectorParam<int32_t> &decoder_seqs_lods,  // NOLINT
    api::VectorParam<int32_t> &encoder_batch_map,  // NOLINT
    api::VectorParam<int32_t> &decoder_batch_map,  // NOLINT
    int64_t hidden_dim) {
  WRAPPER_CHECK_CTX(ctx);
  WRAPPER_DUMP_FUNCTION_T2(ctx, "eb_mtp_gather_next_token", TX, TY);
  WRAPPER_DUMP_PARAM6(ctx,
                      x,
                      y,
                      encoder_seqs_lods,
                      decoder_seqs_lods,
                      encoder_batch_map,
                      decoder_batch_map);
  WRAPPER_DUMP_PARAM1(ctx, hidden_dim);
  WRAPPER_DUMP(ctx);
  int encoder_batch = encoder_batch_map.len;
  int decoder_batch = decoder_batch_map.len;
  int max_encoder_lod = encoder_seqs_lods.cpu[encoder_batch];
  int max_decoder_lod = decoder_seqs_lods.cpu[decoder_batch];
  int m = encoder_seqs_lods.cpu[encoder_batch] +
          decoder_seqs_lods.cpu[decoder_batch];
  int out_m = encoder_batch + decoder_seqs_lods.cpu[decoder_batch];
  WRAPPER_CHECK_PTR(ctx, TX, m * hidden_dim, x);
  WRAPPER_CHECK_PTR(ctx, TY, out_m * hidden_dim, y);
  WRAPPER_ASSERT_GT(ctx, hidden_dim, 0);
  // check VectorParam
  WRAPPER_ASSERT_EQ(ctx, encoder_seqs_lods.len, encoder_batch_map.len + 1);
  WRAPPER_ASSERT_EQ(ctx, decoder_seqs_lods.len, decoder_batch_map.len + 1);
  WRAPPER_ASSERT_GE(ctx, encoder_seqs_lods.cpu[0], 0);
  WRAPPER_ASSERT_LE(ctx, encoder_seqs_lods.cpu[0], max_encoder_lod);
  WRAPPER_ASSERT_GE(ctx, decoder_seqs_lods.cpu[0], 0);
  WRAPPER_ASSERT_LE(ctx, decoder_seqs_lods.cpu[0], max_decoder_lod);
  // Note: encoder/decoder batch-map values can exceed the batch size, because
  // the restored batch layout may be sparse, so only non-negativity is checked
  for (int i = 0; i < encoder_batch_map.len; ++i) {
    WRAPPER_ASSERT_GE(ctx, encoder_batch_map.cpu[i], 0);
    WRAPPER_ASSERT_GE(ctx, encoder_seqs_lods.cpu[i + 1], 0);
    WRAPPER_ASSERT_LE(ctx, encoder_seqs_lods.cpu[i + 1], max_encoder_lod);
  }
  for (int i = 0; i < decoder_batch_map.len; ++i) {
    WRAPPER_ASSERT_GE(ctx, decoder_batch_map.cpu[i], 0);
    WRAPPER_ASSERT_GE(ctx, decoder_seqs_lods.cpu[i + 1], 0);
    WRAPPER_ASSERT_LE(ctx, decoder_seqs_lods.cpu[i + 1], max_decoder_lod);
  }
  if (ctx->dev().type() == api::kCPU) {
    return cpu_wrapper<TX, TY>(ctx,
                               x,
                               y,
                               encoder_seqs_lods.cpu,
                               decoder_seqs_lods.cpu,
                               encoder_batch_map.cpu,
                               decoder_batch_map.cpu,
                               encoder_batch_map.len,
                               decoder_batch_map.len,
                               hidden_dim);
  }
  if (ctx->dev().type() == api::kXPU3) {
    api::ctx_guard RAII_GUARD(ctx);
    api::VectorParam<int32_t> encoder_seqs_lods_xpu =
        encoder_seqs_lods.to_xpu(RAII_GUARD);
    api::VectorParam<int32_t> decoder_seqs_lods_xpu =
        decoder_seqs_lods.to_xpu(RAII_GUARD);
    api::VectorParam<int32_t> encoder_batch_map_xpu =
        encoder_batch_map.to_xpu(RAII_GUARD);
    api::VectorParam<int32_t> decoder_batch_map_xpu =
        decoder_batch_map.to_xpu(RAII_GUARD);
    return xpu3_wrapper<TX, TY>(ctx,
                                x,
                                y,
                                encoder_seqs_lods_xpu,
                                decoder_seqs_lods_xpu,
                                encoder_batch_map_xpu,
                                decoder_batch_map_xpu,
                                encoder_batch_map.len,
                                decoder_batch_map.len,
                                hidden_dim);
  }
  WRAPPER_UNIMPLEMENTED(ctx);
}
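
A note on the VectorParam<int32_t> arguments above: each carries a host copy (.cpu) that drives validation and the CPU path, a length (.len; lod arrays hold batch + 1 entries), and a device copy (.xpu) that to_xpu(RAII_GUARD) materializes before the kernel launch. A toy Python sketch of that convention (not part of the commit; Guard.upload() is hypothetical and stands in for the guard-managed device allocation):

from dataclasses import dataclass


class Guard:
    def upload(self, host_list):
        return tuple(host_list)  # stand-in for a device buffer handle


@dataclass
class VectorParam:
    cpu: list          # host copy, used for checks and the CPU reference path
    len: int           # number of valid entries
    xpu: object = None  # device copy, consumed by the kernel

    def to_xpu(self, guard):
        self.xpu = guard.upload(self.cpu)
        return self


lods = VectorParam(cpu=[0, 3, 5], len=3).to_xpu(Guard())
print(lods.xpu)  # (0, 3, 5)
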
#define INSTANTIATION_EB_MTP_GATHER_NEXT_TOKEN(TX, TY)        \
  template int eb_mtp_gather_next_token<TX, TY>(api::Context *, \
                                                const TX *,     \
                                                TY *,           \
                                                api::VectorParam<int32_t> &, \
                                                api::VectorParam<int32_t> &, \
                                                api::VectorParam<int32_t> &, \
                                                api::VectorParam<int32_t> &, \
                                                int64_t);

INSTANTIATION_EB_MTP_GATHER_NEXT_TOKEN(float16, float16);
INSTANTIATION_EB_MTP_GATHER_NEXT_TOKEN(bfloat16, bfloat16);
INSTANTIATION_EB_MTP_GATHER_NEXT_TOKEN(float, float);
INSTANTIATION_EB_MTP_GATHER_NEXT_TOKEN(float16, float);
INSTANTIATION_EB_MTP_GATHER_NEXT_TOKEN(float, float16);
INSTANTIATION_EB_MTP_GATHER_NEXT_TOKEN(bfloat16, float16);
INSTANTIATION_EB_MTP_GATHER_NEXT_TOKEN(float16, bfloat16);
INSTANTIATION_EB_MTP_GATHER_NEXT_TOKEN(bfloat16, float);
INSTANTIATION_EB_MTP_GATHER_NEXT_TOKEN(float, bfloat16);
}  // namespace plugin
}  // namespace api
}  // namespace xpu
}  // namespace baidu

@@ -0,0 +1,340 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <algorithm>
#include <numeric>
#include "xpu/plugin.h"
#include "xpu/refactor/impl_public/wrapper_check.h"

namespace xpu3 {
namespace plugin {

__attribute__((global)) void speculate_free_and_dispatch_block(
    bool *stop_flags,
    int *seq_lens_this_time,
    int *seq_lens_decoder,
    int *block_tables,
    int *encoder_block_lens,
    bool *is_block_step,
    int *step_block_list,  // [bsz]
    int *step_len,
    int *recover_block_list,
    int *recover_len,
    int *need_block_list,
    int *need_block_len,
    int *used_list_len,
    int *free_list,
    int *free_list_len,
    int64_t *first_token_ids,
    int *accept_num,
    const int bsz,
    const int block_size,
    const int block_num_per_seq,
    const int max_decoder_block_num,
    const int max_draft_tokens);

}  // namespace plugin
}  // namespace xpu3

namespace baidu {
namespace xpu {
namespace api {
namespace plugin {

static int cpu_wrapper(Context *ctx,
                       bool *stop_flags,
                       int *seq_lens_this_time,
                       int *seq_lens_decoder,
                       int *block_tables,
                       int *encoder_block_lens,
                       bool *is_block_step,
                       int *step_block_list,  // [bsz]
                       int *step_len,
                       int *recover_block_list,
                       int *recover_len,
                       int *need_block_list,
                       int *need_block_len,
                       int *used_list_len,
                       int *free_list,
                       int *free_list_len,
                       int64_t *first_token_ids,
                       int *accept_num,
                       const int bsz,
                       const int block_size,
                       const int block_num_per_seq,
                       const int max_decoder_block_num,
                       const int max_draft_tokens) {
  for (int i = 0; i < bsz; i++) {
    int *block_table_now = block_tables + i * block_num_per_seq;
    int max_possible_block_idx =
        (seq_lens_decoder[i] + max_draft_tokens + 1) / block_size;
    if (stop_flags[i] && !is_block_step[i]) {
      // Reclaim this query's blocks
      first_token_ids[i] = -1;
      const int encoder_block_len = encoder_block_lens[i];
      const int decoder_used_len = used_list_len[i];
      if (decoder_used_len > 0) {
        const int ori_free_list_len = free_list_len[0];
        free_list_len[0] += decoder_used_len;
        for (int j = 0; j < decoder_used_len; j++) {
          free_list[ori_free_list_len + j] =
              block_table_now[encoder_block_len + j];
          block_table_now[encoder_block_len + j] = -1;
        }
        encoder_block_lens[i] = 0;
        used_list_len[i] = 0;
      }
    } else if (seq_lens_this_time[i] != 0 &&
               max_possible_block_idx < block_num_per_seq &&
               block_table_now[(seq_lens_decoder[i] + max_draft_tokens + 1) /
                               block_size] == -1) {
      // Record the positions that need a new block, and how many in total
      const int ori_need_block_len = need_block_len[0];
      need_block_len[0] += 1;
      need_block_list[ori_need_block_len] = i;
    }
  }

  while (need_block_len[0] > free_list_len[0]) {
    // Schedule blocks: reclaim from the query with the largest used_list_len,
    // repeating until need_block_len is satisfied
    int max_used_list_len_id = 0;
    int max_used_list_len = 0;
    for (int i = 0; i < bsz; i++) {
      const int used_block_num = !is_block_step[i] ? used_list_len[i] : 0;
      if (used_block_num > max_used_list_len) {
        max_used_list_len_id = i;
        max_used_list_len = used_block_num;
      }
    }

    const int encoder_block_len = encoder_block_lens[max_used_list_len_id];
    int *block_table_now =
        block_tables + max_used_list_len_id * block_num_per_seq;
    for (int i = 0; i < max_used_list_len; i++) {
      free_list[free_list_len[0] + i] = block_table_now[encoder_block_len + i];
      block_table_now[encoder_block_len + i] = -1;
    }
    step_block_list[step_len[0]] = max_used_list_len_id;
    step_len[0] += 1;
    free_list_len[0] += max_used_list_len;
    stop_flags[max_used_list_len_id] = true;
    is_block_step[max_used_list_len_id] = true;
    seq_lens_this_time[max_used_list_len_id] = 0;
    seq_lens_decoder[max_used_list_len_id] = 0;
    accept_num[max_used_list_len_id] = 0;
  }

  // Allocate one block to each position that needs one
  for (int i = 0; i < bsz; i++) {
    if (i < need_block_len[0]) {
      const int need_block_id = need_block_list[i];
      if (!stop_flags[need_block_id]) {
        // If this position was just freed in the previous step, skip it
        used_list_len[need_block_id] += 1;
        const int ori_free_list_len = free_list_len[0];
        free_list_len[0]--;
        int *block_table_now = block_tables + need_block_id * block_num_per_seq;
        block_table_now[(seq_lens_decoder[need_block_id] + max_draft_tokens +
                         1) /
                        block_size] = free_list[ori_free_list_len - 1];
      }
      need_block_list[i] = -1;
    }
  }

  // Compute the query ids that can be recovered
  int ori_step_len = step_len[0];
  if (ori_step_len > 0) {
    int ori_free_list_len = free_list_len[0];
    int ori_step_block_id = step_block_list[ori_step_len - 1];
    int tmp_used_len = used_list_len[ori_step_block_id];
    // Allocate one more block than at schedule-out time, so the just-stepped
    // query is not immediately recovered (e.g. when the reclaimed seq_id is
    // still in need_block_list)
    int used_len =
        tmp_used_len < max_decoder_block_num ? tmp_used_len + 1 : tmp_used_len;
    if (ori_step_len > 0 && ori_free_list_len >= used_len) {
      recover_block_list[recover_len[0]] = ori_step_block_id;
      is_block_step[ori_step_block_id] = false;
      used_list_len[ori_step_block_id] = used_len;
      ori_free_list_len -= used_len;
      step_block_list[ori_step_len - 1] = -1;
      step_len[0] -= 1;
      recover_len[0] += 1;
      ori_step_len = step_len[0];
      if (ori_step_len > 0) {
        ori_step_block_id = step_block_list[ori_step_len - 1];
        tmp_used_len = used_list_len[ori_step_block_id];
        used_len = tmp_used_len < max_decoder_block_num ? tmp_used_len + 1
                                                        : tmp_used_len;
      }
    }
    need_block_len[0] = 0;
  }
  return api::SUCCESS;
}
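
A standalone sketch (not part of the commit) of the scheduling loop above: while more blocks are needed than are free, the request that currently holds the most decoder blocks (and is not already stepped out) is stopped and its blocks returned to the free list, with is_block_step marking it for later recovery.

def reclaim_until_enough(need_len, free_len, used_list_len, is_block_step):
    stepped = []
    while need_len > free_len:
        # pick the request using the most decoder blocks among active ones
        used, victim = max(
            (used, i) for i, used in enumerate(used_list_len) if not is_block_step[i]
        )
        free_len += used              # its decoder blocks rejoin the free list
        is_block_step[victim] = True  # remembered for recovery later
        stepped.append(victim)
    return free_len, stepped


print(reclaim_until_enough(3, 1, [4, 2, 5], [False] * 3))  # (6, [2])
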

static int xpu3_wrapper(Context *ctx,
                        bool *stop_flags,
                        int *seq_lens_this_time,
                        int *seq_lens_decoder,
                        int *block_tables,
                        int *encoder_block_lens,
                        bool *is_block_step,
                        int *step_block_list,  // [bsz]
                        int *step_len,
                        int *recover_block_list,
                        int *recover_len,
                        int *need_block_list,
                        int *need_block_len,
                        int *used_list_len,
                        int *free_list,
                        int *free_list_len,
                        int64_t *first_token_ids,
                        int *accept_num,
                        const int bsz,
                        const int block_size,
                        const int block_num_per_seq,
                        const int max_decoder_block_num,
                        const int max_draft_tokens) {
  using XPU_INT64 = typename XPUIndexType<int64_t>::type;
  auto speculate_free_and_dispatch_block_kernel =
      xpu3::plugin::speculate_free_and_dispatch_block;
  speculate_free_and_dispatch_block_kernel<<<ctx->ncluster(),
                                             64,
                                             ctx->xpu_stream>>>(
      stop_flags,
      seq_lens_this_time,
      seq_lens_decoder,
      block_tables,
      encoder_block_lens,
      is_block_step,
      step_block_list,
      step_len,
      recover_block_list,
      recover_len,
      need_block_list,
      need_block_len,
      used_list_len,
      free_list,
      free_list_len,
      reinterpret_cast<XPU_INT64 *>(first_token_ids),
      accept_num,
      bsz,
      block_size,
      block_num_per_seq,
      max_decoder_block_num,
      max_draft_tokens);
  return api::SUCCESS;
}

int speculate_free_and_dispatch_block(Context *ctx,
                                      bool *stop_flags,
                                      int *seq_lens_this_time,
                                      int *seq_lens_decoder,
                                      int *block_tables,
                                      int *encoder_block_lens,
                                      bool *is_block_step,
                                      int *step_block_list,  // [bsz]
                                      int *step_len,
                                      int *recover_block_list,
                                      int *recover_len,
                                      int *need_block_list,
                                      int *need_block_len,
                                      int *used_list_len,
                                      int *free_list,
                                      int *free_list_len,
                                      int64_t *first_token_ids,
                                      int *accept_num,
                                      const int bsz,
                                      const int block_size,
                                      const int block_num_per_seq,
                                      const int max_decoder_block_num,
                                      const int max_draft_tokens) {
  WRAPPER_CHECK_CTX(ctx);
  WRAPPER_DUMP_FUNCTION_T1(ctx, "speculate_free_and_dispatch_block", float);
  WRAPPER_DUMP_PARAM6(ctx,
                      stop_flags,
                      seq_lens_this_time,
                      seq_lens_decoder,
                      block_tables,
                      encoder_block_lens,
                      is_block_step);
  WRAPPER_DUMP_PARAM6(ctx,
                      step_block_list,
                      step_len,
                      recover_block_list,
                      recover_len,
                      need_block_list,
                      need_block_len);
  WRAPPER_DUMP_PARAM4(
      ctx, used_list_len, free_list, free_list_len, first_token_ids);
  WRAPPER_DUMP_PARAM4(
      ctx, bsz, block_size, block_num_per_seq, max_decoder_block_num);
  WRAPPER_DUMP(ctx);
  if (ctx->dev().type() == api::kCPU) {
    return cpu_wrapper(ctx,
                       stop_flags,
                       seq_lens_this_time,
                       seq_lens_decoder,
                       block_tables,
                       encoder_block_lens,
                       is_block_step,
                       step_block_list,
                       step_len,
                       recover_block_list,
                       recover_len,
                       need_block_list,
                       need_block_len,
                       used_list_len,
                       free_list,
                       free_list_len,
                       first_token_ids,
                       accept_num,
                       bsz,
                       block_size,
                       block_num_per_seq,
                       max_decoder_block_num,
                       max_draft_tokens);
  }
  if (ctx->dev().type() == api::kXPU2 || ctx->dev().type() == api::kXPU3) {
    return xpu3_wrapper(ctx,
                        stop_flags,
                        seq_lens_this_time,
                        seq_lens_decoder,
                        block_tables,
                        encoder_block_lens,
                        is_block_step,
                        step_block_list,
                        step_len,
                        recover_block_list,
                        recover_len,
                        need_block_list,
                        need_block_len,
                        used_list_len,
                        free_list,
                        free_list_len,
                        first_token_ids,
                        accept_num,
                        bsz,
                        block_size,
                        block_num_per_seq,
                        max_decoder_block_num,
                        max_draft_tokens);
  }
  WRAPPER_UNIMPLEMENTED(ctx);
}

}  // namespace plugin
}  // namespace api
}  // namespace xpu
}  // namespace baidu
@@ -33,7 +33,7 @@ __attribute__((global)) void speculate_remove_padding(
    int token_num_data);

__attribute__((global)) void speculate_get_padding_offset(
    int* padding_offset,
    int* batch_id_per_token,
    int* cum_offsets_out,
    int* cu_seqlens_q,
    int* cu_seqlens_k,
@@ -78,7 +78,7 @@ static int cpu_wrapper_remove_padding(Context* ctx,
}

static int cpu_wrapper_get_padding_offset(Context* ctx,
                                          int* padding_offset,
                                          int* batch_id_per_token,
                                          int* cum_offsets_out,
                                          int* cu_seqlens_q,
                                          int* cu_seqlens_k,
@@ -89,7 +89,7 @@ static int cpu_wrapper_get_padding_offset(Context* ctx,
  for (int bi = 0; bi < bsz; ++bi) {
    int cum_offset = bi == 0 ? 0 : cum_offsets[bi - 1];
    for (int i = 0; i < seq_lens[bi]; i++) {
      padding_offset[bi * max_seq_len - cum_offset + i] = cum_offset;
      batch_id_per_token[bi * max_seq_len - cum_offset + i] = bi;
    }
    cum_offsets_out[bi] = cum_offset;
    int cum_seq_len = (bi + 1) * max_seq_len - cum_offsets[bi];
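
A NumPy sketch (not part of the commit) of the loop above: every token of each sequence records the amount of padding removed before it (padding_offset) and its batch id (batch_id_per_token), laid out in the packed, pad-free order.

import numpy as np


def get_padding_offset(seq_lens, max_seq_len):
    cum_offsets = np.cumsum(max_seq_len - np.asarray(seq_lens))
    token_num = int(np.sum(seq_lens))
    padding_offset = np.empty(token_num, dtype=np.int32)
    batch_id_per_token = np.empty(token_num, dtype=np.int32)
    for bi, seq_len in enumerate(seq_lens):
        cum_offset = 0 if bi == 0 else cum_offsets[bi - 1]
        start = bi * max_seq_len - cum_offset   # packed position of this seq
        padding_offset[start:start + seq_len] = cum_offset
        batch_id_per_token[start:start + seq_len] = bi
    return padding_offset, batch_id_per_token


# seq lens [2, 3], max_seq_len 4 -> offsets [0 0 2 2 2], batch ids [0 0 1 1 1]
print(get_padding_offset([2, 3], 4))
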
@@ -129,7 +129,7 @@ static int xpu3_wrapper_remove_padding(Context* ctx,
}

static int xpu3_wrapper_get_padding_offset(Context* ctx,
                                           int* padding_offset,
                                           int* batch_id_per_token,
                                           int* cum_offsets_out,
                                           int* cu_seqlens_q,
                                           int* cu_seqlens_k,
@@ -139,7 +139,7 @@ static int xpu3_wrapper_get_padding_offset(Context* ctx,
                                           int bsz) {
  xpu3::plugin::
      speculate_get_padding_offset<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
          padding_offset,
          batch_id_per_token,
          cum_offsets_out,
          cu_seqlens_q,
          cu_seqlens_k,
@@ -215,7 +215,7 @@ int speculate_remove_padding(Context* ctx,
}

int speculate_get_padding_offset(Context* ctx,
                                 int* padding_offset,
                                 int* batch_id_per_token,
                                 int* cum_offsets_out,
                                 int* cu_seqlens_q,
                                 int* cu_seqlens_k,
@@ -227,7 +227,7 @@ int speculate_get_padding_offset(Context* ctx,

  WRAPPER_DUMP_FUNCTION_T1(ctx, "speculate_get_padding_offset", float);
  WRAPPER_DUMP_PARAM6(ctx,
                      padding_offset,
                      batch_id_per_token,
                      cum_offsets_out,
                      cu_seqlens_q,
                      cu_seqlens_k,
@@ -247,7 +247,7 @@ int speculate_get_padding_offset(Context* ctx,

  if (ctx->dev().type() == api::kCPU) {
    return cpu_wrapper_get_padding_offset(ctx,
                                          padding_offset,
                                          batch_id_per_token,
                                          cum_offsets_out,
                                          cu_seqlens_q,
                                          cu_seqlens_k,
@@ -258,7 +258,7 @@ int speculate_get_padding_offset(Context* ctx,
  }
  if (ctx->dev().type() == api::kXPU3) {
    return xpu3_wrapper_get_padding_offset(ctx,
                                           padding_offset,
                                           batch_id_per_token,
                                           cum_offsets_out,
                                           cu_seqlens_q,
                                           cu_seqlens_k,

@@ -0,0 +1,258 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <algorithm>
#include <numeric>
#include "xpu/plugin.h"
#include "xpu/refactor/impl_public/wrapper_check.h"

namespace xpu3 {
namespace plugin {

__attribute__((global)) void speculate_recover_block(
    int *recover_block_list,  // [bsz]
    int *recover_len,
    bool *stop_flags,
    int *seq_lens_this_time,
    const int *ori_seq_lens_encoder,
    int *seq_lens_encoder,
    const int *seq_lens_decoder,
    int *block_tables,
    int *free_list,
    int *free_list_len,
    int64_t *input_ids,
    const int64_t *pre_ids,
    const int64_t *step_idx,
    const int *encoder_block_lens,
    const int *used_list_len,
    const int64_t *next_tokens,
    const int64_t *first_token_ids,
    const int bsz,
    const int block_num_per_seq,
    const int length,
    const int pre_id_length);

}  // namespace plugin
}  // namespace xpu3

namespace baidu {
namespace xpu {
namespace api {
namespace plugin {

static int cpu_wrapper(Context *ctx,
                       int *recover_block_list,  // [bsz]
                       int *recover_len,
                       bool *stop_flags,
                       int *seq_lens_this_time,
                       const int *ori_seq_lens_encoder,
                       int *seq_lens_encoder,
                       const int *seq_lens_decoder,
                       int *block_tables,
                       int *free_list,
                       int *free_list_len,
                       int64_t *input_ids,
                       const int64_t *pre_ids,
                       const int64_t *step_idx,
                       const int *encoder_block_lens,
                       const int *used_list_len,
                       const int64_t *next_tokens,
                       const int64_t *first_token_ids,
                       const int bsz,
                       const int block_num_per_seq,
                       const int length,
                       const int pre_id_length) {
  for (int bid = 0; bid < recover_len[0]; bid++) {
    const int recover_id = recover_block_list[bid];
    const int ori_seq_len_encoder = ori_seq_lens_encoder[recover_id];
    const int step_idx_now = step_idx[recover_id];
    const int seq_len = ori_seq_len_encoder + step_idx_now;
    const int encoder_block_len = encoder_block_lens[recover_id];
    const int decoder_used_len = used_list_len[recover_id];
    int *block_table_now = block_tables + recover_id * block_num_per_seq;
    int64_t *input_ids_now = input_ids + recover_id * length;
    const int64_t *pre_ids_now = pre_ids + recover_id * pre_id_length;

    seq_lens_this_time[recover_id] = seq_len;
    seq_lens_encoder[recover_id] = seq_len;
    stop_flags[recover_id] = false;
    // input_ids_now[seq_len - 1] = next_tokens[recover_id];  // next tokens
    input_ids_now[0] = first_token_ids[recover_id];  // set first prompt token
    int ori_free_list_len = free_list_len[0];
    free_list_len[0] -= decoder_used_len;

    // Restore the block table
    for (int i = 0; i < decoder_used_len; i++) {
      block_table_now[encoder_block_len + i] =
          free_list[ori_free_list_len - i - 1];
    }
    // Restore input_ids
    for (int i = 0; i < step_idx_now; i++) {
      input_ids_now[ori_seq_len_encoder + i] = pre_ids_now[i + 1];
    }
  }
  recover_len[0] = 0;
  return api::SUCCESS;
}
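
A toy sketch (not part of the commit) of the input_ids reconstruction above: the recovered request re-prefills its original prompt followed by everything it had already generated (pre_ids[1:step_idx + 1]), with the first prompt token restored from first_token_ids.

def rebuild_input_ids(input_ids, first_token_id, ori_seq_len_encoder, pre_ids, step_idx):
    input_ids[0] = first_token_id                  # restore the first prompt token
    for i in range(step_idx):                      # re-append generated tokens
        input_ids[ori_seq_len_encoder + i] = pre_ids[i + 1]
    return input_ids


ids = [0, 11, 12, -1, -1, -1]                      # prompt of length 3, room for 3 more
print(rebuild_input_ids(ids, 10, 3, [99, 21, 22], 2))  # [10, 11, 12, 21, 22, -1]
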

static int xpu3_wrapper(Context *ctx,
                        int *recover_block_list,  // [bsz]
                        int *recover_len,
                        bool *stop_flags,
                        int *seq_lens_this_time,
                        const int *ori_seq_lens_encoder,
                        int *seq_lens_encoder,
                        const int *seq_lens_decoder,
                        int *block_tables,
                        int *free_list,
                        int *free_list_len,
                        int64_t *input_ids,
                        const int64_t *pre_ids,
                        const int64_t *step_idx,
                        const int *encoder_block_lens,
                        const int *used_list_len,
                        const int64_t *next_tokens,
                        const int64_t *first_token_ids,
                        const int bsz,
                        const int block_num_per_seq,
                        const int length,
                        const int pre_id_length) {
  using XPU_INT64 = typename XPUIndexType<int64_t>::type;
  auto recover_block_kernel = xpu3::plugin::speculate_recover_block;
  recover_block_kernel<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
      recover_block_list,  // [bsz]
      recover_len,
      stop_flags,
      seq_lens_this_time,
      ori_seq_lens_encoder,
      seq_lens_encoder,
      seq_lens_decoder,
      block_tables,
      free_list,
      free_list_len,
      reinterpret_cast<XPU_INT64 *>(input_ids),
      reinterpret_cast<const XPU_INT64 *>(pre_ids),
      reinterpret_cast<const XPU_INT64 *>(step_idx),
      encoder_block_lens,
      used_list_len,
      reinterpret_cast<const XPU_INT64 *>(next_tokens),
      reinterpret_cast<const XPU_INT64 *>(first_token_ids),
      bsz,
      block_num_per_seq,
      length,
      pre_id_length);
  return api::SUCCESS;
}

int speculate_recover_block(Context *ctx,
                            int *recover_block_list,  // [bsz]
                            int *recover_len,
                            bool *stop_flags,
                            int *seq_lens_this_time,
                            const int *ori_seq_lens_encoder,
                            int *seq_lens_encoder,
                            const int *seq_lens_decoder,
                            int *block_tables,
                            int *free_list,
                            int *free_list_len,
                            int64_t *input_ids,
                            const int64_t *pre_ids,
                            const int64_t *step_idx,
                            const int *encoder_block_lens,
                            const int *used_list_len,
                            const int64_t *next_tokens,
                            const int64_t *first_token_ids,
                            const int bsz,
                            const int block_num_per_seq,
                            const int length,
                            const int pre_id_length) {
  WRAPPER_CHECK_CTX(ctx);
  WRAPPER_DUMP_FUNCTION_T1(ctx, "speculate_recover_block", float);
  WRAPPER_DUMP_PARAM6(ctx,
                      recover_block_list,
                      recover_len,
                      stop_flags,
                      seq_lens_this_time,
                      ori_seq_lens_encoder,
                      seq_lens_encoder);
  WRAPPER_DUMP_PARAM6(ctx,
                      seq_lens_decoder,
                      block_tables,
                      free_list,
                      free_list_len,
                      input_ids,
                      pre_ids);
  WRAPPER_DUMP_PARAM5(ctx,
                      step_idx,
                      encoder_block_lens,
                      used_list_len,
                      next_tokens,
                      first_token_ids);
  WRAPPER_DUMP_PARAM4(ctx, bsz, block_num_per_seq, length, pre_id_length);
  WRAPPER_DUMP(ctx);
  if (ctx->dev().type() == api::kCPU) {
    return cpu_wrapper(ctx,
                       recover_block_list,  // [bsz]
                       recover_len,
                       stop_flags,
                       seq_lens_this_time,
                       ori_seq_lens_encoder,
                       seq_lens_encoder,
                       seq_lens_decoder,
                       block_tables,
                       free_list,
                       free_list_len,
                       input_ids,
                       pre_ids,
                       step_idx,
                       encoder_block_lens,
                       used_list_len,
                       next_tokens,
                       first_token_ids,
                       bsz,
                       block_num_per_seq,
                       length,
                       pre_id_length);
  }
  if (ctx->dev().type() == api::kXPU2 || ctx->dev().type() == api::kXPU3) {
    return xpu3_wrapper(ctx,
                        recover_block_list,  // [bsz]
                        recover_len,
                        stop_flags,
                        seq_lens_this_time,
                        ori_seq_lens_encoder,
                        seq_lens_encoder,
                        seq_lens_decoder,
                        block_tables,
                        free_list,
                        free_list_len,
                        input_ids,
                        pre_ids,
                        step_idx,
                        encoder_block_lens,
                        used_list_len,
                        next_tokens,
                        first_token_ids,
                        bsz,
                        block_num_per_seq,
                        length,
                        pre_id_length);
  }
  WRAPPER_UNIMPLEMENTED(ctx);
}

}  // namespace plugin
}  // namespace api
}  // namespace xpu
}  // namespace baidu
@@ -48,7 +48,8 @@ __attribute__((global)) void speculate_verify(
    const int max_seq_len,
    const int max_candidate_len,
    const int verify_window,
    const bool prefill_one_step_stop);
    const bool prefill_one_step_stop,
    const bool benchmark_mode);
}  // namespace plugin
}  // namespace xpu3

@@ -136,14 +137,15 @@ static int cpu_wrapper(Context *ctx,
                       const int max_seq_len,
                       const int max_candidate_len,
                       const int verify_window,
                       const bool prefill_one_step_stop) {
                       const bool prefill_one_step_stop,
                       const bool benchmark_mode) {
  for (int bid = 0; bid < real_bsz; ++bid) {
    const int start_token_id = bid * max_seq_len - output_cum_offsets[bid];
    // verify and set stop flags
    int accept_num_now = 1;
    int stop_flag_now_int = 0;

    if (!(is_block_step[bid] || bid >= real_bsz)) {
      const int start_token_id = bid * max_seq_len - output_cum_offsets[bid];
      // printf("debug cpu bid:%d,start_token_id:%d\n",bid, start_token_id);
      // printf("bid %d\n", bid);

@@ -160,6 +162,9 @@ static int cpu_wrapper(Context *ctx,
      // printf("seq_lens_this_time[%d]-1: %d \n",bid,
      // seq_lens_this_time[bid]-1);
      for (; i < seq_lens_this_time[bid] - 1; i++) {
        if (benchmark_mode) {
          break;
        }
        if (seq_lens_encoder[bid] != 0) {
          break;
        }
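
The new benchmark_mode flag exits the draft-token comparison loop immediately, so accept_num stays at its initial value of 1 on every decode step regardless of the sampled targets, which pins acceptance at its worst case for benchmarking. A toy sketch (not part of the commit; the real verify compares against top-p/top-k candidate sets rather than a single target sequence):

def count_accepted(draft, target, benchmark_mode=False):
    accept_num = 1                      # the corrected/bonus token is always kept
    for d, t in zip(draft, target):
        if benchmark_mode or d != t:    # benchmark_mode short-circuits matching
            break
        accept_num += 1
    return accept_num


print(count_accepted([5, 7, 9], [5, 7, 2]))                       # 3
print(count_accepted([5, 7, 9], [5, 7, 2], benchmark_mode=True))  # 1
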
@@ -326,7 +331,8 @@ static int xpu3_wrapper(Context *ctx,
                        const int max_seq_len,
                        const int max_candidate_len,
                        const int verify_window,
                        const bool prefill_one_step_stop) {
                        const bool prefill_one_step_stop,
                        const bool benchmark_mode) {
  using XPU_INT64 = typename XPUIndexType<int64_t>::type;
  xpu3::plugin::speculate_verify<ENABLE_TOPP, USE_TOPK>
      <<<1, 64, ctx->xpu_stream>>>(
@@ -354,7 +360,8 @@ static int xpu3_wrapper(Context *ctx,
          max_seq_len,
          max_candidate_len,
          verify_window,
          prefill_one_step_stop);
          prefill_one_step_stop,
          benchmark_mode);
  return api::SUCCESS;
}
template <bool ENABLE_TOPP, bool USE_TOPK>
@@ -383,7 +390,8 @@ int speculate_verify(Context *ctx,
                     const int max_seq_len,
                     const int max_candidate_len,
                     const int verify_window,
                     const bool prefill_one_step_stop) {
                     const bool prefill_one_step_stop,
                     const bool benchmark_mode) {
  WRAPPER_CHECK_CTX(ctx);
  WRAPPER_DUMP_FUNCTION_T1(ctx, "speculate_verify", int64_t);
  WRAPPER_DUMP_PARAM3(ctx, accept_tokens, accept_num, step_idx);
@@ -406,12 +414,13 @@ int speculate_verify(Context *ctx,
                      actual_candidate_len,
                      real_bsz,
                      max_draft_tokens);
  WRAPPER_DUMP_PARAM5(ctx,
  WRAPPER_DUMP_PARAM6(ctx,
                      end_length,
                      max_seq_len,
                      max_candidate_len,
                      verify_window,
                      prefill_one_step_stop);
                      prefill_one_step_stop,
                      benchmark_mode);
  WRAPPER_DUMP(ctx);
  WRAPPER_CHECK_PTR(ctx, int64_t, real_bsz * max_draft_tokens, accept_tokens);
  WRAPPER_CHECK_PTR(ctx, int, real_bsz, accept_num);
@@ -469,7 +478,8 @@ int speculate_verify(Context *ctx,
                                              max_seq_len,
                                              max_candidate_len,
                                              verify_window,
                                              prefill_one_step_stop);
                                              prefill_one_step_stop,
                                              benchmark_mode);
  }
  if (ctx->dev().type() == api::kXPU3) {
    return xpu3_wrapper<ENABLE_TOPP, USE_TOPK>(ctx,
@@ -497,7 +507,8 @@ int speculate_verify(Context *ctx,
                                               max_seq_len,
                                               max_candidate_len,
                                               verify_window,
                                               prefill_one_step_stop);
                                               prefill_one_step_stop,
                                               benchmark_mode);
  }
  WRAPPER_UNIMPLEMENTED(ctx);
}
@@ -530,7 +541,8 @@ int speculate_verify(Context *ctx,
      int, /* max_seq_len */ \
      int, /* max_candidate_len */ \
      int, /* verify_window */ \
      bool); /* prefill_one_step_stop */
      bool, /* prefill_one_step_stop */ \
      bool); /* benchmark_mode */

INSTANTIATE_SPECULATE_VERIFY(false, false)
INSTANTIATE_SPECULATE_VERIFY(false, true)

@@ -0,0 +1,180 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np
import paddle

from fastdeploy.model_executor.ops.xpu import (
    adjust_batch,
    gather_next_token,
    get_infer_param,
)


def _run_test_base(seq_lens_this_time_data, output_padding_offset):
    """
    Shared base test routine containing the logic common to both scenarios.
    """
    seq_lens_encoder = paddle.to_tensor([100, 0, 0, 0, 120, 140, 0], dtype="int32")
    seq_lens_decoder = paddle.to_tensor([0, 5, 0, 25, 64, 0, 128], dtype="int32")
    seq_lens_this_time = paddle.to_tensor(seq_lens_this_time_data, dtype="int32")

    bsz = seq_lens_this_time.shape[0]
    cum_offsets = paddle.zeros(bsz, dtype="int32")
    block_table = paddle.arange(0, 56, dtype="int32").reshape((bsz, 8))

    infer_params = get_infer_param(seq_lens_encoder, seq_lens_decoder, seq_lens_this_time, block_table, 64)

    (
        encoder_batch_map,
        decoder_batch_map,
        encoder_batch_idx,
        decoder_batch_idx,
        encoder_seq_lod,
        decoder_seq_lod,
        _,
        _,
        _,
        _,
        _,
        encoder_batch_map_cpu,
        decoder_batch_map_cpu,
        encoder_batch_idx_cpu,
        decoder_batch_idx_cpu,
        encoder_seq_lod_cpu,
        decoder_seq_lod_cpu,
        _,
        _,
        _,
        _,
        len_info_cpu,
    ) = infer_params

    token_num = seq_lens_this_time.sum().cpu().item()
    hidden_dim = 8192
    row_indices = paddle.arange(token_num, dtype="int32")
    row_indices_bf16 = row_indices.astype("bfloat16")
    input_tensor = paddle.unsqueeze(row_indices_bf16, axis=1).expand(shape=[token_num, hidden_dim])

    # Test adjust_batch
    adjusted_output = adjust_batch(
        input_tensor,
        cum_offsets,
        encoder_seq_lod,
        decoder_seq_lod,
        encoder_batch_idx,
        decoder_batch_idx,
        encoder_seq_lod_cpu,
        decoder_seq_lod_cpu,
        encoder_batch_idx_cpu,
        decoder_batch_idx_cpu,
        len_info_cpu,
        None,  # output_padding_offset
        -1,  # max_input_length
    )

    adjusted_output_cpu = adjust_batch(
        input_tensor.cpu(),
        cum_offsets,
        encoder_seq_lod,
        decoder_seq_lod,
        encoder_batch_idx,
        decoder_batch_idx,
        encoder_seq_lod_cpu,
        decoder_seq_lod_cpu,
        encoder_batch_idx_cpu,
        decoder_batch_idx_cpu,
        len_info_cpu,
        None,  # output_padding_offset
        -1,  # max_input_length
    )

    # Use np.testing instead of a bare assert for friendlier error messages
    adjusted_output_np = adjusted_output.astype("float32").cpu().numpy()
    adjusted_output_cpu_np = adjusted_output_cpu.astype("float32").cpu().numpy()
    np.testing.assert_allclose(adjusted_output_np, adjusted_output_cpu_np, err_msg="adjust_batch check failed!")

    # Test gather_next_token
    gather_out = gather_next_token(
        adjusted_output,
        cum_offsets,
        encoder_seq_lod,
        decoder_seq_lod,
        encoder_batch_map,
        decoder_batch_map,
        encoder_seq_lod_cpu,
        decoder_seq_lod_cpu,
        encoder_batch_map_cpu,
        decoder_batch_map_cpu,
        len_info_cpu,
        output_padding_offset,
        -1,
    )

    gather_out_cpu = gather_next_token(
        adjusted_output.cpu(),
        cum_offsets,
        encoder_seq_lod,
        decoder_seq_lod,
        encoder_batch_map,
        decoder_batch_map,
        encoder_seq_lod_cpu,
        decoder_seq_lod_cpu,
        encoder_batch_map_cpu,
        decoder_batch_map_cpu,
        len_info_cpu,
        output_padding_offset,
        -1,
    )

    gather_out_np = gather_out.astype("float32").cpu().numpy()
    gather_out_cpu_np = gather_out_cpu.astype("float32").cpu().numpy()

    if output_padding_offset is not None:
        np.testing.assert_allclose(gather_out_np, gather_out_cpu_np, err_msg="gather_next_token check failed!")
    else:
        for i in range(gather_out_cpu.shape[0]):
            if seq_lens_this_time[i] > 0:
                np.testing.assert_allclose(
                    gather_out_np[i], gather_out_cpu_np[i], err_msg=f"gather_next_token check failed at index {i}!"
                )


class TestXPUOps(unittest.TestCase):
    """Tests for the adjust_batch and gather_next_token XPU ops."""

    def test_mix_with_mtp(self):
        """Test the MTP (Multi-Token Prediction) case in a mixed batch."""
        print("\nRunning test: test_mix_with_mtp")
        seq_lens_this_time_data = [100, 2, 0, 1, 120, 140, 3]
        bsz = len(seq_lens_this_time_data)
        output_padding_offset = paddle.zeros(bsz, dtype="int32")

        _run_test_base(seq_lens_this_time_data, output_padding_offset)
        print("Test passed for scenario: With MTP")

    def test_mix_without_mtp(self):
        """Test the non-MTP (single-token prediction) case."""
        print("\nRunning test: test_mix_without_mtp")
        seq_lens_this_time_data = [100, 1, 0, 1, 120, 140, 1]
        output_padding_offset = None  # None in the non-MTP case

        _run_test_base(seq_lens_this_time_data, output_padding_offset)
        print("Test passed for scenario: Without MTP")


if __name__ == "__main__":
    unittest.main()
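
Both scenarios drive the same comparison: each op runs once on XPU tensors and once on their CPU copies, and the outputs must match elementwise after casting bfloat16 to float32. Because the file ends in unittest.main(), it can also be run directly; the file name below is hypothetical:

python -m unittest test_xpu_adjust_batch_gather_next_token.py
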
|
||||
@@ -12,50 +12,284 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import paddle
|
||||
|
||||
from fastdeploy.model_executor.ops.xpu import draft_model_preprocess
|
||||
|
||||
|
||||
def run_test(device="xpu"):
|
||||
paddle.seed(2022)
|
||||
def process_splitwise_prefill(
|
||||
draft_tokens,
|
||||
input_ids,
|
||||
stop_flags,
|
||||
seq_lens_this_time,
|
||||
seq_lens_encoder,
|
||||
seq_lens_decoder,
|
||||
step_idx,
|
||||
not_need_stop,
|
||||
is_block_step,
|
||||
batch_drop,
|
||||
pre_ids,
|
||||
accept_tokens,
|
||||
accept_num,
|
||||
base_model_seq_lens_this_time,
|
||||
base_model_seq_lens_encoder,
|
||||
base_model_seq_lens_decoder,
|
||||
base_model_step_idx,
|
||||
base_model_stop_flags,
|
||||
base_model_is_block_step,
|
||||
base_model_draft_tokens,
|
||||
bsz,
|
||||
num_model_step,
|
||||
base_model_draft_tokens_len,
|
||||
truncate_first_token,
|
||||
kvcache_scheduler_v1,
|
||||
):
|
||||
not_stop_flag_sum = 0
|
||||
|
||||
# Define parameters
|
||||
bsz = 10
|
||||
draft_tokens_len = 4
|
||||
input_ids_len = 8
|
||||
max_draft_token = 10
|
||||
for tid in range(bsz):
|
||||
not_stop_flag = 0
|
||||
input_ids_now = input_ids[tid]
|
||||
accept_tokens_now = accept_tokens[tid]
|
||||
if seq_lens_encoder[tid] > 0:
|
||||
not_stop_flag = 1
|
||||
seq_len_encoder = seq_lens_encoder[tid]
|
||||
stop_flags[tid] = False
|
||||
base_model_first_token = accept_tokens_now[0]
|
||||
position = seq_len_encoder
|
||||
if truncate_first_token:
|
||||
input_ids_now[position - 1] = base_model_first_token
|
||||
seq_lens_this_time[tid] = seq_len_encoder
|
||||
else:
|
||||
input_ids_now[position] = base_model_first_token
|
||||
seq_lens_this_time[tid] = seq_len_encoder + 1
|
||||
else:
|
||||
stop_flags[tid] = True
|
||||
seq_lens_this_time[tid] = 0
|
||||
seq_lens_decoder[tid] = 0
|
||||
seq_lens_encoder[tid] = 0
|
||||
not_stop_flag = 0
|
||||
not_stop_flag_sum = not_stop_flag_sum + not_stop_flag
|
||||
not_need_stop[0] = not_stop_flag_sum > 0
|
||||
|
||||
truncate_first_token = True
|
||||
splitwise_prefill = False
|
||||
# Create input tensors
|
||||
if device == "cpu":
|
||||
paddle.set_device(device)
|
||||
|
||||
draft_tokens = paddle.randint(0, 100, [bsz, draft_tokens_len], dtype="int64")
|
||||
input_ids = paddle.randint(0, 100, [bsz, input_ids_len], dtype="int64")
|
||||
stop_flags = paddle.randint(0, 1, [bsz], dtype="int").cast("bool")
|
||||
seq_lens_this_time = paddle.randint(0, 100, [bsz], dtype="int32")
|
||||
seq_lens_encoder = paddle.randint(0, 100, [bsz], dtype="int32")
|
||||
seq_lens_decoder = paddle.randint(0, 100, [bsz], dtype="int32")
|
||||
step_idx = paddle.randint(0, 100, [bsz], dtype="int64")
|
||||
seq_lens_encoder_record = paddle.randint(0, 100, [bsz], dtype="int32")
|
||||
seq_lens_decoder_record = paddle.randint(0, 100, [bsz], dtype="int32")
|
||||
not_need_stop = paddle.zeros([1], dtype="bool").cpu()
|
||||
batch_drop = paddle.zeros([bsz], dtype="bool")
|
||||
def draft_model_preprocess_kernel(
|
||||
draft_tokens,
|
||||
input_ids,
|
||||
stop_flags,
|
||||
seq_lens_this_time,
|
||||
seq_lens_encoder,
|
||||
seq_lens_decoder,
|
||||
step_idx,
|
||||
not_need_stop,
|
||||
is_block_step,
|
||||
batch_drop,
|
||||
pre_ids,
|
||||
accept_tokens,
|
||||
accept_num,
|
||||
base_model_seq_lens_this_time,
|
||||
base_model_seq_lens_encoder,
|
||||
base_model_seq_lens_decoder,
|
||||
base_model_step_idx,
|
||||
base_model_stop_flags,
|
||||
base_model_is_block_step,
|
||||
base_model_draft_tokens,
|
||||
bsz,
|
||||
num_model_step,
|
||||
base_model_draft_tokens_len,
|
||||
truncate_first_token,
|
||||
kvcache_scheduler_v1,
|
||||
):
|
||||
not_stop_flag_sum = 0
|
||||
|
||||
# Output tensors
|
||||
accept_tokens = paddle.randint(0, 100, [bsz, 100], dtype="int64")
|
||||
accept_num = paddle.randint(1, max_draft_token + 5, [bsz], dtype="int32")
|
||||
base_model_seq_lens_encoder = paddle.randint(0, 100, [bsz], dtype="int32")
|
||||
base_model_seq_lens_decoder = paddle.randint(0, 100, [bsz], dtype="int32")
|
||||
base_model_step_idx = paddle.randint(0, 100, [bsz], dtype="int64")
|
||||
base_model_stop_flags = paddle.zeros([bsz], dtype="bool")
|
||||
base_model_is_block_step = paddle.zeros([bsz], dtype="bool")
|
||||
base_model_draft_tokens = paddle.zeros([bsz, max_draft_token], dtype="int64")
|
||||
# Run the op
|
||||
outputs = draft_model_preprocess(
|
||||
for tid in range(bsz):
|
||||
not_stop_flag = 0
|
||||
accept_tokens_now = accept_tokens[tid]
|
||||
draft_tokens_now = draft_tokens[tid]
|
||||
accept_num_now = accept_num[tid]
|
||||
input_ids_now = input_ids[tid]
|
||||
base_model_draft_tokens_now = base_model_draft_tokens[tid]
|
||||
base_model_seq_len_decoder = base_model_seq_lens_decoder[tid]
|
||||
base_model_seq_len_this_time = base_model_seq_lens_this_time[tid]
|
||||
pre_ids_now = pre_ids[tid]
|
||||
|
||||
base_model_draft_tokens_now[1:base_model_draft_tokens_len] = -1
|
||||
|
||||
if kvcache_scheduler_v1:
|
||||
if base_model_stop_flags[tid] and base_model_is_block_step[tid]:
|
||||
stop_flags[tid] = True
|
||||
is_block_step[tid] = True
|
||||
# Need to continue infer
|
||||
else:
|
||||
if base_model_stop_flags[tid] and base_model_is_block_step[tid]:
|
||||
batch_drop[tid] = True
|
||||
stop_flags[tid] = True
|
||||
|
||||
if not (base_model_stop_flags[tid] or batch_drop[tid]):
|
||||
not_stop_flag = 1
|
||||
# 1. first token
|
||||
if seq_lens_encoder[tid] > 0:
|
||||
# Can be extended to first few tokens
|
||||
seq_len_encoder = seq_lens_encoder[tid]
|
||||
stop_flags[tid] = False
|
||||
base_model_first_token = accept_tokens_now[0]
|
||||
pre_ids_now[0] = base_model_first_token
|
||||
position = seq_len_encoder
|
||||
if truncate_first_token:
|
||||
input_ids_now[position - 1] = base_model_first_token
|
||||
seq_lens_this_time[tid] = seq_len_encoder
|
||||
else:
|
||||
input_ids_now[position] = base_model_first_token
|
||||
seq_lens_this_time[tid] = seq_len_encoder + 1
|
||||
else:
|
||||
if kvcache_scheduler_v1:
|
||||
# 3. try to recover mtp infer in V1 mode
|
||||
if not (base_model_is_block_step[tid] and is_block_step[tid]):
|
||||
is_block_step[tid] = False
|
||||
|
||||
if stop_flags[tid]:
|
||||
stop_flags[tid] = False
|
||||
# TODO: check
|
||||
seq_lens_decoder[tid] = base_model_seq_len_decoder - base_model_seq_len_this_time
|
||||
step_idx[tid] = base_model_step_idx[tid] - base_model_seq_len_this_time
|
||||
else:
|
||||
# 2: Last base model generated token and first MTP token
|
||||
seq_lens_decoder[tid] -= num_model_step - 1
|
||||
step_idx[tid] -= num_model_step - 1
|
||||
|
||||
for i in range(accept_num_now):
|
||||
draft_tokens_now[i] = accept_tokens_now[i]
|
||||
pre_id_pos = base_model_step_idx[tid] - (accept_num_now - i)
|
||||
accept_token = accept_tokens_now[i]
|
||||
pre_ids_now[pre_id_pos] = accept_token
|
||||
|
||||
seq_lens_this_time[tid] = accept_num_now
|
||||
else:
|
||||
stop_flags[tid] = True
|
||||
seq_lens_this_time[tid] = 0
|
||||
seq_lens_decoder[tid] = 0
|
||||
seq_lens_encoder[tid] = 0
|
||||
not_stop_flag_sum = not_stop_flag_sum + not_stop_flag
|
||||
not_need_stop[0] = not_stop_flag_sum > 0
|
||||
|
||||
|
||||
def DispatchRunner(
|
||||
draft_tokens,
|
||||
input_ids,
|
||||
stop_flags,
|
||||
seq_lens_this_time,
|
||||
seq_lens_encoder,
|
||||
seq_lens_decoder,
|
||||
step_idx,
|
||||
not_need_stop,
|
||||
is_block_step,
|
||||
batch_drop,
|
||||
pre_ids,
|
||||
accept_tokens,
|
||||
accept_num,
|
||||
base_model_seq_lens_this_time,
|
||||
base_model_seq_lens_encoder,
|
||||
base_model_seq_lens_decoder,
|
||||
base_model_step_idx,
|
||||
base_model_stop_flags,
|
||||
base_model_is_block_step,
|
||||
base_model_draft_tokens,
|
||||
bsz,
|
||||
num_model_step,
|
||||
truncate_first_token,
|
||||
splitwise_prefill,
|
||||
kvcache_scheduler_v1,
|
||||
):
|
||||
base_model_draft_tokens_len = base_model_draft_tokens.shape[1]
|
||||
if splitwise_prefill:
|
||||
process_splitwise_prefill(
|
||||
draft_tokens,
|
||||
input_ids,
|
||||
stop_flags,
|
||||
seq_lens_this_time,
|
||||
seq_lens_encoder,
|
||||
seq_lens_decoder,
|
||||
step_idx,
|
||||
not_need_stop,
|
||||
is_block_step,
|
||||
batch_drop,
|
||||
pre_ids,
|
||||
accept_tokens,
|
||||
accept_num,
|
||||
base_model_seq_lens_this_time,
|
||||
base_model_seq_lens_encoder,
|
||||
base_model_seq_lens_decoder,
|
||||
base_model_step_idx,
|
||||
base_model_stop_flags,
|
||||
base_model_is_block_step,
|
||||
base_model_draft_tokens,
|
||||
bsz,
|
||||
num_model_step,
|
||||
base_model_draft_tokens_len,
|
||||
truncate_first_token,
|
||||
kvcache_scheduler_v1,
|
||||
)
|
||||
else:
|
||||
draft_model_preprocess_kernel(
|
||||
draft_tokens,
|
||||
input_ids,
|
||||
stop_flags,
|
||||
seq_lens_this_time,
|
||||
seq_lens_encoder,
|
||||
seq_lens_decoder,
|
||||
step_idx,
|
||||
not_need_stop,
|
||||
is_block_step,
|
||||
batch_drop,
|
||||
pre_ids,
|
||||
accept_tokens,
|
||||
accept_num,
|
||||
base_model_seq_lens_this_time,
|
||||
base_model_seq_lens_encoder,
|
||||
base_model_seq_lens_decoder,
|
||||
base_model_step_idx,
|
||||
base_model_stop_flags,
|
||||
base_model_is_block_step,
|
||||
base_model_draft_tokens,
|
||||
bsz,
|
||||
num_model_step,
|
||||
base_model_draft_tokens_len,
|
||||
truncate_first_token,
|
||||
kvcache_scheduler_v1,
|
||||
)
|
||||
|
||||
|
||||
def draft_model_preprocess_ref(
|
||||
draft_tokens,
|
||||
input_ids,
|
||||
stop_flags,
|
||||
seq_lens_this_time,
|
||||
seq_lens_encoder,
|
||||
seq_lens_decoder,
|
||||
step_idx,
|
||||
not_need_stop,
|
||||
is_block_step,
|
||||
batch_drop,
|
||||
pre_ids,
|
||||
accept_tokens,
|
||||
accept_num,
|
||||
base_model_seq_lens_this_time,
|
||||
base_model_seq_lens_encoder,
|
||||
base_model_seq_lens_decoder,
|
||||
base_model_step_idx,
|
||||
base_model_stop_flags,
|
||||
base_model_is_block_step,
|
||||
base_model_draft_tokens,
|
||||
num_model_step,
|
||||
truncate_first_token,
|
||||
splitwise_prefill,
|
||||
kvcache_scheduler_v1,
|
||||
):
|
||||
real_bsz = seq_lens_this_time.shape[0]
|
||||
|
||||
DispatchRunner(
|
||||
draft_tokens,
|
||||
input_ids,
|
||||
stop_flags,
|
||||
@@ -63,73 +297,110 @@ def run_test(device="xpu"):
        seq_lens_encoder,
        seq_lens_decoder,
        step_idx,
        seq_lens_encoder_record,
        seq_lens_decoder_record,
        not_need_stop,
        is_block_step,
        batch_drop,
        pre_ids,
        accept_tokens,
        accept_num,
        base_model_seq_lens_this_time,
        base_model_seq_lens_encoder,
        base_model_seq_lens_decoder,
        base_model_step_idx,
        base_model_stop_flags,
        base_model_is_block_step,
        base_model_draft_tokens,
        max_draft_token=max_draft_token,
        truncate_first_token=truncate_first_token,
        splitwise_prefill=splitwise_prefill,
        real_bsz,
        num_model_step,
        truncate_first_token,
        splitwise_prefill,
        kvcache_scheduler_v1,
    )

    # Return results for comparison
    results = {
        "draft_tokens": draft_tokens.numpy(),
        "input_ids": input_ids.numpy(),
        "stop_flags": stop_flags.numpy(),
        "seq_lens_this_time": seq_lens_this_time.numpy(),
        "accept_tokens": accept_tokens.numpy(),
        "accept_num": accept_num.numpy(),
        "not_need_stop": not_need_stop.numpy(),
        "outputs": [x.numpy() for x in outputs],
    }
    return results

class TestDraftModelPreprocess:
    def _run_tests(self):
        paddle.seed(2022)

def compare_results(cpu_results, xpu_results):
    # Compare all outputs
    for key in cpu_results:
        if key == "outputs":
            for i, (cpu_out, xpu_out) in enumerate(zip(cpu_results[key], xpu_results[key])):
                np.testing.assert_allclose(
                    cpu_out,
                    xpu_out,
                    rtol=1e-5,
                    atol=1e-8,
                    err_msg=f"Output {i} mismatch between CPU and XPU",
                )
        else:
            np.testing.assert_allclose(
                cpu_results[key],
                xpu_results[key],
                rtol=1e-5,
                atol=1e-8,
                err_msg=f"{key} mismatch between CPU and XPU",
            )
    print("CPU and XPU results match!")
# Define parameters
bsz = 10
draft_tokens_len = 4
input_ids_len = 100
max_draft_token = 10

truncate_first_token = True
splitwise_prefill = False

def test_draft_model_preprocess():
    draft_tokens = paddle.randint(0, 100, [bsz, draft_tokens_len], dtype="int64")
    input_ids = paddle.randint(0, 100, [bsz, input_ids_len], dtype="int64")
    stop_flags = paddle.randint(0, 1, [bsz], dtype="int").cast("bool")
    seq_lens_this_time = paddle.randint(0, 100, [bsz], dtype="int32")
    seq_lens_encoder = paddle.randint(0, input_ids_len, [bsz], dtype="int32")
    seq_lens_decoder = paddle.randint(0, input_ids_len, [bsz], dtype="int32")
    step_idx = paddle.randint(0, 100, [bsz], dtype="int64")
    seq_lens_encoder_record = paddle.randint(0, 100, [bsz], dtype="int32")  # noqa: F841
    seq_lens_decoder_record = paddle.randint(0, 100, [bsz], dtype="int32")  # noqa: F841
    not_need_stop = paddle.zeros([1], dtype="bool").cpu()
    is_block_step = paddle.zeros([bsz], dtype="bool")
    batch_drop = paddle.zeros([bsz], dtype="bool")

    print("Running XPU test...")
    xpu_results = run_test("xpu")
    # Output tensors
    accept_tokens = paddle.randint(0, 100, [bsz, 100], dtype="int64")
    accept_num = paddle.randint(1, max_draft_token + 5, [bsz], dtype="int32")
    base_model_seq_lens_encoder = paddle.randint(0, 100, [bsz], dtype="int32")
    base_model_seq_lens_decoder = paddle.randint(0, 100, [bsz], dtype="int32")
    base_model_step_idx = paddle.randint(0, 100, [bsz], dtype="int64")
    base_model_stop_flags = paddle.zeros([bsz], dtype="bool")
    base_model_is_block_step = paddle.zeros([bsz], dtype="bool")
    base_model_draft_tokens = paddle.zeros([bsz, max_draft_token], dtype="int64")
    # Run the op
    pre_ids = input_ids.clone()
    base_model_seq_lens_this_time = seq_lens_this_time
    num_model_step = max_draft_token

    print("Running CPU test...")
    cpu_results = run_test("cpu")
    kvcache_scheduler_v1 = True
    inputs = (
        draft_tokens,
        input_ids,
        stop_flags,
        seq_lens_this_time,
        seq_lens_encoder,
        seq_lens_decoder,
        step_idx,
        not_need_stop,
        is_block_step,
        batch_drop,
        pre_ids,
        accept_tokens,
        accept_num,
        base_model_seq_lens_this_time,
        base_model_seq_lens_encoder,
        base_model_seq_lens_decoder,
        base_model_step_idx,
        base_model_stop_flags,
        base_model_is_block_step,
        base_model_draft_tokens,
        num_model_step,
        truncate_first_token,
        splitwise_prefill,
        kvcache_scheduler_v1,
    )
    # The op modifies its inputs in place, so the inputs must be cloned first
    inputs_clone = [x.clone() if isinstance(x, paddle.Tensor) else x for x in inputs]
    draft_model_preprocess_ref(*inputs)
    draft_model_preprocess(*inputs_clone)
    return inputs, inputs_clone

    print("Comparing results...")
    compare_results(cpu_results, xpu_results)

    print("Test passed!")
    def test_draft_model_preprocess(self):
        results1, results2 = self._run_tests()
        np.testing.assert_allclose(results1[0], results2[0])  # draft_tokens
        np.testing.assert_allclose(results1[1], results2[1])  # input_ids
        np.testing.assert_allclose(results1[2], results2[2])  # stop_flags
        np.testing.assert_allclose(results1[3], results2[3])  # seq_lens_this_time
        np.testing.assert_allclose(results1[11], results2[11])  # accept_tokens
        np.testing.assert_allclose(results1[12], results2[12])  # accept_num
        np.testing.assert_allclose(results1[7], results2[7])  # not_need_stop


if __name__ == "__main__":
    test_draft_model_preprocess()
    unittest.main()
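The clone-then-compare pattern in the test above is the standard way to exercise an in-place op against a reference: snapshot every tensor input, run the reference on one copy and the op on the other, then require the mutated copies to agree. A minimal standalone sketch of that pattern with a toy stand-in for the op (all names here are illustrative, not part of FastDeploy):

import paddle

def ref_increment(t):  # toy reference implementation
    t.add_(paddle.ones_like(t))

def op_increment(t):  # toy stand-in for the custom in-place op
    t.add_(paddle.ones_like(t))

x = paddle.to_tensor([1, 2, 3])
x_clone = x.clone()  # separate buffer with the same initial values
ref_increment(x)
op_increment(x_clone)
# Both buffers went through supposedly identical mutations, so they must agree.
assert (x.numpy() == x_clone.numpy()).all()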
312
custom_ops/xpu_ops/test/test_speculate_step.py
Normal file
@@ -0,0 +1,312 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import unittest

import numpy as np
import paddle

from fastdeploy.model_executor.ops.xpu import speculate_step_paddle

# Fix the random seeds so the test is reproducible
np.random.seed(2023)
paddle.seed(2023)


def generate_test_data():
    """
    Helper that generates the test data.
    Converted from a pytest fixture into a plain function called by the test method.
    """
    # max_bs = 128
    max_bs = 8
    bs = max_bs
    max_seq_len = 8192
    block_size = 64
    block_bs = 8
    block_ratio = 0.75
    max_draft_tokens = 1
    encoder_decoder_block_num = 1

    # Generate the raw test data (reusing the original logic verbatim)
    stop_flags = np.random.randint(0, 2, [max_bs]).astype("bool")
    seq_lens_this_time = np.zeros([bs], "int32")
    seq_lens_encoder = np.zeros([max_bs], "int32")
    seq_lens_decoder = np.zeros([max_bs], "int32")
    accept_num = np.random.randint(1, 3, [max_bs]).astype("int32")
    for i in range(bs):
        seq_lens_decoder[i] = 2 + i * 2
        seq_lens_this_time[i] = 1

    ori_seq_lens_encoder = np.zeros([max_bs], "int32")
    ori_seq_lens_encoder[:] = seq_lens_decoder[:] // 2
    step_idx = (seq_lens_decoder - ori_seq_lens_encoder).astype("int64")

    max_block_num = block_bs * max_seq_len // block_size
    free_list_len = int(max_block_num * (1 - block_ratio))
    free_list_len = np.full([1], free_list_len, "int32")
    free_list = np.arange(
        max_block_num - 1, max_block_num - free_list_len.item() - 1, -1, dtype="int32"  # .item() converts to a scalar
    )
    encoder_block_lens = np.zeros([max_bs], "int32")
    used_list_len = np.zeros([max_bs], "int32")
    block_tables = np.full([max_bs, 128], -1, "int32")
    encoder_block_id = 0

    for i in range(bs):
        enc_block_num = (ori_seq_lens_encoder[i] + block_size - 1) // block_size
        encoder_block_lens[i] = enc_block_num
        dec_block_num = (seq_lens_decoder[i] + block_size - 1) // block_size - enc_block_num
        used_list_len[i] = dec_block_num
        block_tables[i, :enc_block_num] = np.arange(encoder_block_id, encoder_block_id + enc_block_num, 1, "int32")
        encoder_block_id += enc_block_num
        if dec_block_num > 0:
            block_tables[i, enc_block_num : enc_block_num + dec_block_num] = free_list[
                free_list_len[0] - 1 - dec_block_num : free_list_len[0] - 1
            ]
            free_list[free_list_len[0] - 1 - dec_block_num : free_list_len[0] - 1] = -1
            free_list_len[0] -= dec_block_num

    assert free_list_len[0] >= 0, "free_list_len should not be negative"

    is_block_step = np.zeros([max_bs], "bool")
    is_block_step[:bs] = np.random.randint(0, 2, [bs]).astype("bool")
    step_block_list = np.full([max_bs], -1, "int32")
    step_lens = np.full([1], 0, "int32")

    for i in range(bs):
        if is_block_step[i]:
            step_block_list[step_lens[0]] = i
            step_lens[0] += 1

    recover_lens = np.full([1], 0, "int32")
    recover_block_list = np.full([max_bs], -1, "int32")
    need_block_len = np.full([1], 0, "int32")
    need_block_list = np.full([max_bs], -1, "int32")

    input_ids = np.random.randint(0, 1000, [max_bs, max_seq_len], "int64")
    pre_ids = np.random.randint(0, 1000, [max_bs, max_seq_len], "int64")
    next_tokens = np.random.randint(0, 1000, [max_bs], "int64")
    first_token_ids = np.random.randint(0, 1000, [max_bs], "int64")

    paddle.set_device("cpu")
    # Convert to paddle tensors (keeping the original logic)
    data_cpu = {
        "stop_flags": paddle.to_tensor(stop_flags),
        "seq_lens_this_time": paddle.to_tensor(seq_lens_this_time),
        "seq_lens_encoder": paddle.to_tensor(seq_lens_encoder),
        "seq_lens_decoder": paddle.to_tensor(seq_lens_decoder),
        "ori_seq_lens_encoder": paddle.to_tensor(ori_seq_lens_encoder),
        "block_tables": paddle.to_tensor(block_tables),
        "encoder_block_lens": paddle.to_tensor(encoder_block_lens),
        "is_block_step": paddle.to_tensor(is_block_step),
        "step_block_list": paddle.to_tensor(step_block_list),
        "step_lens": paddle.to_tensor(step_lens),
        "recover_block_list": paddle.to_tensor(recover_block_list),
        "recover_lens": paddle.to_tensor(recover_lens),
        "need_block_list": paddle.to_tensor(need_block_list),
        "need_block_len": paddle.to_tensor(need_block_len),
        "used_list_len": paddle.to_tensor(used_list_len),
        "free_list_len": paddle.to_tensor(free_list_len),
        "free_list": paddle.to_tensor(free_list),
        "input_ids": paddle.to_tensor(input_ids),
        "pre_ids": paddle.to_tensor(pre_ids),
        "step_idx": paddle.to_tensor(step_idx),
        "next_tokens": paddle.to_tensor(next_tokens),
        "first_token_ids": paddle.to_tensor(first_token_ids),
        "accept_num": paddle.to_tensor(accept_num),
        "block_size": block_size,
        "encoder_decoder_block_num": encoder_decoder_block_num,
        "max_draft_tokens": max_draft_tokens,
    }

    paddle.set_device("xpu:0")
    data_xpu = {
        "stop_flags": paddle.to_tensor(stop_flags),
        "seq_lens_this_time": paddle.to_tensor(seq_lens_this_time),
        "seq_lens_encoder": paddle.to_tensor(seq_lens_encoder),
        "seq_lens_decoder": paddle.to_tensor(seq_lens_decoder),
        "ori_seq_lens_encoder": paddle.to_tensor(ori_seq_lens_encoder),
        "block_tables": paddle.to_tensor(block_tables),
        "encoder_block_lens": paddle.to_tensor(encoder_block_lens),
        "is_block_step": paddle.to_tensor(is_block_step),
        "step_block_list": paddle.to_tensor(step_block_list),
        "step_lens": paddle.to_tensor(step_lens),
        "recover_block_list": paddle.to_tensor(recover_block_list),
        "recover_lens": paddle.to_tensor(recover_lens),
        "need_block_list": paddle.to_tensor(need_block_list),
        "need_block_len": paddle.to_tensor(need_block_len),
        "used_list_len": paddle.to_tensor(used_list_len),
        "free_list_len": paddle.to_tensor(free_list_len),
        "free_list": paddle.to_tensor(free_list),
        "input_ids": paddle.to_tensor(input_ids),
        "pre_ids": paddle.to_tensor(pre_ids),
        "step_idx": paddle.to_tensor(step_idx),
        "next_tokens": paddle.to_tensor(next_tokens),
        "first_token_ids": paddle.to_tensor(first_token_ids),
        "accept_num": paddle.to_tensor(accept_num),
        "block_size": block_size,
        "encoder_decoder_block_num": encoder_decoder_block_num,
        "max_draft_tokens": max_draft_tokens,
    }

    # Restore the default device so other tests are unaffected
    paddle.set_device("cpu")

    return data_cpu, data_xpu
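To make the block arithmetic in generate_test_data concrete, here is a tiny hand-computed example (values picked for illustration, not taken from the generated data):

import math

block_size = 64
ori_seq_len_encoder = 70   # prompt length
seq_len_decoder = 140      # decoded position so far

enc_block_num = math.ceil(ori_seq_len_encoder / block_size)  # 2 blocks hold the prompt
total_blocks = math.ceil(seq_len_decoder / block_size)       # 3 blocks needed in total
dec_block_num = total_blocks - enc_block_num                 # 1 block drawn from free_list
print(enc_block_num, dec_block_num)  # 2 1

The decoder blocks come off the tail of free_list and are recorded in block_tables right after the encoder blocks, which is exactly what the loop above does while shrinking free_list_len.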
def speculate_step_paddle_execution(test_data):
    """Check that speculate_step_paddle executes and produces sensible outputs."""
    # Extract the inputs
    stop_flags = test_data["stop_flags"]  # clone to avoid mutating the fixture data
    seq_lens_this_time = test_data["seq_lens_this_time"]
    ori_seq_lens_encoder = test_data["ori_seq_lens_encoder"]
    seq_lens_encoder = test_data["seq_lens_encoder"]
    seq_lens_decoder = test_data["seq_lens_decoder"]
    block_tables = test_data["block_tables"]
    encoder_block_lens = test_data["encoder_block_lens"]
    is_block_step = test_data["is_block_step"]
    step_block_list = test_data["step_block_list"]
    step_lens = test_data["step_lens"]
    recover_block_list = test_data["recover_block_list"]
    recover_lens = test_data["recover_lens"]
    need_block_list = test_data["need_block_list"]
    need_block_len = test_data["need_block_len"]
    used_list_len = test_data["used_list_len"]
    free_list = test_data["free_list"]
    free_list_len = test_data["free_list_len"]
    input_ids = test_data["input_ids"]
    pre_ids = test_data["pre_ids"]
    step_idx = test_data["step_idx"]
    next_tokens = test_data["next_tokens"]
    first_token_ids = test_data["first_token_ids"]
    accept_num = test_data["accept_num"]
    block_size = test_data["block_size"]
    encoder_decoder_block_num = test_data["encoder_decoder_block_num"]
    max_draft_tokens = test_data["max_draft_tokens"]

    # Optional: print key state before the op (enable when debugging)
    if os.environ.get("STEP_TEST_DEBUG", "0") == "1":
        print("-" * 50 + "before step op" + "-" * 50)
        # ... (prints omitted for brevity)

    # Run the function under test (the core step)
    speculate_step_paddle(
        stop_flags,
        seq_lens_this_time,
        ori_seq_lens_encoder,
        seq_lens_encoder,
        seq_lens_decoder,
        block_tables,
        encoder_block_lens,
        is_block_step,
        step_block_list,
        step_lens,
        recover_block_list,
        recover_lens,
        need_block_list,
        need_block_len,
        used_list_len,
        free_list,
        free_list_len,
        input_ids,
        pre_ids,
        step_idx,
        next_tokens,
        first_token_ids,
        accept_num,
        block_size,
        encoder_decoder_block_num,
        max_draft_tokens,
    )

    # Optional: print key state after the op (enable when debugging)
    if os.environ.get("STEP_TEST_DEBUG", "0") == "1":
        print("-" * 50 + "after step op" + "-" * 50)
        # ... (prints omitted for brevity)

    return test_data


class TestSpeculateStepPaddle(unittest.TestCase):
    """
    Test class based on unittest.TestCase.
    Every method whose name starts with 'test_' is treated as a test case.
    """

    def assert_test_data_equal(self, test_data1, test_data2, rtol=1e-05, atol=1e-08):
        """
        Custom assertion that compares the structure and data of two test_data dicts.
        In unittest, custom assertions conventionally start with 'assert'.
        """
        # 1. First check that both test_data dicts have exactly the same keys
        keys1 = set(test_data1.keys())
        keys2 = set(test_data2.keys())
        self.assertEqual(
            keys1,
            keys2,
            msg=f"test_data keys differ!\nOnly in the first: {keys1 - keys2}\nOnly in the second: {keys2 - keys1}",
        )

        # 2. Check the data field by field
        for key in keys1:
            data1 = test_data1[key]
            data2 = test_data2[key]

            # Distinguish paddle Tensors (convert to numpy) from plain scalars/arrays (use directly)
            if isinstance(data1, paddle.Tensor):
                np1 = data1.detach().cpu().numpy()
            else:
                np1 = np.asarray(data1)

            if isinstance(data2, paddle.Tensor):
                np2 = data2.detach().cpu().numpy()
            else:
                np2 = np.asarray(data2)

            # 3. Compare the values
            if np1.dtype in (np.bool_, np.int8, np.int16, np.int32, np.int64, np.uint8):
                # Boolean/integer fields must match exactly
                np.testing.assert_array_equal(np1, np2, err_msg=f"Field {key} differs!")
            else:
                # Floating-point fields may differ within rtol/atol
                np.testing.assert_allclose(np1, np2, rtol=rtol, atol=atol, err_msg=f"Float field {key} differs!")

        print("✅ Both test_data dicts match in structure and data!")

    def test_speculate_step_paddle_execution(self):
        """
        The core test case.
        It calls generate_test_data to build the inputs,
        runs the function under test on both CPU and XPU,
        and compares the results with the custom assertion.
        """
        print("\nRunning test: test_speculate_step_paddle_execution")

        # 1. Build the test data
        data_cpu, data_xpu = generate_test_data()

        # 2. Run the function under test
        result_xpu = speculate_step_paddle_execution(data_xpu)
        result_cpu = speculate_step_paddle_execution(data_cpu)

        # 3. Assert the results agree
        self.assert_test_data_equal(result_xpu, result_cpu)


if __name__ == "__main__":
    # Run all test cases through the unittest main program
    unittest.main()
@@ -12,101 +12,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import unittest

# tests/test_speculate_update_v3.py
import numpy as np
import paddle

# Assume this is your custom operator
from fastdeploy.model_executor.ops.xpu import speculate_update_v3


# ---------------- NumPy reference implementation ----------------
def speculate_update_v3_np(
    seq_lens_encoder,
    seq_lens_decoder,
    not_need_stop,
    draft_tokens,
    actual_draft_token_nums,
    accept_tokens,
    accept_num,
    stop_flags,
    seq_lens_this_time,
    is_block_step,
    stop_nums,
):
    """
    NumPy reference that exactly reproduces the CPU / CUDA logic (modifies its arguments in place).
    """
    stop_sum = 0
    real_bsz = seq_lens_this_time.shape[0]
    max_bsz = stop_flags.shape[0]
    max_draft_tokens = draft_tokens.shape[1]

    for bid in range(max_bsz):
        stop_flag_now_int = 0
        inactive = bid >= real_bsz
        block_step = (not inactive) and is_block_step[bid]

        if (not block_step) and (not inactive):

            if stop_flags[bid]:
                stop_flag_now_int = 1

            # When the encoder length is 0, accumulate straight into the decoder length
            if seq_lens_encoder[bid] == 0:
                seq_lens_decoder[bid] += accept_num[bid]

            # Adapt the draft length
            if (seq_lens_encoder[bid] == 0) and (seq_lens_this_time[bid] > 1):
                cur_len = actual_draft_token_nums[bid]
                if accept_num[bid] - 1 == cur_len:  # everything accepted
                    if cur_len + 2 <= max_draft_tokens - 1:
                        cur_len += 2
                    elif cur_len + 1 <= max_draft_tokens - 1:
                        cur_len += 1
                    else:
                        cur_len = max_draft_tokens - 1
                else:  # something was rejected
                    cur_len = max(1, cur_len - 1)
                actual_draft_token_nums[bid] = cur_len

            # Pay back the outstanding encoder length
            if seq_lens_encoder[bid] != 0:
                seq_lens_decoder[bid] += seq_lens_encoder[bid]
                seq_lens_encoder[bid] = 0

            # Write back the first token for the next round
            draft_tokens[bid, 0] = accept_tokens[bid, accept_num[bid] - 1]

            # On stop, zero the decoder length
            if stop_flag_now_int:
                seq_lens_decoder[bid] = 0

        elif inactive:
            stop_flag_now_int = 1  # padding slots count as stopped

        stop_sum += stop_flag_now_int

    # print("stop_sum: ", stop_sum)
    not_need_stop[0] = stop_sum < stop_nums[0]

    # Return the references, for convenience only
    return (
        seq_lens_encoder,
        seq_lens_decoder,
        not_need_stop,
        draft_tokens,
        actual_draft_token_nums,
    )
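A minimal hand-built call of this reference makes the length bookkeeping visible (single active request; every value invented for illustration):

seq_lens_encoder = np.array([0], dtype=np.int32)
seq_lens_decoder = np.array([10], dtype=np.int32)
not_need_stop = np.array([False])
draft_tokens = np.zeros([1, 8], dtype=np.int64)
actual_draft_token_nums = np.array([3], dtype=np.int32)
accept_tokens = np.arange(8, dtype=np.int64).reshape(1, 8)
accept_num = np.array([4], dtype=np.int32)  # accept_num - 1 == 3, i.e. all drafts accepted
stop_flags = np.array([False])
seq_lens_this_time = np.array([4], dtype=np.int32)
is_block_step = np.array([False])
stop_nums = np.array([1], dtype=np.int64)

speculate_update_v3_np(
    seq_lens_encoder, seq_lens_decoder, not_need_stop, draft_tokens,
    actual_draft_token_nums, accept_tokens, accept_num, stop_flags,
    seq_lens_this_time, is_block_step, stop_nums,
)
# The decoder length grew by accept_num (10 -> 14); all drafts were accepted,
# so the draft length grows from 3 to min(3 + 2, 8 - 1) = 5.
print(seq_lens_decoder, actual_draft_token_nums)  # [14] [5]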
# ---------------- Generate random inputs ----------------
def gen_inputs(
    max_bsz=512,  # aligned with the CUDA BlockSize
    max_draft_tokens=16,
    real_bsz=123,  # adjustable; must be <= max_bsz
    seed=2022,
):
    """Generate random test inputs"""
    rng = np.random.default_rng(seed)

    # Basic tensors
@@ -122,89 +43,91 @@ def gen_inputs(
    stop_nums = np.array([5], dtype=np.int64)  # arbitrary threshold

    # seq_lens_this_time only covers the first real_bsz entries
    seq_lens_this_time = rng.integers(1, max_draft_tokens, size=real_bsz, dtype=np.int32)
    seq_lens_this_time = rng.integers(1, max_draft_tokens + 1, size=real_bsz, dtype=np.int32)

    return {
        "seq_lens_encoder": seq_lens_encoder,
        "seq_lens_decoder": seq_lens_decoder,
        "not_need_stop": not_need_stop,
        "draft_tokens": draft_tokens,
        "actual_draft_token_nums": actual_draft_nums,
        "accept_tokens": accept_tokens,
        "accept_num": accept_num,
        "stop_flags": stop_flags,
        "seq_lens_this_time": seq_lens_this_time,
        "is_block_step": is_block_step,
        "stop_nums": stop_nums,
        # real_bsz = real_bsz,
        # max_bsz = max_bsz,
        # max_draft_tokens = max_draft_tokens
    paddle.set_device("xpu:0")
    data_xpu = {
        "seq_lens_encoder": paddle.to_tensor(seq_lens_encoder),
        "seq_lens_decoder": paddle.to_tensor(seq_lens_decoder),
        "not_need_stop": paddle.to_tensor(not_need_stop).cpu(),
        "draft_tokens": paddle.to_tensor(draft_tokens),
        "actual_draft_token_nums": paddle.to_tensor(actual_draft_nums),
        "accept_tokens": paddle.to_tensor(accept_tokens),
        "accept_num": paddle.to_tensor(accept_num),
        "stop_flags": paddle.to_tensor(stop_flags),
        "seq_lens_this_time": paddle.to_tensor(seq_lens_this_time),
        "is_block_step": paddle.to_tensor(is_block_step),
        "stop_nums": paddle.to_tensor(stop_nums),
    }


# ------------------- Test body -------------------
inputs = gen_inputs(max_bsz=512, max_draft_tokens=32, real_bsz=201)

# ---- Paddle side ----
paddle_inputs = {}
for k, v in inputs.items():
    if k in ("real_bsz", "max_bsz", "max_draft_tokens"):
        paddle_inputs[k] = v  # plain python int
    else:
        if k == "not_need_stop":
            paddle_inputs[k] = paddle.to_tensor(v, place=paddle.CPUPlace())
        else:
            # Other tensors keep the default place (add place=paddle.CUDAPlace(0) manually to test on GPU)
            paddle_inputs[k] = paddle.to_tensor(v)

# ---- NumPy side ----
# To guarantee identical initial values, copy the numpy values of the Paddle inputs before passing them to the reference
np_inputs = {
    k: (paddle_inputs[k].numpy().copy() if isinstance(paddle_inputs[k], paddle.Tensor) else paddle_inputs[k])
    for k in paddle_inputs
}

# Call the custom operator
# print("seq_lens_encoder_xpu_before: ", paddle_inputs["seq_lens_encoder"])
out_pd = speculate_update_v3(**paddle_inputs)
# print("seq_lens_encoder_xpu_after: ", out_pd[0])
# print("not_need_stop: ", out_pd[2])

# speculate_update_v3 returns 5 tensors (matching its Outputs)
(
    seq_lens_encoder_pd,
    seq_lens_decoder_pd,
    not_need_stop_pd,
    draft_tokens_pd,
    actual_draft_nums_pd,
) = out_pd

# print("seq_lens_encoder_np_before: ", np_inputs["seq_lens_encoder"])
out_np = speculate_update_v3_np(**np_inputs)
# print("seq_lens_encoder_np_after: ", out_np[0])
# print("not_need_stop: ", out_np[2])
    paddle.set_device("cpu")
    data_cpu = {
        "seq_lens_encoder": paddle.to_tensor(seq_lens_encoder),
        "seq_lens_decoder": paddle.to_tensor(seq_lens_decoder),
        "not_need_stop": paddle.to_tensor(not_need_stop),
        "draft_tokens": paddle.to_tensor(draft_tokens),
        "actual_draft_token_nums": paddle.to_tensor(actual_draft_nums),
        "accept_tokens": paddle.to_tensor(accept_tokens),
        "accept_num": paddle.to_tensor(accept_num),
        "stop_flags": paddle.to_tensor(stop_flags),
        "seq_lens_this_time": paddle.to_tensor(seq_lens_this_time),
        "is_block_step": paddle.to_tensor(is_block_step),
        "stop_nums": paddle.to_tensor(stop_nums),
    }
    return data_xpu, data_cpu


# ---------------- Verification ----------------
names = [
    "seq_lens_encoder",
    "seq_lens_decoder",
    "not_need_stop",
    "draft_tokens",
    "actual_draft_token_nums",
]
pd_tensors = [
    seq_lens_encoder_pd,
    seq_lens_decoder_pd,
    not_need_stop_pd,
    draft_tokens_pd,
    actual_draft_nums_pd,
]
class TestSpeculateUpdateV3(unittest.TestCase):
    """Tests for the speculate_update_v3 operator"""

for name, pd_val, np_val in zip(names, pd_tensors, out_np):
    pd_arr = pd_val.numpy()
    ok = np.array_equal(pd_arr, np_val)
    print(f"{name:25s} equal :", ok)
    def test_op_vs_golden(self, max_bsz=512, max_draft_tokens=16, real_bsz=123):
        """
        Core test: compare the custom operator's outputs against the pure NumPy reference.
        """
        # 1. gen inputs for cpu/xpu
        data_xpu, data_cpu = gen_inputs(max_bsz=max_bsz, max_draft_tokens=max_draft_tokens, real_bsz=real_bsz)

# These prints can also be turned into asserts for pytest:
# assert all(np.array_equal(p.numpy(), n) for p, n in zip(pd_tensors, out_np))
        # 3. run xpu kernel
        speculate_update_v3(**data_xpu)

        # 4. run cpu kernel
        speculate_update_v3(**data_cpu)

        # 5. format outputs
        outputs_xpu = [
            data_xpu["seq_lens_encoder"].cpu().numpy(),
            data_xpu["seq_lens_decoder"].cpu().numpy(),
            data_xpu["not_need_stop"].cpu().numpy(),
            data_xpu["draft_tokens"].cpu().numpy(),
            data_xpu["actual_draft_token_nums"].cpu().numpy(),
        ]

        outputs_cpu = [
            data_cpu["seq_lens_encoder"].numpy(),
            data_cpu["seq_lens_decoder"].numpy(),
            data_cpu["not_need_stop"].numpy(),
            data_cpu["draft_tokens"].numpy(),
            data_cpu["actual_draft_token_nums"].numpy(),
        ]
        output_names = [
            "seq_lens_encoder",
            "seq_lens_decoder",
            "not_need_stop",
            "draft_tokens",
            "actual_draft_token_nums",
        ]

        # 6. check outputs
        for name, pd_out, np_out in zip(output_names, outputs_xpu, outputs_cpu):
            with self.subTest(output_name=name):
                np.testing.assert_allclose(
                    pd_out,
                    np_out,
                    atol=0,
                    rtol=1e-6,
                    err_msg=f"Output mismatch for tensor '{name}'.\nPaddle Output:\n{pd_out}\nGolden Output:\n{np_out}",
                )


if __name__ == "__main__":
    unittest.main()
315
fastdeploy/model_executor/xpu_pre_and_post_process.py
Normal file
@@ -0,0 +1,315 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""

from typing import Dict, Optional

import paddle

from fastdeploy import envs
from fastdeploy.model_executor.forward_meta import XPUForwardMeta
from fastdeploy.platforms import current_platform
from fastdeploy.worker.output import ModelOutputData

if current_platform.is_xpu():
    from fastdeploy.model_executor.ops.xpu import (
        adjust_batch,
        gather_next_token,
        get_infer_param,
        get_padding_offset,
        limit_thinking_content_length_v1,
        limit_thinking_content_length_v2,
        update_inputs_v1,
    )


def xpu_pre_process(
    input_ids: paddle.Tensor,
    seq_lens_this_time: paddle.Tensor,
    share_inputs: Dict,
    use_speculate_method: bool,
    block_size: int,
    draft_tokens: Optional[paddle.Tensor] = None,
    seq_lens_encoder: Optional[paddle.Tensor] = None,
    seq_lens_decoder: Optional[paddle.Tensor] = None,
    is_profiling: bool = False,
) -> XPUForwardMeta:
    """Remove padding from input_ids and build the XPUForwardMeta used by the forward pass."""
    max_len = input_ids.shape[1]
    cum_offsets_now = paddle.cumsum(max_len - seq_lens_this_time, dtype="int32")
    token_num = paddle.sum(seq_lens_this_time)

    (
        ids_remove_padding,
        cum_offsets,
        batch_id_per_token,
        cu_seqlens_q,
        cu_seqlens_k,
    ) = get_padding_offset(input_ids, cum_offsets_now, token_num, seq_lens_this_time)

    share_inputs["ids_remove_padding"] = None  # set this after adjust batch
    share_inputs["cum_offsets"] = cum_offsets
    share_inputs["batch_id_per_token"] = batch_id_per_token
    share_inputs["cu_seqlens_q"] = cu_seqlens_q
    share_inputs["cu_seqlens_k"] = cu_seqlens_k

    xpu_forward_meta = XPUForwardMeta(
        ids_remove_padding=share_inputs["ids_remove_padding"],
        rotary_embs=share_inputs["rope_emb"],
        attn_backend=None,
        seq_lens_encoder=share_inputs["seq_lens_encoder"],
        seq_lens_decoder=share_inputs["seq_lens_decoder"],
        seq_lens_this_time=share_inputs["seq_lens_this_time"],
        cum_offsets=share_inputs["cum_offsets"],
        batch_id_per_token=share_inputs["batch_id_per_token"],
        cu_seqlens_q=share_inputs["cu_seqlens_q"],
        cu_seqlens_k=share_inputs["cu_seqlens_k"],
        block_tables=share_inputs["block_tables"],
        caches=share_inputs["caches"],
    )

    (
        xpu_forward_meta.encoder_batch_map,
        xpu_forward_meta.decoder_batch_map,
        xpu_forward_meta.encoder_batch_idx,
        xpu_forward_meta.decoder_batch_idx,
        xpu_forward_meta.encoder_seq_lod,
        xpu_forward_meta.decoder_seq_lod,
        xpu_forward_meta.encoder_kv_lod,
        xpu_forward_meta.prefix_len,
        xpu_forward_meta.decoder_context_len,
        xpu_forward_meta.decoder_context_len_cache,
        xpu_forward_meta.prefix_block_tables,
        xpu_forward_meta.encoder_batch_map_cpu,
        xpu_forward_meta.decoder_batch_map_cpu,
        xpu_forward_meta.encoder_batch_idx_cpu,
        xpu_forward_meta.decoder_batch_idx_cpu,
        xpu_forward_meta.encoder_seq_lod_cpu,
        xpu_forward_meta.decoder_seq_lod_cpu,
        xpu_forward_meta.encoder_kv_lod_cpu,
        xpu_forward_meta.prefix_len_cpu,
        xpu_forward_meta.decoder_context_len_cpu,
        xpu_forward_meta.decoder_context_len_cache_cpu,
        xpu_forward_meta.len_info_cpu,
    ) = get_infer_param(
        seq_lens_encoder, seq_lens_decoder, seq_lens_this_time, xpu_forward_meta.block_tables, block_size
    )
    xpu_forward_meta.enc_batch = xpu_forward_meta.len_info_cpu[0]
    xpu_forward_meta.dec_batch = xpu_forward_meta.len_info_cpu[1]
    xpu_forward_meta.total_enc_len = xpu_forward_meta.len_info_cpu[2]

    adjusted_input = adjust_batch(
        ids_remove_padding.reshape([-1, 1]),
        cum_offsets,
        xpu_forward_meta.encoder_seq_lod,
        xpu_forward_meta.decoder_seq_lod,
        xpu_forward_meta.encoder_batch_idx,
        xpu_forward_meta.decoder_batch_idx,
        xpu_forward_meta.encoder_seq_lod_cpu,
        xpu_forward_meta.decoder_seq_lod_cpu,
        xpu_forward_meta.encoder_batch_idx_cpu,
        xpu_forward_meta.decoder_batch_idx_cpu,
        xpu_forward_meta.len_info_cpu,
        None,  # output_padding_offset
        -1,  # max_input_length
    )

    adjusted_input = adjusted_input.squeeze(1)

    share_inputs["ids_remove_padding"] = adjusted_input
    xpu_forward_meta.ids_remove_padding = adjusted_input
    # Set forward_meta.is_profiling to True to skip init_kv_signal_per_query for attention backends
    xpu_forward_meta.is_profiling = is_profiling
    return xpu_forward_meta
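For intuition, the packing that get_padding_offset performs can be sketched in plain numpy (a simplified illustration with made-up shapes; the real op also emits per-token batch ids and cumulative offsets):

import numpy as np

input_ids = np.array([[1, 2, 3, 0, 0],   # 3 requests padded to max_len = 5
                      [4, 0, 0, 0, 0],
                      [5, 6, 0, 0, 0]])
seq_lens_this_time = np.array([3, 1, 2])

# Concatenate the valid prefix of each row, dropping the right padding.
ids_remove_padding = np.concatenate(
    [input_ids[i, : seq_lens_this_time[i]] for i in range(len(seq_lens_this_time))]
)
print(ids_remove_padding)  # [1 2 3 4 5 6]

# cu_seqlens marks where each request starts in the packed buffer.
cu_seqlens = np.concatenate([[0], np.cumsum(seq_lens_this_time)])
print(cu_seqlens)  # [0 3 4 6]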
def xpu_process_output(
    forward_output,
    cum_offsets: paddle.Tensor,
    xpu_forward_meta: XPUForwardMeta,
    share_inputs,
) -> paddle.Tensor:
    """Gather each request's next-token hidden state from the packed forward output."""

    output_padding_offset = share_inputs.get("output_padding_offset", None)

    hidden_states = gather_next_token(
        forward_output,
        cum_offsets,
        xpu_forward_meta.encoder_seq_lod,
        xpu_forward_meta.decoder_seq_lod,
        xpu_forward_meta.encoder_batch_map,
        xpu_forward_meta.decoder_batch_map,
        xpu_forward_meta.encoder_seq_lod_cpu,
        xpu_forward_meta.decoder_seq_lod_cpu,
        xpu_forward_meta.encoder_batch_map_cpu,
        xpu_forward_meta.decoder_batch_map_cpu,
        xpu_forward_meta.len_info_cpu,
        output_padding_offset,  # output_padding_offset
        -1,  # max_input_length
    )
    return hidden_states
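A rough numpy picture of what gathering "next-token" hidden states means on packed input (illustration only; the real kernel additionally reorders encoder and decoder requests through the lod/batch-map tensors above):

import numpy as np

hidden = np.arange(12).reshape(6, 2)  # 6 packed tokens, hidden dim 2
cu_seqlens = np.array([0, 3, 4, 6])   # 3 requests of lengths 3, 1, 2
last_token_idx = cu_seqlens[1:] - 1   # [2, 3, 5]
next_token_hidden = hidden[last_token_idx]  # one row per request
print(next_token_hidden.shape)  # (3, 2)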
def xpu_post_process_normal(
    sampled_token_ids: paddle.Tensor,
    model_output: ModelOutputData,
    share_inputs: Dict[str, paddle.Tensor],
    block_size: int = 64,
    skip_save_output: bool = False,
    think_end_id: int = None,
    line_break_id: int = None,
) -> None:
    """Update stop flags and model input buffers after sampling, then publish the sampled tokens."""
    from fastdeploy.model_executor.ops.xpu import (
        save_output,
        set_stop_value_multi_ends,
        update_inputs,
    )

    if think_end_id > 0:
        limit_strategy = envs.FD_LIMIT_THINKING_CONTENT_TRUNCATE_STR
        max_think_lens = share_inputs["max_think_lens"]
        step_idx = share_inputs["step_idx"]
        limit_think_status = share_inputs["limit_think_status"]
        stop_flags = share_inputs["stop_flags"]
        eos_token_ids = share_inputs["eos_token_id"]
        if limit_strategy == "</think>":
            # for ernie-45-vl
            limit_thinking_content_length_v1(
                sampled_token_ids,
                max_think_lens,
                step_idx,
                limit_think_status,
                stop_flags,
                eos_token_ids,  # handles the model occasionally emitting an eos token inside the thinking stage
                think_end_id,
            )
        elif limit_strategy == "\n</think>\n\n":
            # for ernie-x1
            assert line_break_id > 0
            limit_thinking_content_length_v2(
                sampled_token_ids,
                max_think_lens,
                step_idx,
                limit_think_status,
                stop_flags,
                think_end_id,
                line_break_id,
            )
        else:
            raise NotImplementedError(f"Not support {limit_strategy=} for limit thinking content length.")

    # 1. Set stop value
    paddle.assign(
        paddle.where(
            model_output.stop_flags,
            model_output.step_idx,
            model_output.step_idx + 1,
        ),
        model_output.step_idx,
    )
    length_cond = paddle.greater_equal(model_output.step_idx, model_output.max_dec_len)
    paddle.assign(
        paddle.logical_or(model_output.stop_flags, length_cond),
        model_output.stop_flags,
    )
    set_stop_value_multi_ends(
        sampled_token_ids,
        model_output.stop_flags,
        model_output.seq_lens_this_time,
        model_output.eos_token_id,
        model_output.next_tokens,
        False,
    )  # multi ends

    # 2. Update the input buffer of the model
    with paddle.framework._no_check_dy2st_diff():
        if envs.ENABLE_V1_KVCACHE_SCHEDULER and not skip_save_output:
            update_inputs_v1(
                model_output.stop_flags,
                model_output.not_need_stop,
                model_output.seq_lens_this_time,
                model_output.seq_lens_encoder,
                model_output.seq_lens_decoder,
                share_inputs["step_seq_lens_decoder"],
                share_inputs["prompt_lens"],
                sampled_token_ids,
                model_output.input_ids,
                share_inputs["block_tables"],
                model_output.stop_nums,
                model_output.next_tokens,
                model_output.is_block_step,
                block_size,
            )
        else:
            update_inputs(
                model_output.stop_flags,
                model_output.not_need_stop,
                model_output.seq_lens_this_time,
                model_output.seq_lens_encoder,
                model_output.seq_lens_decoder,
                model_output.input_ids,
                model_output.stop_nums,
                sampled_token_ids,
                model_output.is_block_step,
            )
    # 3. Transmit the model's output and stop generation signal via message queue.
    # In the future, we will abandon this approach.
    if not skip_save_output:
        save_output(
            sampled_token_ids,
            model_output.not_need_stop,
            model_output.mp_rank,
            False,  # use_ep
        )
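The step/stop bookkeeping in step 1 above has a simple elementwise reading; a numpy sketch with made-up values:

import numpy as np

step_idx = np.array([3, 7, 9])
stop_flags = np.array([False, True, False])
max_dec_len = np.array([10, 10, 10])

# Stopped requests keep their step_idx; running ones advance by one.
step_idx = np.where(stop_flags, step_idx, step_idx + 1)
# A request also stops once it reaches its decoding budget.
stop_flags = np.logical_or(stop_flags, step_idx >= max_dec_len)
print(step_idx, stop_flags)  # [4 7 10] [False True True]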
def step_xpu(
    share_inputs: Dict[str, paddle.Tensor],
    block_size: int,
    enc_dec_block_num: int,
) -> None:
    """
    TODO(gongshaotian): normalization name
    """
    from fastdeploy.model_executor.ops.xpu import step_paddle

    step_paddle(
        share_inputs["stop_flags"],
        share_inputs["seq_lens_this_time"],
        share_inputs["step_seq_lens_encoder"],
        share_inputs["seq_lens_encoder"],
        share_inputs["seq_lens_decoder"],
        share_inputs["block_tables"],
        share_inputs["encoder_block_lens"],
        share_inputs["is_block_step"],
        share_inputs["step_block_list"],
        share_inputs["step_lens"],
        share_inputs["recover_block_list"],
        share_inputs["recover_lens"],
        share_inputs["need_block_list"],
        share_inputs["need_block_len"],
        share_inputs["used_list_len"],
        share_inputs["free_list"],
        share_inputs["free_list_len"],
        share_inputs["input_ids"],
        share_inputs["pre_ids"],
        share_inputs["step_idx"],
        share_inputs["next_tokens"],
        share_inputs["first_token_ids"],
        block_size,
        enc_dec_block_num,
    )
@@ -17,7 +17,7 @@
import os
import random
import time
from typing import Dict, List, Optional
from typing import List, Optional

import numpy as np
import paddle
@@ -28,7 +28,7 @@ from fastdeploy.config import FDConfig
from fastdeploy.engine.request import Request, RequestType
from fastdeploy.input.ernie4_5_vl_processor import DataProcessor
from fastdeploy.inter_communicator import IPCSignal
from fastdeploy.model_executor.forward_meta import ForwardMeta, XPUForwardMeta
from fastdeploy.model_executor.forward_meta import ForwardMeta
from fastdeploy.model_executor.graph_optimization.utils import (
    profile_run_guard,
    sot_warmup_guard,
@@ -43,17 +43,17 @@ from fastdeploy.model_executor.layers.sample.sampler import Sampler
from fastdeploy.model_executor.model_loader import get_model_loader
from fastdeploy.model_executor.models.ernie4_5_vl.modeling_resampler import ScatterOp
from fastdeploy.model_executor.ops.xpu import (
    adjust_batch,
    create_kv_signal_sender,
    destroy_kv_signal_sender,
    get_infer_param,
    get_padding_offset,
    limit_thinking_content_length_v1,
    limit_thinking_content_length_v2,
    recover_decode_task,
    set_data_ipc,
    share_external_data,
    update_inputs_v1,
)
from fastdeploy.model_executor.xpu_pre_and_post_process import (  # xpu_post_process_specualate, # TODO(chenhuan09): add xpu_post_process_specualate
    step_xpu,
    xpu_post_process_normal,
    xpu_pre_process,
    xpu_process_output,
)
from fastdeploy.utils import get_logger
from fastdeploy.worker.model_runner_base import ModelRunnerBase
@@ -62,282 +62,6 @@ from fastdeploy.worker.output import ModelOutputData, ModelRunnerOutput
logger = get_logger("xpu_model_runner", "xpu_model_runner.log")


def xpu_pre_process(
    input_ids: paddle.Tensor,
    seq_lens_this_time: int,
    share_inputs: Dict,
    use_speculate_method: bool,
    block_size: int,
    draft_tokens: Optional[paddle.Tensor] = None,
    seq_lens_encoder: Optional[paddle.Tensor] = None,
    seq_lens_decoder: Optional[paddle.Tensor] = None,
    is_profiling: bool = False,
) -> XPUForwardMeta:
    """ """
    max_len = input_ids.shape[1]
    cum_offsets_now = paddle.cumsum(max_len - seq_lens_this_time, dtype="int32")
    token_num = paddle.sum(seq_lens_this_time)

    (
        ids_remove_padding,
        cum_offsets,
        batch_id_per_token,
        cu_seqlens_q,
        cu_seqlens_k,
    ) = get_padding_offset(input_ids, cum_offsets_now, token_num, seq_lens_this_time)

    share_inputs["ids_remove_padding"] = None  # set this after adjust batch
    share_inputs["cum_offsets"] = cum_offsets
    share_inputs["batch_id_per_token"] = batch_id_per_token
    share_inputs["cu_seqlens_q"] = cu_seqlens_q
    share_inputs["cu_seqlens_k"] = cu_seqlens_k

    xpu_forward_meta = XPUForwardMeta(
        ids_remove_padding=share_inputs["ids_remove_padding"],
        rotary_embs=share_inputs["rope_emb"],
        attn_backend=None,
        seq_lens_encoder=share_inputs["seq_lens_encoder"],
        seq_lens_decoder=share_inputs["seq_lens_decoder"],
        seq_lens_this_time=share_inputs["seq_lens_this_time"],
        cum_offsets=share_inputs["cum_offsets"],
        batch_id_per_token=share_inputs["batch_id_per_token"],
        cu_seqlens_q=share_inputs["cu_seqlens_q"],
        cu_seqlens_k=share_inputs["cu_seqlens_k"],
        block_tables=share_inputs["block_tables"],
        caches=share_inputs["caches"],
    )

    (
        xpu_forward_meta.encoder_batch_map,
        xpu_forward_meta.decoder_batch_map,
        xpu_forward_meta.encoder_batch_idx,
        xpu_forward_meta.decoder_batch_idx,
        xpu_forward_meta.encoder_seq_lod,
        xpu_forward_meta.decoder_seq_lod,
        xpu_forward_meta.encoder_kv_lod,
        xpu_forward_meta.prefix_len,
        xpu_forward_meta.decoder_context_len,
        xpu_forward_meta.decoder_context_len_cache,
        xpu_forward_meta.prefix_block_tables,
        xpu_forward_meta.encoder_batch_map_cpu,
        xpu_forward_meta.decoder_batch_map_cpu,
        xpu_forward_meta.encoder_batch_idx_cpu,
        xpu_forward_meta.decoder_batch_idx_cpu,
        xpu_forward_meta.encoder_seq_lod_cpu,
        xpu_forward_meta.decoder_seq_lod_cpu,
        xpu_forward_meta.encoder_kv_lod_cpu,
        xpu_forward_meta.prefix_len_cpu,
        xpu_forward_meta.decoder_context_len_cpu,
        xpu_forward_meta.decoder_context_len_cache_cpu,
        xpu_forward_meta.len_info_cpu,
    ) = get_infer_param(
        seq_lens_encoder, seq_lens_decoder, seq_lens_this_time, xpu_forward_meta.block_tables, block_size
    )
    xpu_forward_meta.enc_batch = xpu_forward_meta.len_info_cpu[0]
    xpu_forward_meta.dec_batch = xpu_forward_meta.len_info_cpu[1]
    xpu_forward_meta.total_enc_len = xpu_forward_meta.len_info_cpu[2]

    adjusted_input = adjust_batch(
        ids_remove_padding.reshape([-1, 1]),
        cum_offsets,
        xpu_forward_meta.encoder_seq_lod,
        xpu_forward_meta.encoder_batch_idx,
        xpu_forward_meta.decoder_batch_idx,
        xpu_forward_meta.encoder_seq_lod_cpu,
        xpu_forward_meta.encoder_batch_idx_cpu,
        xpu_forward_meta.decoder_batch_idx_cpu,
        xpu_forward_meta.enc_batch,
        xpu_forward_meta.dec_batch,
        None,  # output_padding_offset
        -1,  # max_input_length
    )

    adjusted_input = adjusted_input.squeeze(1)

    share_inputs["ids_remove_padding"] = adjusted_input
    xpu_forward_meta.ids_remove_padding = adjusted_input
    # Set forward_meta.is_profiling to True to skip init_kv_signal_per_query for attention backends
    xpu_forward_meta.is_profiling = is_profiling
    return xpu_forward_meta


def xpu_process_output(
    forward_output,
    cum_offsets: paddle.Tensor,
    xpu_forward_meta: XPUForwardMeta,
) -> paddle.Tensor:
    """ """
    from fastdeploy.model_executor.ops.xpu import gather_next_token

    hidden_states = gather_next_token(
        forward_output,
        cum_offsets,
        xpu_forward_meta.encoder_seq_lod,
        xpu_forward_meta.encoder_batch_map,
        xpu_forward_meta.decoder_batch_map,
        xpu_forward_meta.encoder_seq_lod_cpu,
        xpu_forward_meta.encoder_batch_map_cpu,
        xpu_forward_meta.decoder_batch_map_cpu,
        xpu_forward_meta.enc_batch,
        xpu_forward_meta.dec_batch,
        None,  # output_padding_offset
        -1,  # max_input_length
    )
    return hidden_states


def xpu_post_process(
    sampled_token_ids: paddle.Tensor,
    model_output: ModelOutputData,
    share_inputs: Dict[str, paddle.Tensor],
    block_size: int = 64,
    skip_save_output: bool = False,
    think_end_id: int = None,
    line_break_id: int = None,
) -> None:
    """ """
    from fastdeploy.model_executor.ops.xpu import (
        save_output,
        set_stop_value_multi_ends,
        update_inputs,
    )

    if think_end_id > 0:
        limit_strategy = envs.FD_LIMIT_THINKING_CONTENT_TRUNCATE_STR
        max_think_lens = share_inputs["max_think_lens"]
        step_idx = share_inputs["step_idx"]
        limit_think_status = share_inputs["limit_think_status"]
        stop_flags = share_inputs["stop_flags"]
        eos_token_ids = share_inputs["eos_token_id"]
        if limit_strategy == "</think>":
            # for ernie-45-vl
            limit_thinking_content_length_v1(
                sampled_token_ids,
                max_think_lens,
                step_idx,
                limit_think_status,
                stop_flags,
                eos_token_ids,  # handles the model occasionally emitting an eos token inside the thinking stage
                think_end_id,
            )
        elif limit_strategy == "\n</think>\n\n":
            # for ernie-x1
            assert line_break_id > 0
            limit_thinking_content_length_v2(
                sampled_token_ids,
                max_think_lens,
                step_idx,
                limit_think_status,
                stop_flags,
                think_end_id,
                line_break_id,
            )
        else:
            raise NotImplementedError(f"Not support {limit_strategy=} for limit thinking content length.")

    # 1. Set stop value
    paddle.assign(
        paddle.where(
            model_output.stop_flags,
            model_output.step_idx,
            model_output.step_idx + 1,
        ),
        model_output.step_idx,
    )
    length_cond = paddle.greater_equal(model_output.step_idx, model_output.max_dec_len)
    paddle.assign(
        paddle.logical_or(model_output.stop_flags, length_cond),
        model_output.stop_flags,
    )
    set_stop_value_multi_ends(
        sampled_token_ids,
        model_output.stop_flags,
        model_output.seq_lens_this_time,
        model_output.eos_token_id,
        model_output.next_tokens,
        False,
    )  # multi ends

    # 2. Update the input buffer of the model
    with paddle.framework._no_check_dy2st_diff():
        if envs.ENABLE_V1_KVCACHE_SCHEDULER and not skip_save_output:
            update_inputs_v1(
                model_output.stop_flags,
                model_output.not_need_stop,
                model_output.seq_lens_this_time,
                model_output.seq_lens_encoder,
                model_output.seq_lens_decoder,
                share_inputs["step_seq_lens_decoder"],
                share_inputs["prompt_lens"],
                sampled_token_ids,
                model_output.input_ids,
                share_inputs["block_tables"],
                model_output.stop_nums,
                model_output.next_tokens,
                model_output.is_block_step,
                block_size,
            )
        else:
            update_inputs(
                model_output.stop_flags,
                model_output.not_need_stop,
                model_output.seq_lens_this_time,
                model_output.seq_lens_encoder,
                model_output.seq_lens_decoder,
                model_output.input_ids,
                model_output.stop_nums,
                sampled_token_ids,
                model_output.is_block_step,
            )
    # 3. Transmit the model's output and stop generation signal via message queue.
    # In the future, we will abandon this approach.
    if not skip_save_output:
        save_output(
            sampled_token_ids,
            model_output.not_need_stop,
            model_output.mp_rank,
            False,  # use_ep
        )


def step_paddle(
    share_inputs: Dict[str, paddle.Tensor],
    block_size: int,
    enc_dec_block_num: int,
) -> None:
    """
    TODO(gongshaotian): normalization name
    """
    from fastdeploy.model_executor.ops.xpu import step_paddle

    step_paddle(
        share_inputs["stop_flags"],
        share_inputs["seq_lens_this_time"],
        share_inputs["step_seq_lens_encoder"],
        share_inputs["seq_lens_encoder"],
        share_inputs["seq_lens_decoder"],
        share_inputs["block_tables"],
        share_inputs["encoder_block_lens"],
        share_inputs["is_block_step"],
        share_inputs["step_block_list"],
        share_inputs["step_lens"],
        share_inputs["recover_block_list"],
        share_inputs["recover_lens"],
        share_inputs["need_block_list"],
        share_inputs["need_block_len"],
        share_inputs["used_list_len"],
        share_inputs["free_list"],
        share_inputs["free_list_len"],
        share_inputs["input_ids"],
        share_inputs["pre_ids"],
        share_inputs["step_idx"],
        share_inputs["next_tokens"],
        share_inputs["first_token_ids"],
        block_size,
        enc_dec_block_num,
    )


class XPUModelRunner(ModelRunnerBase):
    """ """
@@ -1212,8 +936,9 @@ class XPUModelRunner(ModelRunnerBase):
            forward_meta=self.forward_meta,
        )

        hidden_states = xpu_process_output(model_output, self.share_inputs["cum_offsets"], self.forward_meta)

        hidden_states = xpu_process_output(
            model_output, self.share_inputs["cum_offsets"], self.forward_meta, self.share_inputs
        )
        # 4. Compute logits, Sample
        logits = self.model.compute_logits(hidden_states)
        sampler_output = self.sampler(logits, self.sampling_metadata)
@@ -1247,7 +972,7 @@ class XPUModelRunner(ModelRunnerBase):
            stop_token_ids=self.share_inputs["stop_seqs"],
            stop_seqs_len=self.share_inputs["stop_seqs_len"],
        )
        xpu_post_process(
        xpu_post_process_normal(
            sampled_token_ids=sampler_output.sampled_token_ids,
            model_output=model_output_data,
            share_inputs=self.share_inputs,
@@ -1260,7 +985,7 @@ class XPUModelRunner(ModelRunnerBase):
        # 7. Update 'infer_seed' and step_paddle()
        self.share_inputs["infer_seed"].add_(self.infer_seed_increment)
        self.share_inputs["infer_seed"][:] %= self.MAX_INFER_SEED
        step_paddle(
        step_xpu(
            self.share_inputs,
            self.cache_config.block_size,
            self.cache_config.enc_dec_block_num,