[Feature][XPU] add custom kernels for mtp (#3537)

Author: lengxia
Date: 2025-08-25 10:14:17 +08:00
Committed by: GitHub
Parent: bdbac0aa3d
Commit: 137e539456
93 changed files with 13954 additions and 2 deletions


@@ -0,0 +1,52 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <paddle/phi/backends/xpu/xpu_context.h>
#include <xft/xdnn_plugin.h>
#include "paddle/extension.h"
#include "xpu/plugin.h"
void DraftModelPostprocess(const paddle::Tensor& base_model_draft_tokens,
const paddle::Tensor& base_model_seq_lens_this_time,
const paddle::Tensor& base_model_seq_lens_encoder,
const paddle::Tensor& base_model_stop_flags) {
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
auto xpu_ctx = static_cast<const phi::XPUContext*>(dev_ctx);
int real_bsz = base_model_draft_tokens.shape()[0];
int base_model_draft_token_len = base_model_draft_tokens.shape()[1];
int r = baidu::xpu::api::plugin::draft_model_postprocess(
xpu_ctx->x_context(),
const_cast<int64_t*>(base_model_draft_tokens.data<int64_t>()),
const_cast<int*>(base_model_seq_lens_this_time.data<int>()),
const_cast<int*>(base_model_seq_lens_encoder.data<int>()),
const_cast<bool*>(base_model_stop_flags.data<bool>()),
real_bsz,
base_model_draft_token_len);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "draft_model_postprocess");
}
PD_BUILD_OP(draft_model_postprocess)
.Inputs({"base_model_draft_tokens",
"base_model_seq_lens_this_time",
"base_model_seq_lens_encoder",
"base_model_stop_flags"})
.Outputs({"base_model_draft_tokens_out",
"base_model_seq_lens_this_time_out",
"base_model_stop_flags_out"})
.SetInplaceMap({{"base_model_draft_tokens", "base_model_draft_tokens_out"},
{"base_model_seq_lens_this_time",
"base_model_seq_lens_this_time_out"},
{"base_model_stop_flags", "base_model_stop_flags_out"}})
.SetKernelFn(PD_KERNEL(DraftModelPostprocess));
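
Every wrapper in this commit repeats the same three-line device-context lookup seen above. A minimal consolidation sketch; the helper name GetXpuApiContext is hypothetical, not part of this commit:

// Hypothetical helper: resolve the raw XDNN context for the current XPU
// device, mirroring the lookup each op below performs inline.
static inline baidu::xpu::api::Context* GetXpuApiContext() {
  phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
  auto* dev_ctx =
      paddle::experimental::DeviceContextPool::Instance().Get(place);
  return static_cast<const phi::XPUContext*>(dev_ctx)->x_context();
}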


@@ -0,0 +1,138 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <paddle/phi/backends/xpu/xpu_context.h>
#include "paddle/extension.h"
#include "paddle/phi/core/enforce.h"
#include "xpu/plugin.h"
namespace api = baidu::xpu::api;
void DraftModelPreprocess(const paddle::Tensor& draft_tokens,
const paddle::Tensor& input_ids,
const paddle::Tensor& stop_flags,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& step_idx,
const paddle::Tensor& seq_lens_encoder_record,
const paddle::Tensor& seq_lens_decoder_record,
const paddle::Tensor& not_need_stop,
const paddle::Tensor& batch_drop,
const paddle::Tensor& accept_tokens,
const paddle::Tensor& accept_num,
const paddle::Tensor& base_model_seq_lens_encoder,
const paddle::Tensor& base_model_seq_lens_decoder,
const paddle::Tensor& base_model_step_idx,
const paddle::Tensor& base_model_stop_flags,
const paddle::Tensor& base_model_is_block_step,
const paddle::Tensor& base_model_draft_tokens,
const int max_draft_token,
const bool truncate_first_token,
const bool splitwise_prefill) {
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
api::Context* ctx = static_cast<const phi::XPUContext*>(dev_ctx)->x_context();
if (draft_tokens.is_cpu()) {
ctx = new api::Context(api::kCPU);
}
int real_bsz = seq_lens_this_time.shape()[0];
int accept_tokens_len = accept_tokens.shape()[1];
int input_ids_len = input_ids.shape()[1];
int draft_tokens_len = draft_tokens.shape()[1];
int base_model_draft_tokens_len = base_model_draft_tokens.shape()[1];
auto not_need_stop_device =
    not_need_stop.copy_to(seq_lens_this_time.place(), false);
int r = baidu::xpu::api::plugin::draft_model_preprocess(
ctx,
const_cast<int64_t*>(draft_tokens.data<int64_t>()),
const_cast<int64_t*>(input_ids.data<int64_t>()),
const_cast<bool*>(stop_flags.data<bool>()),
const_cast<int*>(seq_lens_this_time.data<int>()),
const_cast<int*>(seq_lens_encoder.data<int>()),
const_cast<int*>(seq_lens_decoder.data<int>()),
const_cast<int64_t*>(step_idx.data<int64_t>()),
const_cast<int*>(seq_lens_encoder_record.data<int>()),
const_cast<int*>(seq_lens_decoder_record.data<int>()),
const_cast<bool*>(not_need_stop_device.data<bool>()),
const_cast<bool*>(batch_drop.data<bool>()),
accept_tokens.data<int64_t>(),
accept_num.data<int>(),
base_model_seq_lens_encoder.data<int>(),
base_model_seq_lens_decoder.data<int>(),
base_model_step_idx.data<int64_t>(),
base_model_stop_flags.data<bool>(),
base_model_is_block_step.data<bool>(),
const_cast<int64_t*>(base_model_draft_tokens.data<int64_t>()),
real_bsz,
max_draft_token,
accept_tokens_len,
draft_tokens_len,
input_ids_len,
base_model_draft_tokens_len,
truncate_first_token,
splitwise_prefill);
PD_CHECK(r == 0, "xpu::plugin::draft_model_preprocess failed.");
auto not_need_stop_cpu =
    not_need_stop_device.copy_to(not_need_stop.place(), false);
bool* not_need_stop_data = const_cast<bool*>(not_need_stop.data<bool>());
not_need_stop_data[0] = not_need_stop_cpu.data<bool>()[0];
}
PD_BUILD_OP(draft_model_preprocess)
.Inputs({"draft_tokens",
"input_ids",
"stop_flags",
"seq_lens_this_time",
"seq_lens_encoder",
"seq_lens_decoder",
"step_idx",
"seq_lens_encoder_record",
"seq_lens_decoder_record",
"not_need_stop",
"batch_drop",
"accept_tokens",
"accept_num",
"base_model_seq_lens_encoder",
"base_model_seq_lens_decoder",
"base_model_step_idx",
"base_model_stop_flags",
"base_model_is_block_step",
"base_model_draft_tokens"})
.Outputs({"draft_tokens_out",
"input_ids_out",
"stop_flags_out",
"seq_lens_this_time_out",
"seq_lens_encoder_out",
"seq_lens_decoder_out",
"step_idx_out",
"not_need_stop_out",
"batch_drop_out",
"seq_lens_encoder_record_out",
"seq_lens_decoder_record_out"})
.Attrs({"max_draft_token: int",
"truncate_first_token: bool",
"splitwise_prefill: bool"})
.SetInplaceMap({{"draft_tokens", "draft_tokens_out"},
{"input_ids", "input_ids_out"},
{"stop_flags", "stop_flags_out"},
{"seq_lens_this_time", "seq_lens_this_time_out"},
{"seq_lens_encoder", "seq_lens_encoder_out"},
{"seq_lens_decoder", "seq_lens_decoder_out"},
{"step_idx", "step_idx_out"},
{"not_need_stop", "not_need_stop_out"},
{"batch_drop", "batch_drop_out"},
{"seq_lens_encoder_record", "seq_lens_encoder_record_out"},
{"seq_lens_decoder_record", "seq_lens_decoder_record_out"}})
.SetKernelFn(PD_KERNEL(DraftModelPreprocess));
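
not_need_stop lives on the host while the kernel runs on the XPU, which is why the code above stages it onto the device and mirrors it back afterwards. A condensed sketch of that round-trip pattern, under the same assumptions as the code above:

// Sketch: stage the host-resident flag onto the device, let the kernel
// mutate it, then copy the result back into the original host tensor.
auto device_flag = not_need_stop.copy_to(seq_lens_this_time.place(), false);
// ... plugin kernel writes device_flag ...
auto host_flag = device_flag.copy_to(not_need_stop.place(), false);
const_cast<bool*>(not_need_stop.data<bool>())[0] = host_flag.data<bool>()[0];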


@@ -0,0 +1,122 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <paddle/phi/backends/xpu/xpu_context.h>
#include "paddle/extension.h"
#include "paddle/phi/core/enforce.h"
#include "xpu/plugin.h"
void DraftModelUpdate(const paddle::Tensor& inter_next_tokens,
const paddle::Tensor& draft_tokens,
const paddle::Tensor& pre_ids,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& step_idx,
const paddle::Tensor& output_cum_offsets,
const paddle::Tensor& stop_flags,
const paddle::Tensor& not_need_stop,
const paddle::Tensor& max_dec_len,
const paddle::Tensor& end_ids,
const paddle::Tensor& base_model_draft_tokens,
const int max_seq_len,
const int substep) {
// printf("enter clear \n");
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
baidu::xpu::api::Context* ctx =
static_cast<const phi::XPUContext*>(dev_ctx)->x_context();
if (draft_tokens.is_cpu()) {
ctx = new baidu::xpu::api::Context(baidu::xpu::api::kCPU);
}
auto seq_lens_this_time_shape = seq_lens_this_time.shape();
const int real_bsz = seq_lens_this_time_shape[0];
auto not_need_stop_device =
not_need_stop.copy_to(seq_lens_this_time.place(), false);
const int end_ids_len = end_ids.shape()[0];
const int max_draft_token = draft_tokens.shape()[1];
const int pre_id_length = pre_ids.shape()[1];
const int max_base_model_draft_token = base_model_draft_tokens.shape()[1];
bool prefill_one_step_stop = false;
if (const char* env_p = std::getenv("PREFILL_NODE_ONE_STEP_STOP")) {
// std::cout << "Your PATH is: " << env_p << '\n';
if (env_p[0] == '1') {
prefill_one_step_stop = true;
}
}
int r = baidu::xpu::api::plugin::draft_model_update(
ctx,
inter_next_tokens.data<int64_t>(),
const_cast<int64_t*>(draft_tokens.data<int64_t>()),
const_cast<int64_t*>(pre_ids.data<int64_t>()),
const_cast<int*>(seq_lens_this_time.data<int>()),
const_cast<int*>(seq_lens_encoder.data<int>()),
const_cast<int*>(seq_lens_decoder.data<int>()),
const_cast<int64_t*>(step_idx.data<int64_t>()),
output_cum_offsets.data<int>(),
const_cast<bool*>(stop_flags.data<bool>()),
const_cast<bool*>(not_need_stop_device.data<bool>()),
max_dec_len.data<int64_t>(),
end_ids.data<int64_t>(),
const_cast<int64_t*>(base_model_draft_tokens.data<int64_t>()),
real_bsz,
max_draft_token,
pre_id_length,
max_base_model_draft_token,
end_ids_len,
max_seq_len,
substep,
prefill_one_step_stop);
PD_CHECK(r == 0, "draft_model_update failed.");
}
PD_BUILD_OP(draft_model_update)
.Inputs({"inter_next_tokens",
"draft_tokens",
"pre_ids",
"seq_lens_this_time",
"seq_lens_encoder",
"seq_lens_decoder",
"step_idx",
"output_cum_offsets",
"stop_flags",
"not_need_stop",
"max_dec_len",
"end_ids",
"base_model_draft_tokens"})
.Attrs({"max_seq_len: int", "substep: int"})
.Outputs({"draft_tokens_out",
"pre_ids_out",
"seq_lens_this_time_out",
"seq_lens_encoder_out",
"seq_lens_decoder_out",
"step_idx_out",
"stop_flags_out",
"not_need_stop_out",
"base_model_draft_tokens_out"})
.SetInplaceMap({{"draft_tokens", "draft_tokens_out"},
{"pre_ids", "pre_ids_out"},
{"seq_lens_this_time", "seq_lens_this_time_out"},
{"seq_lens_encoder", "seq_lens_encoder_out"},
{"seq_lens_decoder", "seq_lens_decoder_out"},
{"step_idx", "step_idx_out"},
{"stop_flags", "stop_flags_out"},
{"not_need_stop", "not_need_stop_out"},
{"base_model_draft_tokens", "base_model_draft_tokens_out"}})
.SetKernelFn(PD_KERNEL(DraftModelUpdate));
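
The PREFILL_NODE_ONE_STEP_STOP gate above only inspects the first character of the variable's value. A small helper sketch expressing the same rule; EnvFlagEnabled is a hypothetical name, not part of this commit:

#include <cstdlib>
// Hypothetical helper: an environment flag counts as enabled when its value
// starts with '1', matching the check inside DraftModelUpdate above.
static inline bool EnvFlagEnabled(const char* name) {
  const char* v = std::getenv(name);
  return v != nullptr && v[0] == '1';
}
// usage: bool prefill_one_step_stop = EnvFlagEnabled("PREFILL_NODE_ONE_STEP_STOP");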


@@ -0,0 +1,116 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <paddle/phi/backends/xpu/xpu_context.h>
#include "paddle/extension.h"
#include "xpu/plugin.h"
namespace api = baidu::xpu::api;
std::vector<paddle::Tensor> EagleGetHiddenStates(
const paddle::Tensor& input,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& stop_flags,
const paddle::Tensor& accept_nums,
const paddle::Tensor& base_model_seq_lens_this_time,
const paddle::Tensor& base_model_seq_lens_encoder,
const int actual_draft_token_num) {
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
auto xpu_ctx = static_cast<const phi::XPUContext*>(dev_ctx);
api::Context* ctx = xpu_ctx->x_context();
if (input.is_cpu()) {
ctx = new api::Context(api::kCPU);
}
auto input_token_num = input.shape()[0];
auto dim_embed = input.shape()[1];
int bsz = seq_lens_this_time.shape()[0];
auto position_map = paddle::full(
{input_token_num}, -1, seq_lens_this_time.dtype(), input.place());
auto output_token_num = paddle::empty(
{1}, seq_lens_this_time.dtype(), seq_lens_this_time.place());
int r = api::plugin::compute_order(ctx,
seq_lens_this_time.data<int>(),
seq_lens_encoder.data<int>(),
base_model_seq_lens_this_time.data<int>(),
base_model_seq_lens_encoder.data<int>(),
accept_nums.data<int>(),
position_map.data<int>(),
output_token_num.data<int>(),
bsz,
actual_draft_token_num,
input_token_num);
PD_CHECK(r == 0, "xpu::plugin::compute_order failed.");
int output_token_num_cpu =
output_token_num.copy_to(paddle::CPUPlace(), false).data<int>()[0];
auto out = paddle::empty(
{output_token_num_cpu, dim_embed}, input.dtype(), input.place());
int elem_cnt = input_token_num * dim_embed;
switch (input.dtype()) {
case paddle::DataType::BFLOAT16:
using XPUTypeBF16 = typename XPUTypeTrait<bfloat16>::Type;
typedef paddle::bfloat16 bf16_data_t;
r = api::plugin::rebuild_hidden_states(
ctx,
reinterpret_cast<const XPUTypeBF16*>(input.data<bf16_data_t>()),
position_map.data<int>(),
reinterpret_cast<XPUTypeBF16*>(out.data<bf16_data_t>()),
dim_embed,
elem_cnt);
PD_CHECK(r == 0, "xpu::plugin::rebuild_hidden_states failed.");
return {out};
case paddle::DataType::FLOAT16:
using XPUTypeFP16 = typename XPUTypeTrait<float16>::Type;
typedef paddle::float16 fp16_data_t;
r = api::plugin::rebuild_hidden_states(
ctx,
reinterpret_cast<const XPUTypeFP16*>(input.data<fp16_data_t>()),
position_map.data<int>(),
reinterpret_cast<XPUTypeFP16*>(out.data<fp16_data_t>()),
dim_embed,
elem_cnt);
PD_CHECK(r == 0, "xpu::plugin::rebuild_hidden_states failed.");
return {out};
case paddle::DataType::FLOAT32:
r = api::plugin::rebuild_hidden_states(
ctx,
reinterpret_cast<const float*>(input.data<float>()),
position_map.data<int>(),
reinterpret_cast<float*>(out.data<float>()),
dim_embed,
elem_cnt);
PD_CHECK(r == 0, "xpu::plugin::rebuild_hidden_states failed.");
return {out};
default:
PD_THROW("Unsupported data type.");
}
}
PD_BUILD_OP(eagle_get_hidden_states)
.Inputs({"input",
"seq_lens_this_time",
"seq_lens_encoder",
"seq_lens_decoder",
"stop_flags",
"accept_nums",
"base_model_seq_lens_this_time",
"base_model_seq_lens_encoder"})
.Attrs({"actual_draft_token_num: int"})
.Outputs({"out"})
.SetKernelFn(PD_KERNEL(EagleGetHiddenStates));
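
The dtype switch above reappears almost verbatim in several ops in this commit. One way it could be factored, sketched with a hypothetical helper (not in this commit); the callable receives a value of the XPU storage type purely as a type tag, and the storage types are assumed default-constructible:

// Hypothetical sketch: centralize the floating-point dtype dispatch;
// `f` is a generic callable templated on the XPU storage type.
template <typename F>
void DispatchFloatingType(paddle::DataType dtype, F&& f) {
  switch (dtype) {
    case paddle::DataType::BFLOAT16:
      f(typename XPUTypeTrait<bfloat16>::Type{});
      break;
    case paddle::DataType::FLOAT16:
      f(typename XPUTypeTrait<float16>::Type{});
      break;
    case paddle::DataType::FLOAT32:
      f(float{});
      break;
    default:
      PD_THROW("Unsupported data type.");
  }
}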


@@ -0,0 +1,104 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <paddle/phi/backends/xpu/xpu_context.h>
#include "paddle/extension.h"
#include "xpu/plugin.h"
namespace api = baidu::xpu::api;
std::vector<paddle::Tensor> EagleGetSelfHiddenStates(
const paddle::Tensor& input,
const paddle::Tensor& last_seq_lens_this_time,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& step_idx) {
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
auto xpu_ctx = static_cast<const phi::XPUContext*>(dev_ctx);
api::Context* ctx = xpu_ctx->x_context();
if (input.is_cpu()) {
ctx = new api::Context(api::kCPU);
}
int input_token_num = input.shape()[0];
int dim_embed = input.shape()[1];
int bsz = seq_lens_this_time.shape()[0];
auto src_map = paddle::empty({input_token_num},
seq_lens_this_time.dtype(),
seq_lens_this_time.place());
auto output_token_num = paddle::empty(
{1}, seq_lens_this_time.dtype(), seq_lens_this_time.place());
int r = api::plugin::compute_self_order(
ctx,
reinterpret_cast<const int*>(last_seq_lens_this_time.data<int>()),
reinterpret_cast<const int*>(seq_lens_this_time.data<int>()),
reinterpret_cast<const int64_t*>(step_idx.data<int64_t>()),
reinterpret_cast<int*>(src_map.data<int>()),
reinterpret_cast<int*>(output_token_num.data<int>()),
bsz);
PD_CHECK(r == 0, "xpu::plugin::compute_self_order failed.");
int output_token_num_cpu =
output_token_num.copy_to(paddle::CPUPlace(), false).data<int>()[0];
auto out = paddle::empty(
{output_token_num_cpu, dim_embed}, input.dtype(), input.place());
int elem_cnt = output_token_num_cpu * dim_embed;
switch (input.dtype()) {
case paddle::DataType::BFLOAT16:
using XPUTypeBF16 = typename XPUTypeTrait<bfloat16>::Type;
typedef paddle::bfloat16 bf16_data_t;
r = api::plugin::rebuild_self_hidden_states(
ctx,
reinterpret_cast<const XPUTypeBF16*>(input.data<bf16_data_t>()),
src_map.data<int>(),
reinterpret_cast<XPUTypeBF16*>(out.data<bf16_data_t>()),
dim_embed,
elem_cnt);
PD_CHECK(r == 0, "xpu::plugin::rebuild_self_hidden_states failed.");
return {out};
case paddle::DataType::FLOAT16:
using XPUTypeFP16 = typename XPUTypeTrait<float16>::Type;
typedef paddle::float16 fp16_data_t;
r = api::plugin::rebuild_self_hidden_states(
ctx,
reinterpret_cast<const XPUTypeFP16*>(input.data<fp16_data_t>()),
src_map.data<int>(),
reinterpret_cast<XPUTypeFP16*>(out.data<fp16_data_t>()),
dim_embed,
elem_cnt);
PD_CHECK(r == 0, "xpu::plugin::rebuild_self_hidden_states failed.");
return {out};
case paddle::DataType::FLOAT32:
r = api::plugin::rebuild_self_hidden_states(
ctx,
reinterpret_cast<const float*>(input.data<float>()),
src_map.data<int>(),
reinterpret_cast<float*>(out.data<float>()),
dim_embed,
elem_cnt);
PD_CHECK(r == 0, "xpu::plugin::rebuild_self_hidden_states failed.");
return {out};
default:
PD_THROW("Unsupported data type.");
}
}
PD_BUILD_OP(eagle_get_self_hidden_states)
.Inputs(
{"input", "last_seq_lens_this_time", "seq_lens_this_time", "step_idx"})
.Outputs({"out"})
.SetKernelFn(PD_KERNEL(EagleGetSelfHiddenStates));
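
rebuild_self_hidden_states receives src_map plus a flattened element count, which suggests a per-row gather. A CPU reference sketch of that assumed semantics; the plugin kernel itself is not part of this diff, so treat this as an inference from the call site:

// Assumed semantics: output row i copies input row src_map[i].
template <typename T>
void RebuildSelfHiddenStatesRef(const T* input, const int* src_map, T* out,
                                int dim_embed, int elem_cnt) {
  for (int idx = 0; idx < elem_cnt; ++idx) {
    const int row = idx / dim_embed;
    const int col = idx % dim_embed;
    out[idx] = input[src_map[row] * dim_embed + col];
  }
}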


@@ -0,0 +1,159 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <stdio.h>
#include <string.h>
#include <sys/ipc.h>
#include <sys/msg.h>
#include <sys/types.h>
#include "paddle/extension.h"
#define MAX_BSZ 256
// #define SAVE_WITH_OUTPUT_DEBUG
#define MAX_DRAFT_TOKENS 6
struct msgdata {
long mtype; // NOLINT
int mtext[2 + MAX_BSZ +
MAX_BSZ * MAX_DRAFT_TOKENS]; // stop_flag, token_num, tokens
};
void MTPSaveFirstToken(const paddle::Tensor& x,
const paddle::Tensor& not_need_stop,
int64_t rank_id,
int msg_queue_id,
bool save_each_rank) {
if (!save_each_rank && rank_id > 0) {
return;
}
int x_dim = x.shape()[1];
auto x_cpu = x.copy_to(paddle::CPUPlace(), false);
int64_t* x_data = x_cpu.data<int64_t>();
static struct msgdata msg_sed;
if (const char* inference_msg_queue_id_env_p =
std::getenv("INFERENCE_MSG_QUEUE_ID")) {
std::string inference_msg_queue_id_env_str(inference_msg_queue_id_env_p);
int inference_msg_queue_id_from_env =
std::stoi(inference_msg_queue_id_env_str);
#ifdef SAVE_WITH_OUTPUT_DEBUG
std::cout << "Your INFERENCE_MSG_QUEUE_ID is: "
<< inference_msg_queue_id_from_env << std::endl;
#endif
msg_queue_id = inference_msg_queue_id_from_env;
}
static key_t key = ftok("./", msg_queue_id);
static int msgid = msgget(key, IPC_CREAT | 0666);
msg_sed.mtype = 1;
bool not_need_stop_data = not_need_stop.data<bool>()[0];
int inference_msg_id_from_env = 1;
if (const char* inference_msg_id_env_p = std::getenv("INFERENCE_MSG_ID")) {
std::string inference_msg_id_env_str(inference_msg_id_env_p);
inference_msg_id_from_env = std::stoi(inference_msg_id_env_str);
if (inference_msg_id_from_env == 2) {
// 2 and -2 are reserved for the no-output indication.
throw std::runtime_error(
"INFERENCE_MSG_ID cannot be 2, please use another number.");
}
if (inference_msg_id_from_env < 0) {
throw std::runtime_error(
" INFERENCE_MSG_ID cannot be negative, please use other "
"number.");
}
#ifdef SAVE_WITH_OUTPUT_DEBUG
std::cout << "Your INFERENCE_MSG_ID is: " << inference_msg_id_from_env
<< std::endl;
#endif
} else {
#ifdef SAVE_WITH_OUTPUT_DEBUG
std::cout << "Failed to got INFERENCE_MSG_ID at env, use (int)1 as default."
<< std::endl;
#endif
}
#ifdef SAVE_WITH_OUTPUT_DEBUG
std::cout << "save_output_key: " << key << std::endl;
std::cout << "save msgid: " << msgid << std::endl;
#endif
msg_sed.mtext[0] = not_need_stop_data ? inference_msg_id_from_env
: -inference_msg_id_from_env;
int bsz = x.shape()[0];
msg_sed.mtext[1] = bsz;
for (int i = 0; i < bsz; i++) {
#ifdef SAVE_WITH_OUTPUT_DEBUG
printf("bid: %d. 1: %d. 2: %d.\n",
i,
static_cast<int>(x_data[i * x_dim]),
static_cast<int>(x_data[i * x_dim + 1]));
#endif
msg_sed.mtext[i + 2] = 2;  // first-token save records exactly two tokens per request
msg_sed.mtext[i * MAX_DRAFT_TOKENS + 2 + MAX_BSZ] =
static_cast<int>(x_data[i * x_dim]);
msg_sed.mtext[i * MAX_DRAFT_TOKENS + 1 + 2 + MAX_BSZ] =
static_cast<int>(x_data[i * x_dim + 1]);
#ifdef SAVE_WITH_OUTPUT_DEBUG
printf("mtext[%d]:%d. mtext[%d]:%d. \n",
i * MAX_DRAFT_TOKENS + 2 + MAX_BSZ,
msg_sed.mtext[i * MAX_DRAFT_TOKENS + 2 + MAX_BSZ],
i * MAX_DRAFT_TOKENS + 1 + 2 + MAX_BSZ,
msg_sed.mtext[i * MAX_DRAFT_TOKENS + 1 + 2 + MAX_BSZ]);
#endif
}
#ifdef SAVE_WITH_OUTPUT_DEBUG
std::cout << "msg data: ";
for (int i = 0; i < bsz; i++) {
std::cout << " " << static_cast<int>(x_data[2 * i]) << " ";
std::cout << " " << static_cast<int>(x_data[2 * i + 1]);
}
std::cout << std::endl;
#endif
if ((msgsnd(msgid,
&msg_sed,
(2 + MAX_BSZ + MAX_BSZ * MAX_DRAFT_TOKENS) * 4,
0)) == -1) {
printf("full msg buffer\n");
}
return;
}
void MTPSaveFirstTokenStatic(const paddle::Tensor& x,
const paddle::Tensor& not_need_stop,
int64_t rank_id,
bool save_each_rank) {
MTPSaveFirstToken(x, not_need_stop, rank_id, 1, save_each_rank);
}
void MTPSaveFirstTokenDynamic(const paddle::Tensor& x,
const paddle::Tensor& not_need_stop,
int64_t rank_id,
int msg_queue_id,
bool save_each_rank) {
MTPSaveFirstToken(x, not_need_stop, rank_id, msg_queue_id, save_each_rank);
}
PD_BUILD_OP(mtp_save_first_token)
.Inputs({"x", "not_need_stop"})
.Attrs({"rank_id: int64_t", "save_each_rank: bool"})
.Outputs({"x_out"})
.SetInplaceMap({{"x", "x_out"}})
.SetKernelFn(PD_KERNEL(MTPSaveFirstTokenStatic));
PD_BUILD_OP(mtp_save_first_token_dynamic)
.Inputs({"x", "not_need_stop"})
.Attrs({"rank_id: int64_t", "msg_queue_id: int", "save_each_rank: bool"})
.Outputs({"x_out"})
.SetInplaceMap({{"x", "x_out"}})
.SetKernelFn(PD_KERNEL(MTPSaveFirstTokenDynamic));
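
The mtext indexing above is easy to misread, so here is the layout it implies, with hypothetical index helpers (not part of this commit): slot 0 carries the stop/continue signal, slot 1 the batch size, the next MAX_BSZ slots the per-request token counts, and the remainder MAX_DRAFT_TOKENS token slots per request.

// Hypothetical helpers matching the indices used in MTPSaveFirstToken above.
constexpr int TokenCountIdx(int bid) { return 2 + bid; }
constexpr int TokenIdx(int bid, int j) {
  return 2 + MAX_BSZ + bid * MAX_DRAFT_TOKENS + j;
}
// e.g. msg_sed.mtext[TokenIdx(i, 0)] replaces
//      msg_sed.mtext[i * MAX_DRAFT_TOKENS + 2 + MAX_BSZ].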


@@ -0,0 +1,90 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <paddle/phi/backends/xpu/xpu_context.h>
#include "paddle/extension.h"
#include "paddle/phi/core/enforce.h"
#include "xpu/plugin.h"
namespace api = baidu::xpu::api;
void MTPStepPaddle(
const paddle::Tensor &base_model_stop_flags,
const paddle::Tensor &stop_flags,
const paddle::Tensor &batch_drop,
const paddle::Tensor &seq_lens_this_time,
const paddle::Tensor &seq_lens_encoder,
const paddle::Tensor &seq_lens_decoder,
const paddle::Tensor &block_tables, // [bsz, block_num_per_seq]
const paddle::Tensor &encoder_block_lens,
const paddle::Tensor &used_list_len,
const paddle::Tensor &free_list,
const paddle::Tensor &free_list_len,
const int block_size,
const int max_draft_tokens) {
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
auto xpu_ctx = static_cast<const phi::XPUContext *>(dev_ctx);
api::Context *ctx = xpu_ctx->x_context();
if (base_model_stop_flags.is_cpu()) {
ctx = new api::Context(api::kCPU);
}
const int bsz = seq_lens_this_time.shape()[0];
const int block_num_per_seq = block_tables.shape()[1];
int r = baidu::xpu::api::plugin::mtp_free_and_dispatch_block(
ctx,
const_cast<bool *>(base_model_stop_flags.data<bool>()),
const_cast<bool *>(stop_flags.data<bool>()),
const_cast<bool *>(batch_drop.data<bool>()),
const_cast<int *>(seq_lens_this_time.data<int>()),
const_cast<int *>(seq_lens_decoder.data<int>()),
const_cast<int *>(block_tables.data<int>()),
const_cast<int *>(encoder_block_lens.data<int>()),
const_cast<int *>(used_list_len.data<int>()),
const_cast<int *>(free_list.data<int>()),
const_cast<int *>(free_list_len.data<int>()),
bsz,
block_size,
block_num_per_seq,
max_draft_tokens);
PD_CHECK(r == 0, "free_and_dispatch_block failed.");
if (base_model_stop_flags.is_cpu() && ctx != nullptr) {
delete ctx;
}
}
PD_BUILD_OP(mtp_step_paddle)
.Inputs({"base_model_stop_flags",
"stop_flags",
"batch_drop",
"seq_lens_this_time",
"seq_lens_encoder",
"seq_lens_decoder",
"block_tables",
"encoder_block_lens",
"used_list_len",
"free_list",
"free_list_len"})
.Attrs({"block_size: int", "max_draft_tokens: int"})
.Outputs({"block_tables_out",
"stop_flags_out",
"used_list_len_out",
"free_list_out",
"free_list_len_out"})
.SetInplaceMap({{"block_tables", "block_tables_out"},
{"stop_flags", "stop_flags_out"},
{"used_list_len", "used_list_len_out"},
{"free_list", "free_list_out"},
{"free_list_len", "free_list_len_out"}})
.SetKernelFn(PD_KERNEL(MTPStepPaddle));
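
MTPStepPaddle is the only wrapper in this commit that frees the temporary CPU context it allocates for the is_cpu() fallback; the other ops leak theirs. A hedged RAII sketch of how the fallback could be made uniform:

#include <memory>
// Sketch: own the CPU fallback context in a unique_ptr so it is released on
// every exit path; device-owned contexts are never deleted.
std::unique_ptr<api::Context> cpu_ctx_guard;
api::Context* ctx = xpu_ctx->x_context();
if (base_model_stop_flags.is_cpu()) {
  cpu_ctx_guard = std::make_unique<api::Context>(api::kCPU);
  ctx = cpu_ctx_guard.get();
}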


@@ -0,0 +1,38 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <paddle/phi/backends/xpu/xpu_context.h>
#include "paddle/extension.h"
#include "xpu/plugin.h"
void SpeculateClearAcceptNums(const paddle::Tensor& accept_num,
const paddle::Tensor& seq_lens_decoder) {
// printf("enter clear \n");
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
auto xpu_ctx = static_cast<const phi::XPUContext*>(dev_ctx);
const int max_bsz = seq_lens_decoder.shape()[0];
int r = baidu::xpu::api::plugin::speculate_clear_accept_nums(
xpu_ctx->x_context(),
const_cast<int*>(accept_num.data<int>()),
seq_lens_decoder.data<int>(),
max_bsz);
PD_CHECK(r == 0, "speculate_clear_accept_nums_kernel failed.");
}
PD_BUILD_OP(speculate_clear_accept_nums)
.Inputs({"accept_num", "seq_lens_decoder"})
.Outputs({"seq_lens_decoder_out"})
.SetInplaceMap({{"seq_lens_decoder", "seq_lens_decoder_out"}})
.SetKernelFn(PD_KERNEL(SpeculateClearAcceptNums));


@@ -0,0 +1,113 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <stdio.h>
#include <string.h>
#include <sys/ipc.h>
#include <sys/msg.h>
#include <sys/types.h>
#include "paddle/extension.h"
#define MAX_BSZ 256
#define MAX_DRAFT_TOKENS 6
struct msgdata {
long mtype;  // NOLINT (SysV message queues require a long mtype)
int mtext[MAX_BSZ * MAX_DRAFT_TOKENS + MAX_BSZ +
2]; // stop_flag, bsz, accept_num*bsz, tokens...
};
void SpeculateGetOutput(const paddle::Tensor& x,
int64_t rank_id,
bool wait_flag,
int msg_queue_id,
bool get_each_rank) {
if (!get_each_rank && rank_id > 0) {
return;
}
if (const char* inference_msg_queue_id_env_p =
std::getenv("INFERENCE_MSG_QUEUE_ID")) {
std::string inference_msg_queue_id_env_str(inference_msg_queue_id_env_p);
int inference_msg_queue_id_from_env =
std::stoi(inference_msg_queue_id_env_str);
#ifdef GET_OUTPUT_DEBUG
std::cout << "Your INFERENCE_MSG_QUEUE_ID is: "
<< inference_msg_queue_id_from_env << std::endl;
#endif
msg_queue_id = inference_msg_queue_id_from_env;
}
static struct msgdata msg_rcv;
static key_t key = ftok("./", msg_queue_id);
static int msgid = msgget(key, IPC_CREAT | 0666);
int64_t* out_data = const_cast<int64_t*>(x.data<int64_t>());
int ret = -1;
if (!wait_flag) {
ret = msgrcv(msgid,
&msg_rcv,
(MAX_BSZ * MAX_DRAFT_TOKENS + MAX_BSZ + 2) * 4,
0,
IPC_NOWAIT);
} else {
ret = msgrcv(
msgid, &msg_rcv, (MAX_BSZ * MAX_DRAFT_TOKENS + MAX_BSZ + 2) * 4, 0, 0);
}
if (ret == -1) {
out_data[0] = -2;
out_data[1] = 0;
return;
}
for (int64_t i = 0; i < MAX_BSZ * MAX_DRAFT_TOKENS + MAX_BSZ + 2; i++) {
out_data[i] = (int64_t)msg_rcv.mtext[i];
}
return;
}
void SpeculateGetOutputStatic(const paddle::Tensor& x,
int64_t rank_id,
bool wait_flag,
bool get_each_rank) {
SpeculateGetOutput(x, rank_id, wait_flag, 1, get_each_rank);
}
void SpeculateGetOutputDynamic(const paddle::Tensor& x,
int64_t rank_id,
bool wait_flag,
int msg_queue_id,
bool get_each_rank) {
SpeculateGetOutput(x, rank_id, wait_flag, msg_queue_id, get_each_rank);
}
PD_BUILD_OP(speculate_get_output)
.Inputs({"x"})
.Attrs({"rank_id: int64_t", "wait_flag: bool", "get_each_rank: bool"})
.Outputs({"x_out"})
.SetInplaceMap({{"x", "x_out"}})
.SetKernelFn(PD_KERNEL(SpeculateGetOutputStatic));
PD_BUILD_OP(speculate_get_output_dynamic)
.Inputs({"x"})
.Attrs({"rank_id: int64_t",
"wait_flag: bool",
"msg_queue_id: int",
"get_each_rank: bool"})
.Outputs({"x_out"})
.SetInplaceMap({{"x", "x_out"}})
.SetKernelFn(PD_KERNEL(SpeculateGetOutputDynamic));
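
After speculate_get_output returns, x holds the flattened message: x[0] is the signal value (-2 when nothing was received), x[1] the batch size, the next MAX_BSZ entries the per-request accept counts, and the rest the token slots. A hedged sketch of how a caller might walk that buffer; the layout is taken from the msgdata comment and the copy loop above:

// Sketch only: decode the int64 buffer filled by SpeculateGetOutput.
void DecodeSpeculateOutput(const int64_t* out_data) {
  if (out_data[0] == -2) return;  // IPC_NOWAIT found no message
  const int bsz = static_cast<int>(out_data[1]);
  for (int bid = 0; bid < bsz; ++bid) {
    const int accept = static_cast<int>(out_data[2 + bid]);
    const int64_t* tokens = out_data + 2 + MAX_BSZ + bid * MAX_DRAFT_TOKENS;
    for (int j = 0; j < accept; ++j) {
      const int64_t tok = tokens[j];  // j-th accepted token for request bid
      (void)tok;
    }
  }
}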


@@ -0,0 +1,78 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <paddle/phi/backends/xpu/xpu_context.h>
#include "paddle/extension.h"
#include "xpu/plugin.h"
std::vector<paddle::Tensor> SpeculateGetOutputPaddingOffset(
const paddle::Tensor& output_cum_offsets_tmp,
const paddle::Tensor& out_token_num,
const paddle::Tensor& seq_lens_output,
const int max_seq_len) {
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
baidu::xpu::api::Context* ctx =
static_cast<const phi::XPUContext*>(dev_ctx)->x_context();
if (output_cum_offsets_tmp.is_cpu()) {
ctx = new baidu::xpu::api::Context(baidu::xpu::api::kCPU);
}
std::vector<int64_t> output_cum_offsets_tmp_shape =
output_cum_offsets_tmp.shape();
const int bsz = output_cum_offsets_tmp_shape[0];
auto cpu_out_token_num = out_token_num.copy_to(paddle::CPUPlace(), false);
auto output_padding_offset = paddle::full({cpu_out_token_num},
0,
paddle::DataType::INT32,
output_cum_offsets_tmp.place());
auto output_cum_offsets =
output_cum_offsets_tmp.copy_to(output_cum_offsets_tmp.place(), false);
int r = baidu::xpu::api::plugin::speculate_get_output_padding_offset(
ctx,
output_padding_offset.mutable_data<int>(),
output_cum_offsets.mutable_data<int>(),
output_cum_offsets_tmp.data<int>(),
seq_lens_output.data<int>(),
bsz,
max_seq_len);
PD_CHECK(r == 0, "speculate_clear_accept_nums_kernel failed.");
return {output_padding_offset, output_cum_offsets};
}
std::vector<std::vector<int64_t>> SpeculateGetOutputPaddingOffsetInferShape(
const std::vector<int64_t>& output_cum_offsets_tmp_shape,
const std::vector<int64_t>& out_token_num_shape,
const std::vector<int64_t>& seq_lens_output_shape) {
int64_t bsz = output_cum_offsets_tmp_shape[0];
return {{-1}, {bsz}};
}
std::vector<paddle::DataType> SpeculateGetOutputPaddingOffsetInferDtype(
const paddle::DataType& output_cum_offsets_tmp_dtype,
const paddle::DataType& out_token_num_dtype,
const paddle::DataType& seq_lens_output_dtype) {
return {output_cum_offsets_tmp_dtype, output_cum_offsets_tmp_dtype};
}
PD_BUILD_OP(speculate_get_output_padding_offset)
.Inputs({"output_cum_offsets_tmp", "out_token_num", "seq_lens_output"})
.Outputs({"output_padding_offset", "output_cum_offsets"})
.Attrs({"max_seq_len: int"})
.SetKernelFn(PD_KERNEL(SpeculateGetOutputPaddingOffset))
.SetInferShapeFn(PD_INFER_SHAPE(SpeculateGetOutputPaddingOffsetInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(SpeculateGetOutputPaddingOffsetInferDtype));


@@ -0,0 +1,127 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <paddle/phi/backends/xpu/xpu_context.h>
#include "paddle/extension.h"
#include "xpu/plugin.h"
std::vector<paddle::Tensor> SpeculateGetPaddingOffset(
const paddle::Tensor& input_ids,
const paddle::Tensor& draft_tokens,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& token_num,
const paddle::Tensor& seq_len,
const paddle::Tensor& seq_lens_encoder) {
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
auto xpu_ctx = static_cast<const phi::XPUContext*>(dev_ctx);
std::vector<int64_t> input_ids_shape = input_ids.shape();
const int bsz = seq_len.shape()[0];
const int seq_length = input_ids_shape[1];
const int max_draft_tokens = draft_tokens.shape()[1];
auto cum_offsets_out = cum_offsets.copy_to(cum_offsets.place(), false);
auto cpu_token_num = token_num.copy_to(paddle::CPUPlace(), false);
const int token_num_data = cpu_token_num.data<int64_t>()[0];
auto x_remove_padding = paddle::empty(
{token_num_data}, paddle::DataType::INT64, input_ids.place());
auto padding_offset = paddle::empty(
{token_num_data}, paddle::DataType::INT32, input_ids.place());
auto cu_seqlens_q =
paddle::empty({bsz + 1}, paddle::DataType::INT32, input_ids.place());
auto cu_seqlens_k =
paddle::empty({bsz + 1}, paddle::DataType::INT32, input_ids.place());
PD_CHECK(input_ids.is_contiguous(), "Input ids tensor must be contiguous");
PD_CHECK(draft_tokens.is_contiguous(),
"Draft tokens tensor must be contiguous");
PD_CHECK(cum_offsets.is_contiguous(),
"Cum offsets tensor must be contiguous");
PD_CHECK(seq_len.is_contiguous(), "Seq lens tensor must be contiguous");
int r = baidu::xpu::api::plugin::speculate_get_padding_offset(
xpu_ctx->x_context(),
padding_offset.data<int>(),
cum_offsets_out.data<int>(),
cu_seqlens_q.data<int>(),
cu_seqlens_k.data<int>(),
cum_offsets.data<int>(),
seq_len.data<int>(),
seq_length,
bsz);
PD_CHECK(r == 0, "XPU speculate_get_padding_offset failed");
r = baidu::xpu::api::plugin::speculate_remove_padding<int64_t>(
xpu_ctx->x_context(),
x_remove_padding.data<int64_t>(),
input_ids.data<int64_t>(),
draft_tokens.data<int64_t>(),
seq_len.data<int>(),
seq_lens_encoder.data<int>(),
cum_offsets_out.data<int>(),
seq_length,
max_draft_tokens,
bsz,
token_num_data);
PD_CHECK(r == 0, "XPU speculate_remove_padding failed");
return {x_remove_padding,
cum_offsets_out,
padding_offset,
cu_seqlens_q,
cu_seqlens_k};
}
std::vector<std::vector<int64_t>> SpeculateGetPaddingOffsetInferShape(
const std::vector<int64_t>& input_ids_shape,
const std::vector<int64_t>& draft_tokens_shape,
const std::vector<int64_t>& cum_offsets_shape,
const std::vector<int64_t>& token_num_shape,
const std::vector<int64_t>& seq_len_shape,
const std::vector<int64_t>& seq_lens_encoder_shape) {
int64_t bsz = seq_len_shape[0];
int64_t seq_len = input_ids_shape[1];
return {{-1}, {bsz}, {-1}, {bsz + 1}, {bsz + 1}};
}
std::vector<paddle::DataType> SpeculateGetPaddingOffsetInferDtype(
const paddle::DataType& input_ids_dtype,
const paddle::DataType& draft_tokens_dtype,
const paddle::DataType& cum_offsets_dtype,
const paddle::DataType& token_num_dtype,
const paddle::DataType& seq_len_dtype,
const paddle::DataType& seq_lens_encoder_dtype) {
return {input_ids_dtype,
seq_len_dtype,
seq_len_dtype,
seq_len_dtype,
seq_len_dtype};
}
PD_BUILD_OP(speculate_get_padding_offset)
.Inputs({"input_ids",
"draft_tokens",
"cum_offsets",
"token_num",
"seq_len",
"seq_lens_encoder"})
.Outputs({"x_remove_padding",
"cum_offsets_out",
"padding_offset",
"cu_seqlens_q",
"cu_seqlens_k"})
.SetKernelFn(PD_KERNEL(SpeculateGetPaddingOffset))
.SetInferShapeFn(PD_INFER_SHAPE(SpeculateGetPaddingOffsetInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(SpeculateGetPaddingOffsetInferDtype));
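
speculate_remove_padding is likewise only visible through its call site. A hedged CPU sketch of what it is assumed to do: concatenate each request's valid tokens, taking prefill tokens from input_ids and drafted tokens from draft_tokens.

// Assumed semantics (sketch; the plugin kernel is not shown in this diff).
void SpeculateRemovePaddingRef(int64_t* out, const int64_t* input_ids,
                               const int64_t* draft_tokens, const int* seq_len,
                               const int* seq_lens_encoder,
                               const int* cum_offsets_out, int seq_length,
                               int max_draft_tokens, int bsz) {
  for (int bid = 0; bid < bsz; ++bid) {
    const int dst = bid * seq_length - cum_offsets_out[bid];
    const int64_t* src = seq_lens_encoder[bid] > 0
                             ? input_ids + bid * seq_length
                             : draft_tokens + bid * max_draft_tokens;
    for (int i = 0; i < seq_len[bid]; ++i) out[dst + i] = src[i];
  }
}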


@@ -0,0 +1,69 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <paddle/phi/backends/xpu/xpu_context.h>
#include "paddle/extension.h"
#include "xpu/plugin.h"
std::vector<paddle::Tensor> SpeculateGetSeqLensOutput(
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder) {
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
baidu::xpu::api::Context* ctx =
static_cast<const phi::XPUContext*>(dev_ctx)->x_context();
if (seq_lens_this_time.is_cpu()) {
ctx = new baidu::xpu::api::Context(baidu::xpu::api::kCPU);
}
std::vector<int64_t> seq_lens_this_time_shape = seq_lens_this_time.shape();
const int bsz = seq_lens_this_time_shape[0];
auto seq_lens_output = paddle::full(
{bsz}, 0, paddle::DataType::INT32, seq_lens_this_time.place());
int r = baidu::xpu::api::plugin::speculate_get_seq_lens_output(
ctx,
seq_lens_output.data<int>(),
seq_lens_this_time.data<int>(),
seq_lens_encoder.data<int>(),
seq_lens_decoder.data<int>(),
bsz);
PD_CHECK(r == 0, "speculate_get_seq_lens_output failed.");
return {seq_lens_output};
}
std::vector<std::vector<int64_t>> SpeculateGetSeqLensOutputInferShape(
const std::vector<int64_t>& seq_lens_this_time_shape,
const std::vector<int64_t>& seq_lens_encoder_shape,
const std::vector<int64_t>& seq_lens_decoder_shape) {
int64_t bsz = seq_lens_this_time_shape[0];
return {{bsz}};
}
std::vector<paddle::DataType> SpeculateGetSeqLensOutputInferDtype(
const paddle::DataType& seq_lens_this_time_dtype,
const paddle::DataType& seq_lens_encoder_dtype,
const paddle::DataType& seq_lens_decoder_dtype) {
return {seq_lens_this_time_dtype};
}
PD_BUILD_OP(speculate_get_seq_lens_output)
.Inputs({"seq_lens_this_time", "seq_lens_encoder", "seq_lens_decoder"})
.Outputs({"seq_lens_output"})
.SetKernelFn(PD_KERNEL(SpeculateGetSeqLensOutput))
.SetInferShapeFn(PD_INFER_SHAPE(SpeculateGetSeqLensOutputInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(SpeculateGetSeqLensOutputInferDtype));
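
A hedged CPU sketch of what speculate_get_seq_lens_output is assumed to compute (the plugin source is not in this diff): inactive slots produce no output token, prefill slots produce one, and decode slots produce one per token run this step.

// Assumed semantics (sketch only). seq_lens_decoder is passed to the real
// kernel but is not needed under this reading.
void GetSeqLensOutputRef(int* seq_lens_output, const int* seq_lens_this_time,
                         const int* seq_lens_encoder, int bsz) {
  for (int bid = 0; bid < bsz; ++bid) {
    if (seq_lens_this_time[bid] == 0) continue;  // inactive slot
    seq_lens_output[bid] =
        seq_lens_encoder[bid] > 0 ? 1 : seq_lens_this_time[bid];
  }
}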


@@ -0,0 +1,31 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <stdio.h>
#include <string.h>
#include <sys/ipc.h>
#include <sys/msg.h>
#include <sys/types.h>
#include "paddle/extension.h"
#define MAX_BSZ 256
#define MAX_DRAFT_TOKENS 6
struct speculate_msgdata {
long mtype; // NOLINT
int mtext[MAX_BSZ * MAX_DRAFT_TOKENS + MAX_BSZ +
2]; // stop_flag, bsz, tokens
};


@@ -0,0 +1,130 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <paddle/phi/backends/xpu/xpu_context.h>
#include "paddle/extension.h"
#include "xpu/plugin.h"
namespace api = baidu::xpu::api;
std::vector<paddle::Tensor> RebuildAppendPadding(
const paddle::Tensor& full_hidden_states,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& seq_len_encoder,
const paddle::Tensor& seq_len_decoder,
const paddle::Tensor& output_padding_offset,
int max_seq_len) {
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
auto xpu_ctx = static_cast<const phi::XPUContext*>(dev_ctx);
api::Context* ctx = xpu_ctx->x_context();
if (full_hidden_states.is_cpu()) {
ctx = new api::Context(api::kCPU);
}
int dim_embed = full_hidden_states.shape()[1];
int output_token_num = output_padding_offset.shape()[0];
int elem_nums = output_token_num * dim_embed;
auto out = paddle::full({output_token_num, dim_embed},
0,
full_hidden_states.dtype(),
full_hidden_states.place());
int r;
switch (full_hidden_states.dtype()) {
case paddle::DataType::BFLOAT16:
using XPUTypeBF16 = typename XPUTypeTrait<bfloat16>::Type;
typedef paddle::bfloat16 bf16_data_t;
r = api::plugin::speculate_rebuild_append_padding<XPUTypeBF16>(
ctx,
const_cast<XPUTypeBF16*>(reinterpret_cast<const XPUTypeBF16*>(
full_hidden_states.data<bf16_data_t>())),
const_cast<int*>(cum_offsets.data<int>()),
const_cast<int*>(seq_len_encoder.data<int>()),
const_cast<int*>(seq_len_decoder.data<int>()),
const_cast<int*>(output_padding_offset.data<int>()),
max_seq_len,
dim_embed,
elem_nums,
reinterpret_cast<XPUTypeBF16*>(out.data<bf16_data_t>()));
PD_CHECK(r == 0, "xpu::plugin::speculate_rebuild_append_padding failed.");
return {out};
case paddle::DataType::FLOAT16:
using XPUTypeFP16 = typename XPUTypeTrait<float16>::Type;
typedef paddle::float16 fp16_data_t;
r = api::plugin::speculate_rebuild_append_padding<XPUTypeFP16>(
ctx,
const_cast<XPUTypeFP16*>(reinterpret_cast<const XPUTypeFP16*>(
full_hidden_states.data<fp16_data_t>())),
const_cast<int*>(cum_offsets.data<int>()),
const_cast<int*>(seq_len_encoder.data<int>()),
const_cast<int*>(seq_len_decoder.data<int>()),
const_cast<int*>(output_padding_offset.data<int>()),
max_seq_len,
dim_embed,
elem_nums,
reinterpret_cast<XPUTypeFP16*>(out.data<fp16_data_t>()));
PD_CHECK(r == 0, "xpu::plugin::speculate_rebuild_append_padding failed.");
return {out};
case paddle::DataType::FLOAT32:
r = api::plugin::speculate_rebuild_append_padding<float>(
ctx,
const_cast<float*>(full_hidden_states.data<float>()),
const_cast<int*>(cum_offsets.data<int>()),
const_cast<int*>(seq_len_encoder.data<int>()),
const_cast<int*>(seq_len_decoder.data<int>()),
const_cast<int*>(output_padding_offset.data<int>()),
max_seq_len,
dim_embed,
elem_nums,
out.data<float>());
PD_CHECK(r == 0, "xpu::plugin::speculate_rebuild_append_padding failed.");
return {out};
default:
PD_THROW("Unsupported data type.");
}
}
std::vector<std::vector<int64_t>> RebuildAppendPaddingInferShape(
const std::vector<int64_t>& full_hidden_states_shape,
const std::vector<int64_t>& cum_offsets_shape,
const std::vector<int64_t>& seq_len_encoder_shape,
const std::vector<int64_t>& seq_len_decoder_shape,
const std::vector<int64_t>& output_padding_offset_shape) {
const int64_t output_token_num = output_padding_offset_shape[0];
const int64_t dim_embed = full_hidden_states_shape[1];
std::vector<int64_t> out_shape = {output_token_num, dim_embed};
return {out_shape};
}
std::vector<paddle::DataType> RebuildAppendPaddingInferDtype(
const paddle::DataType& full_hidden_states_dtype,
const paddle::DataType& cum_offsets_dtype,
const paddle::DataType& seq_len_encoder_dtype,
const paddle::DataType& seq_len_decoder_dtype,
const paddle::DataType& output_padding_offset_dtype) {
return {full_hidden_states_dtype};
}
PD_BUILD_OP(speculate_rebuild_append_padding)
.Inputs({"full_hidden_states",
"cum_offsets",
"seq_len_encoder",
"seq_len_decoder",
"output_padding_offset"})
.Attrs({"max_seq_len: int"})
.Outputs({"out"})
.SetKernelFn(PD_KERNEL(RebuildAppendPadding))
.SetInferShapeFn(PD_INFER_SHAPE(RebuildAppendPaddingInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(RebuildAppendPaddingInferDtype));
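
A hedged CPU sketch of the assumed semantics of speculate_rebuild_append_padding: map each unpadded output token back to its padded position via output_padding_offset, locate the source row inside full_hidden_states (the last encoder position for prefill requests), and copy that row.

// Assumed semantics (sketch; mirrors the common rebuild/append-padding
// pattern, not necessarily the plugin's exact implementation).
template <typename T>
void RebuildAppendPaddingRef(const T* full_hidden_states,
                             const int* cum_offsets, const int* seq_len_encoder,
                             const int* seq_len_decoder,
                             const int* output_padding_offset, int max_seq_len,
                             int dim_embed, int elem_nums, T* out) {
  for (int i = 0; i < elem_nums; ++i) {
    const int out_token = i / dim_embed;
    const int ori_token = out_token + output_padding_offset[out_token];
    const int bi = ori_token / max_seq_len;
    if (seq_len_encoder[bi] == 0 && seq_len_decoder[bi] == 0) continue;
    const int seq_id = seq_len_encoder[bi] > 0 ? seq_len_encoder[bi] - 1 : 0;
    const int src_token = ori_token - cum_offsets[bi] + seq_id;
    out[i] = full_hidden_states[src_token * dim_embed + i % dim_embed];
  }
}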


@@ -0,0 +1,162 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <paddle/phi/backends/xpu/xpu_context.h>
#include <stdio.h>
#include <string.h>
#include <sys/ipc.h>
#include <sys/msg.h>
#include <sys/types.h>
#include "paddle/extension.h"
#include "xpu/plugin.h"
#define MAX_BSZ 256
#define MAX_DRAFT_TOKENS 6
struct msgdata {
long mtype; // NOLINT
int mtext[MAX_BSZ * MAX_DRAFT_TOKENS + MAX_BSZ +
2]; // stop_flag, bsz, tokens
};
void SpeculateSaveWithOutputMsg(const paddle::Tensor& accept_tokens,
const paddle::Tensor& accept_num,
const paddle::Tensor& not_need_stop,
int64_t rank_id,
int msg_queue_id,
bool save_each_rank) {
// printf("enter save output");
if (!save_each_rank && rank_id > 0) {
return;
}
int max_draft_tokens = accept_tokens.shape()[1];
auto accept_tokens_cpu = accept_tokens.copy_to(paddle::CPUPlace(), true);
auto accept_num_cpu = accept_num.copy_to(paddle::CPUPlace(), true);
int64_t* accept_tokens_data = accept_tokens_cpu.data<int64_t>();
int* accept_num_data = accept_num_cpu.data<int>();
if (const char* inference_msg_queue_id_env_p =
std::getenv("INFERENCE_MSG_QUEUE_ID")) {
std::string inference_msg_queue_id_env_str(inference_msg_queue_id_env_p);
int inference_msg_queue_id_from_env =
std::stoi(inference_msg_queue_id_env_str);
#ifdef GET_OUTPUT_DEBUG
std::cout << "Your INFERENCE_MSG_QUEUE_ID is: "
<< inference_msg_queue_id_from_env << std::endl;
#endif
msg_queue_id = inference_msg_queue_id_from_env;
}
static struct msgdata msg_sed;
static key_t key = ftok("./", msg_queue_id);
static int msgid = msgget(key, IPC_CREAT | 0666);
msg_sed.mtype = 1;
bool not_need_stop_data = not_need_stop.data<bool>()[0];
int inference_msg_id_from_env = 1;
if (const char* inference_msg_id_env_p = std::getenv("INFERENCE_MSG_ID")) {
std::string inference_msg_id_env_str(inference_msg_id_env_p);
inference_msg_id_from_env = std::stoi(inference_msg_id_env_str);
if (inference_msg_id_from_env == 2) {
// 2 and -2 are reserved for the no-output indication.
throw std::runtime_error(
"INFERENCE_MSG_ID cannot be 2, please use another number.");
}
if (inference_msg_id_from_env < 0) {
throw std::runtime_error(
" INFERENCE_MSG_ID cannot be negative, please use other "
"number.");
}
#ifdef SAVE_WITH_OUTPUT_DEBUG
std::cout << "Your INFERENCE_MSG_ID is: " << inference_msg_id_from_env
<< std::endl;
#endif
} else {
#ifdef SAVE_WITH_OUTPUT_DEBUG
std::cout << "Failed to got INFERENCE_MSG_ID at env, use (int)1 as default."
<< std::endl;
#endif
}
msg_sed.mtext[0] = not_need_stop_data ? inference_msg_id_from_env
: -inference_msg_id_from_env;
int bsz = accept_tokens.shape()[0];
msg_sed.mtext[1] = bsz;
for (int i = 2; i < MAX_BSZ + 2; i++) {
if (i - 2 >= bsz) {
msg_sed.mtext[i] = 0;
} else {
msg_sed.mtext[i] = static_cast<int>(accept_num_data[i - 2]);
}
}
for (int i = MAX_BSZ + 2; i < MAX_BSZ * MAX_DRAFT_TOKENS + MAX_BSZ + 2; i++) {
int token_id = i - MAX_BSZ - 2;
int bid = token_id / MAX_DRAFT_TOKENS;
int local_token_id = token_id % MAX_DRAFT_TOKENS;
if (token_id / MAX_DRAFT_TOKENS >= bsz) {
msg_sed.mtext[i] = 0;
} else {
msg_sed.mtext[i] =
accept_tokens_data[bid * max_draft_tokens + local_token_id];
}
}
if ((msgsnd(msgid,
&msg_sed,
(MAX_BSZ * MAX_DRAFT_TOKENS + MAX_BSZ + 2) * 4,
0)) == -1) {
printf("full msg buffer\n");
}
return;
}
void SpeculateSaveWithOutputMsgStatic(const paddle::Tensor& accept_tokens,
const paddle::Tensor& accept_num,
const paddle::Tensor& not_need_stop,
int64_t rank_id,
bool save_each_rank) {
SpeculateSaveWithOutputMsg(
accept_tokens, accept_num, not_need_stop, rank_id, 1, save_each_rank);
}
void SpeculateSaveWithOutputMsgDynamic(const paddle::Tensor& accept_tokens,
const paddle::Tensor& accept_num,
const paddle::Tensor& not_need_stop,
int64_t rank_id,
int msg_queue_id,
bool save_each_rank) {
SpeculateSaveWithOutputMsg(accept_tokens,
accept_num,
not_need_stop,
rank_id,
msg_queue_id,
save_each_rank);
}
PD_BUILD_OP(speculate_save_output)
.Inputs({"accept_tokens", "accept_num", "not_need_stop"})
.Attrs({"rank_id: int64_t", "save_each_rank: bool"})
.Outputs({"x_out"})
.SetInplaceMap({{"accept_tokens", "x_out"}})
.SetKernelFn(PD_KERNEL(SpeculateSaveWithOutputMsgStatic));
PD_BUILD_OP(speculate_save_output_dynamic)
.Inputs({"accept_tokens", "accept_num", "not_need_stop"})
.Attrs({"rank_id: int64_t", "msg_queue_id: int", "save_each_rank: bool"})
.Outputs({"x_out"})
.SetInplaceMap({{"accept_tokens", "x_out"}})
.SetKernelFn(PD_KERNEL(SpeculateSaveWithOutputMsgDynamic));
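
speculate_save_output and speculate_get_output must agree on the byte count handed to msgsnd/msgrcv. A compile-time check sketch that would pin the hand-written (MAX_BSZ * MAX_DRAFT_TOKENS + MAX_BSZ + 2) * 4 to the struct definition:

// Sketch: tie the payload size to the struct layout so producer and
// consumer cannot drift apart silently.
static_assert(sizeof(msgdata::mtext) ==
                  (MAX_BSZ * MAX_DRAFT_TOKENS + MAX_BSZ + 2) * sizeof(int),
              "msg payload size mismatch");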


@@ -0,0 +1,80 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <paddle/phi/backends/xpu/xpu_context.h>
#include "paddle/extension.h"
#include "paddle/phi/core/enforce.h"
#include "xpu/plugin.h"
namespace api = baidu::xpu::api;
void SpecGetStopFlagsMultiSeqs(const paddle::Tensor &accept_tokens,
const paddle::Tensor &accept_num,
const paddle::Tensor &pre_ids,
const paddle::Tensor &step_idx,
const paddle::Tensor &stop_flags,
const paddle::Tensor &seq_lens,
const paddle::Tensor &stop_seqs,
const paddle::Tensor &stop_seqs_len,
const paddle::Tensor &end_ids) {
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
api::Context *ctx =
static_cast<const phi::XPUContext *>(dev_ctx)->x_context();
if (accept_tokens.is_cpu()) {
ctx = new api::Context(api::kCPU);
}
PD_CHECK(accept_tokens.dtype() == paddle::DataType::INT64);
PD_CHECK(stop_flags.dtype() == paddle::DataType::BOOL);
std::vector<int64_t> shape = accept_tokens.shape();
std::vector<int64_t> stop_seqs_shape = stop_seqs.shape();
int bs_now = shape[0];
int stop_seqs_bs = stop_seqs_shape[0];
int stop_seqs_max_len = stop_seqs_shape[1];
int pre_ids_len = pre_ids.shape()[1];
int accept_tokens_len = accept_tokens.shape()[1];
int r = baidu::xpu::api::plugin::speculate_set_stop_value_multi_seqs(
ctx,
const_cast<bool *>(stop_flags.data<bool>()),
const_cast<int64_t *>(accept_tokens.data<int64_t>()),
const_cast<int *>(accept_num.data<int>()),
pre_ids.data<int64_t>(),
step_idx.data<int64_t>(),
stop_seqs.data<int64_t>(),
stop_seqs_len.data<int>(),
seq_lens.data<int>(),
end_ids.data<int64_t>(),
bs_now,
accept_tokens_len,
stop_seqs_bs,
stop_seqs_max_len,
pre_ids_len);
PD_CHECK(r == 0, "xpu::plugin::speculate_set_stop_value_multi_seqs failed.");
}
PD_BUILD_OP(speculate_set_stop_value_multi_seqs)
.Inputs({"accept_tokens",
"accept_num",
"pre_ids",
"step_idx",
"stop_flags",
"seq_lens",
"stop_seqs",
"stop_seqs_len",
"end_ids"})
.Outputs({"accept_tokens_out", "stop_flags_out"})
.SetInplaceMap({{"accept_tokens", "accept_tokens_out"},
{"stop_flags", "stop_flags_out"}})
.SetKernelFn(PD_KERNEL(SpecGetStopFlagsMultiSeqs));


@@ -0,0 +1,67 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <paddle/phi/backends/xpu/xpu_context.h>
#include "paddle/extension.h"
#include "xpu/plugin.h"
void SpeculateSetValueByFlagsAndIdx(const paddle::Tensor &pre_ids_all,
const paddle::Tensor &accept_tokens,
const paddle::Tensor &accept_num,
const paddle::Tensor &stop_flags,
const paddle::Tensor &seq_lens_this_time,
const paddle::Tensor &seq_lens_encoder,
const paddle::Tensor &seq_lens_decoder,
const paddle::Tensor &step_idx) {
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
baidu::xpu::api::Context *ctx =
static_cast<const phi::XPUContext *>(dev_ctx)->x_context();
// auto xpu_ctx = static_cast<const phi::XPUContext *>(dev_ctx);
if (pre_ids_all.is_cpu()) {
ctx = new baidu::xpu::api::Context(baidu::xpu::api::kCPU);
}
std::vector<int64_t> pre_ids_all_shape = pre_ids_all.shape();
int bs = seq_lens_this_time.shape()[0];
int length = pre_ids_all_shape[1];
int max_draft_tokens = accept_tokens.shape()[1];
int r = baidu::xpu::api::plugin::speculate_set_value_by_flag_and_id(
ctx,
const_cast<int64_t *>(pre_ids_all.data<int64_t>()),
accept_tokens.data<int64_t>(),
accept_num.data<int>(),
stop_flags.data<bool>(),
seq_lens_encoder.data<int>(),
seq_lens_decoder.data<int>(),
step_idx.data<int64_t>(),
bs,
length,
max_draft_tokens);
PD_CHECK(r == 0, "speculate_clear_accept_nums_kernel failed.");
}
PD_BUILD_OP(speculate_set_value_by_flags_and_idx)
.Inputs({"pre_ids_all",
"accept_tokens",
"accept_num",
"stop_flags",
"seq_lens_this_time",
"seq_lens_encoder",
"seq_lens_decoder",
"step_idx"})
.Outputs({"pre_ids_all_out"})
.SetInplaceMap({{"pre_ids_all", "pre_ids_all_out"}})
.SetKernelFn(PD_KERNEL(SpeculateSetValueByFlagsAndIdx));

View File

@@ -0,0 +1,216 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <paddle/phi/backends/xpu/xpu_context.h>
#include "paddle/phi/core/enforce.h"
#include "speculate_msg.h" // NOLINT
#include "xpu/plugin.h"
// To keep the call sites unchanged, the input arguments are left as-is for now.
void SpeculateStepSchedule(
const paddle::Tensor &stop_flags,
const paddle::Tensor &seq_lens_this_time,
const paddle::Tensor &ori_seq_lens_encoder,
const paddle::Tensor &seq_lens_encoder,
const paddle::Tensor &seq_lens_decoder,
const paddle::Tensor &block_tables, // [bsz, block_num_per_seq]
const paddle::Tensor &encoder_block_lens,
const paddle::Tensor &is_block_step,
const paddle::Tensor &step_block_list,
const paddle::Tensor &step_lens,
const paddle::Tensor &recover_block_list,
const paddle::Tensor &recover_lens,
const paddle::Tensor &need_block_list,
const paddle::Tensor &need_block_len,
const paddle::Tensor &used_list_len,
const paddle::Tensor &free_list,
const paddle::Tensor &free_list_len,
const paddle::Tensor &input_ids,
const paddle::Tensor &pre_ids,
const paddle::Tensor &step_idx,
const paddle::Tensor &next_tokens,
const paddle::Tensor &first_token_ids,
const paddle::Tensor &accept_num,
const int block_size,
const int encoder_decoder_block_num,
const int max_draft_tokens) {
namespace api = baidu::xpu::api;
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
auto xpu_ctx = static_cast<const phi::XPUContext *>(dev_ctx);
api::Context *ctx = xpu_ctx->x_context();
if (stop_flags.is_cpu()) {
ctx = new api::Context(api::kCPU);
}
const int bsz = seq_lens_this_time.shape()[0];
const int block_num_per_seq = block_tables.shape()[1];
const int length = input_ids.shape()[1];
const int pre_id_length = pre_ids.shape()[1];
constexpr int BlockSize = 256; // bsz <= 256
  const int max_decoder_block_num =
      length / block_size -
      encoder_decoder_block_num;  // blocks implied by the max output length,
                                  // minus the blocks the service reserves for
                                  // decoding
auto step_lens_inkernel =
paddle::full({1}, 0, paddle::DataType::INT32, stop_flags.place());
auto step_bs_list =
paddle::full({bsz}, 0, paddle::DataType::INT32, stop_flags.place());
int r = baidu::xpu::api::plugin::speculate_free_and_reschedule(
ctx,
const_cast<bool *>(stop_flags.data<bool>()),
const_cast<int *>(seq_lens_this_time.data<int>()),
const_cast<int *>(seq_lens_decoder.data<int>()),
const_cast<int *>(block_tables.data<int>()),
const_cast<int *>(encoder_block_lens.data<int>()),
const_cast<bool *>(is_block_step.data<bool>()),
const_cast<int *>(step_bs_list.data<int>()),
const_cast<int *>(step_lens_inkernel.data<int>()),
const_cast<int *>(recover_block_list.data<int>()),
const_cast<int *>(recover_lens.data<int>()),
const_cast<int *>(need_block_list.data<int>()),
const_cast<int *>(need_block_len.data<int>()),
const_cast<int *>(used_list_len.data<int>()),
const_cast<int *>(free_list.data<int>()),
const_cast<int *>(free_list_len.data<int>()),
const_cast<int64_t *>(first_token_ids.data<int64_t>()),
bsz,
block_size,
block_num_per_seq,
max_decoder_block_num,
max_draft_tokens);
PD_CHECK(r == 0, "speculate_free_and_reschedule failed.");
// save output
auto step_lens_cpu = step_lens_inkernel.copy_to(paddle::CPUPlace(), false);
if (step_lens_cpu.data<int>()[0] > 0) {
auto step_bs_list_cpu = step_bs_list.copy_to(paddle::CPUPlace(), false);
auto next_tokens =
paddle::full({bsz}, -1, paddle::DataType::INT64, paddle::CPUPlace());
for (int i = 0; i < step_lens_cpu.data<int>()[0]; i++) {
const int step_bid = step_bs_list_cpu.data<int>()[i];
next_tokens.data<int64_t>()[step_bid] = -3; // need reschedule
}
const int rank_id = static_cast<int>(stop_flags.place().GetDeviceId());
printf("reschedule rank_id: %d, step_lens: %d",
rank_id,
step_lens_cpu.data<int>()[0]);
const int64_t *x_data = next_tokens.data<int64_t>();
static struct speculate_msgdata msg_sed;
int msg_queue_id = rank_id;
if (const char *inference_msg_queue_id_env_p =
std::getenv("INFERENCE_MSG_QUEUE_ID")) {
std::string inference_msg_queue_id_env_str(inference_msg_queue_id_env_p);
int inference_msg_queue_id_from_env =
std::stoi(inference_msg_queue_id_env_str);
msg_queue_id = inference_msg_queue_id_from_env;
} else {
std::cout << "Failed to got INFERENCE_MSG_QUEUE_ID at env, use default."
<< std::endl;
}
int inference_msg_id_from_env = 1;
if (const char *inference_msg_id_env_p = std::getenv("INFERENCE_MSG_ID")) {
std::string inference_msg_id_env_str(inference_msg_id_env_p);
inference_msg_id_from_env = std::stoi(inference_msg_id_env_str);
if (inference_msg_id_from_env == 2) {
      // 2 and -2 are reserved for the no-output indication.
throw std::runtime_error(
" INFERENCE_MSG_ID cannot be 2, please use other number.");
}
if (inference_msg_id_from_env < 0) {
throw std::runtime_error(
" INFERENCE_MSG_ID cannot be negative, please use other "
"number.");
}
    }
// static key_t key = ftok("/dev/shm", msg_queue_id);
static key_t key = ftok("./", msg_queue_id);
static int msgid = msgget(key, IPC_CREAT | 0666);
msg_sed.mtype = 1;
msg_sed.mtext[0] = inference_msg_id_from_env;
msg_sed.mtext[1] = bsz;
for (int i = 2; i < bsz + 2; i++) {
msg_sed.mtext[i] = static_cast<int>(x_data[i - 2]);
}
if ((msgsnd(msgid, &msg_sed, (MAX_BSZ + 2) * 4, 0)) == -1) {
printf("full msg buffer\n");
}
}
}
PD_BUILD_OP(speculate_step_reschedule)
.Inputs({"stop_flags",
"seq_lens_this_time",
"ori_seq_lens_encoder",
"seq_lens_encoder",
"seq_lens_decoder",
"block_tables",
"encoder_block_lens",
"is_block_step",
"step_block_list",
"step_lens",
"recover_block_list",
"recover_lens",
"need_block_list",
"need_block_len",
"used_list_len",
"free_list",
"free_list_len",
"input_ids",
"pre_ids",
"step_idx",
"next_tokens",
"first_token_ids",
"accept_num"})
.Attrs({"block_size: int",
"encoder_decoder_block_num: int",
"max_draft_tokens: int"})
.Outputs({"stop_flags_out",
"seq_lens_this_time_out",
"seq_lens_encoder_out",
"seq_lens_decoder_out",
"block_tables_out",
"encoder_block_lens_out",
"is_block_step_out",
"step_block_list_out",
"step_lens_out",
"recover_block_list_out",
"recover_lens_out",
"need_block_list_out",
"need_block_len_out",
"used_list_len_out",
"free_list_out",
"free_list_len_out",
"input_ids_out",
"first_token_ids_out"})
.SetInplaceMap({{"stop_flags", "stop_flags_out"},
{"seq_lens_this_time", "seq_lens_this_time_out"},
{"seq_lens_encoder", "seq_lens_encoder_out"},
{"seq_lens_decoder", "seq_lens_decoder_out"},
{"block_tables", "block_tables_out"},
{"encoder_block_lens", "encoder_block_lens_out"},
{"is_block_step", "is_block_step_out"},
{"step_block_list", "step_block_list_out"},
{"step_lens", "step_lens_out"},
{"recover_block_list", "recover_block_list_out"},
{"recover_lens", "recover_lens_out"},
{"need_block_list", "need_block_list_out"},
{"need_block_len", "need_block_len_out"},
{"used_list_len", "used_list_len_out"},
{"free_list", "free_list_out"},
{"free_list_len", "free_list_len_out"},
{"input_ids", "input_ids_out"},
{"first_token_ids", "first_token_ids_out"}})
.SetKernelFn(PD_KERNEL(SpeculateStepSchedule));
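
For context, the message sent above can be drained on the host side with the matching SysV calls; a minimal consumer sketch (MAX_BSZ is an assumption that must match speculate_msg.h, and the queue key must be derived the same way as in the sender):

#include <sys/ipc.h>
#include <sys/msg.h>
#include <cstdio>

#define MAX_BSZ 256  // assumption: must equal the value in speculate_msg.h

struct speculate_msgdata {
  long mtype;
  int mtext[MAX_BSZ + 2];  // [msg_id, bsz, token_0, ..., token_{bsz-1}]
};

int main() {
  key_t key = ftok("./", 1);  // same path and queue id as the sender
  int msgid = msgget(key, IPC_CREAT | 0666);
  speculate_msgdata msg;
  // Blocks until the scheduler publishes a batch; -3 marks a slot that
  // needs rescheduling, as written by SpeculateStepSchedule.
  if (msgrcv(msgid, &msg, (MAX_BSZ + 2) * 4, 0, 0) != -1) {
    const int bsz = msg.mtext[1];
    for (int i = 0; i < bsz; i++) {
      if (msg.mtext[i + 2] == -3) {
        printf("bid %d needs reschedule\n", i);
      }
    }
  }
  return 0;
}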

View File

@@ -0,0 +1,157 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <paddle/phi/backends/xpu/xpu_context.h>
#include "paddle/extension.h"
#include "paddle/phi/core/enforce.h"
#include "xpu/plugin.h"
void TokenPenaltyMultiScores(const paddle::Tensor& pre_ids,
const paddle::Tensor& logits,
const paddle::Tensor& penalty_scores,
const paddle::Tensor& frequency_scores,
const paddle::Tensor& presence_scores,
const paddle::Tensor& temperatures,
const paddle::Tensor& bad_tokens,
const paddle::Tensor& cur_len,
const paddle::Tensor& min_len,
const paddle::Tensor& eos_token_id,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& output_padding_offset,
const paddle::Tensor& output_cum_offsets,
const int max_seq_len) {
namespace api = baidu::xpu::api;
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
auto xpu_ctx = static_cast<const phi::XPUContext*>(dev_ctx);
api::Context* ctx = xpu_ctx->x_context();
if (pre_ids.is_cpu()) {
ctx = new api::Context(api::kCPU);
}
int64_t bs = seq_lens_this_time.shape()[0];
int64_t token_num = logits.shape()[0];
PADDLE_ENFORCE_LE(bs,
640,
phi::errors::InvalidArgument(
"Only support bs <= 640, but received bsz is %d", bs));
int64_t length = logits.shape()[1];
int64_t length_id = pre_ids.shape()[1];
int64_t length_bad_words = bad_tokens.shape()[0];
int64_t end_length = eos_token_id.shape()[0];
switch (logits.type()) {
case paddle::DataType::BFLOAT16: {
using XPUType = typename XPUTypeTrait<paddle::bfloat16>::Type;
typedef paddle::bfloat16 data_t;
int r = baidu::xpu::api::plugin::speculate_token_penalty_multi_scores(
ctx,
pre_ids.data<int64_t>(),
reinterpret_cast<XPUType*>(
const_cast<data_t*>(logits.data<data_t>())),
reinterpret_cast<const XPUType*>(penalty_scores.data<data_t>()),
reinterpret_cast<const XPUType*>(frequency_scores.data<data_t>()),
reinterpret_cast<const XPUType*>(presence_scores.data<data_t>()),
temperatures.data<float>(),
cur_len.data<int64_t>(),
min_len.data<int64_t>(),
eos_token_id.data<int64_t>(),
bad_tokens.data<int64_t>(),
output_padding_offset.data<int>(),
output_cum_offsets.data<int>(),
bs,
length,
length_id,
end_length,
length_bad_words,
token_num,
max_seq_len);
PD_CHECK(r == 0, "xpu::plugin::token_penalty_multi_scores failed.");
} break;
case paddle::DataType::FLOAT16: {
using XPUType = typename XPUTypeTrait<float16>::Type;
typedef paddle::float16 data_t;
int r = baidu::xpu::api::plugin::speculate_token_penalty_multi_scores(
ctx,
pre_ids.data<int64_t>(),
reinterpret_cast<XPUType*>(
const_cast<data_t*>(logits.data<data_t>())),
reinterpret_cast<const XPUType*>(penalty_scores.data<data_t>()),
reinterpret_cast<const XPUType*>(frequency_scores.data<data_t>()),
reinterpret_cast<const XPUType*>(presence_scores.data<data_t>()),
temperatures.data<float>(),
cur_len.data<int64_t>(),
min_len.data<int64_t>(),
eos_token_id.data<int64_t>(),
bad_tokens.data<int64_t>(),
output_padding_offset.data<int>(),
output_cum_offsets.data<int>(),
bs,
length,
length_id,
end_length,
length_bad_words,
token_num,
max_seq_len);
PD_CHECK(r == 0, "xpu::plugin::token_penalty_multi_scores failed.");
} break;
case paddle::DataType::FLOAT32: {
int r = baidu::xpu::api::plugin::speculate_token_penalty_multi_scores(
ctx,
pre_ids.data<int64_t>(),
const_cast<float*>(logits.data<float>()),
penalty_scores.data<float>(),
frequency_scores.data<float>(),
presence_scores.data<float>(),
temperatures.data<float>(),
cur_len.data<int64_t>(),
min_len.data<int64_t>(),
eos_token_id.data<int64_t>(),
bad_tokens.data<int64_t>(),
output_padding_offset.data<int>(),
output_cum_offsets.data<int>(),
bs,
length,
length_id,
end_length,
length_bad_words,
token_num,
max_seq_len);
PD_CHECK(r == 0, "xpu::plugin::token_penalty_multi_scores failed.");
} break;
default:
PD_THROW(
"NOT supported data type. "
"Only float16 and float32 are supported. ");
break;
}
}
PD_BUILD_OP(speculate_get_token_penalty_multi_scores)
.Inputs({"pre_ids",
"logits",
"penalty_scores",
"frequency_scores",
"presence_scores",
"temperatures",
"bad_tokens",
"cur_len",
"min_len",
"eos_token_id",
"seq_lens_this_time",
"output_padding_offset",
"output_cum_offsets"})
.Outputs({"logits_out"})
.Attrs({"max_seq_len: int"})
.SetInplaceMap({{"logits", "logits_out"}})
.SetKernelFn(PD_KERNEL(TokenPenaltyMultiScores));

View File

@@ -0,0 +1,38 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/extension.h"
void UpdateInputIdsCPU(const paddle::Tensor& input_ids_cpu,
const std::vector<int64_t>& task_input_ids,
const int bid,
const int max_seq_len) {
int64_t* input_ids_cpu_data =
const_cast<int64_t*>(input_ids_cpu.data<int64_t>());
// printf("Input len is %d\n", task_input_ids.size());
for (int i = 0; i < task_input_ids.size(); i++) {
// printf("%lld\n", task_input_ids[i]);
input_ids_cpu_data[bid * max_seq_len + i] = task_input_ids[i];
}
}
PD_BUILD_OP(speculate_update_input_ids_cpu)
.Inputs({"input_ids_cpu"})
.Outputs({"input_ids_cpu_out"})
.Attrs({"task_input_ids: std::vector<int64_t>",
"bid: int",
"max_seq_len: int"})
.SetInplaceMap({{"input_ids_cpu", "input_ids_cpu_out"}})
.SetKernelFn(PD_KERNEL(UpdateInputIdsCPU));

View File

@@ -0,0 +1,91 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <paddle/phi/backends/xpu/xpu_context.h>
#include "paddle/extension.h"
#include "xpu/plugin.h"
namespace api = baidu::xpu::api;
void SpeculateUpdateV3(const paddle::Tensor &seq_lens_encoder,
const paddle::Tensor &seq_lens_decoder,
const paddle::Tensor &not_need_stop,
const paddle::Tensor &draft_tokens,
const paddle::Tensor &actual_draft_token_nums,
const paddle::Tensor &accept_tokens,
const paddle::Tensor &accept_num,
const paddle::Tensor &stop_flags,
const paddle::Tensor &seq_lens_this_time,
const paddle::Tensor &is_block_step,
const paddle::Tensor &stop_nums) {
const int real_bsz = seq_lens_this_time.shape()[0];
const int max_bsz = stop_flags.shape()[0];
auto max_draft_tokens = draft_tokens.shape()[1];
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
api::Context *ctx =
static_cast<const phi::XPUContext *>(dev_ctx)->x_context();
if (draft_tokens.is_cpu()) {
ctx = new api::Context(api::kCPU);
}
auto not_need_stop_xpu = not_need_stop.copy_to(stop_flags.place(), false);
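  // not_need_stop lives on the host; stage a copy on the XPU for the kernel,
  // then write the result back to the host tensor below.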
int r = baidu::xpu::api::plugin::speculate_update_v3(
ctx,
const_cast<int *>(seq_lens_encoder.data<int>()),
const_cast<int *>(seq_lens_decoder.data<int>()),
const_cast<bool *>(not_need_stop_xpu.data<bool>()),
const_cast<int64_t *>(draft_tokens.data<int64_t>()),
const_cast<int *>(actual_draft_token_nums.data<int>()),
accept_tokens.data<int64_t>(),
accept_num.data<int>(),
stop_flags.data<bool>(),
seq_lens_this_time.data<int>(),
is_block_step.data<bool>(),
stop_nums.data<int64_t>(),
real_bsz,
max_bsz,
max_draft_tokens);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "speculate_update_v3");
auto not_need_stop_cpu =
not_need_stop_xpu.copy_to(not_need_stop.place(), true);
bool *not_need_stop_data = const_cast<bool *>(not_need_stop.data<bool>());
not_need_stop_data[0] = not_need_stop_cpu.data<bool>()[0];
}
PD_BUILD_OP(speculate_update_v3)
.Inputs({"seq_lens_encoder",
"seq_lens_decoder",
"not_need_stop",
"draft_tokens",
"actual_draft_token_nums",
"accept_tokens",
"accept_num",
"stop_flags",
"seq_lens_this_time",
"is_block_step",
"stop_nums"})
.Outputs({"seq_lens_encoder_out",
"seq_lens_decoder_out",
"not_need_stop_out",
"draft_tokens_out",
"actual_draft_token_nums_out"})
.SetInplaceMap({{"seq_lens_encoder", "seq_lens_encoder_out"},
{"seq_lens_decoder", "seq_lens_decoder_out"},
{"not_need_stop", "not_need_stop_out"},
{"draft_tokens", "draft_tokens_out"},
{"actual_draft_token_nums", "actual_draft_token_nums_out"}})
.SetKernelFn(PD_KERNEL(SpeculateUpdateV3));

View File

@@ -0,0 +1,251 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <paddle/phi/backends/xpu/xpu_context.h>
#include <stdio.h>
#include <random>  // std::mt19937_64 / std::uniform_real_distribution below
#include "paddle/common/flags.h"
#include "paddle/extension.h"
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "ops/utility/debug.h"
#include "xpu/internal/infra_op.h"
#include "xpu/plugin.h"
namespace api = baidu::xpu::api;
void SpeculateVerify(const paddle::Tensor &accept_tokens,
const paddle::Tensor &accept_num,
const paddle::Tensor &step_idx,
const paddle::Tensor &stop_flags,
const paddle::Tensor &seq_lens_encoder,
const paddle::Tensor &seq_lens_decoder,
const paddle::Tensor &draft_tokens,
const paddle::Tensor &seq_lens_this_time,
const paddle::Tensor &verify_tokens,
const paddle::Tensor &verify_scores,
const paddle::Tensor &max_dec_len,
const paddle::Tensor &end_tokens,
const paddle::Tensor &is_block_step,
const paddle::Tensor &output_cum_offsets,
const paddle::Tensor &actual_candidate_len,
const paddle::Tensor &actual_draft_token_nums,
const paddle::Tensor &topp,
int max_seq_len,
int verify_window,
bool enable_topp) {
auto bsz = accept_tokens.shape()[0];
int real_bsz = seq_lens_this_time.shape()[0];
auto max_draft_tokens = draft_tokens.shape()[1];
auto end_length = end_tokens.shape()[0];
auto max_candidate_len = verify_tokens.shape()[1];
constexpr int BlockSize = 512;
// set topp_seed if needed
const paddle::optional<paddle::Tensor> &topp_seed = nullptr;
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
api::Context *ctx =
static_cast<const phi::XPUContext *>(dev_ctx)->x_context();
bool xpu_ctx_flag = true;
if (draft_tokens.is_cpu()) {
ctx = new api::Context(api::kCPU);
xpu_ctx_flag = false;
}
bool use_topk = false;
char *env_var = getenv("SPECULATE_VERIFY_USE_TOPK");
if (env_var) {
use_topk = static_cast<bool>(std::stoi(env_var));
}
bool prefill_one_step_stop = false;
if (const char *env_p = std::getenv("PREFILL_NODE_ONE_STEP_STOP")) {
// std::cout << "Your PATH is: " << env_p << '\n';
if (env_p[0] == '1') {
prefill_one_step_stop = true;
}
}
// random
int random_seed = 0;
std::vector<int64_t> infer_seed(bsz, random_seed);
std::uniform_real_distribution<float> dist(0.0, 1.0);
std::vector<float> dev_curand_states_cpu;
for (int i = 0; i < bsz; i++) {
std::mt19937_64 engine(infer_seed[i]);
dev_curand_states_cpu.push_back(dist(engine));
}
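  // NOTE: infer_seed is all zeros here, so every batch slot draws the same
  // deterministic uniform sample; dev_curand_states holds one float in
  // [0, 1) per sequence for the acceptance test inside speculate_verify.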
  float *dev_curand_states_xpu = nullptr;
  // ctx_guard frees its allocations when it is destroyed, so it must outlive
  // the speculate_verify calls below; scoping it inside the if block would
  // release the buffer before the kernel reads it.
  xpu::ctx_guard RAII_GUARD(ctx);
  if (xpu_ctx_flag) {
    dev_curand_states_xpu =
        RAII_GUARD.alloc<float>(dev_curand_states_cpu.size());
    xpu_memcpy(dev_curand_states_xpu,
               dev_curand_states_cpu.data(),
               dev_curand_states_cpu.size() * sizeof(float),
               XPUMemcpyKind::XPU_HOST_TO_DEVICE);
  }
  auto dev_curand_states =
      !xpu_ctx_flag ? dev_curand_states_cpu.data() : dev_curand_states_xpu;
if (use_topk) {
if (enable_topp) {
baidu::xpu::api::plugin::speculate_verify<true, true>(
ctx,
const_cast<int64_t *>(accept_tokens.data<int64_t>()),
const_cast<int *>(accept_num.data<int>()),
const_cast<int64_t *>(step_idx.data<int64_t>()),
const_cast<bool *>(stop_flags.data<bool>()),
seq_lens_encoder.data<int>(),
seq_lens_decoder.data<int>(),
draft_tokens.data<int64_t>(),
actual_draft_token_nums.data<int>(),
dev_curand_states,
topp.data<float>(),
seq_lens_this_time.data<int>(),
verify_tokens.data<int64_t>(),
verify_scores.data<float>(),
max_dec_len.data<int64_t>(),
end_tokens.data<int64_t>(),
is_block_step.data<bool>(),
output_cum_offsets.data<int>(),
actual_candidate_len.data<int>(),
real_bsz,
max_draft_tokens,
end_length,
max_seq_len,
max_candidate_len,
verify_window,
prefill_one_step_stop);
} else {
baidu::xpu::api::plugin::speculate_verify<false, true>(
ctx,
const_cast<int64_t *>(accept_tokens.data<int64_t>()),
const_cast<int *>(accept_num.data<int>()),
const_cast<int64_t *>(step_idx.data<int64_t>()),
const_cast<bool *>(stop_flags.data<bool>()),
seq_lens_encoder.data<int>(),
seq_lens_decoder.data<int>(),
draft_tokens.data<int64_t>(),
actual_draft_token_nums.data<int>(),
dev_curand_states,
topp.data<float>(),
seq_lens_this_time.data<int>(),
verify_tokens.data<int64_t>(),
verify_scores.data<float>(),
max_dec_len.data<int64_t>(),
end_tokens.data<int64_t>(),
is_block_step.data<bool>(),
output_cum_offsets.data<int>(),
actual_candidate_len.data<int>(),
real_bsz,
max_draft_tokens,
end_length,
max_seq_len,
max_candidate_len,
verify_window,
prefill_one_step_stop);
}
} else {
if (enable_topp) {
baidu::xpu::api::plugin::speculate_verify<true, false>(
ctx,
const_cast<int64_t *>(accept_tokens.data<int64_t>()),
const_cast<int *>(accept_num.data<int>()),
const_cast<int64_t *>(step_idx.data<int64_t>()),
const_cast<bool *>(stop_flags.data<bool>()),
seq_lens_encoder.data<int>(),
seq_lens_decoder.data<int>(),
draft_tokens.data<int64_t>(),
actual_draft_token_nums.data<int>(),
dev_curand_states,
topp.data<float>(),
seq_lens_this_time.data<int>(),
verify_tokens.data<int64_t>(),
verify_scores.data<float>(),
max_dec_len.data<int64_t>(),
end_tokens.data<int64_t>(),
is_block_step.data<bool>(),
output_cum_offsets.data<int>(),
actual_candidate_len.data<int>(),
real_bsz,
max_draft_tokens,
end_length,
max_seq_len,
max_candidate_len,
verify_window,
prefill_one_step_stop);
} else {
baidu::xpu::api::plugin::speculate_verify<false, false>(
ctx,
const_cast<int64_t *>(accept_tokens.data<int64_t>()),
const_cast<int *>(accept_num.data<int>()),
const_cast<int64_t *>(step_idx.data<int64_t>()),
const_cast<bool *>(stop_flags.data<bool>()),
seq_lens_encoder.data<int>(),
seq_lens_decoder.data<int>(),
draft_tokens.data<int64_t>(),
actual_draft_token_nums.data<int>(),
dev_curand_states,
topp.data<float>(),
seq_lens_this_time.data<int>(),
verify_tokens.data<int64_t>(),
verify_scores.data<float>(),
max_dec_len.data<int64_t>(),
end_tokens.data<int64_t>(),
is_block_step.data<bool>(),
output_cum_offsets.data<int>(),
actual_candidate_len.data<int>(),
real_bsz,
max_draft_tokens,
end_length,
max_seq_len,
max_candidate_len,
verify_window,
prefill_one_step_stop);
}
}
}
PD_BUILD_OP(speculate_verify)
.Inputs({"accept_tokens",
"accept_num",
"step_idx",
"stop_flags",
"seq_lens_encoder",
"seq_lens_decoder",
"draft_tokens",
"seq_lens_this_time",
"verify_tokens",
"verify_scores",
"max_dec_len",
"end_tokens",
"is_block_step",
"output_cum_offsets",
"actual_candidate_len",
"actual_draft_token_nums",
"topp"})
.Outputs({"accept_tokens_out",
"accept_num_out",
"step_idx_out",
"stop_flags_out"})
.Attrs({"max_seq_len: int", "verify_window: int", "enable_topp: bool"})
.SetInplaceMap({{"accept_tokens", "accept_tokens_out"},
{"accept_num", "accept_num_out"},
{"step_idx", "step_idx_out"},
{"stop_flags", "stop_flags_out"}})
.SetKernelFn(PD_KERNEL(SpeculateVerify));

View File

@@ -0,0 +1,158 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <paddle/phi/backends/xpu/xpu_context.h>
#include "paddle/extension.h"
#include "xpu/plugin.h"
#define FIXED_TOPK_BASE(topk, ...) \
case (topk): { \
constexpr auto kTopK = topk; \
__VA_ARGS__; \
} break
#define FIXED_TOPK(...) \
FIXED_TOPK_BASE(2, ##__VA_ARGS__); \
FIXED_TOPK_BASE(3, ##__VA_ARGS__); \
FIXED_TOPK_BASE(4, ##__VA_ARGS__); \
FIXED_TOPK_BASE(5, ##__VA_ARGS__); \
FIXED_TOPK_BASE(8, ##__VA_ARGS__); \
FIXED_TOPK_BASE(10, ##__VA_ARGS__); \
default: { \
PD_THROW("Unsupported candidates_len."); \
}
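// For reference, FIXED_TOPK(stmt) expands to one `case` per supported
// candidate count; e.g. for candidates_len == 2 the generated arm is
// (illustrative expansion, not part of the original source):
//
//   case (2): {
//     constexpr auto kTopK = 2;
//     stmt;
//   } break;
//
// so kTopK is a compile-time constant inside each dispatch arm below.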
namespace api = baidu::xpu::api;
std::vector<paddle::Tensor> TopPCandidates(
const paddle::Tensor& probs,
const paddle::Tensor& top_p,
const paddle::Tensor& output_padding_offset,
int candidates_len,
int max_seq_len) {
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
auto xpu_ctx = static_cast<const phi::XPUContext*>(dev_ctx);
api::Context* ctx = xpu_ctx->x_context();
if (probs.is_cpu()) {
ctx = new api::Context(api::kCPU);
}
std::vector<int64_t> input_shape = probs.shape();
const int token_num = input_shape[0];
const int vocab_size = input_shape[1];
auto verify_scores =
paddle::empty({token_num, candidates_len}, probs.dtype(), probs.place());
auto verify_tokens = paddle::empty(
{token_num, candidates_len}, paddle::DataType::INT64, probs.place());
auto actual_candidate_lens =
paddle::empty({token_num}, paddle::DataType::INT32, probs.place());
constexpr int TopKMaxLength = 2;
int r;
switch (probs.dtype()) {
case paddle::DataType::BFLOAT16:
using XPUTypeBF16 = typename XPUTypeTrait<bfloat16>::Type;
typedef paddle::bfloat16 bf16_data_t;
switch (candidates_len) {
FIXED_TOPK(
r = api::plugin::top_p_candidates<XPUTypeBF16,
TopKMaxLength,
kTopK>(
ctx,
reinterpret_cast<const XPUTypeBF16*>(probs.data<bf16_data_t>()),
reinterpret_cast<const XPUTypeBF16*>(top_p.data<bf16_data_t>()),
output_padding_offset.data<int>(),
verify_tokens.data<int64_t>(),
reinterpret_cast<XPUTypeBF16*>(
verify_scores.data<bf16_data_t>()),
actual_candidate_lens.data<int>(),
vocab_size,
token_num,
candidates_len,
max_seq_len);
PD_CHECK(r == 0, "xpu::plugin::top_p_candidates failed.");
return {verify_scores, verify_tokens, actual_candidate_lens});
}
case paddle::DataType::FLOAT16:
using XPUTypeFP16 = typename XPUTypeTrait<float16>::Type;
typedef paddle::float16 fp16_data_t;
switch (candidates_len) {
FIXED_TOPK(
r = api::plugin::top_p_candidates<XPUTypeFP16,
TopKMaxLength,
kTopK>(
ctx,
reinterpret_cast<const XPUTypeFP16*>(probs.data<fp16_data_t>()),
reinterpret_cast<const XPUTypeFP16*>(top_p.data<fp16_data_t>()),
output_padding_offset.data<int>(),
verify_tokens.data<int64_t>(),
reinterpret_cast<XPUTypeFP16*>(
verify_scores.data<fp16_data_t>()),
actual_candidate_lens.data<int>(),
vocab_size,
token_num,
candidates_len,
max_seq_len);
PD_CHECK(r == 0, "xpu::plugin::top_p_candidates failed.");
return {verify_scores, verify_tokens, actual_candidate_lens});
}
case paddle::DataType::FLOAT32:
switch (candidates_len) {
FIXED_TOPK(
r = api::plugin::top_p_candidates<float, TopKMaxLength, kTopK>(
ctx,
probs.data<float>(),
top_p.data<float>(),
output_padding_offset.data<int>(),
verify_tokens.data<int64_t>(),
verify_scores.data<float>(),
actual_candidate_lens.data<int>(),
vocab_size,
token_num,
candidates_len,
max_seq_len);
PD_CHECK(r == 0, "xpu::plugin::top_p_candidates failed.");
return {verify_scores, verify_tokens, actual_candidate_lens});
}
default:
PD_THROW("Unsupported data type.");
}
}
std::vector<std::vector<int64_t>> TopPCandidatesInferShape(
const std::vector<int64_t>& probs_shape,
const std::vector<int64_t>& top_p_shape,
const std::vector<int64_t>& output_padding_offset_shape,
int max_candidates_len) {
int token_num = probs_shape[0];
return {{token_num, max_candidates_len},
{token_num, max_candidates_len},
{token_num}};
}
std::vector<paddle::DataType> TopPCandidatesInferDtype(
const paddle::DataType& probs_dtype,
const paddle::DataType& top_p_dtype,
const paddle::DataType& output_padding_offset_dtype) {
return {probs_dtype, paddle::DataType::INT64, paddle::DataType::INT32};
}
PD_BUILD_OP(top_p_candidates)
.Inputs({"probs", "top_p", "output_padding_offset"})
.Outputs({"verify_scores", "verify_tokens", "actual_candidate_lens"})
.Attrs({"candidates_len: int", "max_seq_len: int"})
.SetKernelFn(PD_KERNEL(TopPCandidates))
.SetInferShapeFn(PD_INFER_SHAPE(TopPCandidatesInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(TopPCandidatesInferDtype));

View File

@@ -0,0 +1,95 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <paddle/phi/backends/xpu/xpu_context.h>
#include <string>
#include <vector>
#include "paddle/extension.h"
namespace paddle {
std::string string_format(const std::string fmt_str, ...);
template <typename T = std::string>
static T string_parse(const std::string& v) {
return v;
}
template <>
inline int32_t string_parse<int32_t>(const std::string& v) {
  return std::stoi(v);
}
template <>
inline int64_t string_parse<int64_t>(const std::string& v) {
  return std::stoll(v);
}
template <>
inline float string_parse<float>(const std::string& v) {
  return std::stof(v);
}
template <>
inline double string_parse<double>(const std::string& v) {
  return std::stod(v);
}
// Explicit specializations defined in a header must be inline to avoid
// duplicate-symbol errors when the header is included in several TUs.
template <>
inline bool string_parse<bool>(const std::string& v) {
std::string upper;
for (size_t i = 0; i < v.length(); i++) {
char ch = v[i];
if (ch >= 'a' && ch <= 'z') {
ch = ch - 'a' + 'A';
}
upper.push_back(ch);
}
return upper == "TRUE" || upper == "1";
}
template <class T = std::string>
static std::vector<T> string_split(const std::string& original,
const std::string& separator) {
std::vector<T> results;
std::string::size_type pos1, pos2;
pos2 = original.find(separator);
pos1 = 0;
while (std::string::npos != pos2) {
results.push_back(string_parse<T>(original.substr(pos1, pos2 - pos1)));
pos1 = pos2 + separator.size();
pos2 = original.find(separator, pos1);
}
if (pos1 != original.length()) {
results.push_back(string_parse<T>(original.substr(pos1)));
}
return results;
}
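// Example (illustrative): split a comma-separated env value into ints using
// the specializations above:
//   std::vector<int32_t> ids = string_split<int32_t>("0,1,2", ",");
//   // -> {0, 1, 2}; string_parse<bool>("True") == true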
std::string shape_to_string(const std::vector<int64_t>& shape);
template <typename T>
void DebugPrintXPUTensor(const phi::XPUContext* xpu_ctx,
const paddle::Tensor& input,
std::string tag = "",
int len = 1);
template <typename T>
void DebugPrintXPUTensorv2(const paddle::Tensor& input,
std::string tag = "",
int len = 1);
} // namespace paddle

View File

@@ -174,6 +174,10 @@ macro(
separate_arguments(arg_device_o_extra_flags)
set(arg_host_o_extra_flags ${host_o_extra_flags})
separate_arguments(arg_host_o_extra_flags)
set(MTP_KERNEL_COMPILE_FLAGS "")
if(${kernel_path} MATCHES "mtp_kernel")
  list(APPEND MTP_KERNEL_COMPILE_FLAGS -mllvm -fix-mfence=all)
endif()
add_custom_command(
  OUTPUT ${kernel_name}.device.bin.o ${kernel_name}.o
@@ -181,8 +185,8 @@ macro(
  ${XPU_CLANG} -std=c++11 ${OPT_LEVEL} ${arg_device_o_extra_flags} -c
  ${kernel_path} -D ${xpu_n_macro} --target=${TARGET_ARCH} ${HOST_XPU_FLAGS}
  --basename ${kernel_name} -fno-builtin --xpu-arch=${xpu_n} -fPIC
  -Wno-int-to-void-pointer-cast -Wno-int-to-pointer-cast -Werror ${MTP_KERNEL_COMPILE_FLAGS}
  -mllvm --xpu-inline-cost -mllvm --xpu-inline-hot-call -I${XDNN_INC_DIR} -I${XRE_INC_DIR}
  -fxpu-launch-return
  -I${CMAKE_CURRENT_SOURCE_DIR}/include -I${CMAKE_CURRENT_SOURCE_DIR}/src
  -I${CMAKE_CURRENT_SOURCE_DIR}/src/kernel

View File

@@ -139,6 +139,307 @@ template <typename TX, typename TSCALE = float, typename TY = int8_t>
DLL_EXPORT int quant2d_per_channel(api::Context *ctx, const TX *x,
                                   const TSCALE *scale_in, TY *y,
                                   TSCALE *scale_out, int64_t m, int64_t n);
/*--------------------------------------- MTP begin --------------------------------------------*/
template <typename T>
DLL_EXPORT int speculate_token_penalty_multi_scores(
Context* ctx,
const int64_t* pre_ids,
T* logits,
const T* penalty_scores,
const T* frequency_scores,
const T* presence_scores,
const float* temperatures,
const int64_t* cur_len,
const int64_t* min_len,
const int64_t* eos_token_id,
const int64_t* bad_words,
const int* output_padding_offset,
const int* output_cum_offsets,
const int64_t bs,
const int64_t length,
const int64_t length_id,
const int64_t end_length,
const int64_t length_bad_words,
const int64_t token_num,
const int64_t max_seq_len);
DLL_EXPORT int mtp_free_and_dispatch_block(Context* ctx,
bool* base_model_stop_flags,
bool* stop_flags,
bool* batch_drop,
int* seq_lens_this_time,
int* seq_lens_decoder,
int* block_tables,
int* encoder_block_lens,
int* used_list_len,
int* free_list,
int* free_list_len,
const int bsz,
const int block_size,
const int block_num_per_seq,
const int max_draft_tokens);
template <bool ENABLE_TOPP, bool USE_TOPK>
DLL_EXPORT int speculate_verify(Context* ctx,
int64_t* accept_tokens,
int* accept_num,
int64_t* step_idx,
bool* stop_flags,
const int* seq_lens_encoder,
const int* seq_lens_decoder,
const int64_t* draft_tokens,
const int* actual_draft_token_nums,
const float* dev_curand_states,
const float* topp,
const int* seq_lens_this_time,
const int64_t* verify_tokens,
const float* verify_scores,
const int64_t* max_dec_len,
const int64_t* end_tokens,
const bool* is_block_step,
const int* output_cum_offsets,
const int* actual_candidate_len,
const int real_bsz,
const int max_draft_tokens,
const int end_length,
const int max_seq_len,
const int max_candidate_len,
const int verify_window,
const bool prefill_one_step_stop);
DLL_EXPORT int speculate_clear_accept_nums(Context* ctx,
int* accept_num,
const int* seq_lens_decoder,
const int max_bsz);
DLL_EXPORT int speculate_get_seq_lens_output(Context* ctx,
int* seq_lens_output,
const int* seq_lens_this_time,
const int* seq_lens_encoder,
const int* seq_lens_decoder,
const int real_bsz);
DLL_EXPORT int draft_model_update(Context* ctx,
const int64_t* inter_next_tokens,
int64_t* draft_tokens,
int64_t* pre_ids,
int* seq_lens_this_time,
int* seq_lens_encoder,
int* seq_lens_decoder,
int64_t* step_idx,
const int* output_cum_offsets,
bool* stop_flags,
bool* not_need_stop,
const int64_t* max_dec_len,
const int64_t* end_ids,
int64_t* base_model_draft_tokens,
const int bsz,
const int max_draft_token,
const int pre_id_length,
const int max_base_model_draft_token,
const int end_ids_len,
const int max_seq_len,
const int substep,
const bool prefill_one_step_stop);
DLL_EXPORT int draft_model_preprocess(api::Context* ctx,
int64_t* draft_tokens,
int64_t* input_ids,
bool* stop_flags,
int* seq_lens_this_time,
int* seq_lens_encoder,
int* seq_lens_decoder,
int64_t* step_idx,
int* seq_lens_encoder_record,
int* seq_lens_decoder_record,
bool* not_need_stop,
bool* batch_drop,
const int64_t* accept_tokens,
const int* accept_num,
const int* base_model_seq_lens_encoder,
const int* base_model_seq_lens_decoder,
const int64_t* base_model_step_idx,
const bool* base_model_stop_flags,
const bool* base_model_is_block_step,
int64_t* base_model_draft_tokens,
int real_bsz,
int max_draft_token,
int accept_tokens_len,
int draft_tokens_len,
int input_ids_len,
int base_model_draft_tokens_len,
bool truncate_first_token,
bool splitwise_prefill);
DLL_EXPORT int speculate_set_stop_value_multi_seqs(Context* ctx,
bool* stop_flags,
int64_t* accept_tokens,
int* accept_nums,
const int64_t* pre_ids,
const int64_t* step_idx,
const int64_t* stop_seqs,
const int* stop_seqs_len,
const int* seq_lens,
const int64_t* end_ids,
const int bs_now,
const int accept_tokens_len,
const int stop_seqs_bs,
const int stop_seqs_max_len,
const int pre_ids_len);
template <typename T>
DLL_EXPORT int speculate_rebuild_append_padding(api::Context* ctx,
T* full_hidden_states,
int* cum_offsets,
int* seq_len_encoder,
int* seq_len_decoder,
int* output_padding_offset,
int max_seq_len,
int dim_embed,
int elem_nums,
T* out);
template <typename T>
DLL_EXPORT int speculate_remove_padding(Context* ctx,
T* x_remove_padding,
const T* input_ids,
const T* draft_tokens,
const int* seq_lens,
const int* seq_lens_encoder,
const int* cum_offsets_out,
int seq_length,
int max_draft_tokens,
int bsz,
int token_num_data);
DLL_EXPORT int speculate_get_padding_offset(Context* ctx,
int* padding_offset,
int* cum_offsets_out,
int* cu_seqlens_q,
int* cu_seqlens_k,
const int* cum_offsets,
const int* seq_lens,
const int max_seq_len,
int bsz);
DLL_EXPORT int compute_self_order(api::Context* ctx,
const int* last_seq_lens_this_time,
const int* seq_lens_this_time,
const int64_t* step_idx,
int* src_map,
int* output_token_num,
int bsz);
DLL_EXPORT int compute_order(api::Context* ctx,
const int* seq_lens_this_time,
const int* seq_lens_encoder,
const int* base_model_seq_lens_this_time,
const int* base_model_seq_lens_encoder,
const int* accept_nums,
int* position_map,
int* output_token_num,
const int bsz,
const int actual_draft_token_num,
const int input_token_num);
DLL_EXPORT int draft_model_postprocess(Context* ctx,
const int64_t* base_model_draft_tokens,
int* base_model_seq_lens_this_time,
const int* base_model_seq_lens_encoder,
const bool* base_model_stop_flags,
int bsz,
int base_model_draft_token_len);
DLL_EXPORT int speculate_set_value_by_flag_and_id(Context* ctx,
int64_t* pre_ids_all,
const int64_t* accept_tokens,
const int* accept_num,
const bool* stop_flags,
const int* seq_lens_encoder,
const int* seq_lens_decoder,
const int64_t* step_idx,
int bs,
int length,
int max_draft_tokens);
DLL_EXPORT int speculate_get_output_padding_offset(
Context* ctx,
int* output_padding_offset,
int* output_cum_offsets,
const int* output_cum_offsets_tmp,
const int* seq_lens_output,
const int bsz,
const int max_seq_len);
template <typename T, int MaxLength, int TopPBeamTopK>
DLL_EXPORT int top_p_candidates(api::Context* ctx,
const T* src,
const T* top_ps,
const int* output_padding_offset,
int64_t* out_id,
T* out_val,
int* actual_candidates_lens,
int vocab_size,
int token_num,
                                int max_candidate_len,
int max_seq_len);
DLL_EXPORT int speculate_free_and_reschedule(Context* ctx,
bool* stop_flags,
int* seq_lens_this_time,
int* seq_lens_decoder,
int* block_tables,
int* encoder_block_lens,
bool* is_block_step,
int* step_block_list, // [bsz]
int* step_len,
int* recover_block_list,
int* recover_len,
int* need_block_list,
int* need_block_len,
int* used_list_len,
int* free_list,
int* free_list_len,
int64_t* first_token_ids,
const int bsz,
const int block_size,
const int block_num_per_seq,
const int max_decoder_block_num,
const int max_draft_tokens);
DLL_EXPORT int speculate_update_v3(Context* ctx,
int* seq_lens_encoder,
int* seq_lens_decoder,
bool* not_need_stop,
int64_t* draft_tokens,
int* actual_draft_token_nums,
const int64_t* accept_tokens,
const int* accept_num,
const bool* stop_flags,
const int* seq_lens_this_time,
const bool* is_block_step,
const int64_t* stop_nums,
const int real_bsz,
const int max_bsz,
const int max_draft_tokens);
template <typename T>
DLL_EXPORT int rebuild_hidden_states(api::Context* ctx,
const T* input,
const int* position_map,
T* out,
int dim_embed,
int elem_cnt);
template <typename T>
DLL_EXPORT int rebuild_self_hidden_states(api::Context* ctx,
const T* input,
int* src_map,
T* output,
int dim_embed,
int elem_cnt);
/*--------------------------------------- MTP end --------------------------------------------*/
} // namespace plugin
} // namespace api
} // namespace xpu

View File

@@ -0,0 +1,112 @@
#include "xpu/kernel/cluster.h"
#include "xpu/kernel/cluster_partition.h"
#include "xpu/kernel/cluster_primitive.h"
namespace xpu3 {
namespace plugin {
__global__ void ComputeOrderKernel(const int* seq_lens_this_time,
const int* seq_lens_encoder,
const int* base_model_seq_lens_this_time,
const int* base_model_seq_lens_encoder,
const int* accept_nums,
int* position_map,
int* output_token_num,
const int bsz,
const int actual_draft_token_num,
const int input_token_num) {
int tid = core_id() * cluster_num() + cluster_id();
if (tid != 0) {
return;
}
char lm[6 * 1024];
int buf_size = 6 * 1024 / (6 * sizeof(int));
int* lm_base_model_seq_lens_this_time = (int*)lm;
int* lm_base_model_seq_lens_encoder =
lm_base_model_seq_lens_this_time + buf_size;
int* lm_seq_lens_this_time = lm_base_model_seq_lens_encoder + buf_size;
int* lm_accept_nums = lm_seq_lens_this_time + buf_size;
int* lm_seq_lens_encoder = lm_accept_nums + buf_size;
int* lm_position_map = lm_seq_lens_encoder + buf_size;
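  // The 6 KB local-memory scratch above is carved into six equal int buffers
  // of buf_size (= 256) entries each; the batch dimension is streamed through
  // them in chunks in the loop below.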
int in_offset = 0;
int out_offset = 0;
for (int i = 0; i < bsz; i += buf_size) {
int64_t read_size = min(static_cast<int64_t>(bsz - i), buf_size);
GM2LM_ASYNC(base_model_seq_lens_this_time + i,
lm_base_model_seq_lens_this_time,
read_size * sizeof(int));
GM2LM_ASYNC(base_model_seq_lens_encoder + i,
lm_base_model_seq_lens_encoder,
read_size * sizeof(int));
GM2LM_ASYNC(
seq_lens_this_time + i, lm_seq_lens_this_time, read_size * sizeof(int));
GM2LM_ASYNC(accept_nums + i, lm_accept_nums, read_size * sizeof(int));
GM2LM(seq_lens_encoder + i, lm_seq_lens_encoder, read_size * sizeof(int));
for (int j = 0; j < read_size; j++) {
int cur_base_model_seq_lens_this_time =
lm_base_model_seq_lens_this_time[j];
int cur_base_model_seq_lens_encoder = lm_base_model_seq_lens_encoder[j];
int cur_seq_lens_this_time = lm_seq_lens_this_time[j];
int accept_num = lm_accept_nums[j];
int cur_seq_lens_encoder = lm_seq_lens_encoder[j];
// 1. eagle encoder. Base step=1
if (cur_seq_lens_encoder > 0) {
for (int k = 0; k < cur_seq_lens_encoder; k += buf_size) {
int64_t write_size =
min(static_cast<int64_t>(cur_seq_lens_encoder - k),
static_cast<int64_t>(buf_size));
for (int l = 0; l < write_size; l++) {
lm_position_map[l] = out_offset;
out_offset++;
}
mfence_lm();
LM2GM(lm_position_map,
position_map + in_offset,
write_size * sizeof(int));
in_offset += write_size;
}
mfence_lm();
// 2. base model encoder. Base step=0
} else if (cur_base_model_seq_lens_encoder != 0) {
// nothing happens
// 3. New end
} else if (cur_base_model_seq_lens_this_time != 0 &&
cur_seq_lens_this_time == 0) {
in_offset += cur_base_model_seq_lens_this_time;
// 4. stopped
} else if (cur_base_model_seq_lens_this_time == 0 &&
cur_seq_lens_this_time == 0) {
// nothing happens
} else {
if (accept_num <= actual_draft_token_num) {
int position_map_val = out_offset;
LM2GM(&position_map_val,
position_map + in_offset + accept_num - 1,
sizeof(int));
out_offset++;
in_offset += cur_base_model_seq_lens_this_time;
} else {
int position_map_val_1 = out_offset;
LM2GM(&position_map_val_1,
position_map + in_offset + accept_num - 2,
sizeof(int));
out_offset++;
int position_map_val_2 = out_offset;
LM2GM(&position_map_val_2,
position_map + in_offset + accept_num - 1,
sizeof(int));
out_offset++;
in_offset += cur_base_model_seq_lens_this_time;
}
mfence_lm();
}
}
}
mfence_lm();
LM2GM(&out_offset, output_token_num, sizeof(int));
}
} // namespace plugin
} // namespace xpu3

View File

@@ -0,0 +1,75 @@
#include "xpu/kernel/cluster.h"
#include "xpu/kernel/cluster_partition.h"
#include "xpu/kernel/cluster_primitive.h"
namespace xpu3 {
namespace plugin {
__global__ void ComputeSelfOrderKernel(const int* last_seq_lens_this_time,
const int* seq_lens_this_time,
const int64_t* step_idx,
int* src_map,
int* output_token_num,
int bsz) {
int tid = core_id() * cluster_num() + cluster_id();
if (tid != 0) {
return;
}
char lm[6 * 1024];
int buf_size = 256;
int* lm_last_seq_lens_this_time = (int*)lm;
int* lm_seq_lens_this_time = lm_last_seq_lens_this_time + buf_size;
int64_t* lm_step_idx = (int64_t*)(lm_seq_lens_this_time + buf_size);
int* lm_src_map = (int*)(lm_step_idx + buf_size);
int in_offset = 0;
int out_offset = 0;
int previous_out_offset = out_offset;
for (int i = 0; i < bsz; i += buf_size) {
int64_t read_size = min(static_cast<int64_t>(bsz - i), buf_size);
GM2LM_ASYNC(last_seq_lens_this_time + i,
lm_last_seq_lens_this_time,
read_size * sizeof(int));
GM2LM_ASYNC(
seq_lens_this_time + i, lm_seq_lens_this_time, read_size * sizeof(int));
GM2LM(step_idx + i, lm_step_idx, read_size * sizeof(int64_t));
for (int j = 0; j < read_size; j++) {
int cur_seq_lens_this_time = lm_seq_lens_this_time[j];
int cur_last_seq_lens_this_time = lm_last_seq_lens_this_time[j];
int64_t cur_step_idx = lm_step_idx[j];
// 1. encoder
if (cur_step_idx == 1 && cur_seq_lens_this_time > 0) {
in_offset += 1;
lm_src_map[j] = in_offset - 1;
out_offset++;
// 2. decoder
} else if (cur_seq_lens_this_time > 0) /* =1 */ {
in_offset += cur_last_seq_lens_this_time;
lm_src_map[j] = in_offset - 1;
out_offset++;
// 3. stop
} else {
// first token end
if (cur_step_idx == 1) {
in_offset += cur_last_seq_lens_this_time > 0 ? 1 : 0;
// normal end
} else {
in_offset += cur_last_seq_lens_this_time;
}
}
}
mfence_lm();
if (out_offset > previous_out_offset) {
LM2GM_ASYNC(lm_src_map,
src_map + previous_out_offset,
(out_offset - previous_out_offset) * sizeof(int));
}
previous_out_offset = out_offset;
}
mfence_lm();
LM2GM(&out_offset, output_token_num, sizeof(int));
}
} // namespace plugin
} // namespace xpu3

View File

@@ -0,0 +1,189 @@
#include "xpu/kernel/cluster.h"
#include "xpu/kernel/cluster_partition.h"
#include "xpu/kernel/cluster_primitive.h"
#include "xpu/kernel/cluster_primitive_template.h"
#include "xpu/kernel/cluster_simd.h"
#include "xpu/kernel/xtdk_io.h"
namespace xpu3 {
namespace plugin {
static inline __device__ int v_reduce(int32x16_t v) {
auto v0 = vsrlp_int32x16(256, v);
v = vvadd_int32x16(v0, v);
v0 = vsrlp_int32x16(128, v);
v = vvadd_int32x16(v0, v);
v0 = vsrlp_int32x16(64, v);
v = vvadd_int32x16(v0, v);
v0 = vsrlp_int32x16(32, v);
v = vvadd_int32x16(v0, v);
int res;
res = vextract_int32x16(v);
return res;
}
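// Scalar reference for v_reduce above (documentation only, not called by the
// kernel). Assumption: vsrlp_int32x16(n, v) shifts the 512-bit register right
// by n bits, so the shift/add ladder folds all 16 int32 lanes into lane 0 and
// vextract_int32x16 reads that lane; i.e. v_reduce returns the lane sum.
static inline __device__ int v_reduce_scalar(const int* lanes) {
  int sum = 0;
  for (int i = 0; i < 16; i++) {
    sum += lanes[i];
  }
  return sum;
}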
__device__ int do_calc(int64_t* lmptr, int read_len) {
int res = 0;
int32x16_t v0;
int32x16_t v1;
int32x16_t v2 = {0};
int* lmptr_i16 = (int*)lmptr;
int rounddown_size = rounddown16(read_len * 2);
int comp = -1;
int i = 0;
for (; i < rounddown_size; i += 16) {
v0 = vload_lm_int32x16(lmptr_i16 + i);
v1 = vload_lm_int32x16(lmptr_i16 + i);
unsigned int mask0 =
static_cast<unsigned int>(sveq_int32x16_mz(comp, v0, 0xAAAA));
unsigned int mask1 =
static_cast<unsigned int>(sveq_int32x16_mz(comp, v1, 0x5555));
mask1 = mask1 << 1;
unsigned int mask2 = (mask0 & 0xFFFFFFFF) & (mask1 & 0xFFFFFFFF);
v2 = svadd_int32x16_mh(1, v2, v2, mask2);
}
res = i / 2 - v_reduce(v2);
mfence_lm();
for (int j = i / 2; j < read_len; j++) {
if (lmptr[j] != -1) {
res += 1;
}
}
return res;
}
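// Scalar reference for do_calc above (documentation only, not called by the
// kernel): it counts int64 entries that differ from -1. The SIMD version
// reaches the same count by testing the low/high int32 halves against -1
// under the 0x5555/0xAAAA lane masks and subtracting the matches.
static inline __device__ int count_non_minus_one(const int64_t* buf, int n) {
  int res = 0;
  for (int i = 0; i < n; i++) {
    if (buf[i] != -1) {
      res += 1;
    }
  }
  return res;
}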
__global__ void draft_model_postprocess(const int64_t* base_model_draft_tokens,
int* base_model_seq_lens_this_time,
const int* base_model_seq_lens_encoder,
const bool* base_model_stop_flags,
int bsz,
int base_model_draft_token_len) {
int cid = core_id();
int ncores = core_num();
int nclusters = cluster_num();
int nthreads = ncores * nclusters;
const int max_sm_len = 256 * 1024 / sizeof(int);
const int core_limit_row = max_sm_len / ncores;
  const int cluster_limit_row = max_sm_len * nclusters;
int bsz_start_cluster;
int bsz_end_cluster;
int bsz_start_core;
int bsz_end_core;
  int row_to_partition = min(bsz, cluster_limit_row);
// cluster partition
partition(cluster_id(),
nclusters,
row_to_partition,
1,
&bsz_start_cluster,
&bsz_end_cluster);
if (bsz_start_cluster >= bsz_end_cluster) {
return;
}
int rows_cluster =
bsz_end_cluster - bsz_start_cluster; // total rows for a cluster
// core partition
partition(
core_id(), core_num(), rows_cluster, 1, &bsz_start_core, &bsz_end_core);
__shared__ int base_model_sm[max_sm_len];
const int LM_SIZE = 3072;
const int BUFSIZE = LM_SIZE / sizeof(int64_t);
__simd__ int64_t output_lm[BUFSIZE * 2];
DoublePtr<BUFSIZE, LmPtr<int64_t>> local_base_model(
(LmPtr<int64_t>((int64_t*)output_lm)));
const int BSZ_BUF = 16;
__simd__ bool base_model_stop_lm[BSZ_BUF];
__simd__ int base_model_seq_lm[BSZ_BUF];
bsz_start_core += bsz_start_cluster;
bsz_end_core += bsz_start_cluster;
int read_len_sm = 0;
int offset_loop = 0;
int offset_cluster = 0;
  int cur_row_to_all_clusters = 0;
  for (int limit_loop = 0; limit_loop < roundup_div(bsz, cluster_limit_row);
       limit_loop += 1) {
    offset_loop = limit_loop * cluster_limit_row;
    if (bsz_start_core + offset_loop >= bsz) {
      break;
    }
    cur_row_to_all_clusters = min(bsz - offset_loop, cluster_limit_row);
    offset_cluster = 0;
    // compute offset_cluster
    for (int start_cluster = 0; start_cluster < cluster_id();
         start_cluster += 1) {
      offset_cluster += (rounddown_div(cur_row_to_all_clusters, nclusters) +
                         (start_cluster < cur_row_to_all_clusters % nclusters));
    }
    if (core_id() == 0) {
      if (cur_row_to_all_clusters < nclusters) {
        // bsz is small: each cluster gets less than one row on average, so
        // reading a single entry is enough
        read_len_sm = 1;
      } else {
        // bsz is large enough: each cluster reads its own slice, capped at
        // max_sm_len entries
        read_len_sm =
            min(max_sm_len,
                rounddown_div(cur_row_to_all_clusters, nclusters) +
                    (cluster_id() < (cur_row_to_all_clusters % nclusters)));
      }
      GM2SM(base_model_seq_lens_this_time + offset_loop + offset_cluster,
            base_model_sm + offset_cluster,
            sizeof(int) * read_len_sm);
    }
    cur_row_to_all_clusters -= cluster_limit_row;
sync_cluster();
for (int bsz_index = bsz_start_core + offset_loop;
(bsz_index < bsz_end_core + offset_loop) && (bsz_index < bsz);
bsz_index += 1) {
int bsz_offset = bsz_index - bsz_start_core;
if (bsz_offset % BSZ_BUF == 0) {
int64_t readm = min(bsz - bsz_index, BSZ_BUF);
GM2LM_ASYNC(base_model_stop_flags + bsz_index,
base_model_stop_lm,
sizeof(bool) * readm);
GM2LM(base_model_seq_lens_encoder + bsz_index,
base_model_seq_lm,
sizeof(int) * readm);
}
if (!base_model_stop_lm[bsz_offset % BSZ_BUF] &&
(base_model_seq_lm[bsz_offset % BSZ_BUF] == 0)) {
        // count the valid tokens (entries that are not -1)
int token_num = 0;
int j = 0;
int read_len = min(base_model_draft_token_len - j, BUFSIZE);
local_base_model.gm_load(base_model_draft_tokens +
bsz_index * base_model_draft_token_len + j,
read_len);
for (; j < base_model_draft_token_len; j += BUFSIZE) {
int next_idx = j + BUFSIZE;
int read_len_next =
min(base_model_draft_token_len - next_idx, BUFSIZE);
if (read_len_next > 0) {
local_base_model.next().gm_load_async(
base_model_draft_tokens +
bsz_index * base_model_draft_token_len + next_idx,
read_len_next);
}
token_num += do_calc(local_base_model.ptr, read_len);
read_len = read_len_next;
local_base_model.toggle();
mfence_lm();
}
base_model_sm[bsz_index % max_sm_len] = token_num;
} else if (base_model_stop_lm[bsz_offset % BSZ_BUF]) {
int token_num = 0;
base_model_sm[bsz_index % max_sm_len] = token_num;
}
}
sync_cluster();
if (core_id() == 0) {
SM2GM(base_model_sm + offset_cluster,
base_model_seq_lens_this_time + offset_loop + offset_cluster,
sizeof(int) * read_len_sm);
}
sync_cluster();
}
}
} // namespace plugin
} // namespace xpu3

View File

@@ -0,0 +1,243 @@
#include "xpu/kernel/cluster.h"
#include "xpu/kernel/cluster_debug.h"
#include "xpu/kernel/cluster_partition.h"
#include "xpu/kernel/cluster_primitive.h"
#include "xpu/kernel/cluster_simd.h"
namespace xpu3 {
namespace plugin {
__global__ void draft_model_preprocess(int64_t* draft_tokens,
int64_t* input_ids,
bool* stop_flags,
int* seq_lens_this_time,
int* seq_lens_encoder,
int* seq_lens_decoder,
int64_t* step_idx,
int* seq_lens_encoder_record,
int* seq_lens_decoder_record,
bool* not_need_stop,
bool* batch_drop,
const int64_t* accept_tokens,
const int* accept_num,
const int* base_model_seq_lens_encoder,
const int* base_model_seq_lens_decoder,
const int64_t* base_model_step_idx,
const bool* base_model_stop_flags,
const bool* base_model_is_block_step,
int64_t* base_model_draft_tokens,
int real_bsz,
int max_draft_token,
int accept_tokens_len,
int draft_tokens_len,
int input_ids_len,
int base_model_draft_tokens_len,
bool truncate_first_token,
bool splitwise_prefill) {
int cid = core_id();
int ncores = core_num();
int clusterid = cluster_id();
int nclusters = cluster_num();
int tid = clusterid * ncores + cid;
__shared__ int not_stop_flag_sm[64];
not_stop_flag_sm[cid] = 0;
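  // Per-core count of sequences that are still running; assumed to be reduced
  // into not_need_stop once the batch loop finishes.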
int64_t accept_tokens_now[128];
int value_zero = 0;
int64_t value_fu = -1;
if (splitwise_prefill) {
for (; tid < real_bsz; tid += ncores * nclusters) {
int64_t base_model_step_idx_now = 0;
int seq_lens_encoder_now = 0;
int seq_lens_this_time_now = 0;
bool stop_flags_now = false;
int64_t base_model_first_token;
int seq_lens_encoder_record_now = 0;
int64_t input_ids_now = 0;
GM2LM_ASYNC(
base_model_step_idx + tid, &base_model_step_idx_now, sizeof(int64_t));
GM2LM_ASYNC(seq_lens_encoder_record + tid,
&seq_lens_encoder_record_now,
sizeof(int));
GM2LM(accept_tokens + tid * accept_tokens_len,
&base_model_first_token,
sizeof(int64_t));
if (base_model_step_idx_now == 1 && seq_lens_encoder_record_now > 0) {
not_stop_flag_sm[cid] += 1;
int seq_len_encoder_record = seq_lens_encoder_record_now;
seq_lens_encoder_now = seq_len_encoder_record;
seq_lens_encoder_record_now = -1;
stop_flags_now = false;
int position = seq_len_encoder_record;
if (truncate_first_token) {
position = position - 1;
input_ids_now = base_model_first_token;
seq_lens_this_time_now = seq_len_encoder_record;
} else {
input_ids_now = base_model_first_token;
seq_lens_this_time_now = seq_len_encoder_record + 1;
}
LM2GM_ASYNC(&input_ids_now,
input_ids + tid * input_ids_len + position,
sizeof(int64_t));
LM2GM_ASYNC(&seq_lens_encoder_record_now,
seq_lens_encoder_record + tid,
sizeof(int));
} else {
stop_flags_now = true;
seq_lens_this_time_now = 0;
seq_lens_encoder_now = 0;
not_stop_flag_sm[cid] += 0;
LM2GM_ASYNC(&value_zero, seq_lens_decoder + tid, sizeof(int));
}
LM2GM_ASYNC(&seq_lens_encoder_now, seq_lens_encoder + tid, sizeof(int));
LM2GM_ASYNC(&stop_flags_now, stop_flags + tid, sizeof(bool));
LM2GM(&seq_lens_this_time_now, seq_lens_this_time + tid, sizeof(int));
}
} else {
for (; tid < real_bsz; tid += ncores * nclusters) {
bool base_model_stop_flags_now = false;
bool base_model_is_block_step_now = false;
bool batch_drop_now = false;
bool stop_flags_now = false;
int seq_lens_this_time_now = 0;
int seq_lens_encoder_record_now = 0;
int seq_lens_encoder_now = 0;
int seq_lens_decoder_new = 0;
int seq_lens_decoder_record_now = 0;
int accept_num_now = 0;
int base_model_seq_lens_decoder_now = 0;
int64_t step_id_now = 0;
int64_t base_model_step_idx_now;
mfence();
GM2LM_ASYNC(base_model_stop_flags + tid,
&base_model_stop_flags_now,
sizeof(bool));
GM2LM_ASYNC(base_model_is_block_step + tid,
&base_model_is_block_step_now,
sizeof(bool));
GM2LM_ASYNC(batch_drop + tid, &batch_drop_now, sizeof(bool));
GM2LM_ASYNC(stop_flags + tid, &stop_flags_now, sizeof(bool));
GM2LM_ASYNC(seq_lens_encoder_record + tid,
&seq_lens_encoder_record_now,
sizeof(int));
GM2LM_ASYNC(seq_lens_decoder_record + tid,
&seq_lens_decoder_record_now,
sizeof(int));
GM2LM_ASYNC(seq_lens_encoder + tid, &seq_lens_encoder_now, sizeof(int));
GM2LM_ASYNC(seq_lens_decoder + tid, &seq_lens_decoder_new, sizeof(int));
GM2LM_ASYNC(accept_tokens + tid * accept_tokens_len,
accept_tokens_now,
accept_tokens_len * sizeof(int64_t));
GM2LM_ASYNC(accept_num + tid, &accept_num_now, sizeof(int));
GM2LM_ASYNC(base_model_seq_lens_decoder + tid,
&base_model_seq_lens_decoder_now,
sizeof(int));
GM2LM_ASYNC(step_idx + tid, &step_id_now, sizeof(int64_t));
GM2LM(
base_model_step_idx + tid, &base_model_step_idx_now, sizeof(int64_t));
      for (int i = 1; i < base_model_draft_tokens_len; i++) {
        LM2GM_ASYNC(
            &value_fu,
            base_model_draft_tokens + tid * base_model_draft_tokens_len + i,
            sizeof(int64_t));
      }
if (base_model_stop_flags_now && base_model_is_block_step_now) {
batch_drop_now = true;
stop_flags_now = true;
}
if (!(base_model_stop_flags_now || batch_drop_now)) {
not_stop_flag_sm[cid] += 1;
if (base_model_step_idx_now == 0) {
seq_lens_this_time_now = 0;
          not_stop_flag_sm[cid] -= 1;  // undo the increment above so this case contributes 0
} else if (base_model_step_idx_now == 1 &&
seq_lens_encoder_record_now > 0) {
int seq_len_encoder_record = seq_lens_encoder_record_now;
seq_lens_encoder_now = seq_len_encoder_record;
seq_lens_encoder_record_now = -1;
seq_lens_decoder_new = seq_lens_decoder_record_now;
seq_lens_decoder_record_now = 0;
stop_flags_now = false;
int64_t base_model_first_token = accept_tokens_now[0];
int position = seq_len_encoder_record;
if (truncate_first_token) {
LM2GM(&base_model_first_token,
input_ids + tid * input_ids_len + position - 1,
sizeof(int64_t));
seq_lens_this_time_now = seq_len_encoder_record;
} else {
LM2GM(&base_model_first_token,
input_ids + tid * input_ids_len + position,
sizeof(int64_t));
seq_lens_this_time_now = seq_len_encoder_record + 1;
}
} else if (accept_num_now <= max_draft_token) {
if (stop_flags_now) {
stop_flags_now = false;
seq_lens_decoder_new = base_model_seq_lens_decoder_now;
step_id_now = base_model_step_idx_now;
} else {
seq_lens_decoder_new -= max_draft_token - accept_num_now;
step_id_now -= max_draft_token - accept_num_now;
}
int64_t modified_token = accept_tokens_now[accept_num_now - 1];
LM2GM(&modified_token,
draft_tokens + tid * draft_tokens_len,
sizeof(int64_t));
seq_lens_this_time_now = 1;
} else /*Accept all draft tokens*/ {
LM2GM(accept_tokens_now + max_draft_token,
draft_tokens + tid * draft_tokens_len + 1,
sizeof(int64_t));
seq_lens_this_time_now = 2;
}
} else {
stop_flags_now = true;
seq_lens_this_time_now = 0;
seq_lens_encoder_now = 0;
seq_lens_decoder_new = 0;
}
LM2GM_ASYNC(&stop_flags_now, stop_flags + tid, sizeof(bool));
LM2GM_ASYNC(&batch_drop_now, batch_drop + tid, sizeof(bool));
LM2GM_ASYNC(&seq_lens_decoder_new, seq_lens_decoder + tid, sizeof(int));
LM2GM_ASYNC(
&seq_lens_this_time_now, seq_lens_this_time + tid, sizeof(int));
LM2GM_ASYNC(&seq_lens_encoder_now, seq_lens_encoder + tid, sizeof(int));
LM2GM_ASYNC(&seq_lens_encoder_record_now,
seq_lens_encoder_record + tid,
sizeof(int));
LM2GM_ASYNC(&seq_lens_decoder_record_now,
seq_lens_decoder_record + tid,
sizeof(int));
LM2GM_ASYNC(&step_id_now, step_idx + tid, sizeof(int64_t));
}
}
mfence();
sync_cluster();
bool value_true = true;
bool value_false = false;
if (cid == 0) {
    for (int i = 1; i < ncores; i++) {
      not_stop_flag_sm[0] += not_stop_flag_sm[i];
    }
if (not_stop_flag_sm[0] > 0) {
LM2GM(&value_true, not_need_stop, sizeof(bool));
} else {
LM2GM(&value_false, not_need_stop, sizeof(bool));
}
}
}
} // namespace plugin
} // namespace xpu3
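
For reference, the per-sequence branch selection of the non-splitwise path above can be restated host-side as below. This is a simplified sketch transcribed from the kernel: it returns only the stop flag and seq_lens_this_time, omitting the decoder-length and step-idx rollback.

#include <cstdint>

struct PreprocessOut {
  int seq_lens_this_time;
  bool stop;
};

PreprocessOut preprocess_one(bool base_model_stopped, bool batch_drop,
                             bool stopped, int64_t base_model_step_idx,
                             int seq_lens_encoder_record, int accept_num,
                             int max_draft_token, bool truncate_first_token) {
  if (base_model_stopped || batch_drop) return {0, true};  // inactive slot
  if (base_model_step_idx == 0) return {0, stopped};  // base model not started
  if (base_model_step_idx == 1 && seq_lens_encoder_record > 0) {
    // first step after prefill: replay the prompt (plus one token when the
    // first token is not truncated away)
    int n = seq_lens_encoder_record + (truncate_first_token ? 0 : 1);
    return {n, false};
  }
  if (accept_num <= max_draft_token) return {1, false};  // partial accept
  return {2, false};                                     // all drafts accepted
}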

View File

@@ -0,0 +1,114 @@
#include "xpu/kernel/cluster.h"
#include "xpu/kernel/cluster_debug.h"
#include "xpu/kernel/cluster_partition.h"
#include "xpu/kernel/cluster_primitive.h"
#include "xpu/kernel/cluster_primitive_template.h"
namespace xpu3 {
namespace plugin {
inline __device__ bool is_in_end(const int64_t id,
const __global_ptr__ int64_t* end_ids,
int length) {
bool flag = false;
for (int i = 0; i < length; i++) {
if (id == end_ids[i]) {
return true;
}
}
return flag;
}
__global__ void draft_model_update(const int64_t* inter_next_tokens,
int64_t* draft_tokens,
int64_t* pre_ids,
int* seq_lens_this_time,
int* seq_lens_encoder,
int* seq_lens_decoder,
int64_t* step_idx,
const int* output_cum_offsets,
bool* stop_flags,
bool* not_need_stop,
const int64_t* max_dec_len,
const int64_t* end_ids,
int64_t* base_model_draft_tokens,
const int bsz,
const int max_draft_token,
const int pre_id_length,
const int max_base_model_draft_token,
const int end_ids_len,
const int max_seq_len,
const int substep,
const bool prefill_one_step_stop) {
int cid = core_id();
int ncores = core_num();
  __shared__ int stop_flag_now_int_sm[64];
stop_flag_now_int_sm[cid] = 0;
for (int tid = cid; tid < bsz; tid += ncores) {
auto* draft_token_now = draft_tokens + tid * max_draft_token;
auto* pre_ids_now = pre_ids + tid * pre_id_length;
auto* base_model_draft_tokens_now =
base_model_draft_tokens + tid * max_base_model_draft_token;
const int next_tokens_start_id =
tid * max_seq_len - output_cum_offsets[tid];
auto* next_tokens_start = inter_next_tokens + next_tokens_start_id;
auto seq_len_this_time = seq_lens_this_time[tid];
auto seq_len_encoder = seq_lens_encoder[tid];
auto seq_len_decoder = seq_lens_decoder[tid];
if (!stop_flags[tid] /* seq_lens_decoder > 0 or seq_lens_encoder > 0 */) {
int64_t token_this_time = -1;
// decoder step
if (seq_len_decoder > 0 && seq_len_encoder <= 0) {
seq_lens_decoder[tid] += seq_len_this_time;
token_this_time = next_tokens_start[seq_len_this_time - 1];
draft_token_now[0] = next_tokens_start[seq_len_this_time - 1];
base_model_draft_tokens_now[substep + 1] = token_this_time;
for (int i = 0; i < seq_len_this_time; ++i) {
pre_ids_now[step_idx[tid] + 1 + i] = next_tokens_start[i];
}
step_idx[tid] += seq_len_this_time;
} else {
token_this_time = next_tokens_start[0];
seq_lens_decoder[tid] = seq_len_encoder + seq_len_decoder;
seq_lens_encoder[tid] = 0;
pre_ids_now[1] = token_this_time;
step_idx[tid] += 1;
draft_token_now[0] = token_this_time;
base_model_draft_tokens_now[substep + 1] = token_this_time;
}
// multi_end
if (is_in_end(token_this_time, end_ids, end_ids_len) ||
prefill_one_step_stop) {
stop_flags[tid] = true;
stop_flag_now_int_sm[cid] += 1;
// max_dec_len
} else if (step_idx[tid] >= max_dec_len[tid]) {
stop_flags[tid] = true;
draft_token_now[seq_len_this_time - 1] = end_ids[0];
base_model_draft_tokens_now[substep + 1] = end_ids[0];
stop_flag_now_int_sm[cid] += 1;
}
} else {
draft_token_now[0] = -1;
base_model_draft_tokens_now[substep + 1] = -1;
stop_flag_now_int_sm[cid] += 1;
}
// 2. set end
if (!stop_flags[tid]) {
seq_lens_this_time[tid] = 1;
} else {
seq_lens_this_time[tid] = 0;
seq_lens_encoder[tid] = 0;
}
}
mfence();
sync_all();
if (cid == 0) {
    int sum_stop = 0;
    for (int i = 0; i < ncores; i++) {  // only entries written by active cores
      sum_stop += stop_flag_now_int_sm[i];
    }
not_need_stop[0] = sum_stop < bsz;
}
mfence();
}
} // namespace plugin
} // namespace xpu3
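
The stop decision above combines an end-id match, the `prefill_one_step_stop` switch, and the `max_dec_len` limit. A host-side restatement of just that predicate:

#include <cstdint>

bool should_stop(int64_t token, const int64_t* end_ids, int end_ids_len,
                 int64_t step_idx, int64_t max_dec_len,
                 bool prefill_one_step_stop) {
  for (int i = 0; i < end_ids_len; ++i) {
    if (token == end_ids[i]) return true;  // multi_end: token is an end id
  }
  if (prefill_one_step_stop) return true;  // single-step prefill mode
  return step_idx >= max_dec_len;          // generation-length limit reached
}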

View File

@@ -0,0 +1,209 @@
#include "xpu/kernel/cluster.h"
#include "xpu/kernel/cluster_partition.h"
#include "xpu/kernel/cluster_primitive.h"
namespace xpu3 {
namespace plugin {
static __device__ inline int loada_float(_shared_ptr_ const int *ptr) {
int ret;
__asm__ __volatile__("loada.w %0,%1" : "=&r"(ret) : "r"(ptr));
return ret;
}
static __device__ inline bool storea_float(_shared_ptr_ int *ptr, int value) {
bool ret;
__asm__ __volatile__("storea.w %0,%1,%2" : "=&r"(ret) : "r"(value), "r"(ptr));
return ret;
}
static __device__ int atomic_add(_shared_ptr_ int *ptr, int value) {
bool fail = true;
int old_value;
while (fail) {
old_value = loada_float(ptr);
int new_value = old_value + value;
fail = storea_float(ptr, new_value);
}
return old_value;
}
__global__ void mtp_free_and_dispatch_block(bool *base_model_stop_flags,
bool *stop_flags,
bool *batch_drop,
int *seq_lens_this_time,
int *seq_lens_decoder,
int *block_tables,
int *encoder_block_lens,
int *used_list_len,
int *free_list,
int *free_list_len,
const int bsz,
const int block_size,
const int block_num_per_seq,
const int max_draft_tokens) {
int cid = core_id();
int ncores = core_num();
int clusterid = cluster_id();
if (clusterid != 0 || cid >= bsz) return;
// assert bsz <= 640
const int max_bs = 640;
int value_zero = 0;
bool flag_true = true;
__shared__ int free_list_len_sm;
  // handle at most block_table_now_len free-list entries per pass
const int block_table_now_len = 128;
int block_table_now[block_table_now_len];
for (int i = 0; i < block_table_now_len; i++) {
block_table_now[i] = -1;
}
__shared__ bool base_model_stop_flags_sm[max_bs];
__shared__ bool batch_drop_sm[max_bs];
__shared__ int encoder_block_lens_sm[max_bs];
__shared__ int seq_lens_decoder_sm[max_bs];
int free_list_now[block_table_now_len];
__shared__ int need_block_len_sm;
__shared__ int need_block_list_sm[max_bs];
__shared__ int used_list_len_sm[max_bs];
if (cid == 0) {
// len = 1
need_block_len_sm = 0;
GM2SM_ASYNC(free_list_len, &free_list_len_sm, sizeof(int));
// len = bsz
    GM2SM_ASYNC(
        base_model_stop_flags, base_model_stop_flags_sm, bsz * sizeof(bool));
    GM2SM_ASYNC(batch_drop, batch_drop_sm, bsz * sizeof(bool));
    GM2SM_ASYNC(encoder_block_lens, encoder_block_lens_sm, bsz * sizeof(int));
GM2SM_ASYNC(used_list_len, used_list_len_sm, bsz * sizeof(int));
GM2SM_ASYNC(seq_lens_decoder, seq_lens_decoder_sm, bsz * sizeof(int));
}
for (int tid = cid; tid < bsz; tid += ncores) {
need_block_list_sm[tid] = 0;
}
mfence();
sync_all();
for (int tid = cid; tid < bsz; tid += ncores) {
if (base_model_stop_flags_sm[tid] || batch_drop_sm[tid]) {
      // reclaim this sequence's blocks
const int encoder_block_len_lm = encoder_block_lens_sm[tid];
const int decoder_used_len_lm = used_list_len_sm[tid];
if (decoder_used_len_lm > 0) {
const int ori_free_list_len =
atomic_add(&free_list_len_sm, decoder_used_len_lm);
for (int i = 0; i < decoder_used_len_lm; i += block_table_now_len) {
int process_len = min(block_table_now_len, decoder_used_len_lm - i);
GM2LM(
block_tables + tid * block_num_per_seq + encoder_block_len_lm + i,
free_list_now,
process_len * sizeof(int));
LM2GM(free_list_now,
free_list + ori_free_list_len + i,
process_len * sizeof(int));
LM2GM(
block_table_now,
block_tables + tid * block_num_per_seq + encoder_block_len_lm + i,
process_len * sizeof(int));
}
encoder_block_lens_sm[tid] = 0;
used_list_len_sm[tid] = 0;
}
mfence();
}
int max_possible_block_idx =
(seq_lens_decoder_sm[tid] + max_draft_tokens + 1) / block_size;
int next_block_id;
GM2LM(block_tables + tid * block_num_per_seq + max_possible_block_idx,
&next_block_id,
sizeof(int));
if (!base_model_stop_flags[tid] && !batch_drop[tid] &&
max_possible_block_idx < block_num_per_seq && next_block_id == -1) {
      // record which slots need a new block, and the total count
const int ori_need_block_len = atomic_add(&need_block_len_sm, 1);
need_block_list_sm[ori_need_block_len] = tid;
mfence();
}
} // for
sync_cluster();
if (cid == 0) {
while (need_block_len_sm > free_list_len_sm) {
      // scheduling: reclaim blocks from the sequence with the largest
      // used_list_len downward, until need_block_len is satisfied
int max_used_list_len_id = 0;
int max_used_list_len = 0;
for (int i = 0; i < bsz; i++) {
if ((!base_model_stop_flags_sm[i]) &&
(used_list_len_sm[i] > max_used_list_len)) {
max_used_list_len_id = i;
max_used_list_len = used_list_len_sm[i];
}
}
const int encoder_block_len_lm =
encoder_block_lens_sm[max_used_list_len_id];
for (int i = 0; i < max_used_list_len; i += block_table_now_len) {
int process_len = min(block_table_now_len, max_used_list_len - i);
GM2LM(block_tables + max_used_list_len_id * block_num_per_seq +
encoder_block_len_lm + i,
free_list_now,
process_len * sizeof(int));
LM2GM(free_list_now,
free_list + free_list_len_sm + i,
process_len * sizeof(int));
LM2GM(block_table_now,
block_tables + max_used_list_len_id * block_num_per_seq +
encoder_block_len_lm + i,
process_len * sizeof(int));
}
free_list_len_sm += max_used_list_len;
LM2GM_ASYNC(&flag_true, stop_flags + max_used_list_len_id, sizeof(bool));
LM2GM_ASYNC(
&value_zero, seq_lens_this_time + max_used_list_len_id, sizeof(int));
      // still needed below, so keep these in SM and write back to GM at the end
batch_drop_sm[max_used_list_len_id] = true;
seq_lens_decoder_sm[max_used_list_len_id] = 0;
used_list_len_sm[max_used_list_len_id] = 0;
mfence();
}
}
sync_cluster();
int need_block_len_all = need_block_len_sm;
for (int tid = cid; tid < need_block_len_all; tid += ncores) {
    // assign one block to each slot that needs one
const int need_block_id = need_block_list_sm[tid];
if (!batch_drop_sm[need_block_id]) {
used_list_len_sm[need_block_id]++;
const int ori_free_list_len = atomic_add(&free_list_len_sm, -1);
int free_block_id;
GM2LM(free_list + ori_free_list_len - 1, &free_block_id, sizeof(int));
LM2GM(&free_block_id,
block_tables + need_block_id * block_num_per_seq +
(seq_lens_decoder_sm[need_block_id] + max_draft_tokens + 1) /
block_size,
sizeof(int));
}
}
sync_cluster();
if (cid == 0) {
mfence();
SM2GM_ASYNC(&free_list_len_sm, free_list_len, sizeof(int));
SM2GM_ASYNC(used_list_len_sm, used_list_len, sizeof(int) * bsz);
SM2GM_ASYNC(seq_lens_decoder_sm, seq_lens_decoder, sizeof(int) * bsz);
SM2GM_ASYNC(batch_drop_sm, batch_drop, sizeof(bool) * bsz);
SM2GM_ASYNC(encoder_block_lens_sm, encoder_block_lens, sizeof(int) * bsz);
mfence();
}
}
} // namespace plugin
} // namespace xpu3
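
When the free list cannot cover the demand, the loop above greedily evicts the sequence holding the most decoder blocks until it can. A host-side sketch of that policy, where `dropped` stands in for the kernel's stop/batch-drop flags:

#include <vector>

void reschedule(std::vector<int>& used_list_len, std::vector<bool>& dropped,
                int& free_list_len, int need_block_len) {
  while (need_block_len > free_list_len) {
    int victim = -1;
    int max_used = 0;
    for (int i = 0; i < static_cast<int>(used_list_len.size()); ++i) {
      if (!dropped[i] && used_list_len[i] > max_used) {
        victim = i;
        max_used = used_list_len[i];
      }
    }
    if (victim < 0) break;        // nothing evictable is left
    free_list_len += max_used;    // its blocks go back to the free list
    used_list_len[victim] = 0;
    dropped[victim] = true;       // victim is stopped and dropped
  }
}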

View File

@@ -0,0 +1,90 @@
#include "xpu/kernel/cluster.h"
#include "xpu/kernel/cluster_partition.h"
#include "xpu/kernel/cluster_primitive.h"
namespace xpu3 {
namespace plugin {
template <typename T>
__global__ void RebuildAppendPaddingKernel(const T *full_hidden_states,
const int *cum_offset,
const int *seq_len_encoder,
const int *seq_len_decoder,
const int *output_padding_offset,
int max_seq_len,
int dim_embed,
int elem_nums,
T *out) {
int ncores = core_num();
int cid = core_id();
int tid = cid * cluster_num() + cluster_id();
int nthreads = cluster_num() * ncores;
int64_t mstart = -1;
int64_t mend = -1;
int64_t nstart = -1;
int64_t nend = -1;
partition2d<int64_t>(tid,
nthreads,
elem_nums / dim_embed,
dim_embed,
&mstart,
&mend,
&nstart,
&nend);
const int64_t BUFFER_LEN = rounddown(6144 / sizeof(T), 64);
__simd__ T lm_full_hidden_states[BUFFER_LEN];
int output_padding_offset_val, cum_offset_val, seq_len_encoder_val,
seq_len_decoder_val;
for (int64_t _m = mstart; _m < mend; _m++) {
int out_token_id = _m;
GM2LM(output_padding_offset + out_token_id,
&output_padding_offset_val,
sizeof(int));
int ori_token_id = out_token_id + output_padding_offset_val;
int bi = ori_token_id / max_seq_len;
GM2LM_ASYNC(seq_len_encoder + bi, &seq_len_encoder_val, sizeof(int));
GM2LM(seq_len_decoder + bi, &seq_len_decoder_val, sizeof(int));
int seq_id = 0;
    if (seq_len_encoder_val == 0 && seq_len_decoder_val == 0) {
continue;
} else if (seq_len_encoder_val != 0) {
seq_id = seq_len_encoder_val - 1;
}
GM2LM(cum_offset + bi, &cum_offset_val, sizeof(int));
int input_token_id = ori_token_id - cum_offset_val + seq_id;
for (int64_t _n = nstart; _n < nend; _n += BUFFER_LEN) {
int64_t read_size = min(BUFFER_LEN, nend - _n);
// out[i] = full_hidden_states[(i / dim_embed +
// output_padding_offset[i / dim_embed] - cum_offset[(i / dim_embed
// + output_padding_offset[i / dim_embed]) / max_seq_len] + seq_id)
// * dim_embed + i % dim_embed]
GM2LM(full_hidden_states + input_token_id * dim_embed + _n,
lm_full_hidden_states,
read_size * sizeof(T));
LM2GM(lm_full_hidden_states,
out + _m * dim_embed + _n,
read_size * sizeof(T));
}
}
}
#define _XPU_DEF_REBUILD_APPEND_PADDING_KERNEL(T) \
template __global__ void RebuildAppendPaddingKernel<T>( \
const T *full_hidden_states, \
const int *cum_offset, \
const int *seq_len_encoder, \
const int *seq_len_decoder, \
const int *output_padding_offset, \
int max_seq_len, \
int dim_embed, \
int elem_nums, \
T *out);
_XPU_DEF_REBUILD_APPEND_PADDING_KERNEL(bfloat16);
_XPU_DEF_REBUILD_APPEND_PADDING_KERNEL(float16);
_XPU_DEF_REBUILD_APPEND_PADDING_KERNEL(float);
} // namespace plugin
} // namespace xpu3
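
The gather index computed above can be checked against this host-side restatement of the same arithmetic; note the kernel additionally skips tokens whose encoder and decoder lengths are both zero:

int input_token_index(int out_token_id, const int* output_padding_offset,
                      const int* cum_offset, const int* seq_len_encoder,
                      int max_seq_len) {
  int ori_token_id = out_token_id + output_padding_offset[out_token_id];
  int bi = ori_token_id / max_seq_len;
  // during prefill only the last prompt position produces an output
  int seq_id = (seq_len_encoder[bi] != 0) ? seq_len_encoder[bi] - 1 : 0;
  return ori_token_id - cum_offset[bi] + seq_id;  // row in full_hidden_states
}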

View File

@@ -0,0 +1,65 @@
#include "xpu/kernel/cluster.h"
#include "xpu/kernel/cluster_partition.h"
#include "xpu/kernel/cluster_primitive.h"
namespace xpu3 {
namespace plugin {
template <typename T>
__global__ void rebuildHiddenStatesKernel(const T* input,
const int* position_map,
T* output,
int dim_embed,
int elem_cnt) {
int ncores = core_num();
int cid = core_id();
int tid = cid * cluster_num() + cluster_id();
int nthreads = cluster_num() * ncores;
int64_t mstart = -1;
int64_t mend = -1;
int64_t nstart = -1;
int64_t nend = -1;
partition2d<int64_t>(tid,
nthreads,
elem_cnt / dim_embed,
dim_embed,
&mstart,
&mend,
&nstart,
&nend);
const int64_t BUFFER_LEN = 6144 / sizeof(T);
T lm_input[BUFFER_LEN];
for (int64_t _m = mstart; _m < mend; _m++) {
int ori_token_idx = _m;
int token_idx;
GM2LM(position_map + _m, &token_idx, sizeof(int));
if (token_idx >= 0) {
for (int64_t _n = nstart; _n < nend; _n += BUFFER_LEN) {
int64_t read_size = min(BUFFER_LEN, nend - _n);
GM2LM(input + ori_token_idx * dim_embed + _n,
lm_input,
read_size * sizeof(T));
LM2GM(lm_input,
output + token_idx * dim_embed + _n,
read_size * sizeof(T));
}
}
}
}
#define _XPU_DEF_REBUILD_HIDDEN_STATES_KERNEL(T) \
template __global__ void rebuildHiddenStatesKernel<T>( \
const T* input, \
const int* position_map, \
T* output, \
int dim_embed, \
int elem_cnt);
_XPU_DEF_REBUILD_HIDDEN_STATES_KERNEL(bfloat16);
_XPU_DEF_REBUILD_HIDDEN_STATES_KERNEL(float);
_XPU_DEF_REBUILD_HIDDEN_STATES_KERNEL(float16);
} // namespace plugin
} // namespace xpu3
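
Functionally this kernel is a row scatter guarded by a validity check. A CPU reference, assuming row-major [tokens, dim_embed] buffers:

#include <algorithm>
#include <cstddef>
#include <vector>

// output[position_map[m]] = input[m]; negative map entries mark padded rows.
template <typename T>
void rebuild_hidden_states(const std::vector<T>& input,
                           const std::vector<int>& position_map,
                           std::vector<T>& output, int dim_embed) {
  for (std::size_t m = 0; m < position_map.size(); ++m) {
    int dst = position_map[m];
    if (dst < 0) continue;  // padded slot
    std::copy_n(input.begin() + m * dim_embed, dim_embed,
                output.begin() + static_cast<std::size_t>(dst) * dim_embed);
  }
}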

View File

@@ -0,0 +1,56 @@
#include "xpu/kernel/cluster.h"
#include "xpu/kernel/cluster_partition.h"
#include "xpu/kernel/cluster_primitive.h"
namespace xpu3 {
namespace plugin {
template <typename T>
__global__ void rebuildSelfHiddenStatesKernel(
const T* input, int* src_map, T* output, int dim_embed, int elem_cnt) {
int ncores = core_num();
int cid = core_id();
int tid = cid * cluster_num() + cluster_id();
int nthreads = cluster_num() * ncores;
int64_t mstart = -1;
int64_t mend = -1;
int64_t nstart = -1;
int64_t nend = -1;
partition2d<int64_t>(tid,
nthreads,
elem_cnt / dim_embed,
dim_embed,
&mstart,
&mend,
&nstart,
&nend);
const int64_t BUFFER_LEN = 6144 / sizeof(T);
T lm_input[BUFFER_LEN];
for (int64_t _m = mstart; _m < mend; _m++) {
int output_token_idx = _m;
int input_token_idx;
GM2LM(src_map + _m, &input_token_idx, sizeof(int));
if (input_token_idx >= 0) {
for (int64_t _n = nstart; _n < nend; _n += BUFFER_LEN) {
int64_t read_size = min(BUFFER_LEN, nend - _n);
GM2LM(input + input_token_idx * dim_embed + _n,
lm_input,
read_size * sizeof(T));
LM2GM(lm_input, output + _m * dim_embed + _n, read_size * sizeof(T));
}
}
}
}
#define _XPU_DEF_REBUILD_SELF_HIDDEN_STATES_KERNEL(T) \
template __global__ void rebuildSelfHiddenStatesKernel<T>( \
const T* input, int* src_map, T* output, int dim_embed, int elem_cnt);
_XPU_DEF_REBUILD_SELF_HIDDEN_STATES_KERNEL(bfloat16);
_XPU_DEF_REBUILD_SELF_HIDDEN_STATES_KERNEL(float);
_XPU_DEF_REBUILD_SELF_HIDDEN_STATES_KERNEL(float16);
} // namespace plugin
} // namespace xpu3

View File

@@ -0,0 +1,78 @@
#include "xpu/kernel/cluster.h"
#include "xpu/kernel/cluster_partition.h"
#include "xpu/kernel/cluster_primitive.h"
namespace xpu3 {
namespace plugin {
template <typename T>
inline __device__ void update_bad_words_logit(_global_ptr_ T* logits) {
__local__ T min_value = -1e10f;
mfence_lm();
LM2GM((void*)&(min_value), logits, sizeof(T));
}
template <>
inline __device__ void update_bad_words_logit<float16>(
_global_ptr_ float16* logits) {
__local__ short min_value = 0xFBFF;
mfence_lm();
LM2GM((void*)&(min_value), logits, sizeof(float16));
}
template <typename T>
__global__ void speculate_ban_bad_words(T* logits,
const int64_t* bad_words_list,
const int* output_padding_offset,
const int64_t bs,
const int64_t length,
const int64_t bad_words_length,
const int64_t token_num,
const int64_t max_seq_len) {
int tid = core_id() * cluster_num() + cluster_id();
int nthreads = cluster_num() * core_num();
int start = -1;
int end = -1;
int output_padding_offset_lm;
partition(tid,
nthreads,
static_cast<int>(token_num * bad_words_length),
1,
&start,
&end);
for (int i = start; i < end; i++) {
int token_idx = i / bad_words_length;
GM2LM(output_padding_offset + token_idx,
&output_padding_offset_lm,
sizeof(int));
int bs_idx = (token_idx + output_padding_offset_lm) / max_seq_len;
if (bs_idx >= bs) {
continue;
}
int bad_words_idx = i - token_idx * bad_words_length;
int64_t bad_words_token_id = -1;
mfence_lm();
GM2LM(bad_words_list + bad_words_idx,
(void*)&(bad_words_token_id),
sizeof(int64_t));
if (bad_words_token_id >= length || bad_words_token_id < 0) continue;
update_bad_words_logit<T>(logits + token_idx * length + bad_words_token_id);
}
}
#define _XPU_DEF__BAN_BAD_WORDS_(DATA_TYPE) \
template __global__ void speculate_ban_bad_words( \
DATA_TYPE* logits, \
const int64_t* bad_words_list, \
const int* output_padding_offset, \
const int64_t bs, \
const int64_t length, \
const int64_t bad_words_length, \
const int64_t token_num, \
const int64_t max_seq_len);
_XPU_DEF__BAN_BAD_WORDS_(float);
_XPU_DEF__BAN_BAD_WORDS_(float16);
_XPU_DEF__BAN_BAD_WORDS_(bfloat16);
} // namespace plugin
} // namespace xpu3
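
Stripped of the padding-offset bookkeeping, the kernel above masks every valid bad-word id for every output token. A CPU reference of that core step:

#include <cstdint>
#include <vector>

// Mask bad-word logits to a large negative value for every output token.
void ban_bad_words(std::vector<float>& logits,  // [token_num, length]
                   const std::vector<int64_t>& bad_words,
                   int64_t token_num, int64_t length) {
  for (int64_t t = 0; t < token_num; ++t) {
    for (int64_t w : bad_words) {
      if (w < 0 || w >= length) continue;  // out-of-vocab id, skip
      logits[t * length + w] = -1e10f;
    }
  }
}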

View File

@@ -0,0 +1,44 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/*
* copyright (C) 2022 KUNLUNXIN, Inc
*/
#include "xpu/kernel/cluster.h"
#include "xpu/kernel/cluster_partition.h"
#include "xpu/kernel/cluster_primitive.h"
#include "xpu/kernel/xtdk_io.h"
namespace xpu3 {
namespace plugin {
__global__ void speculate_clear_accept_nums(int* accept_num,
const int* seq_lens_decoder,
const int max_bsz) {
int cid = core_id();
int ncores = core_num();
int accept_num_lm = 0;
int seq_lens_decoder_lm;
for (int i = cid; i < max_bsz; i += ncores) {
GM2LM(seq_lens_decoder + i, &seq_lens_decoder_lm, sizeof(int));
if (seq_lens_decoder_lm == 0) {
LM2GM_ASYNC(&accept_num_lm, accept_num + i, sizeof(int));
}
mfence_lm();
}
}
} // namespace plugin
} // namespace xpu3

View File

@@ -0,0 +1,288 @@
#include "xpu/kernel/cluster.h"
#include "xpu/kernel/cluster_partition.h"
#include "xpu/kernel/cluster_primitive.h"
namespace xpu3 {
namespace plugin {
static __device__ inline int loada_float(_shared_ptr_ const int *ptr) {
int ret;
__asm__ __volatile__("loada.w %0,%1" : "=&r"(ret) : "r"(ptr));
return ret;
}
static __device__ inline bool storea_float(_shared_ptr_ int *ptr, int value) {
bool ret;
__asm__ __volatile__("storea.w %0,%1,%2" : "=&r"(ret) : "r"(value), "r"(ptr));
return ret;
}
static __device__ int atomic_add(_shared_ptr_ int *ptr, int value) {
bool fail = true;
int old_value;
while (fail) {
old_value = loada_float(ptr);
int new_value = old_value + value;
fail = storea_float(ptr, new_value);
}
return old_value;
}
static __device__ bool in_need_block_list(const int qid,
_shared_ptr_ int *need_block_list,
const int need_block_len) {
bool res = false;
for (int i = 0; i < need_block_len; i++) {
if (qid == need_block_list[i]) {
need_block_list[i] = -1;
res = true;
break;
}
}
return res;
}
__global__ void speculate_free_and_reschedule(bool *stop_flags,
int *seq_lens_this_time,
int *seq_lens_decoder,
int *block_tables,
int *encoder_block_lens,
bool *is_block_step,
int *step_block_list, // [bsz]
int *step_len,
int *recover_block_list,
int *recover_len,
int *need_block_list,
int *need_block_len,
int *used_list_len,
int *free_list,
int *free_list_len,
int64_t *first_token_ids,
const int bsz,
const int block_size,
const int block_num_per_seq,
const int max_decoder_block_num,
const int max_draft_tokens) {
int cid = core_id();
int ncores = core_num();
int clusterid = cluster_id();
if (clusterid != 0 || cid >= bsz) return;
// assert bsz <= 640
const int max_bs = 640;
int value_zero = 0;
bool flag_true = true;
// 128 = seq_len(8192) / block_size(64)
  // handle at most 128 block_table entries per pass
const int block_table_now_len = 128;
int block_table_now[block_table_now_len];
for (int i = 0; i < block_table_now_len; i++) {
block_table_now[i] = -1;
}
bool stop_flag_lm;
int seq_lens_decoder_lm;
__shared__ int free_list_len_sm;
  // handle at most block_table_now_len free-list entries per pass
int free_list_now[block_table_now_len];
__shared__ int need_block_len_sm;
__shared__ int need_block_list_sm[max_bs];
__shared__ int used_list_len_sm[max_bs];
__shared__ bool step_max_block_flag;
__shared__ int in_need_block_list_len;
if (cid == 0) {
step_max_block_flag = false;
in_need_block_list_len = 0;
GM2SM_ASYNC(free_list_len, &free_list_len_sm, sizeof(int));
GM2SM_ASYNC(need_block_len, &need_block_len_sm, sizeof(int));
mfence();
if (need_block_len_sm > 0) {
GM2SM_ASYNC(
need_block_list, need_block_list_sm, sizeof(int) * need_block_len_sm);
}
GM2SM_ASYNC(used_list_len, used_list_len_sm, sizeof(int) * bsz);
mfence();
}
sync_cluster();
for (int tid = cid; tid < bsz; tid += ncores) {
int seq_lens_this_time_lm;
mfence();
GM2LM_ASYNC(stop_flags + tid, &stop_flag_lm, sizeof(bool));
GM2LM_ASYNC(seq_lens_decoder + tid, &seq_lens_decoder_lm, sizeof(int));
GM2LM_ASYNC(seq_lens_this_time + tid, &seq_lens_this_time_lm, sizeof(int));
mfence();
int max_possible_block_idx =
(seq_lens_decoder_lm + max_draft_tokens + 1) / block_size;
if (stop_flag_lm) {
      // reclaim this sequence's blocks
int64_t first_token_id_lm = -1;
mfence_lm();
LM2GM(&first_token_id_lm, first_token_ids + tid, sizeof(int64_t));
int encoder_block_len_lm;
int decoder_used_len_lm = used_list_len_sm[tid];
GM2LM(encoder_block_lens + tid, &encoder_block_len_lm, sizeof(int));
if (decoder_used_len_lm > 0) {
const int ori_free_list_len =
atomic_add(&free_list_len_sm, decoder_used_len_lm);
for (int i = 0; i < decoder_used_len_lm; i += block_table_now_len) {
int process_len = min(block_table_now_len, decoder_used_len_lm - i);
GM2LM(
block_tables + tid * block_num_per_seq + encoder_block_len_lm + i,
free_list_now,
process_len * sizeof(int));
LM2GM(free_list_now,
free_list + ori_free_list_len + i,
process_len * sizeof(int));
LM2GM(
block_table_now,
block_tables + tid * block_num_per_seq + encoder_block_len_lm + i,
process_len * sizeof(int));
}
used_list_len_sm[tid] = 0;
mfence();
LM2GM(&value_zero, encoder_block_lens + tid, sizeof(int));
}
} else if (seq_lens_this_time_lm != 0 &&
max_possible_block_idx < block_num_per_seq) {
int next_block_id;
GM2LM(block_tables + tid * block_num_per_seq +
(seq_lens_decoder_lm + max_draft_tokens + 1) / block_size,
&next_block_id,
sizeof(int));
if (next_block_id == -1) {
        // record which slots need a new block, and the total count
const int ori_need_block_len = atomic_add(&need_block_len_sm, 1);
need_block_list_sm[ori_need_block_len] = tid;
}
}
}
sync_cluster();
bool is_block_step_lm[max_bs];
int step_len_lm;
int step_block_list_lm[max_bs];
int recover_len_lm;
int recover_block_list_lm[max_bs];
if (cid == 0) {
GM2LM_ASYNC(is_block_step, is_block_step_lm, sizeof(bool) * bsz);
GM2LM_ASYNC(step_len, &step_len_lm, sizeof(int));
GM2LM_ASYNC(step_block_list, step_block_list_lm, sizeof(int) * bsz);
GM2LM_ASYNC(recover_len, &recover_len_lm, sizeof(int));
GM2LM_ASYNC(recover_block_list, recover_block_list_lm, sizeof(int) * bsz);
mfence();
}
if (cid == 0) {
while (need_block_len_sm > free_list_len_sm) {
      // scheduling: reclaim blocks from the sequence with the largest
      // used_list_len downward until need_block_len is satisfied; a query
      // already decoding into its last block is about to finish and is not
      // rescheduled
int max_used_list_len_id = 0;
int max_used_list_len = 0;
for (int i = 0; i < bsz; i++) {
if (used_list_len_sm[i] > max_used_list_len) {
max_used_list_len_id = i;
max_used_list_len = used_list_len_sm[i];
}
}
if (max_used_list_len == 0) {
step_max_block_flag = true;
} else {
int encoder_block_len;
GM2LM(encoder_block_lens + max_used_list_len_id,
&encoder_block_len,
sizeof(int));
for (int i = 0; i < max_used_list_len; i += block_table_now_len) {
int process_len = min(block_table_now_len, max_used_list_len - i);
GM2LM(block_tables + max_used_list_len_id * block_num_per_seq +
encoder_block_len + i,
free_list_now,
process_len * sizeof(int));
LM2GM(free_list_now,
free_list + free_list_len_sm + i,
process_len * sizeof(int));
LM2GM(block_table_now,
block_tables + max_used_list_len_id * block_num_per_seq +
encoder_block_len + i,
process_len * sizeof(int));
}
step_block_list_lm[step_len_lm] = max_used_list_len_id;
int need_block_len_all = need_block_len_sm + in_need_block_list_len;
if (in_need_block_list(
max_used_list_len_id, need_block_list_sm, need_block_len_all)) {
need_block_len_sm--;
in_need_block_list_len++;
}
step_len_lm++;
free_list_len_sm += max_used_list_len;
LM2GM_ASYNC(
&flag_true, stop_flags + max_used_list_len_id, sizeof(bool));
LM2GM_ASYNC(&value_zero,
seq_lens_this_time + max_used_list_len_id,
sizeof(int));
LM2GM_ASYNC(
&value_zero, seq_lens_decoder + max_used_list_len_id, sizeof(int));
LM2GM_ASYNC(&value_zero,
encoder_block_lens + max_used_list_len_id,
sizeof(int));
used_list_len_sm[max_used_list_len_id] = 0;
mfence();
}
}
}
sync_cluster();
int need_block_len_all = need_block_len_sm + in_need_block_list_len;
for (int tid = cid; tid < need_block_len_all; tid += ncores) {
    // assign one block to each slot that needs one
const int need_block_id = need_block_list_sm[tid];
if (need_block_id != -1) {
GM2LM(stop_flags + need_block_id, &stop_flag_lm, sizeof(bool));
if (!stop_flag_lm) {
        // a slot that was just freed by the scheduling step above has its
        // stop flag set and is left alone
used_list_len_sm[need_block_id]++;
const int ori_free_list_len = atomic_add(&free_list_len_sm, -1);
int tmp_seq_lens_decoder;
GM2LM(seq_lens_decoder + need_block_id,
&tmp_seq_lens_decoder,
sizeof(int));
int free_block_id;
GM2LM(free_list + ori_free_list_len - 1, &free_block_id, sizeof(int));
LM2GM(&free_block_id,
block_tables + need_block_id * block_num_per_seq +
(tmp_seq_lens_decoder + max_draft_tokens + 1) / block_size,
sizeof(int));
}
need_block_list_sm[tid] = -1;
}
}
sync_cluster();
  int ori_need_block_len;
  if (cid == 0) {
    ori_need_block_len = need_block_len_sm;
    need_block_len_sm = 0;
    mfence();
LM2GM_ASYNC(step_block_list_lm, step_block_list, sizeof(int) * bsz);
LM2GM_ASYNC(is_block_step_lm, is_block_step, sizeof(bool) * bsz);
LM2GM_ASYNC(&step_len_lm, step_len, sizeof(int));
LM2GM_ASYNC(&recover_len_lm, recover_len, sizeof(int));
LM2GM_ASYNC(recover_block_list_lm, recover_block_list, sizeof(int) * bsz);
SM2GM_ASYNC(&free_list_len_sm, free_list_len, sizeof(int));
SM2GM_ASYNC(&need_block_len_sm, need_block_len, sizeof(int));
if (ori_need_block_len > 0) {
SM2GM_ASYNC(need_block_list_sm,
need_block_list,
sizeof(int) * ori_need_block_len);
}
SM2GM_ASYNC(used_list_len_sm, used_list_len, sizeof(int) * bsz);
mfence();
}
}
} // namespace plugin
} // namespace xpu3
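
The loada.w/storea.w pair above forms an LL/SC-style retry loop: the store reports failure when the shared-memory slot changed since the load, so the add repeats until it lands. The same fetch-and-add semantics in portable C++:

#include <atomic>

// Portable equivalent of the shared-memory atomic_add retry loop above;
// returns the pre-add value, like the kernel's helper.
int fetch_add_cas(std::atomic<int>& slot, int value) {
  int old_value = slot.load();
  // compare_exchange_weak reloads old_value on failure, mirroring the
  // loada.w / storea.w retry.
  while (!slot.compare_exchange_weak(old_value, old_value + value)) {
  }
  return old_value;
}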

View File

@@ -0,0 +1,63 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/*
* copyright (C) 2022 KUNLUNXIN, Inc
*/
#include "xpu/kernel/cluster.h"
#include "xpu/kernel/cluster_partition.h"
#include "xpu/kernel/cluster_primitive.h"
#include "xpu/kernel/xtdk_io.h"
namespace xpu3 {
namespace plugin {
__global__ void speculate_get_output_padding_offset(
int* output_padding_offset,
int* output_cum_offsets,
const int* output_cum_offsets_tmp,
const int* seq_lens_output,
const int bsz,
const int max_seq_len) {
int cid = core_id();
int ncores = core_num();
int clusterid = cluster_id();
int nclusters = cluster_num();
int seq_lens_output_lm;
int cum_offset_lm;
for (int bi = clusterid; bi < bsz; bi += nclusters) {
if (bi == 0) {
cum_offset_lm = 0;
} else {
GM2LM_ASYNC(output_cum_offsets_tmp + bi - 1, &cum_offset_lm, sizeof(int));
}
GM2LM_ASYNC(seq_lens_output + bi, &seq_lens_output_lm, sizeof(int));
mfence_lm();
for (int i = cid; i < seq_lens_output_lm; i += ncores) {
LM2GM_ASYNC(&cum_offset_lm,
output_padding_offset + bi * max_seq_len - cum_offset_lm + i,
sizeof(int));
}
if (cid == 0) {
LM2GM_ASYNC(&cum_offset_lm, output_cum_offsets + bi, sizeof(int));
}
mfence_lm();
}
}
} // namespace plugin
} // namespace xpu3
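
A CPU restatement of the offset computation above: each output token of row bi records how many padded slots precede its row, and row starts are compacted by the same amount.

#include <vector>

void get_output_padding_offset(std::vector<int>& output_padding_offset,
                               std::vector<int>& output_cum_offsets,
                               const std::vector<int>& cum_offsets_tmp,
                               const std::vector<int>& seq_lens_output,
                               int bsz, int max_seq_len) {
  for (int bi = 0; bi < bsz; ++bi) {
    int cum = (bi == 0) ? 0 : cum_offsets_tmp[bi - 1];
    for (int i = 0; i < seq_lens_output[bi]; ++i) {
      output_padding_offset[bi * max_seq_len - cum + i] = cum;
    }
    output_cum_offsets[bi] = cum;
  }
}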

View File

@@ -0,0 +1,122 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/*
* copyright (C) 2022 KUNLUNXIN, Inc
*/
#include "xpu/kernel/cluster.h"
#include "xpu/kernel/cluster_partition.h"
#include "xpu/kernel/cluster_primitive.h"
#include "xpu/kernel/cluster_simd.h"
#include "xpu/kernel/xtdk.h"
namespace xpu3 {
namespace plugin {
template <typename T>
__global__ void speculate_remove_padding(T* output_data,
const T* input_data,
const T* draft_tokens,
const int* seq_lens,
const int* seq_lens_encoder,
const int* cum_offsets,
int sequence_length,
int max_draft_tokens,
int bsz,
int token_num_data) {
int bid = cluster_id();
int tid = core_id();
int ncores = core_num();
int nclusters = cluster_num();
int seq_lens_now = 0;
int seq_lens_encoder_now = 0;
int cum_offsets_now = 0;
  T input_data_now;
for (int bi = bid; bi < bsz; bi += nclusters) {
GM2LM(seq_lens + bi, &seq_lens_now, sizeof(int));
GM2LM(seq_lens_encoder + bi, &seq_lens_encoder_now, sizeof(int));
GM2LM(cum_offsets + bi, &cum_offsets_now, sizeof(int));
for (int i = tid; i < seq_lens_now; i += ncores) {
const int tgt_seq_id = bi * sequence_length - cum_offsets_now + i;
if (seq_lens_encoder_now > 0) {
const int src_seq_id = bi * sequence_length + i;
        GM2LM(input_data + src_seq_id, &input_data_now, sizeof(T));
        LM2GM(&input_data_now, output_data + tgt_seq_id, sizeof(T));
} else {
const int src_seq_id = bi * max_draft_tokens + i;
        GM2LM(draft_tokens + src_seq_id, &input_data_now, sizeof(T));
        LM2GM(&input_data_now, output_data + tgt_seq_id, sizeof(T));
}
}
}
}
__global__ void speculate_get_padding_offset(int* padding_offset,
int* cum_offsets_out,
int* cu_seqlens_q,
int* cu_seqlens_k,
const int* cum_offsets,
const int* seq_lens,
const int max_seq_len,
int bsz) {
int bid = cluster_id();
int tid = core_id();
int ncores = core_num();
int nclusters = cluster_num();
int seq_lens_now = 0;
int cum_offsets_now = 0;
int cum_offsets_now_ind = 0;
for (int bi = bid; bi < bsz; bi += nclusters) {
GM2LM(seq_lens + bi, &seq_lens_now, sizeof(int));
if (bi == 0) {
cum_offsets_now = 0;
} else {
GM2LM(cum_offsets + bi - 1, &cum_offsets_now, sizeof(int));
}
GM2LM(cum_offsets + bi, &cum_offsets_now_ind, sizeof(int));
for (int i = tid; i < seq_lens_now; i += ncores) {
LM2GM(&cum_offsets_now,
padding_offset + bi * max_seq_len - cum_offsets_now + i,
sizeof(int));
}
LM2GM(&cum_offsets_now, cum_offsets_out + bi, sizeof(int));
int cum_seq_len = (bi + 1) * max_seq_len - cum_offsets_now_ind;
LM2GM(&cum_seq_len, cu_seqlens_q + bi + 1, sizeof(int));
LM2GM(&cum_seq_len, cu_seqlens_k + bi + 1, sizeof(int));
}
}
#define _XPU_DEF_SPECULATE_KERNELS_(T) \
template __global__ void speculate_remove_padding<T>(T*, \
const T*, \
const T*, \
const int*, \
const int*, \
const int*, \
int, \
int, \
int, \
int);
_XPU_DEF_SPECULATE_KERNELS_(float);
_XPU_DEF_SPECULATE_KERNELS_(float16);
_XPU_DEF_SPECULATE_KERNELS_(bfloat16);
_XPU_DEF_SPECULATE_KERNELS_(int64_t);
} // namespace plugin
} // namespace xpu3
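
A CPU reference of speculate_get_padding_offset above; cu_seqlens_k receives the same values as cu_seqlens_q, and element 0 of each is assumed pre-zeroed:

#include <vector>

void get_padding_offset(std::vector<int>& padding_offset,
                        std::vector<int>& cum_offsets_out,
                        std::vector<int>& cu_seqlens_q,
                        const std::vector<int>& cum_offsets,
                        const std::vector<int>& seq_lens,
                        int max_seq_len, int bsz) {
  for (int bi = 0; bi < bsz; ++bi) {
    int cum = (bi == 0) ? 0 : cum_offsets[bi - 1];
    for (int i = 0; i < seq_lens[bi]; ++i) {
      padding_offset[bi * max_seq_len - cum + i] = cum;
    }
    cum_offsets_out[bi] = cum;
    cu_seqlens_q[bi + 1] = (bi + 1) * max_seq_len - cum_offsets[bi];
  }
}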

View File

@@ -0,0 +1,58 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/*
* copyright (C) 2022 KUNLUNXIN, Inc
*/
#include "xpu/kernel/cluster.h"
#include "xpu/kernel/cluster_partition.h"
#include "xpu/kernel/cluster_primitive.h"
#include "xpu/kernel/xtdk_io.h"
namespace xpu3 {
namespace plugin {
__global__ void speculate_get_seq_lens_output(int* seq_lens_output,
const int* seq_lens_this_time,
const int* seq_lens_encoder,
const int* seq_lens_decoder,
const int real_bsz) {
int cid = core_id();
int ncores = core_num();
int clusterid = cluster_id();
int nclusters = cluster_num();
int thread_num = ncores * nclusters;
int bid = clusterid * ncores + cid;
int one = 1;
int lm_seq_lens_this_time;
int lm_seq_lens_encoder;
for (; bid < real_bsz; bid += thread_num) {
GM2LM_ASYNC(seq_lens_this_time + bid, &lm_seq_lens_this_time, sizeof(int));
GM2LM(seq_lens_encoder + bid, &lm_seq_lens_encoder, sizeof(int));
if (lm_seq_lens_this_time == 0) {
continue;
} else if (lm_seq_lens_this_time == 1) {
LM2GM_ASYNC(&one, seq_lens_output + bid, sizeof(int));
} else if (lm_seq_lens_encoder != 0) {
LM2GM_ASYNC(&one, seq_lens_output + bid, sizeof(int));
} else {
LM2GM_ASYNC(&lm_seq_lens_this_time, seq_lens_output + bid, sizeof(int));
}
mfence_lm();
}
}
} // namespace plugin
} // namespace xpu3
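
The per-sequence rule above, as a plain function; -1 marks the slots the kernel skips, leaving the previous value in place:

int seq_len_output(int seq_len_this_time, int seq_len_encoder) {
  if (seq_len_this_time == 0) return -1;  // skipped slot
  if (seq_len_this_time == 1) return 1;   // single decode step
  if (seq_len_encoder != 0) return 1;     // prefill emits one output token
  return seq_len_this_time;               // verify step: all drafts scored
}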

View File

@@ -0,0 +1,91 @@
#include "xpu/kernel/cluster.h"
#include "xpu/kernel/cluster_debug.h"
#include "xpu/kernel/cluster_partition.h"
#include "xpu/kernel/cluster_primitive.h"
namespace xpu3 {
namespace plugin {
template <typename T>
__global__ void speculate_min_length_logits_process(
T* logits,
const int64_t* cur_len,
const int64_t* min_len,
const int64_t* eos_token_id,
const int* output_padding_offset,
const int* output_cum_offsets,
const int64_t bs,
const int64_t length,
const int64_t length_id,
const int64_t end_length,
const int64_t token_num,
const int64_t max_seq_len) {
int ncores = core_num();
int cid = core_id();
int tid = cluster_num() * cid + cluster_id();
int nthreads = cluster_num() * ncores;
int64_t cur_len_now;
int64_t min_len_now;
int64_t eos_token_id_now;
int64_t bi;
int64_t end_num;
int output_padding_offset_now;
int output_cum_offsets_now;
__simd__ float float32logits_now[32];
for (int64_t i = tid; i < token_num * end_length; i += nthreads) {
int64_t token_idx = i / end_length;
GM2LM(output_padding_offset + token_idx,
&output_padding_offset_now,
sizeof(int));
bi = (token_idx + output_padding_offset_now) / max_seq_len;
if (bi >= bs) {
continue;
}
end_num = i % end_length;
GM2LM_ASYNC(
output_cum_offsets + bi, (void*)&output_cum_offsets_now, sizeof(int));
GM2LM_ASYNC(cur_len + bi, (void*)&(cur_len_now), sizeof(int64_t));
GM2LM_ASYNC(min_len + bi, (void*)&(min_len_now), sizeof(int64_t));
mfence();
int query_start_token_idx = bi * max_seq_len - output_cum_offsets_now;
if (cur_len_now >= 0 &&
(cur_len_now + (token_idx - query_start_token_idx) < min_len_now)) {
GM2LM(
eos_token_id + end_num, (void*)&(eos_token_id_now), sizeof(int64_t));
GM2LM(logits + token_idx * length + eos_token_id_now,
(void*)float32logits_now,
sizeof(T));
primitive_cast<T, float>(
(const T*)(float32logits_now), float32logits_now, 1);
float32logits_now[0] = std::is_same<T, float16>::value ? -1e4 : -1e10;
mfence_lm();
primitive_cast<float, T>(float32logits_now, (T*)float32logits_now, 1);
LM2GM((void*)float32logits_now,
logits + token_idx * length + eos_token_id_now,
sizeof(T));
}
}
}
#define _XPU_DEF__UPDATE_LOGITS_REPEAT_TIMES_(DATA_TYPE) \
template __global__ void speculate_min_length_logits_process<DATA_TYPE>( \
DATA_TYPE * logits, \
const int64_t* cur_len, \
const int64_t* min_len, \
const int64_t* eos_token_id, \
const int* output_padding_offset, \
const int* output_cum_offsets, \
const int64_t bs, \
const int64_t length, \
const int64_t length_id, \
const int64_t end_length, \
const int64_t token_num, \
const int64_t max_seq_len);
_XPU_DEF__UPDATE_LOGITS_REPEAT_TIMES_(float);
_XPU_DEF__UPDATE_LOGITS_REPEAT_TIMES_(float16);
_XPU_DEF__UPDATE_LOGITS_REPEAT_TIMES_(bfloat16);
} // namespace plugin
} // namespace xpu3
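
Stripped of the offset bookkeeping and the float16 round-trip, the kernel above enforces a minimum generation length by suppressing EOS logits. A CPU reference of that rule (-1e4 is used for float16 in the kernel to stay in range):

#include <cstdint>
#include <vector>

void suppress_eos_if_short(std::vector<float>& logits,  // [token_num, length]
                           const std::vector<int64_t>& eos_token_id,
                           int64_t token_idx, int64_t offset_in_query,
                           int64_t cur_len, int64_t min_len, int64_t length) {
  if (cur_len < 0) return;  // inactive slot
  if (cur_len + offset_in_query >= min_len) return;
  for (int64_t eos : eos_token_id) {
    logits[token_idx * length + eos] = -1e10f;
  }
}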

View File

@@ -0,0 +1,98 @@
#include "xpu/kernel/cluster_debug.h"
#include "xpu/kernel/cluster_partition.h"
#include "xpu/kernel/xtdk.h"
#include "xpu/kernel/xtdk_math.h"
#include "xpu/kernel/xtdk_simd.h"
namespace xpu3 {
namespace plugin {
__global__ void speculate_set_stop_value_multi_seqs(bool *stop_flags,
int64_t *accept_tokens,
int *accept_nums,
const int64_t *pre_ids,
const int64_t *step_idx,
const int64_t *stop_seqs,
const int *stop_seqs_len,
const int *seq_lens,
const int64_t *end_ids,
const int bs,
const int accept_tokens_len,
const int stop_seqs_bs,
const int stop_seqs_max_len,
const int pre_ids_len) {
int cls_id = cluster_id();
int cid = core_id();
int ncores = core_num();
int nclusters = cluster_num();
int accept_num = 0;
int64_t step_idx_now = 0;
bool stop_flags_now = false;
int stop_seq_len = 0;
for (int bid = cls_id; bid < bs; bid += nclusters) {
GM2LM_ASYNC(accept_nums + bid, &accept_num, sizeof(int));
GM2LM_ASYNC(step_idx + bid, &step_idx_now, sizeof(int64_t));
GM2LM(stop_flags + bid, &stop_flags_now, sizeof(bool));
if (stop_flags_now) {
continue;
}
for (int tid = cid; tid < stop_seqs_bs; tid += ncores) {
GM2LM_ASYNC(stop_seqs_len + tid, &stop_seq_len, sizeof(int));
if (stop_seq_len <= 0) {
continue;
}
int accept_idx = 0;
bool is_end = false;
int64_t stop_seq_now_lm = 0;
for (; accept_idx <= accept_num - 1 && !is_end; accept_idx++) {
if (step_idx_now - accept_num + accept_idx + 1 < stop_seq_len) {
continue;
}
        // walk one stop sequence from its last token backwards
for (int i = stop_seq_len - 1; i >= 0; --i) {
int64_t cur_token_idx = -1;
          // decide from the index whether this token lives in pre_ids or in
          // accept_tokens
if (stop_seq_len - 1 - i < accept_idx) {
GM2LM(accept_tokens + bid * accept_tokens_len + accept_idx -
(stop_seq_len - 1 - i) - 1,
&cur_token_idx,
sizeof(int64_t));
} else {
int pre_ids_idx =
step_idx_now - accept_num + accept_idx - (stop_seq_len - 1 - i);
          // EC3: special concatenation can leave the last position of
          // input_ids without a special token, i.e. pre_ids[0] may be 23,
          // which would cause a spurious stop
if (pre_ids_idx <= 0) {
break;
}
GM2LM(pre_ids + bid * pre_ids_len + pre_ids_idx,
&cur_token_idx,
sizeof(int64_t));
}
GM2LM(stop_seqs + tid * stop_seqs_max_len + i,
&stop_seq_now_lm,
sizeof(int64_t));
if (cur_token_idx != stop_seq_now_lm) {
break;
}
if (i == 0) {
is_end = true;
}
}
}
if (is_end) {
int64_t end_id_lm;
bool value_true = true;
GM2LM(end_ids, &end_id_lm, sizeof(int64_t));
LM2GM_ASYNC(&end_id_lm,
accept_tokens + bid * accept_tokens_len + accept_idx - 1,
sizeof(int64_t));
LM2GM(&value_true, stop_flags + bid, sizeof(bool));
}
}
}
}
} // namespace plugin
} // namespace xpu3
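
The backward walk above asks whether the token history ends with a stop sequence. A simplified CPU reference of that check; the kernel additionally guards pre_ids index 0 (the EC3 case) and patches the accepted tokens on a hit:

#include <algorithm>
#include <cstdint>
#include <vector>

// Does the generated history (previous ids followed by the tokens accepted
// so far) end with stop_seq?
bool ends_with_stop_seq(const std::vector<int64_t>& history,
                        const std::vector<int64_t>& stop_seq) {
  if (history.size() < stop_seq.size()) return false;
  return std::equal(stop_seq.rbegin(), stop_seq.rend(), history.rbegin());
}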

View File

@@ -0,0 +1,83 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/*
* copyright (C) 2022 KUNLUNXIN, Inc
*/
#include "xpu/kernel/cluster.h"
#include "xpu/kernel/cluster_partition.h"
#include "xpu/kernel/cluster_primitive.h"
#include "xpu/kernel/xtdk_io.h"
namespace xpu3 {
namespace plugin {
__global__ void speculate_set_value_by_flag_and_id(int64_t *pre_ids_all,
const int64_t *accept_tokens,
const int *accept_num,
const bool *stop_flags,
const int *seq_lens_encoder,
const int *seq_lens_decoder,
const int64_t *step_idx,
int bs,
int length,
int max_draft_tokens) {
int cid = core_id();
int ncores = core_num();
int clusterid = cluster_id();
if (clusterid != 0) return;
int64_t pre_ids_all_lm[max_draft_tokens];
int64_t accept_tokens_lm[max_draft_tokens];
int accept_num_lm;
bool stop_flags_lm;
int seq_lens_encoder_lm;
int seq_lens_decoder_lm;
int64_t step_idx_lm;
for (int i = cid; i < bs; i += ncores) {
GM2LM_ASYNC(stop_flags + i, &stop_flags_lm, sizeof(bool));
GM2LM_ASYNC(seq_lens_encoder + i, &seq_lens_encoder_lm, sizeof(int));
GM2LM_ASYNC(seq_lens_decoder + i, &seq_lens_decoder_lm, sizeof(int));
GM2LM_ASYNC(step_idx + i, &step_idx_lm, sizeof(int64_t));
GM2LM_ASYNC(accept_num + i, &accept_num_lm, sizeof(int));
mfence_lm();
if (stop_flags_lm ||
(seq_lens_encoder_lm == 0 && seq_lens_decoder_lm == 0) ||
step_idx_lm < 0)
continue;
// Avoid loading large amounts of data
int pre_ids_start_idx = i * length + step_idx_lm - max_draft_tokens + 1;
GM2LM_ASYNC(pre_ids_all + pre_ids_start_idx,
pre_ids_all_lm,
max_draft_tokens * sizeof(int64_t));
GM2LM_ASYNC(accept_tokens + i * max_draft_tokens,
accept_tokens_lm,
max_draft_tokens * sizeof(int64_t));
mfence_lm();
for (int j = 0; j < accept_num_lm; j++) {
pre_ids_all_lm[max_draft_tokens - 1 - j] =
accept_tokens_lm[accept_num_lm - 1 - j];
}
    LM2GM(pre_ids_all_lm,
pre_ids_all + pre_ids_start_idx,
max_draft_tokens * sizeof(int64_t));
}
}
} // namespace plugin
} // namespace xpu3
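
The window update above writes the accepted tokens into the tail of a max_draft_tokens-wide slice of pre_ids_all. A CPU reference of that copy:

#include <cstdint>
#include <vector>

// The accepted tokens fill the tail of the window, newest token last.
void set_pre_ids_tail(std::vector<int64_t>& pre_ids_window,  // max_draft_tokens
                      const std::vector<int64_t>& accept_tokens,
                      int accept_num) {
  int n = static_cast<int>(pre_ids_window.size());
  for (int j = 0; j < accept_num; ++j) {
    pre_ids_window[n - 1 - j] = accept_tokens[accept_num - 1 - j];
  }
}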

View File

@@ -0,0 +1,268 @@
#include "xpu/kernel/cluster.h"
#include "xpu/kernel/cluster_partition.h"
#include "xpu/kernel/cluster_primitive.h"
#include "xpu/kernel/cluster_primitive_template.h"
namespace xpu3 {
namespace plugin {
static __device__ void atomic_add(_shared_ptr_ int *ptr, int v) {
bool fail = true;
while (fail) {
int a;
__asm__ __volatile__("loada.w %0,%1" : "=&r"(a) : "r"(ptr));
a += v;
__asm__ __volatile__("storea.w %0,%1,%2" : "=&r"(fail) : "r"(a), "r"(ptr));
}
}
// original version
__device__ void speculate_update_repeat_times_normal(
char *lm,
__shared_ptr__ char *sm,
__global_ptr__ const int64_t *pre_ids,
__global_ptr__ const int64_t *cur_len,
__global_ptr__ int *repeat_times,
__global_ptr__ const int *output_padding_offset,
const int64_t bs,
const int64_t length,
const int64_t length_id,
const int64_t token_num,
const int64_t max_seq_len) {
int cid = core_id();
int ncores = core_num();
int clusterid = cluster_id();
int nclusters = cluster_num();
int tid = clusterid * ncores + cid;
const int max_sm_len = 256 * 1024 / sizeof(int);
__shared_ptr__ int *repeated_times_sm = (__shared_ptr__ int *)sm;
int64_t pre_id_lm;
int n_length = (length + max_sm_len - 1) / max_sm_len;
int64_t *cur_len_lm = (int64_t *)lm;
int output_padding_offset_now;
GM2LM(cur_len, cur_len_lm, bs * sizeof(int64_t));
for (int nli = 0; nli < n_length; nli++) {
int step = nli * max_sm_len;
int cur_length = min(max_sm_len, length - step);
for (int64_t i = clusterid; i < token_num; i += nclusters) {
GM2LM(output_padding_offset + i, &output_padding_offset_now, sizeof(int));
int64_t bi = (i + output_padding_offset_now) / max_seq_len;
if (bi >= bs || cur_len_lm[bi] < 0) {
continue;
}
if (cid == 0) {
GM2SM_ASYNC(repeat_times + i * length + step,
repeated_times_sm,
sizeof(int) * cur_length);
}
mfence();
sync_cluster();
for (int j = cid; j < length_id; j += ncores) {
GM2LM(pre_ids + bi * length_id + j, &pre_id_lm, sizeof(int64_t));
if (pre_id_lm < 0) {
break;
}
if (pre_id_lm >= step && pre_id_lm < step + cur_length) {
atomic_add(repeated_times_sm + pre_id_lm - step, 1);
}
mfence();
}
sync_cluster();
if (cid == 0) {
SM2GM_ASYNC(repeated_times_sm,
repeat_times + i * length + step,
sizeof(int) * cur_length);
}
mfence();
sync_cluster();
}
}
}
// best optimized version
// about 49000+ ns
__device__ void speculate_update_repeat_times_optimized(
char *lm,
__shared_ptr__ char *sm,
__global_ptr__ const int64_t *pre_ids, // {bs, length_id}
__global_ptr__ const int64_t *cur_len, // {bs}
__global_ptr__ int *repeat_times, // {token_num, length}
__global_ptr__ const int *output_padding_offset, // {token_num}
const int64_t bs,
const int64_t length,
const int64_t length_id,
const int64_t token_num,
const int64_t max_seq_len) {
int cid = core_id();
int ncores = core_num();
int clusterid = cluster_id();
int nclusters = cluster_num();
int tid = clusterid * ncores + cid;
const int repeat_times_sm_len = 250 * 1024 / sizeof(int);
__shared_ptr__ int *repeat_times_sm = (__shared_ptr__ int *)sm;
// assert bs <= 640
int cur_len_sm_len = 640;
__shared_ptr__ int64_t *cur_len_sm =
(__shared_ptr__ int64_t *)(repeat_times_sm + repeat_times_sm_len);
__shared_ptr__ int *output_padding_offset_sm =
(__shared_ptr__ int *)(cur_len_sm + cur_len_sm_len);
DoublePtr<1, SmPtr<int>> buffer_ptr_output_padding_offset(
(SmPtr<int>(output_padding_offset_sm)));
int pre_ids_lm_len = 4;
int64_t *pre_ids_lm = (int64_t *)lm;
DoublePtr<4, LmPtr<int64_t>> buffer_ptr_pre_ids((LmPtr<int64_t>(pre_ids_lm)));
int64_t i = clusterid;
if (i < token_num && cid == 0) {
GM2SM_ASYNC(cur_len, cur_len_sm, bs * sizeof(int64_t));
buffer_ptr_output_padding_offset.gm_load_async(output_padding_offset + i,
1);
mfence_sm();
}
sync_all();
for (; i < token_num; i += nclusters) {
if (cid == 0 && i + nclusters < token_num) {
buffer_ptr_output_padding_offset.next().gm_load_async(
output_padding_offset + i + nclusters, 1);
}
int64_t bi = (i + (buffer_ptr_output_padding_offset.ptr[0])) / max_seq_len;
buffer_ptr_output_padding_offset.toggle();
if (bi >= bs || cur_len_sm[bi] < 0) {
mfence_sm();
sync_all();
continue;
}
int64_t boundary = -1;
for (int64_t repeat_times_start = 0; repeat_times_start < length;
repeat_times_start += repeat_times_sm_len) {
int64_t repeat_times_read_size =
min(length - repeat_times_start, repeat_times_sm_len);
int64_t start, end;
partition(cid, ncores, repeat_times_read_size, 1, &start, &end);
int64_t load_start = repeat_times_start + start;
int64_t repeat_times_read_size_per_core = end - start;
if (repeat_times_read_size_per_core > 0) {
GM2SM(repeat_times + i * length + load_start,
repeat_times_sm + start,
repeat_times_read_size_per_core * sizeof(int));
}
sync_all();
// each core loads pre_ids step by step and record the index of pre_ids
// which is less than zero, and store the index to boundary
if (repeat_times_start == 0) {
        bool do_prune = false;
int64_t j = cid * pre_ids_lm_len;
int64_t pre_ids_read_size =
min(static_cast<int64_t>(pre_ids_lm_len), length_id - j);
buffer_ptr_pre_ids.gm_load(pre_ids + bi * length_id + j,
pre_ids_read_size);
        for (; j < length_id && !do_prune; j += ncores * pre_ids_lm_len) {
int64_t pre_ids_read_size_next =
min(static_cast<int64_t>(pre_ids_lm_len),
length_id - (j + ncores * pre_ids_lm_len));
if (buffer_ptr_pre_ids.ptr[pre_ids_read_size - 1] >= 0 &&
pre_ids_read_size_next > 0) {
buffer_ptr_pre_ids.next().gm_load_async(
pre_ids + bi * length_id + j + ncores * pre_ids_lm_len,
pre_ids_read_size_next);
}
for (int k = 0; k < pre_ids_read_size; k++) {
if (buffer_ptr_pre_ids.ptr[k] < 0) {
              do_prune = true;
boundary = j + k;
break;
}
if (buffer_ptr_pre_ids.ptr[k] >= repeat_times_start &&
buffer_ptr_pre_ids.ptr[k] <
repeat_times_start + repeat_times_read_size) {
atomic_add(repeat_times_sm + buffer_ptr_pre_ids.ptr[k] -
repeat_times_start,
1);
}
}
mfence_lm();
pre_ids_read_size = pre_ids_read_size_next;
buffer_ptr_pre_ids.toggle();
}
}
      // each core loads all the needed pre_ids into LM without an mfence in
      // between, using the boundary recorded by the first iteration
else {
int cnt = -1;
int64_t pre_ids_read_size = 0;
for (int64_t j = cid * pre_ids_lm_len; j < boundary;
j += ncores * pre_ids_lm_len) {
cnt++;
pre_ids_read_size =
min(static_cast<int64_t>(pre_ids_lm_len), boundary - j);
GM2LM_ASYNC(pre_ids + bi * length_id + j,
pre_ids_lm + cnt * pre_ids_lm_len,
pre_ids_read_size * sizeof(int64_t));
}
mfence_lm();
cnt = max(0, cnt);
for (int k = 0; k < cnt * pre_ids_lm_len + pre_ids_read_size; k++) {
if (pre_ids_lm[k] >= repeat_times_start &&
pre_ids_lm[k] < repeat_times_start + repeat_times_read_size) {
atomic_add(repeat_times_sm + pre_ids_lm[k] - repeat_times_start, 1);
}
}
}
mfence_sm();
sync_cluster();
if (repeat_times_read_size_per_core > 0) {
SM2GM(repeat_times_sm + start,
repeat_times + i * length + load_start,
repeat_times_read_size_per_core * sizeof(int));
}
sync_all();
}
}
}
__global__ void speculate_update_repeat_times(const int64_t *pre_ids,
const int64_t *cur_len,
int *repeat_times,
const int *output_padding_offset,
const int64_t bs,
const int64_t length,
const int64_t length_id,
const int64_t token_num,
const int64_t max_seq_len) {
char lm[6 * 1024];
__shared__ char sm[256 * 1024];
if (length_id <= 6 * 1024 * 64 / sizeof(int64_t)) {
speculate_update_repeat_times_optimized(lm,
sm,
pre_ids,
cur_len,
repeat_times,
output_padding_offset,
bs,
length,
length_id,
token_num,
max_seq_len);
} else {
speculate_update_repeat_times_normal(lm,
sm,
pre_ids,
cur_len,
repeat_times,
output_padding_offset,
bs,
length,
length_id,
token_num,
max_seq_len);
}
}
} // namespace plugin
} // namespace xpu3
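
Both variants above compute the same thing: a per-token histogram of previously generated ids, used downstream for repetition penalties. A CPU reference for one output token:

#include <cstdint>
#include <vector>

// For one output token's batch row, count occurrences of each previously
// generated id; a negative pre_ids entry marks the end of the history.
void update_repeat_times(std::vector<int>& repeat_times,       // [token_num, length]
                         const std::vector<int64_t>& pre_ids,  // [bs, length_id]
                         int64_t bi, int64_t token_idx,
                         int64_t length, int64_t length_id) {
  for (int64_t j = 0; j < length_id; ++j) {
    int64_t id = pre_ids[bi * length_id + j];
    if (id < 0) break;  // history ends here
    if (id < length) ++repeat_times[token_idx * length + id];
  }
}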

View File

@@ -0,0 +1,202 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/*
* copyright (C) 2025 KUNLUNXIN, Inc
*/
#include "xpu/kernel/cluster.h"
#include "xpu/kernel/cluster_partition.h"
#include "xpu/kernel/cluster_primitive.h"
#include "xpu/kernel/cluster_primitive_template.h"
namespace xpu3 {
namespace plugin {
static inline __device__ int v_reduce(int32x16_t &v0, int32x16_t &v1) {
int res;
v1 = vvadd_int32x16(v0, v1);
auto v = vsrlp_int32x16(256, v1);
v1 = vvadd_int32x16(v, v1);
v = vsrlp_int32x16(128, v1);
v1 = vvadd_int32x16(v, v1);
v = vsrlp_int32x16(64, v1);
v1 = vvadd_int32x16(v, v1);
v = vsrlp_int32x16(32, v1);
v1 = vvadd_int32x16(v, v1);
res = vextract_int32x16(v1, 1);
return res;
}
static inline __device__ int ClusterReduce(
const _shared_ptr_ int *stop_flag_now_int_sm, int len) {
int sum = 0;
if (core_id() == 0) {
int32x16_t vec_x_0;
int32x16_t vec_x_1;
int32x16_t vec_y_0 = vzero<int>();
int32x16_t vec_y_1 = vzero<int>();
for (int i = 0; i < len; i += 32) {
vload2_sm(stop_flag_now_int_sm + i, vec_x_0, vec_x_1);
vec_y_0 = vvadd_int32x16(vec_y_0, vec_x_0);
vec_y_1 = vvadd_int32x16(vec_y_1, vec_x_1);
}
sum = v_reduce(vec_y_0, vec_y_1);
}
return sum;
}
template <int THREADBLOCK_SIZE>
__global__ void speculate_update_v3(
    int *seq_lens_encoder,          // in/out [B_max]
    int *seq_lens_decoder,          // in/out [B_max]
    bool *not_need_stop,            // out    [1]
    int64_t *draft_tokens,          // out    [B_max, T_max]
    int *actual_draft_token_nums,   // in/out [B_max]
    const int64_t *accept_tokens,   // in     [B_max, T_max]
    const int *accept_num,          // in     [B_max]
    const bool *stop_flags,         // in     [B_max]
    const int *seq_lens_this_time,  // in     [B_real]
    const bool *is_block_step,      // in     [B_max]
    const int64_t *stop_nums,       // in     [1]
const int real_bsz,
const int max_bsz,
const int max_draft_tokens) {
// real_bsz <= max_bsz <= THREADBLOCK_SIZE;
const int cid = core_id();
const int tid = core_id() * cluster_num() + cluster_id();
const int nthreads = core_num() * cluster_num();
  __shared__ int seq_lens_encoder_sm[THREADBLOCK_SIZE];  // in/out [B_max] 2K
  __shared__ int seq_lens_decoder_sm[THREADBLOCK_SIZE];  // in/out [B_max] 2K
  __shared__ int
      actual_draft_token_nums_sm[THREADBLOCK_SIZE];  // out [B_max] 2K
  __shared__ int accept_num_sm[THREADBLOCK_SIZE];    // in/out [B_max] 2K
  __shared__ bool stop_flags_sm[THREADBLOCK_SIZE];   // in [B_max] 512B
  __shared__ int seq_lens_this_time_sm[THREADBLOCK_SIZE];  // in [B_real] 2K
  __shared__ bool is_block_step_sm[THREADBLOCK_SIZE];      // in [B_max] 512B
__shared__ int stop_flag_now_int_sm[64];
bool not_need_stop_lm; // 输出[1]
int64_t stop_nums_lm; // 输入[1]
int bid_start_core, bid_end_core;
partition(tid, nthreads, max_bsz, 1, &bid_start_core, &bid_end_core);
if (cid == 0) {
GM2SM_ASYNC(seq_lens_encoder, seq_lens_encoder_sm, max_bsz * sizeof(int));
GM2SM_ASYNC(seq_lens_decoder, seq_lens_decoder_sm, max_bsz * sizeof(int));
GM2SM_ASYNC(actual_draft_token_nums,
actual_draft_token_nums_sm,
max_bsz * sizeof(int));
GM2SM_ASYNC(accept_num, accept_num_sm, max_bsz * sizeof(int));
GM2SM_ASYNC(stop_flags, stop_flags_sm, max_bsz * sizeof(bool));
GM2SM_ASYNC(
seq_lens_this_time, seq_lens_this_time_sm, max_bsz * sizeof(int));
GM2SM_ASYNC(is_block_step, is_block_step_sm, max_bsz * sizeof(bool));
GM2LM_ASYNC(stop_nums, &stop_nums_lm, sizeof(int64_t));
mfence_lm_sm();
}
sync_all();
stop_flag_now_int_sm[cid] = 0;
for (int bid = bid_start_core; bid < bid_end_core; bid++) {
const int accept_num_now = accept_num_sm[bid];
int stop_flag_now_int = 0;
if (!is_block_step_sm[bid] && bid < real_bsz) {
if (stop_flags_sm[bid]) {
stop_flag_now_int = 1;
}
if (seq_lens_encoder_sm[bid] == 0) {
seq_lens_decoder_sm[bid] += accept_num_now;
}
      // In append mode, decide from the acceptance result whether to reduce
      // the number of draft tokens for the next step.
if (seq_lens_this_time_sm[bid] > 1 && seq_lens_encoder_sm[bid] == 0) {
auto current_actual_draft_token_num = actual_draft_token_nums_sm[bid];
if (accept_num_now - 1 == current_actual_draft_token_num) {
if (current_actual_draft_token_num + 2 <= max_draft_tokens - 1) {
actual_draft_token_nums_sm[bid] =
current_actual_draft_token_num + 2;
} else if (current_actual_draft_token_num + 1 <=
max_draft_tokens - 1) {
actual_draft_token_nums_sm[bid] =
current_actual_draft_token_num + 1;
} else {
actual_draft_token_nums_sm[bid] = max_draft_tokens - 1;
}
} else {
actual_draft_token_nums_sm[bid] =
actual_draft_token_nums_sm[bid] - 1 >= 1
? actual_draft_token_nums_sm[bid] - 1
: 1;
}
}
if (seq_lens_encoder_sm[bid] != 0) {
seq_lens_decoder_sm[bid] += seq_lens_encoder_sm[bid];
seq_lens_encoder_sm[bid] = 0;
}
if (stop_flag_now_int) {
seq_lens_decoder_sm[bid] = 0;
} else {
        // Trying out a new compiler feature here.
draft_tokens[bid * max_draft_tokens] =
accept_tokens[bid * max_draft_tokens + accept_num_now - 1];
}
} else if (bid >= real_bsz && bid < max_bsz) {
stop_flag_now_int = 1;
}
stop_flag_now_int_sm[cid] += stop_flag_now_int;
mfence_lm();
}
mfence_sm();
sync_all();
int64_t stop_sum = ClusterReduce(stop_flag_now_int_sm, 64);
sync_all();
if (cid == 0) {
not_need_stop_lm = stop_sum < stop_nums_lm;
mfence_lm();
SM2GM_ASYNC(seq_lens_encoder_sm, seq_lens_encoder, max_bsz * sizeof(int));
SM2GM_ASYNC(seq_lens_decoder_sm, seq_lens_decoder, max_bsz * sizeof(int));
LM2GM_ASYNC(&not_need_stop_lm, not_need_stop, 1 * sizeof(bool));
SM2GM_ASYNC(actual_draft_token_nums_sm,
actual_draft_token_nums,
max_bsz * sizeof(int));
mfence();
}
}
template __global__ void speculate_update_v3<512>(int *seq_lens_encoder,
int *seq_lens_decoder,
bool *not_need_stop,
int64_t *draft_tokens,
int *actual_draft_token_nums,
const int64_t *accept_tokens,
const int *accept_num,
const bool *stop_flags,
const int *seq_lens_this_time,
const bool *is_block_step,
const int64_t *stop_nums,
const int real_bsz,
const int max_bsz,
const int max_draft_tokens);
} // namespace plugin
} // namespace xpu3
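The draft-length adaptation inside speculate_update_v3 is easier to audit in scalar form. A sketch of the same rule (grow by 2 after full acceptance, shrink by 1 otherwise, always clamped to [1, max_draft_tokens - 1]); the helper name is ours, not part of the plugin:

// Scalar sketch of the actual_draft_token_nums update above. accept_num
// includes the bonus token, so full acceptance shows up as
// accept_num - 1 == cur.
inline int next_draft_len(int cur, int accept_num, int max_draft_tokens) {
  if (accept_num - 1 == cur) {
    // Full acceptance: grow by 2, capped at max_draft_tokens - 1.
    return cur + 2 <= max_draft_tokens - 1 ? cur + 2 : max_draft_tokens - 1;
  }
  // Partial acceptance: shrink by 1, floored at 1.
  return cur - 1 >= 1 ? cur - 1 : 1;
}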

View File

@@ -0,0 +1,281 @@
#include "xpu/kernel/cluster.h"
#include "xpu/kernel/cluster_debug.h"
#include "xpu/kernel/cluster_partition.h"
#include "xpu/kernel/cluster_primitive.h"
namespace xpu3 {
namespace plugin {
__device__ void do_cast(const int *xlm, float *ylm, int64_t len) {
for (int64_t i = 0; i < len; i += 32) {
int32x16_t xl = vload_lm_int32x16(xlm + i);
int32x16_t xh = vload_lm_int32x16(xlm + i + 16);
float32x16_t yl = vfix2float(xl);
float32x16_t yh = vfix2float(xh);
vstore_lm_float32x16(ylm + i, yl);
vstore_lm_float32x16(ylm + i + 16, yh);
}
mfence_lm();
}
template <typename T>
__global__ void speculate_update_value_by_repeat_times(
const int *repeat_times,
const T *penalty_scores,
const T *frequency_score,
const T *presence_score,
const float *temperatures,
T *logits,
const int *output_padding_offset,
const int64_t bs,
const int64_t length,
const int64_t token_num,
const int64_t max_seq_len) {
int ncores = core_num();
int cid = core_id();
int thread_id = cid * cluster_num() + cluster_id();
int nthreads = cluster_num() * ncores;
int64_t start = -1;
int64_t end = -1;
partition(thread_id, nthreads, token_num * length, 1, &start, &end);
if (start >= end) {
return;
}
int64_t token_start = start / length;
int64_t token_end = end / length;
if (token_end >= token_num) {
token_end = token_num - 1;
}
int output_padding_offset_start_lm;
int output_padding_offset_end_lm;
GM2LM_ASYNC(output_padding_offset + token_start,
(void *)&output_padding_offset_start_lm,
sizeof(int));
GM2LM(output_padding_offset + token_end,
(void *)&output_padding_offset_end_lm,
sizeof(int));
int64_t bs_start =
(token_start + output_padding_offset_start_lm) / max_seq_len;
int64_t bs_end = (token_end + output_padding_offset_end_lm) / max_seq_len;
const int param_len = 256;
  // ncores = 64 for xpu3
__shared__ __simd__ float alpha_buf[param_len * 64];
__shared__ __simd__ float beta_buf[param_len * 64];
__shared__ __simd__ float gamma_buf[param_len * 64];
__shared__ __simd__ float temperatures_buf[param_len * 64];
_shared_ptr_ float *alpha_sm = alpha_buf + cid * param_len;
_shared_ptr_ float *beta_sm = beta_buf + cid * param_len;
_shared_ptr_ float *gamma_sm = gamma_buf + cid * param_len;
_shared_ptr_ float *temperatures_sm = temperatures_buf + cid * param_len;
int read_param_len = bs_end - bs_start + 1;
GM2SM_ASYNC(penalty_scores + bs_start, alpha_sm, read_param_len * sizeof(T));
GM2SM_ASYNC(frequency_score + bs_start, beta_sm, read_param_len * sizeof(T));
GM2SM_ASYNC(presence_score + bs_start, gamma_sm, read_param_len * sizeof(T));
GM2SM(
temperatures + bs_start, temperatures_sm, read_param_len * sizeof(float));
primitive_cast_sm<T, float>(
(const _shared_ptr_ T *)(alpha_sm), alpha_sm, read_param_len);
primitive_cast_sm<T, float>(
(const _shared_ptr_ T *)(beta_sm), beta_sm, read_param_len);
primitive_cast_sm<T, float>(
(const _shared_ptr_ T *)(gamma_sm), gamma_sm, read_param_len);
float logit_now;
float alpha;
float beta;
float gamma;
float temperature;
int time;
const int buffer_len = 512;
__simd__ float logits_lm[buffer_len];
int times_lm[buffer_len];
int output_padding_offset_lm[buffer_len];
for (int64_t i = start; i < end; i += buffer_len) {
int read_len = min(end - i, buffer_len);
GM2LM_ASYNC(logits + i, logits_lm, read_len * sizeof(T));
GM2LM_ASYNC(output_padding_offset + i / length,
output_padding_offset_lm,
((read_len + length - 1) / length + 1) * sizeof(int));
GM2LM(repeat_times + i, times_lm, read_len * sizeof(int));
primitive_cast<T, float>((const T *)(logits_lm), logits_lm, read_len);
for (int j = 0; j < read_len; j++) {
time = times_lm[j];
logit_now = logits_lm[j];
int token_idx = (i + j) / length;
int bs_idx =
(token_idx + output_padding_offset_lm[token_idx - i / length]) /
max_seq_len;
if (bs_idx >= bs) {
continue;
}
int param_idx = bs_idx - bs_start;
temperature = temperatures_sm[param_idx];
if (time != 0) {
alpha = alpha_sm[param_idx];
beta = beta_sm[param_idx];
gamma = gamma_sm[param_idx];
logit_now = logit_now < 0.0f ? logit_now * alpha : logit_now / alpha;
logit_now = logit_now - time * beta - gamma;
}
logits_lm[j] = logit_now / temperature;
}
mfence_lm();
primitive_cast<float, T>(logits_lm, (T *)logits_lm, read_len);
LM2GM(logits_lm, logits + i, read_len * sizeof(T));
}
}
#define _XPU_DEF__UPDATE_VALUE_BY_REPEAT_TIMES_(DATA_TYPE) \
template __global__ void speculate_update_value_by_repeat_times( \
const int *repeat_times, \
const DATA_TYPE *penalty_scores, \
const DATA_TYPE *frequency_score, \
const DATA_TYPE *presence_score, \
const float *temperatures, \
DATA_TYPE *logits, \
const int *output_padding_offset, \
const int64_t bs, \
const int64_t length, \
const int64_t token_num, \
const int64_t max_seq_len);
_XPU_DEF__UPDATE_VALUE_BY_REPEAT_TIMES_(float);
_XPU_DEF__UPDATE_VALUE_BY_REPEAT_TIMES_(float16);
_XPU_DEF__UPDATE_VALUE_BY_REPEAT_TIMES_(bfloat16);
template <typename T>
__global__ void speculate_update_value_by_repeat_times_simd(
const int *repeat_times, // [bs * length]
const T *penalty_scores, // [bs]
const T *frequency_score, // [bs]
const T *presence_score, // [bs]
const float *temperatures, // [bs]
T *logits, // [bs * length]
const int *output_padding_offset,
const int64_t bs,
const int64_t length,
const int64_t token_num,
const int64_t max_seq_len) {
int ncores = core_num();
int cid = core_id();
int thread_id = cid * cluster_num() + cluster_id();
int nthreads = cluster_num() * ncores;
int64_t start = -1;
int64_t end = -1;
partition(thread_id, nthreads, token_num * length, 16, &start, &end);
if (start >= end) {
return;
}
const int param_len = 256;
// ncores = 64 for xpu3
__shared__ __simd__ float alpha_buf[param_len * 64];
__shared__ __simd__ float beta_buf[param_len * 64];
__shared__ __simd__ float gamma_buf[param_len * 64];
__shared__ __simd__ float temperatures_buf[param_len * 64];
// assert bs <= param_len * 64
if (cid == 0) {
GM2SM_ASYNC(penalty_scores, alpha_buf, bs * sizeof(T));
GM2SM_ASYNC(frequency_score, beta_buf, bs * sizeof(T));
GM2SM_ASYNC(presence_score, gamma_buf, bs * sizeof(T));
GM2SM(temperatures, temperatures_buf, bs * sizeof(float));
primitive_cast_sm<T, float>(
(const _shared_ptr_ T *)(alpha_buf), alpha_buf, bs);
primitive_cast_sm<T, float>(
(const _shared_ptr_ T *)(beta_buf), beta_buf, bs);
primitive_cast_sm<T, float>(
(const _shared_ptr_ T *)(gamma_buf), gamma_buf, bs);
}
mfence();
sync_all();
float logit_now;
float alpha;
float beta;
float gamma;
float temperature;
int time;
const int buffer_len = 512;
__simd__ float logits_lm[buffer_len];
  __simd__ float times_lm[buffer_len];  // filled with int data, cast in place
int output_padding_offset_lm[buffer_len];
float32x16_t logits_;
float32x16_t logits_tmp_0;
float32x16_t logits_tmp_1;
float32x16_t time_;
for (int64_t i = start; i < end; i += buffer_len) {
int read_len = min(end - i, buffer_len);
GM2LM_ASYNC(logits + i, logits_lm, read_len * sizeof(T));
GM2LM_ASYNC(output_padding_offset + i / length,
output_padding_offset_lm,
((read_len + length - 1) / length + 1) * sizeof(int));
GM2LM(repeat_times + i, times_lm, read_len * sizeof(int));
primitive_cast<T, float>((const T *)(logits_lm), logits_lm, read_len);
do_cast((const int *)(times_lm), times_lm, read_len);
int time_mask = 0;
int logit_mask = 0;
for (int j = 0; j < read_len; j += 16) {
time_ = vload_lm_float32x16(times_lm + j);
logits_ = vload_lm_float32x16(logits_lm + j);
int token_idx = (i + j) / length;
int bs_idx =
(token_idx + output_padding_offset_lm[token_idx - i / length]) /
max_seq_len;
if (bs_idx >= bs) {
continue;
}
int param_idx = bs_idx;
temperature = temperatures_buf[param_idx];
alpha = alpha_buf[param_idx];
beta = beta_buf[param_idx];
gamma = gamma_buf[param_idx];
time_mask = svneq_float32x16(0.f, time_); // time != 0 mask
logit_mask = svle_float32x16(0.f, logits_); // logit >= 0 mask
time_ = svmul_float32x16(beta, time_); // time * beta
time_ = svadd_float32x16(gamma, time_); // time * beta + gamma
logits_ = svmul_float32x16_mh(
alpha,
logits_,
logits_,
(time_mask &
~logit_mask)); // when time != 0 && logit < 0, do alpha * logit
logits_ = svmul_float32x16_mh(
1.0f / alpha,
logits_,
logits_,
        (time_mask & logit_mask));  // when time != 0 && logit >= 0, do logit / alpha
logits_ = vvsub_float32x16_mh(
logits_, time_, logits_, time_mask); // when time != 0, do logit =
// logit - time * beta - gamma;
logits_ =
svmul_float32x16(1.0f / temperature, logits_); // logit / temperature
vstore_lm_float32x16(logits_lm + j, logits_);
}
mfence_lm();
primitive_cast<float, T>(logits_lm, (T *)logits_lm, read_len);
LM2GM(logits_lm, logits + i, read_len * sizeof(T));
}
}
#define _XPU_DEF__UPDATE_VALUE_BY_REPEAT_TIMES_SIMD(DATA_TYPE) \
template __global__ void speculate_update_value_by_repeat_times_simd( \
const int *repeat_times, \
const DATA_TYPE *penalty_scores, \
const DATA_TYPE *frequency_score, \
const DATA_TYPE *presence_score, \
const float *temperatures, \
DATA_TYPE *logits, \
const int *output_padding_offset, \
const int64_t bs, \
const int64_t length, \
const int64_t token_num, \
const int64_t max_seq_len);
_XPU_DEF__UPDATE_VALUE_BY_REPEAT_TIMES_SIMD(float);
_XPU_DEF__UPDATE_VALUE_BY_REPEAT_TIMES_SIMD(float16);
_XPU_DEF__UPDATE_VALUE_BY_REPEAT_TIMES_SIMD(bfloat16);
} // namespace plugin
} // namespace xpu3
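Both the scalar and the SIMD variants above apply the same per-logit formula. A reference version for checking (alpha is the repetition penalty, beta the frequency penalty, gamma the presence penalty, time the repeat count of the token; the helper is illustrative, not part of the plugin):

// Scalar reference for speculate_update_value_by_repeat_times: penalize a
// logit by its repeat count, then apply the temperature.
inline float penalize_logit(float logit, int time, float alpha, float beta,
                            float gamma, float temperature) {
  if (time != 0) {
    logit = logit < 0.0f ? logit * alpha : logit / alpha;  // repetition
    logit -= time * beta + gamma;  // frequency + presence
  }
  return logit / temperature;
}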

View File

@@ -0,0 +1,335 @@
#include "xpu/kernel/cluster_debug.h"
#include "xpu/kernel/cluster_partition.h"
#include "xpu/kernel/xtdk.h"
#include "xpu/kernel/xtdk_math.h"
#include "xpu/kernel/xtdk_simd.h"
namespace xpu3 {
namespace plugin {
static inline __device__ int v_reduce(int32x16_t &v0, int32x16_t &v1) {
int res;
v1 = vvadd_int32x16(v0, v1);
auto v = vsrlp_int32x16(256, v1);
v1 = vvadd_int32x16(v, v1);
v = vsrlp_int32x16(128, v1);
v1 = vvadd_int32x16(v, v1);
v = vsrlp_int32x16(64, v1);
v1 = vvadd_int32x16(v, v1);
v = vsrlp_int32x16(32, v1);
v1 = vvadd_int32x16(v, v1);
res = vextract_int32x16(v1, 1);
return res;
}
static inline __device__ int ClusterReduce(
const _shared_ptr_ int *stop_flag_now_int_sm, int len) {
int sum = 0;
if (core_id() == 0) {
int32x16_t vec_x_0;
int32x16_t vec_x_1;
int32x16_t vec_y_0 = vzero<int>();
int32x16_t vec_y_1 = vzero<int>();
for (int i = 0; i < len; i += 32) {
vload2_sm(stop_flag_now_int_sm + i, vec_x_0, vec_x_1);
vec_y_0 = vvadd_int32x16(vec_y_0, vec_x_0);
vec_y_1 = vvadd_int32x16(vec_y_1, vec_x_1);
}
sum = v_reduce(vec_y_0, vec_y_1);
}
return sum;
}
__device__ bool is_in_end(const int64_t id,
__global_ptr__ const int64_t *end_ids,
int length) {
bool flag = false;
for (int i = 0; i < length; i++) {
if (id == end_ids[i]) {
return true;
}
}
return flag;
}
__device__ inline bool is_in(__global_ptr__ const int64_t *candidates,
const int64_t draft,
const int candidate_len) {
for (int i = 0; i < candidate_len; i++) {
if (draft == candidates[i]) {
return true;
}
}
return false;
}
static __device__ inline unsigned int xorwow(unsigned int &state) {
state ^= state >> 7;
state ^= state << 9;
state ^= state >> 13;
return state;
}
typedef uint32_t curandStatePhilox4_32_10_t;
__device__ int64_t
topp_sampling_kernel(__global_ptr__ const int64_t *candidate_ids,
__global_ptr__ const float *candidate_scores,
__global_ptr__ const float *dev_curand_states,
const int candidate_len,
const float topp) {
const int tid = core_id();
float sum_scores = 0.0f;
float rand_top_p = *dev_curand_states * topp;
for (int i = 0; i < candidate_len; i++) {
sum_scores += candidate_scores[i];
if (rand_top_p <= sum_scores) {
return candidate_ids[i];
}
}
return candidate_ids[0];
}
#define sm_size 1024
template <bool ENABLE_TOPP, bool USE_TOPK>
__global__ void speculate_verify(
    int64_t *accept_tokens,  // out [real_bsz, max_draft_tokens], the finally
                             // accepted tokens (verified or sampled)
    int *accept_num,         // out [real_bsz], number of tokens each sequence
                             // finally accepts (verified ones only)
    int64_t *step_idx,       // out [real_bsz], number of tokens generated or
                             // accepted so far for each bid
    bool *stop_flags,        // out [real_bsz], per-sequence stop flag, set to
                             // true on <eos> or when the length limit is hit
    const int *seq_lens_encoder,  // [real_bsz], encoder input length per
                                  // sample, used to detect the prefill stage
    const int *seq_lens_decoder,  // [real_bsz], decoder output token count per
                                  // sample (i.e. the draft token count)
    const int64_t *draft_tokens,  // [real_bsz, max_draft_tokens], tokens
                                  // produced by the draft model
    const int *actual_draft_token_nums,  // [real_bsz], number of valid tokens
                                         // in draft_tokens
    const float *dev_curand_states,  // used for random
    const float *topp,  // [real_bsz], TopP threshold (e.g. 0.9) controlling
                        // nucleus truncation and the candidate count
    const int *seq_lens_this_time,  // [real_bsz], tokens per sample actually
                                    // verified in this round
    const int64_t *verify_tokens,  // [sum(seq_lens_this_time),
                                   // max_candidate_len], candidate tokens
                                   // emitted by the verify decoder
    const float *verify_scores,  // same shape, probability of each verify
                                 // token, used for sampling
    const int64_t *max_dec_len,  // [real_bsz], maximum generation length per
                                 // sample (exceeding it triggers stop)
    const int64_t *end_tokens,   // [end_length], stop-token list (e.g. <eos>);
                                 // hitting one terminates the sequence
    const bool *is_block_step,   // [real_bsz], whether this is a block step;
                                 // verify is skipped when true
    const int *output_cum_offsets,  // [real_bsz], start offset into
                                    // verify_tokens for locating each token
    const int *actual_candidate_len,  // [sum(seq_lens_this_time)], usable
                                      // candidates per verify token (for TopP
                                      // truncation)
    const int real_bsz,          // batch size
    const int max_draft_tokens,  // scalar, max draft tokens per sample
    const int end_length,
    const int max_seq_len,        // scalar, max tokens per sequence (used for
                                  // offset computation)
    const int max_candidate_len,  // scalar, max candidates per verify token
                                  // (for verification or sampling)
    const int verify_window,  // scalar, TopK verify window (allowed run of
                              // consecutive top-1 matches)
    const bool prefill_one_step_stop) {
const int cid = core_id();
const int64_t tid = cluster_id() * core_num() + core_id();
const int64_t nthreads = cluster_num() * core_num();
for (int64_t bid = tid; bid < real_bsz; bid += nthreads) {
int stop_flag_now_int = 0;
int accept_num_now = 1;
if (is_block_step[bid]) {
continue;
}
const int start_token_id = bid * max_seq_len - output_cum_offsets[bid];
if (stop_flags[bid]) {
stop_flag_now_int = 1;
    } else {  // The prefill stage also reaches here, but since the draft
              // tokens are zeroed it falls straight through to the final
              // sampling stage.
auto *verify_tokens_now =
verify_tokens + start_token_id * max_candidate_len;
auto *draft_tokens_now = draft_tokens + bid * max_draft_tokens;
auto *actual_candidate_len_now = actual_candidate_len + start_token_id;
int i = 0;
for (; i < seq_lens_this_time[bid] - 1; i++) {
if (seq_lens_encoder[bid] != 0) {
break;
}
if (USE_TOPK) {
if (verify_tokens_now[i * max_candidate_len] ==
draft_tokens_now[i + 1]) {
step_idx[bid]++;
auto accept_token = draft_tokens_now[i + 1];
accept_tokens[bid * max_draft_tokens + i] = accept_token;
if (is_in_end(accept_token, end_tokens, end_length) ||
step_idx[bid] >= max_dec_len[bid]) {
stop_flags[bid] = true;
stop_flag_now_int = 1;
if (step_idx[bid] >= max_dec_len[bid])
accept_tokens[bid * max_draft_tokens + i] = end_tokens[0];
break;
} else {
accept_num_now++;
}
} else {
break;
}
} else {
auto actual_candidate_len_value =
actual_candidate_len_now[i] > max_candidate_len
? max_candidate_len
: actual_candidate_len_now[i];
if (is_in(verify_tokens_now + i * max_candidate_len,
draft_tokens_now[i + 1],
actual_candidate_len_value)) {
// Top P verify
step_idx[bid]++;
auto accept_token = draft_tokens_now[i + 1];
accept_tokens[bid * max_draft_tokens + i] = accept_token;
if (is_in_end(accept_token, end_tokens, end_length) ||
step_idx[bid] >= max_dec_len[bid]) {
stop_flags[bid] = true;
stop_flag_now_int = 1;
if (step_idx[bid] >= max_dec_len[bid])
accept_tokens[bid * max_draft_tokens + i] = end_tokens[0];
break;
} else {
accept_num_now++;
}
} else {
// TopK verify
int ii = i;
if (max_candidate_len >= 2 &&
verify_tokens_now[ii * max_candidate_len + 1] ==
draft_tokens_now[ii + 1]) { // top-2
int j = 0;
ii += 1;
for (; j < verify_window && ii < seq_lens_this_time[bid] - 1;
j++, ii++) {
if (verify_tokens_now[ii * max_candidate_len] !=
draft_tokens_now[ii + 1]) {
break;
}
}
if (j >= verify_window) { // accept all
accept_num_now += verify_window + 1;
step_idx[bid] += verify_window + 1;
for (; i < ii; i++) {
auto accept_token = draft_tokens_now[i + 1];
accept_tokens[bid * max_draft_tokens + i] = accept_token;
if (is_in_end(accept_token, end_tokens, end_length) ||
step_idx[bid] >= max_dec_len[bid]) {
stop_flags[bid] = true;
stop_flag_now_int = 1;
if (step_idx[bid] >= max_dec_len[bid])
accept_tokens[bid * max_draft_tokens + i] = end_tokens[0];
accept_num_now--;
step_idx[bid]--;
break;
}
}
}
}
break;
}
}
}
    // Sampling stage. Two cases reach it:
    //   1) draft_token[i + 1] was rejected: pick one token from
    //      verify_tokens_now[i];
    //   2) i == seq_lens_this_time[bid] - 1: also pick from
    //      verify_tokens_now[i]. Stopped sequences are excluded.
if (!stop_flag_now_int) {
int64_t accept_token;
__global_ptr__ const float *verify_scores_now =
verify_scores + start_token_id * max_candidate_len;
step_idx[bid]++;
if (ENABLE_TOPP) {
auto actual_candidate_len_value =
actual_candidate_len_now[i] > max_candidate_len
? max_candidate_len
: actual_candidate_len_now[i];
accept_token =
topp_sampling_kernel(verify_tokens_now + i * max_candidate_len,
verify_scores_now + i * max_candidate_len,
dev_curand_states,
actual_candidate_len_value,
topp[bid]);
} else {
accept_token = verify_tokens_now[i * max_candidate_len];
}
accept_tokens[bid * max_draft_tokens + i] = accept_token;
if (prefill_one_step_stop) {
stop_flags[bid] = true;
}
if (is_in_end(accept_token, end_tokens, end_length) ||
step_idx[bid] >= max_dec_len[bid]) {
stop_flags[bid] = true;
stop_flag_now_int = 1;
if (step_idx[bid] >= max_dec_len[bid])
accept_tokens[bid * max_draft_tokens + i] = end_tokens[0];
}
}
accept_num[bid] = accept_num_now;
}
}
}
#define SPECULATE_VERIFY_INSTANTIATE(ENABLE_TOPP, USE_TOPK) \
template __global__ void speculate_verify<ENABLE_TOPP, USE_TOPK>( \
int64_t * accept_tokens, \
int *accept_num, \
int64_t *step_idx, \
bool *stop_flags, \
const int *seq_lens_encoder, \
const int *seq_lens_decoder, \
const int64_t *draft_tokens, \
const int *actual_draft_token_nums, \
const float *dev_curand_states, \
const float *topp, \
const int *seq_lens_this_time, \
const int64_t *verify_tokens, \
const float *verify_scores, \
const int64_t *max_dec_len, \
const int64_t *end_tokens, \
const bool *is_block_step, \
const int *output_cum_offsets, \
const int *actual_candidate_len, \
int real_bsz, \
int max_draft_tokens, \
int end_length, \
int max_seq_len, \
int max_candidate_len, \
int verify_window, \
bool prefill_one_step_stop);
SPECULATE_VERIFY_INSTANTIATE(true, true)
SPECULATE_VERIFY_INSTANTIATE(true, false)
SPECULATE_VERIFY_INSTANTIATE(false, true)
SPECULATE_VERIFY_INSTANTIATE(false, false)
} // namespace plugin
} // namespace xpu3
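When verification fails, or the last verified position is reached, speculate_verify falls back to top-p sampling over candidates that are already sorted by descending probability. A standalone sketch of that pick (r stands in for the value read from dev_curand_states; the function name is ours):

#include <cstdint>
#include <vector>

// Scalar reference for topp_sampling_kernel: scan the candidates in
// descending probability order and return the first id whose cumulative
// score reaches r * topp; if the threshold is never reached, fall back to
// the top-1 candidate.
int64_t topp_pick(const std::vector<int64_t>& candidate_ids,
                  const std::vector<float>& candidate_scores,
                  float r,
                  float topp) {
  const float threshold = r * topp;
  float sum = 0.0f;
  for (size_t i = 0; i < candidate_ids.size(); ++i) {
    sum += candidate_scores[i];
    if (threshold <= sum) {
      return candidate_ids[i];
    }
  }
  return candidate_ids[0];
}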

View File

@@ -0,0 +1,349 @@
#include "xpu/kernel/cluster.h"
#include "xpu/kernel/cluster_partition.h"
#include "xpu/kernel/cluster_primitive.h"
#include "xpu/kernel/cluster_primitive_template.h"
namespace xpu3 {
namespace plugin {
template <typename T, int MaxLength, int TopPBeamTopK>
__device__ void top_p_candidates_big_n(
char* lm,
__global_ptr__ const T* src,
__global_ptr__ const T* top_ps,
__global_ptr__ const int* output_padding_offset,
__global_ptr__ int64_t* out_id,
__global_ptr__ T* out_val,
__global_ptr__ int* actual_candidates_lens,
int vocab_size,
int token_num,
    int max_candidate_len,
int max_seq_len) {
int ncores = core_num();
int cid = core_id();
int tid = cid * cluster_num() + cluster_id();
int nthreads = cluster_num() * ncores;
int64_t buf_size = 6 * 1024 / sizeof(T);
T* lm_src = (T*)lm;
int64_t lm_out_id[TopPBeamTopK];
T lm_out_val[TopPBeamTopK];
__shared__ int64_t sm_out_id[64 * TopPBeamTopK];
__shared__ T sm_out_val[64 * TopPBeamTopK];
// only used in core 0
int lm_output_padding_offset;
for (int64_t i = cluster_id(); i < token_num; i += cluster_num()) {
if (cid == 0) {
GM2LM(output_padding_offset + i, &lm_output_padding_offset, sizeof(int));
}
for (int64_t j = 0; j < TopPBeamTopK; j++) {
lm_out_id[j] = -1;
}
for (int j = cid * buf_size; j < vocab_size; j += ncores * buf_size) {
int64_t read_size = min(buf_size, static_cast<int64_t>(vocab_size - j));
GM2LM(src + i * vocab_size + j, lm_src, read_size * sizeof(T));
for (int k = 0; k < read_size; k++) {
        if (lm_out_id[TopPBeamTopK - 1] == -1 ||
            lm_src[k] > lm_out_val[TopPBeamTopK - 1] ||
            (lm_src[k] == lm_out_val[TopPBeamTopK - 1] &&
             (j + k) < lm_out_id[TopPBeamTopK - 1])) {
          int l = TopPBeamTopK - 2;
          for (; l >= 0; l--) {
            if (lm_out_id[l] == -1 || lm_src[k] > lm_out_val[l] ||
                (lm_src[k] == lm_out_val[l] && (j + k) < lm_out_id[l])) {
lm_out_id[l + 1] = lm_out_id[l];
lm_out_val[l + 1] = lm_out_val[l];
} else {
break;
}
}
lm_out_id[l + 1] = j + k;
lm_out_val[l + 1] = lm_src[k];
}
}
mfence_lm();
}
if (cid % 16 != 0) {
for (int64_t j = 0; j < TopPBeamTopK; j++) {
sm_out_id[cid * TopPBeamTopK + j] = lm_out_id[j];
sm_out_val[cid * TopPBeamTopK + j] = lm_out_val[j];
}
}
mfence_sm();
sync_all();
if (cid % 16 == 0) {
int64_t local_sm_out_id;
T local_sm_out_val;
for (int j = cid + 1; j < cid + 16; j += 1) {
for (int offset = 0; offset < TopPBeamTopK; offset++) {
local_sm_out_id = sm_out_id[j * TopPBeamTopK + offset];
local_sm_out_val = sm_out_val[j * TopPBeamTopK + offset];
          if (local_sm_out_val > lm_out_val[TopPBeamTopK - 1] ||
              (local_sm_out_val == lm_out_val[TopPBeamTopK - 1] &&
               local_sm_out_id < lm_out_id[TopPBeamTopK - 1])) {
            int k = TopPBeamTopK - 2;
            for (; k >= 0; k--) {
              if (local_sm_out_val > lm_out_val[k] ||
                  (local_sm_out_val == lm_out_val[k] &&
                   local_sm_out_id < lm_out_id[k])) {
lm_out_id[k + 1] = lm_out_id[k];
lm_out_val[k + 1] = lm_out_val[k];
} else {
break;
}
}
lm_out_id[k + 1] = local_sm_out_id;
lm_out_val[k + 1] = local_sm_out_val;
} else {
break;
}
}
}
if (cid != 0) {
for (int64_t j = 0; j < TopPBeamTopK; j++) {
sm_out_id[cid * TopPBeamTopK + j] = lm_out_id[j];
sm_out_val[cid * TopPBeamTopK + j] = lm_out_val[j];
}
}
}
mfence_sm();
sync_all();
if (cid == 0) {
int64_t local_sm_out_id;
T local_sm_out_val;
for (int j = cid + 16; j < ncores; j += 16) {
for (int offset = 0; offset < TopPBeamTopK; offset++) {
local_sm_out_id = sm_out_id[j * TopPBeamTopK + offset];
local_sm_out_val = sm_out_val[j * TopPBeamTopK + offset];
          if (local_sm_out_val > lm_out_val[TopPBeamTopK - 1] ||
              (local_sm_out_val == lm_out_val[TopPBeamTopK - 1] &&
               local_sm_out_id < lm_out_id[TopPBeamTopK - 1])) {
            int k = TopPBeamTopK - 2;
            for (; k >= 0; k--) {
              if (local_sm_out_val > lm_out_val[k] ||
                  (local_sm_out_val == lm_out_val[k] &&
                   local_sm_out_id < lm_out_id[k])) {
lm_out_id[k + 1] = lm_out_id[k];
lm_out_val[k + 1] = lm_out_val[k];
} else {
break;
}
}
lm_out_id[k + 1] = local_sm_out_id;
lm_out_val[k + 1] = local_sm_out_val;
} else {
break;
}
}
}
int ori_token_id = i + lm_output_padding_offset;
int bid = ori_token_id / max_seq_len;
T lm_top_p;
GM2LM(top_ps + bid, &lm_top_p, sizeof(T));
float top_p_value = static_cast<float>(lm_top_p);
T default_val = static_cast<T>(0.f);
int lm_actual_candidates_len = 0;
      float sum_prob = 0.0f;
      for (int j = 0; j < TopPBeamTopK; j++) {
        sum_prob += static_cast<float>(lm_out_val[j]);
        if (sum_prob >= top_p_value) {
          for (int k = j + 1; k < TopPBeamTopK; k++) {
            lm_out_id[k] = 0;
            lm_out_val[k] = default_val;
          }
          lm_actual_candidates_len = j + 1;
          break;
        }
      }
mfence_lm();
LM2GM_ASYNC(
&lm_actual_candidates_len, actual_candidates_lens + i, sizeof(int));
      LM2GM_ASYNC(lm_out_id,
                  out_id + i * max_candidate_len,
                  TopPBeamTopK * sizeof(int64_t));
      LM2GM_ASYNC(
          lm_out_val, out_val + i * max_candidate_len, TopPBeamTopK * sizeof(T));
}
mfence();
sync_all();
}
}
template <typename T, int MaxLength, int TopPBeamTopK>
__device__ void top_p_candidates_normal(
char* lm,
__global_ptr__ const T* src,
__global_ptr__ const T* top_ps,
__global_ptr__ const int* output_padding_offset,
__global_ptr__ int64_t* out_id,
__global_ptr__ T* out_val,
__global_ptr__ int* actual_candidates_lens,
int vocab_size,
int token_num,
    int max_candidate_len,
int max_seq_len) {
int ncores = core_num();
int cid = core_id();
int tid = cid * cluster_num() + cluster_id();
int nthreads = cluster_num() * ncores;
int64_t buf_size = 6 * 1024 / sizeof(T);
T* lm_src = (T*)lm;
int64_t lm_out_id[TopPBeamTopK];
T lm_out_val[TopPBeamTopK];
int lm_output_padding_offset;
T lm_top_p;
int64_t default_id = 0;
T default_val = static_cast<T>(0.f);
for (int64_t i = tid; i < token_num; i += nthreads) {
float sum_prob = 0.0f;
for (int64_t j = 0; j < TopPBeamTopK; j++) {
lm_out_id[j] = -1;
}
for (int j = 0; j < vocab_size; j += buf_size) {
int64_t read_size = min(buf_size, static_cast<int64_t>(vocab_size - j));
GM2LM(src + i * vocab_size + j, lm_src, read_size * sizeof(T));
for (int k = 0; k < read_size; k++) {
        if (lm_out_id[TopPBeamTopK - 1] == -1 ||
            lm_src[k] > lm_out_val[TopPBeamTopK - 1] ||
            (lm_src[k] == lm_out_val[TopPBeamTopK - 1] &&
             (j + k) < lm_out_id[TopPBeamTopK - 1])) {
          lm_out_id[TopPBeamTopK - 1] = j + k;
          lm_out_val[TopPBeamTopK - 1] = lm_src[k];
          for (int l = TopPBeamTopK - 2; l >= 0; l--) {
            if (lm_out_id[l] == -1 || lm_out_val[l + 1] > lm_out_val[l] ||
                (lm_out_val[l + 1] == lm_out_val[l] &&
                 lm_out_id[l + 1] < lm_out_id[l])) {
int64_t swap_id = lm_out_id[l];
T swap_val = lm_out_val[l];
lm_out_id[l] = lm_out_id[l + 1];
lm_out_val[l] = lm_out_val[l + 1];
lm_out_id[l + 1] = swap_id;
lm_out_val[l + 1] = swap_val;
}
}
}
}
mfence_lm();
}
GM2LM(output_padding_offset + i, &lm_output_padding_offset, sizeof(int));
int ori_token_id = i + lm_output_padding_offset;
int bid = ori_token_id / max_seq_len;
GM2LM(top_ps + bid, &lm_top_p, sizeof(T));
float top_p_value = static_cast<float>(lm_top_p);
bool set_to_default_val = false;
int lm_actual_candidates_len = 0;
for (int j = 0; j < TopPBeamTopK; j++) {
if (set_to_default_val) {
        LM2GM_ASYNC(
            &default_id, out_id + i * max_candidate_len + j, sizeof(int64_t));
        LM2GM_ASYNC(
            &default_val, out_val + i * max_candidate_len + j, sizeof(T));
      } else {
        LM2GM_ASYNC(
            lm_out_id + j, out_id + i * max_candidate_len + j, sizeof(int64_t));
        LM2GM_ASYNC(
            lm_out_val + j, out_val + i * max_candidate_len + j, sizeof(T));
sum_prob += static_cast<float>(lm_out_val[j]);
if (sum_prob >= top_p_value) {
lm_actual_candidates_len = j + 1;
mfence_lm();
LM2GM_ASYNC(&lm_actual_candidates_len,
actual_candidates_lens + i,
sizeof(int));
set_to_default_val = true;
}
}
}
mfence_lm();
}
}
template <typename T, int MaxLength, int TopPBeamTopK>
__global__ void top_p_candidates(const T* src,
const T* top_ps,
const int* output_padding_offset,
int64_t* out_id,
T* out_val,
int* actual_candidates_lens,
int vocab_size,
int token_num,
                                 int max_candidate_len,
int max_seq_len) {
char lm[6 * 1024];
if (token_num % (core_num() * cluster_num()) != 0 &&
vocab_size >= core_num() * (6 * 1024 / sizeof(T)) &&
vocab_size >= core_num() * TopPBeamTopK) {
top_p_candidates_big_n<T, MaxLength, TopPBeamTopK>(lm,
src,
top_ps,
output_padding_offset,
out_id,
out_val,
actual_candidates_lens,
vocab_size,
token_num,
                                                       max_candidate_len,
max_seq_len);
} else {
top_p_candidates_normal<T, MaxLength, TopPBeamTopK>(lm,
src,
top_ps,
output_padding_offset,
out_id,
out_val,
actual_candidates_lens,
vocab_size,
token_num,
                                                        max_candidate_len,
max_seq_len);
}
}
#define _XPU_DEF_TOP_P_CANDIDATES_KERNEL(T, MaxLength, TopPBeamTopK) \
template __global__ void top_p_candidates<T, MaxLength, TopPBeamTopK>( \
const T* src, \
const T* top_ps, \
const int* output_padding_offset, \
int64_t* out_id, \
T* out_val, \
int* actual_candidates_lens, \
int vocab_size, \
int token_num, \
      int max_candidate_len,                                              \
int max_seq_len);
_XPU_DEF_TOP_P_CANDIDATES_KERNEL(bfloat16, 2, 2);
_XPU_DEF_TOP_P_CANDIDATES_KERNEL(bfloat16, 2, 3);
_XPU_DEF_TOP_P_CANDIDATES_KERNEL(bfloat16, 2, 4);
_XPU_DEF_TOP_P_CANDIDATES_KERNEL(bfloat16, 2, 5);
_XPU_DEF_TOP_P_CANDIDATES_KERNEL(bfloat16, 2, 8);
_XPU_DEF_TOP_P_CANDIDATES_KERNEL(bfloat16, 2, 10);
_XPU_DEF_TOP_P_CANDIDATES_KERNEL(float16, 2, 2);
_XPU_DEF_TOP_P_CANDIDATES_KERNEL(float16, 2, 3);
_XPU_DEF_TOP_P_CANDIDATES_KERNEL(float16, 2, 4);
_XPU_DEF_TOP_P_CANDIDATES_KERNEL(float16, 2, 5);
_XPU_DEF_TOP_P_CANDIDATES_KERNEL(float16, 2, 8);
_XPU_DEF_TOP_P_CANDIDATES_KERNEL(float16, 2, 10);
_XPU_DEF_TOP_P_CANDIDATES_KERNEL(float, 2, 2);
_XPU_DEF_TOP_P_CANDIDATES_KERNEL(float, 2, 3);
_XPU_DEF_TOP_P_CANDIDATES_KERNEL(float, 2, 4);
_XPU_DEF_TOP_P_CANDIDATES_KERNEL(float, 2, 5);
_XPU_DEF_TOP_P_CANDIDATES_KERNEL(float, 2, 8);
_XPU_DEF_TOP_P_CANDIDATES_KERNEL(float, 2, 10);
} // namespace plugin
} // namespace xpu3
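Both candidate kernels above maintain a running top-K by bounded insertion, breaking ties toward the smaller vocabulary index. A scalar sketch of one insertion step (out_id[i] == -1 marks an empty slot, matching the kernels; the helper is illustrative, not part of the plugin):

#include <cstdint>

// Insert (val, idx) into a descending top-K held in out_val/out_id.
// Ties prefer the smaller vocabulary index.
template <int K>
void topk_insert(float val, int64_t idx, float* out_val, int64_t* out_id) {
  if (out_id[K - 1] != -1 &&
      (val < out_val[K - 1] ||
       (val == out_val[K - 1] && idx >= out_id[K - 1]))) {
    return;  // does not displace the current K-th entry
  }
  int l = K - 2;
  for (; l >= 0; --l) {
    if (out_id[l] == -1 || val > out_val[l] ||
        (val == out_val[l] && idx < out_id[l])) {
      out_id[l + 1] = out_id[l];  // shift the weaker entry down
      out_val[l + 1] = out_val[l];
    } else {
      break;
    }
  }
  out_id[l + 1] = idx;
  out_val[l + 1] = val;
}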

View File

@@ -0,0 +1,181 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "xpu/plugin.h"
#include "xpu/refactor/impl_public/wrapper_check.h"
namespace xpu3 {
namespace plugin {
__attribute__((global)) void ComputeOrderKernel(
const int* seq_lens_this_time,
const int* seq_lens_encoder,
const int* base_model_seq_lens_this_time,
const int* base_model_seq_lens_encoder,
const int* accept_nums,
int* position_map,
int* output_token_num,
const int bsz,
const int actual_draft_token_num,
const int input_token_num);
} // namespace plugin
} // namespace xpu3
namespace baidu {
namespace xpu {
namespace api {
namespace plugin {
static int cpu_wrapper(Context* ctx,
const int* seq_lens_this_time,
const int* seq_lens_encoder,
const int* base_model_seq_lens_this_time,
const int* base_model_seq_lens_encoder,
const int* accept_nums,
int* position_map,
int* output_token_num,
const int bsz,
const int actual_draft_token_num,
const int input_token_num) {
int in_offset = 0; // input_offset(long)
int out_offset = 0; // output_offset(short)
for (int i = 0; i < bsz; ++i) {
int cur_base_model_seq_lens_this_time = base_model_seq_lens_this_time[i];
int cur_base_model_seq_lens_encoder = base_model_seq_lens_encoder[i];
int cur_seq_lens_this_time = seq_lens_this_time[i];
int accept_num = accept_nums[i];
int cur_seq_lens_encoder = seq_lens_encoder[i];
// 1. eagle encoder. Base step=1
if (cur_seq_lens_encoder > 0) {
for (int j = 0; j < cur_seq_lens_encoder; j++) {
position_map[in_offset++] = out_offset++;
}
// 2. base model encoder. Base step=0
} else if (cur_base_model_seq_lens_encoder != 0) {
// nothing happens
// 3. New end
} else if (cur_base_model_seq_lens_this_time != 0 &&
cur_seq_lens_this_time == 0) {
in_offset += cur_base_model_seq_lens_this_time;
// 4. stopped
} else if (cur_base_model_seq_lens_this_time == 0 &&
cur_seq_lens_this_time == 0) /* end */ {
// nothing happens
} else {
if (accept_num <=
actual_draft_token_num) /*Accept partial draft tokens*/ {
position_map[in_offset + accept_num - 1] = out_offset++;
in_offset += cur_base_model_seq_lens_this_time;
} else /*Accept all draft tokens*/ {
position_map[in_offset + accept_num - 2] = out_offset++;
position_map[in_offset + accept_num - 1] = out_offset++;
in_offset += cur_base_model_seq_lens_this_time;
}
}
}
output_token_num[0] = out_offset;
return api::SUCCESS;
}
static int xpu3_wrapper(Context* ctx,
const int* seq_lens_this_time,
const int* seq_lens_encoder,
const int* base_model_seq_lens_this_time,
const int* base_model_seq_lens_encoder,
const int* accept_nums,
int* position_map,
int* output_token_num,
const int bsz,
const int actual_draft_token_num,
const int input_token_num) {
xpu3::plugin::ComputeOrderKernel<<<1, 1, ctx->xpu_stream>>>(
seq_lens_this_time,
seq_lens_encoder,
base_model_seq_lens_this_time,
base_model_seq_lens_encoder,
accept_nums,
position_map,
output_token_num,
bsz,
actual_draft_token_num,
input_token_num);
return api::SUCCESS;
}
int compute_order(Context* ctx,
const int* seq_lens_this_time,
const int* seq_lens_encoder,
const int* base_model_seq_lens_this_time,
const int* base_model_seq_lens_encoder,
const int* accept_nums,
int* position_map,
int* output_token_num,
const int bsz,
const int actual_draft_token_num,
const int input_token_num) {
WRAPPER_CHECK_CTX(ctx);
WRAPPER_DUMP_PARAM5(ctx,
seq_lens_this_time,
seq_lens_encoder,
base_model_seq_lens_this_time,
base_model_seq_lens_encoder,
accept_nums);
WRAPPER_DUMP_PARAM5(ctx,
position_map,
output_token_num,
bsz,
actual_draft_token_num,
input_token_num);
WRAPPER_DUMP(ctx);
WRAPPER_CHECK_PTR(ctx, int, bsz, seq_lens_this_time);
WRAPPER_CHECK_PTR(ctx, int, bsz, seq_lens_encoder);
WRAPPER_CHECK_PTR(ctx, int, bsz, base_model_seq_lens_this_time);
WRAPPER_CHECK_PTR(ctx, int, bsz, base_model_seq_lens_encoder);
WRAPPER_CHECK_PTR(ctx, int, bsz, accept_nums);
WRAPPER_CHECK_PTR(ctx, int, input_token_num, position_map);
WRAPPER_CHECK_PTR(ctx, int, 1, output_token_num);
if (ctx->dev().type() == api::kCPU) {
return cpu_wrapper(ctx,
seq_lens_this_time,
seq_lens_encoder,
base_model_seq_lens_this_time,
base_model_seq_lens_encoder,
accept_nums,
position_map,
output_token_num,
bsz,
actual_draft_token_num,
input_token_num);
} else if (ctx->dev().type() == api::kXPU3) {
return xpu3_wrapper(ctx,
seq_lens_this_time,
seq_lens_encoder,
base_model_seq_lens_this_time,
base_model_seq_lens_encoder,
accept_nums,
position_map,
output_token_num,
bsz,
actual_draft_token_num,
input_token_num);
}
WRAPPER_UNIMPLEMENTED(ctx);
}
} // namespace plugin
} // namespace api
} // namespace xpu
} // namespace baidu
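To make the mapping concrete, here is a standalone replica of the CPU reference above with one worked example (all values are illustrative, not from a real trace):

#include <cstdio>
#include <vector>

// Example: bsz = 2, actual_draft_token_num = 4.
//   batch 0: eagle encoder with seq_lens_encoder = 3
//            -> input positions 0,1,2 map to outputs 0,1,2
//   batch 1: decode step accepting 2 of 4 draft tokens
//            (base_model_seq_lens_this_time = 5)
//            -> input position 3 + (2 - 1) = 4 maps to output 3
int main() {
  const int bsz = 2, actual_draft_token_num = 4;
  int seq_lens_encoder[] = {3, 0};
  int base_model_seq_lens_encoder[] = {0, 0};
  int base_model_seq_lens_this_time[] = {0, 5};
  int seq_lens_this_time[] = {3, 1};
  int accept_nums[] = {0, 2};
  std::vector<int> position_map(16, -1);
  int in_offset = 0, out_offset = 0;
  for (int i = 0; i < bsz; ++i) {
    if (seq_lens_encoder[i] > 0) {
      for (int j = 0; j < seq_lens_encoder[i]; j++) {
        position_map[in_offset++] = out_offset++;
      }
    } else if (base_model_seq_lens_encoder[i] != 0) {
      // base model encoder step: nothing to emit
    } else if (base_model_seq_lens_this_time[i] != 0 &&
               seq_lens_this_time[i] == 0) {
      in_offset += base_model_seq_lens_this_time[i];  // sequence just ended
    } else if (base_model_seq_lens_this_time[i] == 0 &&
               seq_lens_this_time[i] == 0) {
      // stopped: nothing to emit
    } else if (accept_nums[i] <= actual_draft_token_num) {
      position_map[in_offset + accept_nums[i] - 1] = out_offset++;
      in_offset += base_model_seq_lens_this_time[i];
    } else {
      position_map[in_offset + accept_nums[i] - 2] = out_offset++;
      position_map[in_offset + accept_nums[i] - 1] = out_offset++;
      in_offset += base_model_seq_lens_this_time[i];
    }
  }
  printf("output_token_num = %d\n", out_offset);  // prints 4
  return 0;
}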

View File

@@ -0,0 +1,133 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "xpu/plugin.h"
#include "xpu/refactor/impl_public/wrapper_check.h"
namespace xpu3 {
namespace plugin {
__attribute__((global)) void ComputeSelfOrderKernel(
const int* last_seq_lens_this_time,
const int* seq_lens_this_time,
const int64_t* step_idx,
int* src_map,
int* output_token_num,
int bsz);
} // namespace plugin
} // namespace xpu3
namespace baidu {
namespace xpu {
namespace api {
namespace plugin {
static int cpu_wrapper(Context* ctx,
const int* last_seq_lens_this_time,
const int* seq_lens_this_time,
const int64_t* step_idx,
int* src_map,
int* output_token_num,
int bsz) {
int in_offset = 0;
int out_offset = 0;
for (int i = 0; i < bsz; i++) {
int cur_seq_lens_this_time = seq_lens_this_time[i];
int cur_last_seq_lens_this_time = last_seq_lens_this_time[i];
// 1. encoder
if (step_idx[i] == 1 && cur_seq_lens_this_time > 0) {
in_offset += 1;
src_map[out_offset++] = in_offset - 1;
// 2. decoder
} else if (cur_seq_lens_this_time > 0) /* =1 */ {
in_offset += cur_last_seq_lens_this_time;
src_map[out_offset++] = in_offset - 1;
// 3. stop
} else {
// first token end
if (step_idx[i] == 1) {
in_offset += cur_last_seq_lens_this_time > 0 ? 1 : 0;
// normal end
} else {
in_offset += cur_last_seq_lens_this_time;
}
}
}
output_token_num[0] = out_offset;
return api::SUCCESS;
}
static int xpu3_wrapper(Context* ctx,
const int* last_seq_lens_this_time,
const int* seq_lens_this_time,
const int64_t* step_idx,
int* src_map,
int* output_token_num,
int bsz) {
using XPU_INT64 = typename XPUIndexType<int64_t>::type;
xpu3::plugin::ComputeSelfOrderKernel<<<1, 1, ctx->xpu_stream>>>(
last_seq_lens_this_time,
seq_lens_this_time,
reinterpret_cast<const XPU_INT64*>(step_idx),
src_map,
output_token_num,
bsz);
return api::SUCCESS;
}
int compute_self_order(Context* ctx,
const int* last_seq_lens_this_time,
const int* seq_lens_this_time,
const int64_t* step_idx,
int* src_map,
int* output_token_num,
int bsz) {
WRAPPER_CHECK_CTX(ctx);
WRAPPER_DUMP_PARAM6(ctx,
last_seq_lens_this_time,
seq_lens_this_time,
step_idx,
src_map,
output_token_num,
bsz);
WRAPPER_DUMP(ctx);
WRAPPER_CHECK_PTR(ctx, int, bsz, last_seq_lens_this_time);
WRAPPER_CHECK_PTR(ctx, int, bsz, seq_lens_this_time);
WRAPPER_CHECK_PTR(ctx, int64_t, bsz, step_idx);
if (ctx->dev().type() == api::kCPU) {
return cpu_wrapper(ctx,
last_seq_lens_this_time,
seq_lens_this_time,
step_idx,
src_map,
output_token_num,
bsz);
} else if (ctx->dev().type() == api::kXPU3) {
return xpu3_wrapper(ctx,
last_seq_lens_this_time,
seq_lens_this_time,
step_idx,
src_map,
output_token_num,
bsz);
}
WRAPPER_UNIMPLEMENTED(ctx);
}
} // namespace plugin
} // namespace api
} // namespace xpu
} // namespace baidu
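A short worked example of the mapping above (illustrative values): with last_seq_lens_this_time = {4, 2}, seq_lens_this_time = {1, 0} and step_idx = {5, 1}, batch 0 is a normal decode step, so it consumes four input positions and src_map[0] = 3 points at its last token; batch 1 stopped on its first token, so it consumes one input position and emits nothing, giving output_token_num = 1.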

View File

@@ -0,0 +1,142 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "xpu/plugin.h"
#include "xpu/refactor/impl_public/wrapper_check.h"
namespace xpu2 {
namespace plugin {
__attribute__((global)) void draft_model_postprocess(
const int64_t* base_model_draft_tokens,
int* base_model_seq_lens_this_time,
const int* base_model_seq_lens_encoder,
const bool* base_model_stop_flags,
int bsz,
int base_model_draft_token_len);
} // namespace plugin
} // namespace xpu2
namespace xpu3 {
namespace plugin {
__attribute__((global)) void draft_model_postprocess(
const int64_t* base_model_draft_tokens,
int* base_model_seq_lens_this_time,
const int* base_model_seq_lens_encoder,
const bool* base_model_stop_flags,
int bsz,
int base_model_draft_token_len);
} // namespace plugin
} // namespace xpu3
namespace baidu {
namespace xpu {
namespace api {
namespace plugin {
static int cpu_wrapper(
Context* ctx,
const int64_t*
base_model_draft_tokens, // size = [bsz, base_model_draft_token_len]
int* base_model_seq_lens_this_time, // size = [bsz]
const int* base_model_seq_lens_encoder, // size = [bsz]
const bool* base_model_stop_flags, // size = [bsz]
int bsz,
int base_model_draft_token_len) {
  // Iterate over every sample.
  for (int tid = 0; tid < bsz; ++tid) {
    if (!base_model_stop_flags[tid] && base_model_seq_lens_encoder[tid] == 0) {
      // Get the draft-token pointer of the current sample.
      const int64_t* base_model_draft_tokens_now =
          base_model_draft_tokens + tid * base_model_draft_token_len;
      // Count the valid tokens (entries that are not -1).
      int token_num = 0;
      for (int i = 0; i < base_model_draft_token_len; ++i) {
        if (base_model_draft_tokens_now[i] != -1) {
          token_num++;
        }
      }
      // Update the sequence length.
      base_model_seq_lens_this_time[tid] = token_num;
    } else if (base_model_stop_flags[tid]) {
      // A stopped sample gets sequence length 0.
      base_model_seq_lens_this_time[tid] = 0;
}
}
return api::SUCCESS;
}
static int xpu3_wrapper(Context* ctx,
const int64_t* base_model_draft_tokens,
int* base_model_seq_lens_this_time,
const int* base_model_seq_lens_encoder,
const bool* base_model_stop_flags,
int bsz,
int base_model_draft_token_len) {
xpu3::plugin::draft_model_postprocess<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
reinterpret_cast<const xpu3::int64_t*>(base_model_draft_tokens),
base_model_seq_lens_this_time,
base_model_seq_lens_encoder,
base_model_stop_flags,
bsz,
base_model_draft_token_len);
return api::SUCCESS;
}
int draft_model_postprocess(Context* ctx,
const int64_t* base_model_draft_tokens,
int* base_model_seq_lens_this_time,
const int* base_model_seq_lens_encoder,
const bool* base_model_stop_flags,
int bsz,
int base_model_draft_token_len) {
WRAPPER_CHECK_CTX(ctx);
WRAPPER_DUMP_PARAM6(ctx,
base_model_draft_tokens,
base_model_seq_lens_this_time,
base_model_seq_lens_encoder,
base_model_stop_flags,
bsz,
base_model_draft_token_len);
WRAPPER_DUMP(ctx);
WRAPPER_CHECK_PTR(
ctx, int64_t, bsz * base_model_draft_token_len, base_model_draft_tokens);
WRAPPER_CHECK_PTR(ctx, int, bsz, base_model_seq_lens_encoder);
WRAPPER_CHECK_PTR(ctx, bool, bsz, base_model_stop_flags);
if (ctx->dev().type() == api::kCPU) {
return cpu_wrapper(ctx,
base_model_draft_tokens,
base_model_seq_lens_this_time,
base_model_seq_lens_encoder,
base_model_stop_flags,
bsz,
base_model_draft_token_len);
}
if (ctx->dev().type() == api::kXPU3) {
return xpu3_wrapper(ctx,
base_model_draft_tokens,
base_model_seq_lens_this_time,
base_model_seq_lens_encoder,
base_model_stop_flags,
bsz,
base_model_draft_token_len);
}
WRAPPER_UNIMPLEMENTED(ctx);
}
} // namespace plugin
} // namespace api
} // namespace xpu
} // namespace baidu
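As a quick sanity check of the reference above (illustrative values): with base_model_draft_token_len = 4 and draft tokens {7, 9, -1, -1} for an active decode sample, two entries differ from -1, so base_model_seq_lens_this_time becomes 2; a sample whose stop flag is set is forced to 0, and prefill samples (seq_lens_encoder > 0) are left untouched.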

View File

@@ -0,0 +1,392 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "xpu/plugin.h"
#include "xpu/refactor/impl/launch_strategy.h"
#include "xpu/refactor/impl_public/wrapper_check.h"
#include "xpu/xdnn.h"
namespace xpu3 {
namespace plugin {
__attribute__((global)) void draft_model_preprocess(
int64_t* draft_tokens,
int64_t* input_ids,
bool* stop_flags,
int* seq_lens_this_time,
int* seq_lens_encoder,
int* seq_lens_decoder,
int64_t* step_idx,
int* seq_lens_encoder_record,
int* seq_lens_decoder_record,
bool* not_need_stop,
bool* batch_drop,
const int64_t* accept_tokens,
const int* accept_num,
const int* base_model_seq_lens_encoder,
const int* base_model_seq_lens_decoder,
const int64_t* base_model_step_idx,
const bool* base_model_stop_flags,
const bool* base_model_is_block_step,
int64_t* base_model_draft_tokens,
int real_bsz,
int max_draft_token,
int accept_tokens_len,
int draft_tokens_len,
int input_ids_len,
int base_model_draft_tokens_len,
bool truncate_first_token,
bool splitwise_prefill);
} // namespace plugin
} // namespace xpu3
namespace xpu2 {
namespace plugin {} // namespace plugin
} // namespace xpu2
namespace baidu {
namespace xpu {
namespace api {
namespace plugin {
static int cpu_wrapper(api::Context* ctx,
int64_t* draft_tokens,
int64_t* input_ids,
bool* stop_flags,
int* seq_lens_this_time,
int* seq_lens_encoder,
int* seq_lens_decoder,
int64_t* step_idx,
int* seq_lens_encoder_record,
int* seq_lens_decoder_record,
bool* not_need_stop,
bool* batch_drop,
const int64_t* accept_tokens,
const int* accept_num,
const int* base_model_seq_lens_encoder,
const int* base_model_seq_lens_decoder,
const int64_t* base_model_step_idx,
const bool* base_model_stop_flags,
const bool* base_model_is_block_step,
int64_t* base_model_draft_tokens,
int real_bsz,
int max_draft_token,
int accept_tokens_len,
int draft_tokens_len,
int input_ids_len,
int base_model_draft_tokens_len,
bool truncate_first_token,
bool splitwise_prefill) {
int64_t not_stop_flag_sum = 0;
int64_t not_stop_flag = 0;
for (int tid = 0; tid < real_bsz; tid++) {
if (splitwise_prefill) {
int base_model_step_idx_now = base_model_step_idx[tid];
auto* input_ids_now = input_ids + tid * input_ids_len;
auto* accept_tokens_now = accept_tokens + tid * accept_tokens_len;
if (base_model_step_idx_now == 1 && seq_lens_encoder_record[tid] > 0) {
not_stop_flag = 1;
int seq_len_encoder_record = seq_lens_encoder_record[tid];
seq_lens_encoder[tid] = seq_len_encoder_record;
seq_lens_encoder_record[tid] = -1;
stop_flags[tid] = false;
int64_t base_model_first_token = accept_tokens_now[0];
int position = seq_len_encoder_record;
if (truncate_first_token) {
input_ids_now[position - 1] = base_model_first_token;
seq_lens_this_time[tid] = seq_len_encoder_record;
} else {
input_ids_now[position] = base_model_first_token;
seq_lens_this_time[tid] = seq_len_encoder_record + 1;
}
} else {
stop_flags[tid] = true;
seq_lens_this_time[tid] = 0;
seq_lens_decoder[tid] = 0;
seq_lens_encoder[tid] = 0;
not_stop_flag = 0;
}
not_stop_flag_sum += not_stop_flag;
} else {
auto base_model_step_idx_now = base_model_step_idx[tid];
auto* accept_tokens_now = accept_tokens + tid * accept_tokens_len;
auto* draft_tokens_now = draft_tokens + tid * draft_tokens_len;
auto accept_num_now = accept_num[tid];
auto* input_ids_now = input_ids + tid * input_ids_len;
auto* base_model_draft_tokens_now =
base_model_draft_tokens + tid * base_model_draft_tokens_len;
for (int i = 1; i < base_model_draft_tokens_len; i++) {
base_model_draft_tokens_now[i] = -1;
}
if (base_model_stop_flags[tid] && base_model_is_block_step[tid]) {
batch_drop[tid] = true;
stop_flags[tid] = true;
}
if (!(base_model_stop_flags[tid] || batch_drop[tid])) {
not_stop_flag = 1;
// 1. first token
if (base_model_step_idx_now == 0) {
seq_lens_this_time[tid] = 0;
not_stop_flag = 0;
} else if (base_model_step_idx_now == 1 &&
seq_lens_encoder_record[tid] > 0) {
// Can be extended to first few tokens
int seq_len_encoder_record = seq_lens_encoder_record[tid];
seq_lens_encoder[tid] = seq_len_encoder_record;
seq_lens_encoder_record[tid] = -1;
seq_lens_decoder[tid] = seq_lens_decoder_record[tid];
seq_lens_decoder_record[tid] = 0;
stop_flags[tid] = false;
int64_t base_model_first_token = accept_tokens_now[0];
int position = seq_len_encoder_record;
if (truncate_first_token) {
input_ids_now[position - 1] = base_model_first_token;
seq_lens_this_time[tid] = seq_len_encoder_record;
} else {
input_ids_now[position] = base_model_first_token;
seq_lens_this_time[tid] = seq_len_encoder_record + 1;
}
} else if (accept_num_now <=
max_draft_token) /*Accept partial draft tokens*/ {
// Base Model reject stop
if (stop_flags[tid]) {
stop_flags[tid] = false;
seq_lens_decoder[tid] = base_model_seq_lens_decoder[tid];
step_idx[tid] = base_model_step_idx[tid];
} else {
seq_lens_decoder[tid] -= max_draft_token - accept_num_now;
step_idx[tid] -= max_draft_token - accept_num_now;
}
int64_t modified_token = accept_tokens_now[accept_num_now - 1];
draft_tokens_now[0] = modified_token;
seq_lens_this_time[tid] = 1;
} else /*Accept all draft tokens*/ {
draft_tokens_now[1] = accept_tokens_now[max_draft_token];
seq_lens_this_time[tid] = 2;
}
      } else {
        stop_flags[tid] = true;
        seq_lens_this_time[tid] = 0;
        seq_lens_decoder[tid] = 0;
        seq_lens_encoder[tid] = 0;
        not_stop_flag = 0;
      }
not_stop_flag_sum += not_stop_flag;
}
}
not_need_stop[0] = not_stop_flag_sum > 0;
return api::SUCCESS;
}
static int xpu3_wrapper(api::Context* ctx,
int64_t* draft_tokens,
int64_t* input_ids,
bool* stop_flags,
int* seq_lens_this_time,
int* seq_lens_encoder,
int* seq_lens_decoder,
int64_t* step_idx,
int* seq_lens_encoder_record,
int* seq_lens_decoder_record,
bool* not_need_stop,
bool* batch_drop,
const int64_t* accept_tokens,
const int* accept_num,
const int* base_model_seq_lens_encoder,
const int* base_model_seq_lens_decoder,
const int64_t* base_model_step_idx,
const bool* base_model_stop_flags,
const bool* base_model_is_block_step,
int64_t* base_model_draft_tokens,
int real_bsz,
int max_draft_token,
int accept_tokens_len,
int draft_tokens_len,
int input_ids_len,
int base_model_draft_tokens_len,
bool truncate_first_token,
bool splitwise_prefill) {
using XPU_INT64 = typename XPUIndexType<int64_t>::type;
  // NOTE: Don't change 16 to 64, because the kernel uses GSM.
xpu3::plugin::draft_model_preprocess<<<1, 64, ctx->xpu_stream>>>(
reinterpret_cast<XPU_INT64*>(draft_tokens),
reinterpret_cast<XPU_INT64*>(input_ids),
stop_flags,
seq_lens_this_time,
seq_lens_encoder,
seq_lens_decoder,
reinterpret_cast<XPU_INT64*>(step_idx),
seq_lens_encoder_record,
seq_lens_decoder_record,
not_need_stop,
batch_drop,
reinterpret_cast<const XPU_INT64*>(accept_tokens),
accept_num,
base_model_seq_lens_encoder,
base_model_seq_lens_decoder,
reinterpret_cast<const XPU_INT64*>(base_model_step_idx),
base_model_stop_flags,
base_model_is_block_step,
reinterpret_cast<XPU_INT64*>(base_model_draft_tokens),
real_bsz,
max_draft_token,
accept_tokens_len,
draft_tokens_len,
input_ids_len,
base_model_draft_tokens_len,
truncate_first_token,
splitwise_prefill);
return api::SUCCESS;
}
int draft_model_preprocess(api::Context* ctx,
int64_t* draft_tokens,
int64_t* input_ids,
bool* stop_flags,
int* seq_lens_this_time,
int* seq_lens_encoder,
int* seq_lens_decoder,
int64_t* step_idx,
int* seq_lens_encoder_record,
int* seq_lens_decoder_record,
bool* not_need_stop,
bool* batch_drop,
const int64_t* accept_tokens,
const int* accept_num,
const int* base_model_seq_lens_encoder,
const int* base_model_seq_lens_decoder,
const int64_t* base_model_step_idx,
const bool* base_model_stop_flags,
const bool* base_model_is_block_step,
int64_t* base_model_draft_tokens,
int real_bsz,
int max_draft_token,
int accept_tokens_len,
int draft_tokens_len,
int input_ids_len,
int base_model_draft_tokens_len,
bool truncate_first_token,
bool splitwise_prefill) {
WRAPPER_CHECK_CTX(ctx);
WRAPPER_DUMP_FUNCTION_T1(ctx, "draft_model_preprocess", int64_t);
WRAPPER_DUMP_PARAM6(ctx,
draft_tokens,
input_ids,
stop_flags,
seq_lens_this_time,
seq_lens_encoder,
seq_lens_decoder);
WRAPPER_DUMP_PARAM5(ctx,
step_idx,
seq_lens_encoder_record,
seq_lens_decoder_record,
not_need_stop,
batch_drop);
WRAPPER_DUMP_PARAM3(
ctx, accept_tokens, accept_num, base_model_seq_lens_encoder);
WRAPPER_DUMP_PARAM3(ctx,
base_model_seq_lens_decoder,
base_model_step_idx,
base_model_stop_flags);
WRAPPER_DUMP_PARAM3(
ctx, base_model_is_block_step, base_model_draft_tokens, real_bsz);
WRAPPER_DUMP_PARAM3(
ctx, max_draft_token, accept_tokens_len, draft_tokens_len);
WRAPPER_DUMP_PARAM3(
ctx, input_ids_len, base_model_draft_tokens_len, truncate_first_token);
WRAPPER_DUMP_PARAM1(ctx, splitwise_prefill);
WRAPPER_DUMP(ctx);
WRAPPER_CHECK_PTR(ctx, int, real_bsz, seq_lens_this_time);
WRAPPER_CHECK_PTR(ctx, int64_t, real_bsz * accept_tokens_len, accept_tokens);
WRAPPER_CHECK_PTR(ctx, int64_t, real_bsz * input_ids_len, input_ids);
WRAPPER_CHECK_PTR(ctx, int64_t, real_bsz * draft_tokens_len, draft_tokens);
WRAPPER_CHECK_PTR(ctx,
int64_t,
real_bsz * base_model_draft_tokens_len,
base_model_draft_tokens);
WRAPPER_ASSERT_GT(ctx, real_bsz, 0);
WRAPPER_ASSERT_LT(ctx, accept_tokens_len, 128);
if (ctx->dev().type() == api::kCPU) {
return cpu_wrapper(ctx,
draft_tokens,
input_ids,
stop_flags,
seq_lens_this_time,
seq_lens_encoder,
seq_lens_decoder,
step_idx,
seq_lens_encoder_record,
seq_lens_decoder_record,
not_need_stop,
batch_drop,
accept_tokens,
accept_num,
base_model_seq_lens_encoder,
base_model_seq_lens_decoder,
base_model_step_idx,
base_model_stop_flags,
base_model_is_block_step,
base_model_draft_tokens,
real_bsz,
max_draft_token,
accept_tokens_len,
draft_tokens_len,
input_ids_len,
base_model_draft_tokens_len,
truncate_first_token,
splitwise_prefill);
}
if (ctx->dev().type() == api::kXPU3) {
return xpu3_wrapper(ctx,
draft_tokens,
input_ids,
stop_flags,
seq_lens_this_time,
seq_lens_encoder,
seq_lens_decoder,
step_idx,
seq_lens_encoder_record,
seq_lens_decoder_record,
not_need_stop,
batch_drop,
accept_tokens,
accept_num,
base_model_seq_lens_encoder,
base_model_seq_lens_decoder,
base_model_step_idx,
base_model_stop_flags,
base_model_is_block_step,
base_model_draft_tokens,
real_bsz,
max_draft_token,
accept_tokens_len,
draft_tokens_len,
input_ids_len,
base_model_draft_tokens_len,
truncate_first_token,
splitwise_prefill);
}
WRAPPER_UNIMPLEMENTED(ctx);
}
} // namespace plugin
} // namespace api
} // namespace xpu
} // namespace baidu
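The partial-acceptance branch in the CPU reference above rewinds the draft model by the number of rejected tokens and restarts drafting from the last accepted token. A scalar sketch of that bookkeeping (the non-rejection-stop path; the helper name is ours, not part of the plugin):

#include <cstdint>

// Partial acceptance: the base model accepted accept_num of the
// max_draft_token proposed tokens, so roll seq_lens_decoder and step_idx
// back by the rejected count and seed the next draft round with the last
// accepted token.
inline void rewind_on_partial_accept(int accept_num,
                                     int max_draft_token,
                                     const int64_t* accept_tokens_now,
                                     int64_t* draft_tokens_now,
                                     int* seq_len_decoder,
                                     int64_t* step_idx,
                                     int* seq_len_this_time) {
  const int rejected = max_draft_token - accept_num;
  *seq_len_decoder -= rejected;
  *step_idx -= rejected;
  draft_tokens_now[0] = accept_tokens_now[accept_num - 1];
  *seq_len_this_time = 1;
}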

View File

@@ -0,0 +1,324 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "xpu/plugin.h"
#include "xpu/refactor/impl_public/wrapper_check.h"
namespace xpu3 {
namespace plugin {
__attribute__((global)) void draft_model_update(
const int64_t* inter_next_tokens,
int64_t* draft_tokens,
int64_t* pre_ids,
int* seq_lens_this_time,
int* seq_lens_encoder,
int* seq_lens_decoder,
int64_t* step_idx,
const int* output_cum_offsets,
bool* stop_flags,
bool* not_need_stop,
const int64_t* max_dec_len,
const int64_t* end_ids,
int64_t* base_model_draft_tokens,
const int bsz,
const int max_draft_token,
const int pre_id_length,
const int max_base_model_draft_token,
const int end_ids_len,
const int max_seq_len,
const int substep,
const bool prefill_one_step_stop);
} // namespace plugin
} // namespace xpu3
namespace baidu {
namespace xpu {
namespace api {
namespace plugin {
bool is_in_end(int64_t token, const int64_t* end_ids, int end_ids_len) {
for (int i = 0; i < end_ids_len; ++i) {
if (end_ids[i] == token) {
return true;
}
}
return false;
}
static int cpu_wrapper(Context* ctx,
const int64_t* inter_next_tokens,
int64_t* draft_tokens,
int64_t* pre_ids,
int* seq_lens_this_time,
int* seq_lens_encoder,
int* seq_lens_decoder,
int64_t* step_idx,
const int* output_cum_offsets,
bool* stop_flags,
bool* not_need_stop,
const int64_t* max_dec_len,
const int64_t* end_ids,
int64_t* base_model_draft_tokens,
const int bsz,
const int max_draft_token,
const int pre_id_length,
const int max_base_model_draft_token,
const int end_ids_len,
const int max_seq_len,
const int substep,
const bool prefill_one_step_stop) {
int64_t stop_sum = 0;
  // Iterate over all batches
for (int tid = 0; tid < bsz; ++tid) {
auto* draft_token_now = draft_tokens + tid * max_draft_token;
auto* pre_ids_now = pre_ids + tid * pre_id_length;
auto* base_model_draft_tokens_now =
base_model_draft_tokens + tid * max_base_model_draft_token;
const int next_tokens_start_id =
tid * max_seq_len - output_cum_offsets[tid];
auto* next_tokens_start = inter_next_tokens + next_tokens_start_id;
auto seq_len_this_time = seq_lens_this_time[tid];
auto seq_len_encoder = seq_lens_encoder[tid];
auto seq_len_decoder = seq_lens_decoder[tid];
int64_t stop_flag_now_int = 0;
    // 1. update step_idx && seq_lens_decoder
if (!stop_flags[tid]) {
int64_t token_this_time = -1;
// decoder step
if (seq_len_decoder > 0 && seq_len_encoder <= 0) {
seq_lens_decoder[tid] += seq_len_this_time;
        token_this_time = next_tokens_start[seq_len_this_time - 1];
        draft_token_now[0] = token_this_time;
base_model_draft_tokens_now[substep + 1] = token_this_time;
for (int i = 0; i < seq_len_this_time; ++i) {
pre_ids_now[step_idx[tid] + 1 + i] = next_tokens_start[i];
}
step_idx[tid] += seq_len_this_time;
} else {
token_this_time = next_tokens_start[0];
seq_lens_decoder[tid] = seq_len_encoder + seq_len_decoder;
seq_lens_encoder[tid] = 0;
pre_ids_now[1] = token_this_time;
step_idx[tid] += 1;
draft_token_now[0] = token_this_time;
base_model_draft_tokens_now[substep + 1] = token_this_time;
}
// multi_end
if (is_in_end(token_this_time, end_ids, end_ids_len) ||
prefill_one_step_stop) {
stop_flags[tid] = true;
stop_flag_now_int = 1;
// max_dec_len
} else if (step_idx[tid] >= max_dec_len[tid]) {
stop_flags[tid] = true;
draft_token_now[seq_len_this_time - 1] = end_ids[0];
base_model_draft_tokens_now[substep + 1] = end_ids[0];
stop_flag_now_int = 1;
}
} else {
draft_token_now[0] = -1;
base_model_draft_tokens_now[substep + 1] = -1;
stop_flag_now_int = 1;
}
// 2. set end
if (!stop_flags[tid]) {
seq_lens_this_time[tid] = 1;
} else {
seq_lens_this_time[tid] = 0;
seq_lens_encoder[tid] = 0;
}
stop_sum += stop_flag_now_int;
}
  // Equivalent to the BlockReduce sum in the CUDA implementation
not_need_stop[0] = stop_sum < bsz;
return SUCCESS;
}
static int xpu2or3_wrapper(Context* ctx,
const int64_t* inter_next_tokens,
int64_t* draft_tokens,
int64_t* pre_ids,
int* seq_lens_this_time,
int* seq_lens_encoder,
int* seq_lens_decoder,
int64_t* step_idx,
const int* output_cum_offsets,
bool* stop_flags,
bool* not_need_stop,
const int64_t* max_dec_len,
const int64_t* end_ids,
int64_t* base_model_draft_tokens,
const int bsz,
const int max_draft_token,
const int pre_id_length,
const int max_base_model_draft_token,
const int end_ids_len,
const int max_seq_len,
const int substep,
const bool prefill_one_step_stop) {
ctx_guard RAII_GUARD(ctx);
using XPU_INT64 = typename XPUIndexType<int64_t>::type;
xpu3::plugin::draft_model_update<<<1, 64, ctx->xpu_stream>>>(
reinterpret_cast<const XPU_INT64*>(inter_next_tokens),
reinterpret_cast<XPU_INT64*>(draft_tokens),
reinterpret_cast<XPU_INT64*>(pre_ids),
seq_lens_this_time,
seq_lens_encoder,
seq_lens_decoder,
reinterpret_cast<XPU_INT64*>(step_idx),
output_cum_offsets,
stop_flags,
not_need_stop,
reinterpret_cast<const XPU_INT64*>(max_dec_len),
reinterpret_cast<const XPU_INT64*>(end_ids),
reinterpret_cast<XPU_INT64*>(base_model_draft_tokens),
bsz,
max_draft_token,
pre_id_length,
max_base_model_draft_token,
end_ids_len,
max_seq_len,
substep,
prefill_one_step_stop);
return api::SUCCESS;
}
int draft_model_update(Context* ctx,
const int64_t* inter_next_tokens,
int64_t* draft_tokens,
int64_t* pre_ids,
int* seq_lens_this_time,
int* seq_lens_encoder,
int* seq_lens_decoder,
int64_t* step_idx,
const int* output_cum_offsets,
bool* stop_flags,
bool* not_need_stop,
const int64_t* max_dec_len,
const int64_t* end_ids,
int64_t* base_model_draft_tokens,
const int bsz,
const int max_draft_token,
const int pre_id_length,
const int max_base_model_draft_token,
const int end_ids_len,
const int max_seq_len,
const int substep,
const bool prefill_one_step_stop) {
WRAPPER_CHECK_CTX(ctx);
WRAPPER_DUMP_FUNCTION_T1(ctx, "draft_model_update", int);
WRAPPER_DUMP_PARAM6(ctx,
inter_next_tokens,
draft_tokens,
pre_ids,
seq_lens_this_time,
seq_lens_encoder,
seq_lens_decoder);
WRAPPER_DUMP_PARAM6(ctx,
step_idx,
output_cum_offsets,
stop_flags,
not_need_stop,
max_dec_len,
end_ids);
WRAPPER_DUMP_PARAM6(ctx,
base_model_draft_tokens,
bsz,
max_draft_token,
pre_id_length,
max_base_model_draft_token,
end_ids_len);
WRAPPER_DUMP_PARAM3(ctx, max_seq_len, substep, prefill_one_step_stop);
WRAPPER_DUMP(ctx);
WRAPPER_CHECK_PTR(ctx, int64_t, bsz * max_seq_len, inter_next_tokens);
WRAPPER_CHECK_PTR(ctx, int64_t, bsz * max_draft_token, draft_tokens);
WRAPPER_CHECK_PTR(ctx, int64_t, bsz * pre_id_length, pre_ids);
WRAPPER_CHECK_PTR(ctx, int, bsz, seq_lens_this_time);
WRAPPER_CHECK_PTR(ctx, int, bsz, seq_lens_encoder);
WRAPPER_CHECK_PTR(ctx, int, bsz, seq_lens_decoder);
WRAPPER_CHECK_PTR(ctx, int64_t, bsz, step_idx);
WRAPPER_CHECK_PTR(ctx, int, bsz, output_cum_offsets);
WRAPPER_CHECK_PTR(ctx, bool, bsz, stop_flags);
WRAPPER_CHECK_PTR(ctx, bool, 1, not_need_stop);
WRAPPER_CHECK_PTR(ctx, int64_t, bsz, max_dec_len);
WRAPPER_CHECK_PTR(ctx, int64_t, end_ids_len, end_ids);
WRAPPER_CHECK_PTR(
ctx, int64_t, bsz * max_base_model_draft_token, base_model_draft_tokens);
if (ctx->dev().type() == api::kCPU) {
return cpu_wrapper(ctx,
inter_next_tokens,
draft_tokens,
pre_ids,
seq_lens_this_time,
seq_lens_encoder,
seq_lens_decoder,
step_idx,
output_cum_offsets,
stop_flags,
not_need_stop,
max_dec_len,
end_ids,
base_model_draft_tokens,
bsz,
max_draft_token,
pre_id_length,
max_base_model_draft_token,
end_ids_len,
max_seq_len,
substep,
prefill_one_step_stop);
}
if (ctx->dev().type() == api::kXPU3) {
return xpu2or3_wrapper(ctx,
inter_next_tokens,
draft_tokens,
pre_ids,
seq_lens_this_time,
seq_lens_encoder,
seq_lens_decoder,
step_idx,
output_cum_offsets,
stop_flags,
not_need_stop,
max_dec_len,
end_ids,
base_model_draft_tokens,
bsz,
max_draft_token,
pre_id_length,
max_base_model_draft_token,
end_ids_len,
max_seq_len,
substep,
prefill_one_step_stop);
}
WRAPPER_UNIMPLEMENTED(ctx);
}
} // namespace plugin
} // namespace api
} // namespace xpu
} // namespace baidu
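
For reference, a host-only walk-through of the decode branch of the cpu_wrapper above, on toy data (all names and values here are illustrative): three tokens were verified this step, so they are appended to pre_ids, step_idx and seq_lens_decoder advance by three, and the last verified token seeds the next draft round.

#include <cassert>
#include <cstdint>

int main() {
  int64_t next_tokens[3] = {11, 12, 13};  // inter_next_tokens for this request
  int64_t pre_ids[16] = {0};
  int64_t step_idx = 4;                   // tokens generated so far
  int seq_len_this_time = 3;
  int seq_len_decoder = 7;

  // seq_lens_decoder advances by the number of tokens verified this step.
  seq_len_decoder += seq_len_this_time;

  // All freshly verified tokens are recorded into pre_ids after step_idx.
  for (int i = 0; i < seq_len_this_time; ++i) {
    pre_ids[step_idx + 1 + i] = next_tokens[i];
  }
  step_idx += seq_len_this_time;

  // The last verified token seeds the next draft round.
  int64_t draft_token0 = next_tokens[seq_len_this_time - 1];

  assert(seq_len_decoder == 10);
  assert(step_idx == 7);
  assert(pre_ids[5] == 11 && pre_ids[6] == 12 && pre_ids[7] == 13);
  assert(draft_token0 == 13);
  return 0;
}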

View File

@@ -0,0 +1,254 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <algorithm>
#include <numeric>
#include "xpu/plugin.h"
#include "xpu/refactor/impl_public/wrapper_check.h"
namespace xpu2 {
namespace plugin {} // namespace plugin
} // namespace xpu2
namespace xpu3 {
namespace plugin {
__attribute__((global)) void mtp_free_and_dispatch_block(
bool *base_model_stop_flags,
bool *stop_flags,
bool *batch_drop,
int *seq_lens_this_time,
int *seq_lens_decoder,
int *block_tables,
int *encoder_block_lens,
int *used_list_len,
int *free_list,
int *free_list_len,
const int bsz,
const int block_size,
const int block_num_per_seq,
const int max_draft_tokens);
} // namespace plugin
} // namespace xpu3
namespace baidu {
namespace xpu {
namespace api {
namespace plugin {
static int cpu_wrapper(Context *ctx,
bool *base_model_stop_flags,
bool *stop_flags,
bool *batch_drop,
int *seq_lens_this_time,
int *seq_lens_decoder,
int *block_tables,
int *encoder_block_lens,
int *used_list_len,
int *free_list,
int *free_list_len,
const int bsz,
const int block_size,
const int block_num_per_seq,
const int max_draft_tokens) {
int need_block_len = 0;
  int need_block_list[640];  // bsz is asserted to be <= 640 in the wrapper
for (int tid = 0; tid < bsz; tid++) {
need_block_list[tid] = 0;
int *block_table_now = block_tables + tid * block_num_per_seq;
if (base_model_stop_flags[tid] || batch_drop[tid]) {
      // Reclaim this sequence's decoder blocks
const int encoder_block_len = encoder_block_lens[tid];
const int decoder_used_len = used_list_len[tid];
if (decoder_used_len > 0) {
for (int i = 0; i < decoder_used_len; i++) {
free_list[free_list_len[0] + i] =
block_table_now[encoder_block_len + i];
block_table_now[encoder_block_len + i] = -1;
}
free_list_len[0] += decoder_used_len;
encoder_block_lens[tid] = 0;
used_list_len[tid] = 0;
}
}
}
for (int tid = 0; tid < bsz; tid++) {
int *block_table_now = block_tables + tid * block_num_per_seq;
int max_possible_block_idx =
(seq_lens_decoder[tid] + max_draft_tokens + 1) / block_size;
if (!base_model_stop_flags[tid] && !batch_drop[tid] &&
max_possible_block_idx < block_num_per_seq &&
block_table_now[max_possible_block_idx] == -1) {
need_block_list[need_block_len] = tid;
need_block_len++;
}
}
  // Scan directly from bid 0 here
while (need_block_len > free_list_len[0]) {
int max_used_list_len_id = 0;
int max_used_list_len = 0;
for (int i = 0; i < bsz; i++) {
if (!base_model_stop_flags[i] && used_list_len[i] > max_used_list_len) {
max_used_list_len = used_list_len[i];
max_used_list_len_id = i;
}
}
const int encoder_block_len = encoder_block_lens[max_used_list_len_id];
int *block_table_now =
block_tables + max_used_list_len_id * block_num_per_seq;
for (int i = 0; i < max_used_list_len; i++) {
free_list[free_list_len[0] + i] = block_table_now[encoder_block_len + i];
block_table_now[encoder_block_len + i] = -1;
}
stop_flags[max_used_list_len_id] = true;
batch_drop[max_used_list_len_id] = true;
seq_lens_this_time[max_used_list_len_id] = 0;
seq_lens_decoder[max_used_list_len_id] = 0;
used_list_len[max_used_list_len_id] = 0;
free_list_len[0] += max_used_list_len;
}
for (int tid = 0; tid < need_block_len; tid++) {
const int need_block_id = need_block_list[tid];
    // Must check batch_drop here, not stop_flags
if (!batch_drop[need_block_id]) {
used_list_len[need_block_id] += 1;
int *block_table_now = block_tables + need_block_id * block_num_per_seq;
block_table_now[(seq_lens_decoder[need_block_id] + max_draft_tokens + 1) /
block_size] = free_list[free_list_len[0] - 1];
free_list_len[0] -= 1;
}
}
return api::SUCCESS;
}
static int xpu2or3_wrapper(Context *ctx,
bool *base_model_stop_flags,
bool *stop_flags,
bool *batch_drop,
int *seq_lens_this_time,
int *seq_lens_decoder,
int *block_tables,
int *encoder_block_lens,
int *used_list_len,
int *free_list,
int *free_list_len,
const int bsz,
const int block_size,
const int block_num_per_seq,
const int max_draft_tokens) {
  // Only XPU3 provides a kernel implementation; other devices fall through.
  if (ctx->dev().type() != api::kXPU3) {
    WRAPPER_UNIMPLEMENTED(ctx);
  }
auto mtp_free_and_dispatch_block = xpu3::plugin::mtp_free_and_dispatch_block;
mtp_free_and_dispatch_block<<<12, 64, ctx->xpu_stream>>>(
base_model_stop_flags,
stop_flags,
batch_drop,
seq_lens_this_time,
seq_lens_decoder,
block_tables,
encoder_block_lens,
used_list_len,
free_list,
free_list_len,
bsz,
block_size,
block_num_per_seq,
max_draft_tokens);
return api::SUCCESS;
}
int mtp_free_and_dispatch_block(Context *ctx,
bool *base_model_stop_flags,
bool *stop_flags,
bool *batch_drop,
int *seq_lens_this_time,
int *seq_lens_decoder,
int *block_tables,
int *encoder_block_lens,
int *used_list_len,
int *free_list,
int *free_list_len,
const int bsz,
const int block_size,
const int block_num_per_seq,
const int max_draft_tokens) {
WRAPPER_CHECK_CTX(ctx);
WRAPPER_DUMP_FUNCTION_T1(ctx, "mtp_free_and_dispatch_block", float);
WRAPPER_DUMP_PARAM6(ctx,
base_model_stop_flags,
stop_flags,
batch_drop,
seq_lens_this_time,
seq_lens_decoder,
block_tables);
WRAPPER_DUMP_PARAM4(
ctx, encoder_block_lens, used_list_len, free_list, free_list_len);
WRAPPER_DUMP_PARAM4(
ctx, bsz, block_size, block_num_per_seq, max_draft_tokens);
WRAPPER_ASSERT_LE(ctx, bsz, 640);
WRAPPER_CHECK_PTR(ctx, bool, bsz, base_model_stop_flags);
WRAPPER_CHECK_PTR(ctx, bool, bsz, stop_flags);
WRAPPER_CHECK_PTR(ctx, bool, bsz, batch_drop);
WRAPPER_CHECK_PTR(ctx, int, bsz, seq_lens_this_time);
WRAPPER_CHECK_PTR(ctx, int, bsz, seq_lens_decoder);
  WRAPPER_CHECK_PTR(ctx, int, bsz * block_num_per_seq, block_tables);
WRAPPER_CHECK_PTR(ctx, int, bsz, encoder_block_lens);
WRAPPER_CHECK_PTR(ctx, int, bsz, used_list_len);
WRAPPER_CHECK_PTR(ctx, int, 1, free_list_len);
WRAPPER_DUMP(ctx);
if (ctx->dev().type() == api::kCPU) {
return cpu_wrapper(ctx,
base_model_stop_flags,
stop_flags,
batch_drop,
seq_lens_this_time,
seq_lens_decoder,
block_tables,
encoder_block_lens,
used_list_len,
free_list,
free_list_len,
bsz,
block_size,
block_num_per_seq,
max_draft_tokens);
}
if (ctx->dev().type() == api::kXPU2 || ctx->dev().type() == api::kXPU3) {
return xpu2or3_wrapper(ctx,
base_model_stop_flags,
stop_flags,
batch_drop,
seq_lens_this_time,
seq_lens_decoder,
block_tables,
encoder_block_lens,
used_list_len,
free_list,
free_list_len,
bsz,
block_size,
block_num_per_seq,
max_draft_tokens);
}
WRAPPER_UNIMPLEMENTED(ctx);
}
} // namespace plugin
} // namespace api
} // namespace xpu
} // namespace baidu
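
A compact toy model of the block bookkeeping implemented by the cpu_wrapper above (illustrative data only): a stopped sequence returns its decoder blocks to the free list, and a running sequence that has crossed a block boundary then takes a block from the back of that list.

#include <cassert>

int main() {
  const int block_num_per_seq = 4;
  int block_tables[2 * block_num_per_seq] = {
      0, 1, 5, 6,    // seq 0: 2 encoder blocks + 2 decoder blocks
      2, 3, 7, -1};  // seq 1: 2 encoder blocks + 1 decoder block
  int encoder_block_lens[2] = {2, 2};
  int used_list_len[2] = {2, 1};
  int free_list[8] = {0};
  int free_list_len = 0;

  // Reclaim: seq 0 stopped, return its decoder blocks (slots 2..3).
  for (int i = 0; i < used_list_len[0]; ++i) {
    free_list[free_list_len + i] = block_tables[encoder_block_lens[0] + i];
    block_tables[encoder_block_lens[0] + i] = -1;
  }
  free_list_len += used_list_len[0];
  used_list_len[0] = 0;
  assert(free_list_len == 2 && free_list[0] == 5 && free_list[1] == 6);

  // Dispatch: seq 1 needs a 4th block; pop from the back of the free list.
  int *seq1_table = block_tables + block_num_per_seq;
  seq1_table[3] = free_list[free_list_len - 1];
  free_list_len -= 1;
  used_list_len[1] += 1;
  assert(seq1_table[3] == 6 && free_list_len == 1);
  return 0;
}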

View File

@@ -0,0 +1,101 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "xpu/plugin.h"
#include "xpu/refactor/impl_public/wrapper_check.h"
namespace xpu3 {
namespace plugin {
template <typename T>
__attribute__((global)) void rebuildHiddenStatesKernel(const T* input,
const int* position_map,
T* output,
int dim_embed,
int elem_cnt);
} // namespace plugin
} // namespace xpu3
namespace baidu {
namespace xpu {
namespace api {
namespace plugin {
template <typename T>
static int cpu_wrapper(Context* ctx,
const T* input,
const int* position_map,
T* output,
int dim_embed,
int elem_cnt) {
for (int elem_id = 0; elem_id < elem_cnt; elem_id++) {
int ori_token_idx = elem_id / dim_embed;
int token_idx = position_map[ori_token_idx];
int offset = elem_id % dim_embed;
if (token_idx >= 0) {
output[token_idx * dim_embed + offset] =
input[ori_token_idx * dim_embed + offset];
}
}
return api::SUCCESS;
}
template <typename T>
static int xpu3_wrapper(Context* ctx,
const T* input,
const int* position_map,
T* output,
int dim_embed,
int elem_cnt) {
xpu3::plugin::rebuildHiddenStatesKernel<T>
<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
input, position_map, output, dim_embed, elem_cnt);
return api::SUCCESS;
}
template <typename T>
int rebuild_hidden_states(Context* ctx,
const T* input,
const int* position_map,
T* output,
int dim_embed,
int elem_cnt) {
WRAPPER_CHECK_CTX(ctx);
WRAPPER_DUMP_FUNCTION_T1(ctx, "rebuild_hidden_states", T);
WRAPPER_DUMP_PARAM5(ctx, input, position_map, output, dim_embed, elem_cnt);
WRAPPER_DUMP(ctx);
WRAPPER_ASSERT_GT(ctx, dim_embed, 0);
WRAPPER_ASSERT_GT(ctx, elem_cnt, 0);
if (ctx->dev().type() == api::kCPU) {
return cpu_wrapper<T>(
ctx, input, position_map, output, dim_embed, elem_cnt);
} else if (ctx->dev().type() == api::kXPU3) {
return xpu3_wrapper<T>(
ctx, input, position_map, output, dim_embed, elem_cnt);
}
  WRAPPER_UNIMPLEMENTED(ctx);
}
template int rebuild_hidden_states(
Context*, const bfloat16*, const int*, bfloat16*, int, int);
template int rebuild_hidden_states(
Context*, const float*, const int*, float*, int, int);
template int rebuild_hidden_states(
Context*, const float16*, const int*, float16*, int, int);
} // namespace plugin
} // namespace api
} // namespace xpu
} // namespace baidu
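
The kernel is a row-wise scatter: input row i lands at output row position_map[i], and rows mapped to -1 are dropped. A minimal CPU sketch of that semantics on toy data:

#include <cassert>

int main() {
  const int dim_embed = 2;
  const float input[3 * dim_embed] = {1, 1, 2, 2, 3, 3};  // 3 token rows
  const int position_map[3] = {1, -1, 0};                 // row 1 is dropped
  float output[2 * dim_embed] = {0};

  const int elem_cnt = 3 * dim_embed;
  for (int e = 0; e < elem_cnt; ++e) {
    int src_row = e / dim_embed;
    int dst_row = position_map[src_row];
    if (dst_row >= 0) output[dst_row * dim_embed + e % dim_embed] = input[e];
  }
  assert(output[0] == 3 && output[1] == 3);  // input row 2 -> output row 0
  assert(output[2] == 1 && output[3] == 1);  // input row 0 -> output row 1
  return 0;
}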

View File

@@ -0,0 +1,94 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "xpu/plugin.h"
#include "xpu/refactor/impl_public/wrapper_check.h"
namespace xpu3 {
namespace plugin {
template <typename T>
__attribute__((global)) void rebuildSelfHiddenStatesKernel(
const T* input, int* src_map, T* output, int dim_embed, int elem_cnt);
} // namespace plugin
} // namespace xpu3
namespace baidu {
namespace xpu {
namespace api {
namespace plugin {
template <typename T>
static int cpu_wrapper(Context* ctx,
const T* input,
int* src_map,
T* output,
int dim_embed,
int elem_cnt) {
for (int elem_id = 0; elem_id < elem_cnt; elem_id++) {
int output_token_idx = elem_id / dim_embed;
int input_token_idx = src_map[output_token_idx];
int offset = elem_id % dim_embed;
output[output_token_idx * dim_embed + offset] =
input[input_token_idx * dim_embed + offset];
}
return api::SUCCESS;
}
template <typename T>
static int xpu3_wrapper(Context* ctx,
const T* input,
int* src_map,
T* output,
int dim_embed,
int elem_cnt) {
xpu3::plugin::rebuildSelfHiddenStatesKernel<T>
<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
input, src_map, output, dim_embed, elem_cnt);
return api::SUCCESS;
}
template <typename T>
int rebuild_self_hidden_states(Context* ctx,
const T* input,
int* src_map,
T* output,
int dim_embed,
int elem_cnt) {
WRAPPER_CHECK_CTX(ctx);
WRAPPER_DUMP_FUNCTION_T1(ctx, "rebuild_self_hidden_states", T);
WRAPPER_DUMP_PARAM5(ctx, input, src_map, output, dim_embed, elem_cnt);
WRAPPER_DUMP(ctx);
WRAPPER_CHECK_PTR(ctx, T, elem_cnt, output);
WRAPPER_ASSERT_GT(ctx, dim_embed, 0);
WRAPPER_ASSERT_GT(ctx, elem_cnt, 0);
if (ctx->dev().type() == api::kCPU) {
return cpu_wrapper<T>(ctx, input, src_map, output, dim_embed, elem_cnt);
} else if (ctx->dev().type() == api::kXPU3) {
return xpu3_wrapper<T>(ctx, input, src_map, output, dim_embed, elem_cnt);
}
WRAPPER_UNIMPLEMENTED(ctx);
}
template int rebuild_self_hidden_states(
Context*, const bfloat16*, int*, bfloat16*, int, int);
template int rebuild_self_hidden_states(
Context*, const float*, int*, float*, int, int);
template int rebuild_self_hidden_states(
Context*, const float16*, int*, float16*, int, int);
} // namespace plugin
} // namespace api
} // namespace xpu
} // namespace baidu
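
rebuild_self_hidden_states is the mirror image of the scatter shown earlier: a gather in which output row i copies input row src_map[i]. A toy sketch:

#include <cassert>

int main() {
  const int dim_embed = 2;
  const float input[3 * dim_embed] = {1, 1, 2, 2, 3, 3};
  const int src_map[2] = {2, 0};  // output row 0 <- input row 2, row 1 <- row 0
  float output[2 * dim_embed];
  for (int e = 0; e < 2 * dim_embed; ++e) {
    output[e] = input[src_map[e / dim_embed] * dim_embed + e % dim_embed];
  }
  assert(output[0] == 3 && output[2] == 1);
  return 0;
}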

View File

@@ -0,0 +1,73 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "xpu/plugin.h"
#include "xpu/refactor/impl_public/wrapper_check.h"
namespace xpu3 {
namespace plugin {
__attribute__((global)) void speculate_clear_accept_nums(
int* accept_num, const int* seq_lens_decoder, const int max_bsz);
} // namespace plugin
} // namespace xpu3
namespace baidu {
namespace xpu {
namespace api {
namespace plugin {
static int cpu_wrapper(Context* ctx,
int* accept_num,
const int* seq_lens_decoder,
const int max_bsz) {
for (int i = 0; i < max_bsz; i++) {
accept_num[i] = seq_lens_decoder[i] == 0 ? 0 : accept_num[i];
}
return SUCCESS;
}
static int xpu2or3_wrapper(Context* ctx,
int* accept_num,
const int* seq_lens_decoder,
const int max_bsz) {
ctx_guard RAII_GUARD(ctx);
xpu3::plugin::speculate_clear_accept_nums<<<1, 64, ctx->xpu_stream>>>(
accept_num, seq_lens_decoder, max_bsz);
return api::SUCCESS;
}
int speculate_clear_accept_nums(Context* ctx,
int* accept_num,
const int* seq_lens_decoder,
const int max_bsz) {
WRAPPER_CHECK_CTX(ctx);
WRAPPER_DUMP_FUNCTION_T1(ctx, "speculate_clear_accept_nums", int);
WRAPPER_DUMP_PARAM3(ctx, accept_num, seq_lens_decoder, max_bsz);
WRAPPER_DUMP(ctx);
if (ctx->dev().type() == api::kCPU) {
return cpu_wrapper(ctx, accept_num, seq_lens_decoder, max_bsz);
}
if (ctx->dev().type() == api::kXPU3) {
return xpu2or3_wrapper(ctx, accept_num, seq_lens_decoder, max_bsz);
}
WRAPPER_UNIMPLEMENTED(ctx);
}
} // namespace plugin
} // namespace api
} // namespace xpu
} // namespace baidu

View File

@@ -0,0 +1,230 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <algorithm>
#include <numeric>
#include "xpu/plugin.h"
#include "xpu/refactor/impl_public/wrapper_check.h"
namespace xpu3 {
namespace plugin {
__attribute__((global)) void speculate_free_and_reschedule(
bool *stop_flags,
int *seq_lens_this_time,
int *seq_lens_decoder,
int *block_tables,
int *encoder_block_lens,
bool *is_block_step,
int *step_block_list, // [bsz]
int *step_len,
int *recover_block_list,
int *recover_len,
int *need_block_list,
int *need_block_len,
int *used_list_len,
int *free_list,
int *free_list_len,
int64_t *first_token_ids,
const int bsz,
const int block_size,
const int block_num_per_seq,
const int max_decoder_block_num,
const int max_draft_tokens);
} // namespace plugin
} // namespace xpu3
namespace baidu {
namespace xpu {
namespace api {
namespace plugin {
static int cpu_wrapper(Context *ctx,
bool *stop_flags,
int *seq_lens_this_time,
int *seq_lens_decoder,
int *block_tables,
int *encoder_block_lens,
bool *is_block_step,
int *step_block_list, // [bsz]
int *step_len,
int *recover_block_list,
int *recover_len,
int *need_block_list,
int *need_block_len,
int *used_list_len,
int *free_list,
int *free_list_len,
int64_t *first_token_ids,
const int bsz,
const int block_size,
const int block_num_per_seq,
const int max_decoder_block_num,
const int max_draft_tokens) {
  // No CPU reference implementation is provided for this op; report failure.
  return -1;
}
static int xpu3_wrapper(Context *ctx,
bool *stop_flags,
int *seq_lens_this_time,
int *seq_lens_decoder,
int *block_tables,
int *encoder_block_lens,
bool *is_block_step,
int *step_block_list, // [bsz]
int *step_len,
int *recover_block_list,
int *recover_len,
int *need_block_list,
int *need_block_len,
int *used_list_len,
int *free_list,
int *free_list_len,
int64_t *first_token_ids,
const int bsz,
const int block_size,
const int block_num_per_seq,
const int max_decoder_block_num,
const int max_draft_tokens) {
using XPU_INT64 = typename XPUIndexType<int64_t>::type;
auto speculate_free_and_reschedule =
xpu3::plugin::speculate_free_and_reschedule;
speculate_free_and_reschedule<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
stop_flags,
seq_lens_this_time,
seq_lens_decoder,
block_tables,
encoder_block_lens,
is_block_step,
step_block_list,
step_len,
recover_block_list,
recover_len,
need_block_list,
need_block_len,
used_list_len,
free_list,
free_list_len,
reinterpret_cast<XPU_INT64 *>(first_token_ids),
bsz,
block_size,
block_num_per_seq,
max_decoder_block_num,
max_draft_tokens);
return api::SUCCESS;
}
int speculate_free_and_reschedule(Context *ctx,
bool *stop_flags,
int *seq_lens_this_time,
int *seq_lens_decoder,
int *block_tables,
int *encoder_block_lens,
bool *is_block_step,
int *step_block_list, // [bsz]
int *step_len,
int *recover_block_list,
int *recover_len,
int *need_block_list,
int *need_block_len,
int *used_list_len,
int *free_list,
int *free_list_len,
int64_t *first_token_ids,
const int bsz,
const int block_size,
const int block_num_per_seq,
const int max_decoder_block_num,
const int max_draft_tokens) {
WRAPPER_CHECK_CTX(ctx);
WRAPPER_DUMP_FUNCTION_T1(ctx, "speculate_free_and_reschedule", float);
WRAPPER_DUMP_PARAM6(ctx,
stop_flags,
seq_lens_this_time,
seq_lens_decoder,
block_tables,
encoder_block_lens,
is_block_step);
WRAPPER_DUMP_PARAM6(ctx,
step_block_list,
step_len,
recover_block_list,
recover_len,
need_block_list,
need_block_len);
WRAPPER_DUMP_PARAM4(
ctx, used_list_len, free_list, free_list_len, first_token_ids);
WRAPPER_DUMP_PARAM5(ctx,
bsz,
block_size,
block_num_per_seq,
max_decoder_block_num,
max_draft_tokens);
WRAPPER_DUMP(ctx);
if (ctx->dev().type() == api::kCPU) {
return cpu_wrapper(ctx,
stop_flags,
seq_lens_this_time,
seq_lens_decoder,
block_tables,
encoder_block_lens,
is_block_step,
step_block_list,
step_len,
recover_block_list,
recover_len,
need_block_list,
need_block_len,
used_list_len,
free_list,
free_list_len,
first_token_ids,
bsz,
block_size,
block_num_per_seq,
max_decoder_block_num,
max_draft_tokens);
}
if (ctx->dev().type() == api::kXPU3) {
return xpu3_wrapper(ctx,
stop_flags,
seq_lens_this_time,
seq_lens_decoder,
block_tables,
encoder_block_lens,
is_block_step,
step_block_list,
step_len,
recover_block_list,
recover_len,
need_block_list,
need_block_len,
used_list_len,
free_list,
free_list_len,
first_token_ids,
bsz,
block_size,
block_num_per_seq,
max_decoder_block_num,
max_draft_tokens);
}
WRAPPER_UNIMPLEMENTED(ctx);
}
} // namespace plugin
} // namespace api
} // namespace xpu
} // namespace baidu

View File

@@ -0,0 +1,118 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "xpu/plugin.h"
#include "xpu/refactor/impl_public/wrapper_check.h"
namespace xpu3 {
namespace plugin {
__attribute__((global)) void speculate_get_output_padding_offset(
int* output_padding_offset,
int* output_cum_offsets,
const int* output_cum_offsets_tmp,
const int* seq_lens_output,
const int bsz,
const int max_seq_len);
} // namespace plugin
} // namespace xpu3
namespace baidu {
namespace xpu {
namespace api {
namespace plugin {
static int cpu_wrapper(Context* ctx,
int* output_padding_offset,
int* output_cum_offsets,
const int* output_cum_offsets_tmp,
const int* seq_lens_output,
const int bsz,
const int max_seq_len) {
for (int bi = 0; bi < bsz; bi++) {
int cum_offset = 0;
if (bi > 0) {
cum_offset = output_cum_offsets_tmp[bi - 1];
}
output_cum_offsets[bi] = cum_offset;
for (int token_i = 0; token_i < seq_lens_output[bi]; token_i++) {
output_padding_offset[bi * max_seq_len - cum_offset + token_i] =
cum_offset;
}
}
return SUCCESS;
}
static int xpu2or3_wrapper(Context* ctx,
int* output_padding_offset,
int* output_cum_offsets,
const int* output_cum_offsets_tmp,
const int* seq_lens_output,
const int bsz,
const int max_seq_len) {
ctx_guard RAII_GUARD(ctx);
xpu3::plugin::speculate_get_output_padding_offset<<<ctx->ncluster(),
64,
ctx->xpu_stream>>>(
output_padding_offset,
output_cum_offsets,
output_cum_offsets_tmp,
seq_lens_output,
bsz,
max_seq_len);
return api::SUCCESS;
}
int speculate_get_output_padding_offset(Context* ctx,
int* output_padding_offset,
int* output_cum_offsets,
const int* output_cum_offsets_tmp,
const int* seq_lens_output,
const int bsz,
const int max_seq_len) {
WRAPPER_CHECK_CTX(ctx);
WRAPPER_DUMP_FUNCTION_T1(ctx, "speculate_get_output_padding_offset", int);
  WRAPPER_DUMP_PARAM6(ctx,
                      output_padding_offset,
                      output_cum_offsets,
                      output_cum_offsets_tmp,
                      seq_lens_output,
                      bsz,
                      max_seq_len);
WRAPPER_DUMP(ctx);
if (ctx->dev().type() == api::kCPU) {
return cpu_wrapper(ctx,
output_padding_offset,
output_cum_offsets,
output_cum_offsets_tmp,
seq_lens_output,
bsz,
max_seq_len);
}
if (ctx->dev().type() == api::kXPU3) {
return xpu2or3_wrapper(ctx,
output_padding_offset,
output_cum_offsets,
output_cum_offsets_tmp,
seq_lens_output,
bsz,
max_seq_len);
}
WRAPPER_UNIMPLEMENTED(ctx);
}
} // namespace plugin
} // namespace api
} // namespace xpu
} // namespace baidu
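
A worked example of the offset math in the cpu_wrapper above (values are illustrative): with max_seq_len = 4 and seq_lens_output = {2, 3}, output_cum_offsets_tmp is assumed to be the inclusive prefix sum of the per-request padding, i.e. {2, 3}.

#include <cassert>

int main() {
  const int bsz = 2, max_seq_len = 4;
  const int seq_lens_output[2] = {2, 3};
  const int output_cum_offsets_tmp[2] = {2, 3};
  int output_cum_offsets[2];
  int output_padding_offset[5];  // 2 + 3 unpadded output tokens

  for (int bi = 0; bi < bsz; bi++) {
    int cum_offset = bi == 0 ? 0 : output_cum_offsets_tmp[bi - 1];
    output_cum_offsets[bi] = cum_offset;
    for (int t = 0; t < seq_lens_output[bi]; t++) {
      output_padding_offset[bi * max_seq_len - cum_offset + t] = cum_offset;
    }
  }
  // Tokens of request 0 sit at unpadded positions 0..1 (offset 0);
  // tokens of request 1 sit at unpadded positions 2..4 (offset 2).
  assert(output_cum_offsets[1] == 2);
  assert(output_padding_offset[2] == 2 && output_padding_offset[4] == 2);
  return 0;
}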

View File

@@ -0,0 +1,295 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "xpu/plugin.h"
#include "xpu/refactor/impl/xdnn_impl.h"
#include "xpu/refactor/impl_public/wrapper_check.h"
namespace xpu3 {
namespace plugin {
template <typename T>
__attribute__((global)) void speculate_remove_padding(
T* output_data,
const T* input_data,
const T* draft_tokens,
const int* seq_lens,
const int* seq_lens_encoder,
const int* cum_offsets,
int sequence_length,
int max_draft_tokens,
int bsz,
int token_num_data);
__attribute__((global)) void speculate_get_padding_offset(
int* padding_offset,
int* cum_offsets_out,
int* cu_seqlens_q,
int* cu_seqlens_k,
const int* cum_offsets,
const int* seq_lens,
const int max_seq_len,
int bsz);
} // namespace plugin
} // namespace xpu3
namespace baidu {
namespace xpu {
namespace api {
namespace plugin {
template <typename T>
static int cpu_wrapper_remove_padding(Context* ctx,
T* output_data,
const T* input_data,
const T* draft_tokens,
const int* seq_lens,
const int* seq_lens_encoder,
const int* cum_offsets,
int sequence_length,
int max_draft_tokens,
int bsz,
int token_num_data) {
for (int bi = 0; bi < bsz; ++bi) {
for (int i = 0; i < seq_lens[bi]; i++) {
const int tgt_seq_id = bi * sequence_length - cum_offsets[bi] + i;
if (seq_lens_encoder[bi] > 0) {
const int src_seq_id = bi * sequence_length + i;
output_data[tgt_seq_id] = input_data[src_seq_id];
} else {
const int src_seq_id = bi * max_draft_tokens + i;
output_data[tgt_seq_id] = draft_tokens[src_seq_id];
}
}
}
return api::SUCCESS;
}
static int cpu_wrapper_get_padding_offset(Context* ctx,
int* padding_offset,
int* cum_offsets_out,
int* cu_seqlens_q,
int* cu_seqlens_k,
const int* cum_offsets,
const int* seq_lens,
const int max_seq_len,
int bsz) {
for (int bi = 0; bi < bsz; ++bi) {
int cum_offset = bi == 0 ? 0 : cum_offsets[bi - 1];
for (int i = 0; i < seq_lens[bi]; i++) {
padding_offset[bi * max_seq_len - cum_offset + i] = cum_offset;
}
cum_offsets_out[bi] = cum_offset;
int cum_seq_len = (bi + 1) * max_seq_len - cum_offsets[bi];
cu_seqlens_q[bi + 1] = cum_seq_len;
cu_seqlens_k[bi + 1] = cum_seq_len;
}
return api::SUCCESS;
}
template <typename T>
static int xpu3_wrapper_remove_padding(Context* ctx,
T* output_data,
const T* input_data,
const T* draft_tokens,
const int* seq_lens,
const int* seq_lens_encoder,
const int* cum_offsets,
int sequence_length,
int max_draft_tokens,
int bsz,
int token_num_data) {
using XPU_T = typename XPUIndexType<T>::type;
xpu3::plugin::speculate_remove_padding<XPU_T>
<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
static_cast<XPU_T*>(static_cast<void*>(output_data)),
static_cast<const XPU_T*>(static_cast<const void*>(input_data)),
static_cast<const XPU_T*>(static_cast<const void*>(draft_tokens)),
seq_lens,
seq_lens_encoder,
cum_offsets,
sequence_length,
max_draft_tokens,
bsz,
token_num_data);
return api::SUCCESS;
}
static int xpu3_wrapper_get_padding_offset(Context* ctx,
int* padding_offset,
int* cum_offsets_out,
int* cu_seqlens_q,
int* cu_seqlens_k,
const int* cum_offsets,
const int* seq_lens,
const int max_seq_len,
int bsz) {
xpu3::plugin::
speculate_get_padding_offset<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
padding_offset,
cum_offsets_out,
cu_seqlens_q,
cu_seqlens_k,
cum_offsets,
seq_lens,
max_seq_len,
bsz);
return api::SUCCESS;
}
template <typename T>
int speculate_remove_padding(Context* ctx,
T* x_remove_padding,
const T* input_ids,
const T* draft_tokens,
const int* seq_lens,
const int* seq_lens_encoder,
const int* cum_offsets_out,
int seq_length,
int max_draft_tokens,
int bsz,
int token_num_data) {
WRAPPER_CHECK_CTX(ctx);
WRAPPER_DUMP_FUNCTION_T1(ctx, "speculate_remove_padding", T);
WRAPPER_DUMP_PARAM6(ctx,
x_remove_padding,
input_ids,
draft_tokens,
seq_lens,
seq_lens_encoder,
cum_offsets_out);
WRAPPER_DUMP_PARAM4(ctx, seq_length, max_draft_tokens, bsz, token_num_data);
WRAPPER_DUMP(ctx);
WRAPPER_CHECK_PTR(ctx, T, bsz * seq_length, input_ids);
WRAPPER_CHECK_PTR(ctx, T, bsz * max_draft_tokens, draft_tokens);
WRAPPER_CHECK_PTR(ctx, int, bsz, seq_lens);
WRAPPER_CHECK_PTR(ctx, int, bsz, seq_lens_encoder);
WRAPPER_CHECK_PTR(ctx, int, bsz, cum_offsets_out);
WRAPPER_CHECK_PTR(ctx, T, token_num_data, x_remove_padding);
WRAPPER_ASSERT_GT(ctx, bsz, 0);
WRAPPER_ASSERT_GT(ctx, seq_length, 0);
WRAPPER_ASSERT_GT(ctx, max_draft_tokens, 0);
if (ctx->dev().type() == api::kCPU) {
return cpu_wrapper_remove_padding(ctx,
x_remove_padding,
input_ids,
draft_tokens,
seq_lens,
seq_lens_encoder,
cum_offsets_out,
seq_length,
max_draft_tokens,
bsz,
token_num_data);
}
if (ctx->dev().type() == api::kXPU3) {
return xpu3_wrapper_remove_padding(ctx,
x_remove_padding,
input_ids,
draft_tokens,
seq_lens,
seq_lens_encoder,
cum_offsets_out,
seq_length,
max_draft_tokens,
bsz,
token_num_data);
}
WRAPPER_UNIMPLEMENTED(ctx);
}
int speculate_get_padding_offset(Context* ctx,
int* padding_offset,
int* cum_offsets_out,
int* cu_seqlens_q,
int* cu_seqlens_k,
const int* cum_offsets,
const int* seq_lens,
const int max_seq_len,
int bsz) {
WRAPPER_CHECK_CTX(ctx);
WRAPPER_DUMP_FUNCTION_T1(ctx, "speculate_get_padding_offset", float);
WRAPPER_DUMP_PARAM6(ctx,
padding_offset,
cum_offsets_out,
cu_seqlens_q,
cu_seqlens_k,
cum_offsets,
seq_lens);
WRAPPER_DUMP_PARAM2(ctx, max_seq_len, bsz);
WRAPPER_DUMP(ctx);
WRAPPER_CHECK_PTR(ctx, int, bsz, cum_offsets);
WRAPPER_CHECK_PTR(ctx, int, bsz, seq_lens);
WRAPPER_CHECK_PTR(ctx, int, bsz, cum_offsets_out);
WRAPPER_CHECK_PTR(ctx, int, bsz + 1, cu_seqlens_q);
WRAPPER_CHECK_PTR(ctx, int, bsz + 1, cu_seqlens_k);
WRAPPER_ASSERT_GT(ctx, bsz, 0);
WRAPPER_ASSERT_GT(ctx, max_seq_len, 0);
if (ctx->dev().type() == api::kCPU) {
return cpu_wrapper_get_padding_offset(ctx,
padding_offset,
cum_offsets_out,
cu_seqlens_q,
cu_seqlens_k,
cum_offsets,
seq_lens,
max_seq_len,
bsz);
}
if (ctx->dev().type() == api::kXPU3) {
return xpu3_wrapper_get_padding_offset(ctx,
padding_offset,
cum_offsets_out,
cu_seqlens_q,
cu_seqlens_k,
cum_offsets,
seq_lens,
max_seq_len,
bsz);
}
WRAPPER_UNIMPLEMENTED(ctx);
}
#define INSTANTIATION_SPECULATE_REMOVE_PADDING(T) \
template int speculate_remove_padding<T>(Context * ctx, \
T * x_remove_padding, \
const T* input_ids, \
const T* draft_tokens, \
const int* seq_len, \
const int* seq_lens_encoder, \
const int* cum_offsets_out, \
int seq_length, \
int max_draft_tokens, \
int bsz, \
int token_num_data)
INSTANTIATION_SPECULATE_REMOVE_PADDING(float);
INSTANTIATION_SPECULATE_REMOVE_PADDING(float16);
INSTANTIATION_SPECULATE_REMOVE_PADDING(bfloat16);
INSTANTIATION_SPECULATE_REMOVE_PADDING(int64_t);
} // namespace plugin
} // namespace api
} // namespace xpu
} // namespace baidu
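
A toy trace of cpu_wrapper_remove_padding above (values are illustrative): a prefill request pulls its tokens from input_ids while a decoding request pulls from its draft_tokens row, and both are packed into the unpadded output. cum_offsets = {0, 1} because request 0 leaves one padding slot (sequence_length 4, seq_len 3).

#include <cassert>
#include <cstdint>

int main() {
  const int bsz = 2, sequence_length = 4, max_draft_tokens = 2;
  const int64_t input_ids[8] = {10, 11, 12, 0, 20, 21, 0, 0};
  const int64_t draft_tokens[4] = {50, 51, 60, 61};
  const int seq_lens[2] = {3, 2};
  const int seq_lens_encoder[2] = {3, 0};  // request 0 prefill, 1 decode
  const int cum_offsets[2] = {0, 1};
  int64_t out[5];

  for (int bi = 0; bi < bsz; ++bi) {
    for (int i = 0; i < seq_lens[bi]; ++i) {
      const int tgt = bi * sequence_length - cum_offsets[bi] + i;
      out[tgt] = seq_lens_encoder[bi] > 0
                     ? input_ids[bi * sequence_length + i]       // prefill
                     : draft_tokens[bi * max_draft_tokens + i];  // decode
    }
  }
  const int64_t expect[5] = {10, 11, 12, 60, 61};
  for (int i = 0; i < 5; ++i) assert(out[i] == expect[i]);
  return 0;
}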

View File

@@ -0,0 +1,111 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "xpu/plugin.h"
#include "xpu/refactor/impl_public/wrapper_check.h"
namespace xpu3 {
namespace plugin {
__attribute__((global)) void speculate_get_seq_lens_output(
int* seq_lens_output,
const int* seq_lens_this_time,
const int* seq_lens_encoder,
const int* seq_lens_decoder,
const int real_bsz);
} // namespace plugin
} // namespace xpu3
namespace baidu {
namespace xpu {
namespace api {
namespace plugin {
static int cpu_wrapper(Context* ctx,
int* seq_lens_output,
const int* seq_lens_this_time,
const int* seq_lens_encoder,
const int* seq_lens_decoder,
const int real_bsz) {
for (int bid = 0; bid < real_bsz; ++bid) {
if (seq_lens_this_time[bid] == 0) {
continue;
} else if (seq_lens_this_time[bid] == 1) {
seq_lens_output[bid] = 1;
} else if (seq_lens_encoder[bid] != 0) {
seq_lens_output[bid] = 1;
} else {
seq_lens_output[bid] = seq_lens_this_time[bid];
}
}
return SUCCESS;
}
static int xpu2or3_wrapper(Context* ctx,
int* seq_lens_output,
const int* seq_lens_this_time,
const int* seq_lens_encoder,
const int* seq_lens_decoder,
const int real_bsz) {
ctx_guard RAII_GUARD(ctx);
xpu3::plugin::
speculate_get_seq_lens_output<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
seq_lens_output,
seq_lens_this_time,
seq_lens_encoder,
seq_lens_decoder,
real_bsz);
return api::SUCCESS;
}
int speculate_get_seq_lens_output(Context* ctx,
int* seq_lens_output,
const int* seq_lens_this_time,
const int* seq_lens_encoder,
const int* seq_lens_decoder,
const int real_bsz) {
WRAPPER_CHECK_CTX(ctx);
WRAPPER_DUMP_FUNCTION_T1(ctx, "speculate_get_seq_lens_output", int);
WRAPPER_DUMP_PARAM5(ctx,
seq_lens_output,
seq_lens_this_time,
seq_lens_encoder,
seq_lens_decoder,
real_bsz);
WRAPPER_DUMP(ctx);
if (ctx->dev().type() == api::kCPU) {
return cpu_wrapper(ctx,
seq_lens_output,
seq_lens_this_time,
seq_lens_encoder,
seq_lens_decoder,
real_bsz);
}
if (ctx->dev().type() == api::kXPU3) {
return xpu2or3_wrapper(ctx,
seq_lens_output,
seq_lens_this_time,
seq_lens_encoder,
seq_lens_decoder,
real_bsz);
}
WRAPPER_UNIMPLEMENTED(ctx);
}
} // namespace plugin
} // namespace api
} // namespace xpu
} // namespace baidu
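
The per-request rule implemented above is small enough to state as a pure function; a sketch covering the four cases (in the real kernel the "skipped" case leaves the existing value untouched):

#include <cassert>

int seq_len_output_for(int this_time, int encoder_len) {
  if (this_time == 0) return 0;     // skipped
  if (this_time == 1) return 1;     // single-token decode step
  if (encoder_len != 0) return 1;   // prefill emits one token
  return this_time;                 // verify step emits all tokens
}

int main() {
  assert(seq_len_output_for(0, 0) == 0);
  assert(seq_len_output_for(1, 0) == 1);
  assert(seq_len_output_for(8, 8) == 1);  // prefill
  assert(seq_len_output_for(4, 0) == 4);  // verify
  return 0;
}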

View File

@@ -0,0 +1,156 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "xpu/plugin.h"
#include "xpu/refactor/impl_public/wrapper_check.h"
namespace xpu3 {
namespace plugin {
template <typename T>
__attribute__((global)) void RebuildAppendPaddingKernel(
const T* full_hidden_states,
const int* cum_offsets,
const int* seq_len_encoder,
const int* seq_len_decoder,
const int* output_padding_offset,
int max_seq_len,
int dim_embed,
int elem_nums,
T* out);
} // namespace plugin
} // namespace xpu3
namespace baidu {
namespace xpu {
namespace api {
namespace plugin {
template <typename T>
static int cpu_wrapper(Context* ctx,
T* full_hidden_states,
int* cum_offsets,
int* seq_len_encoder,
int* seq_len_decoder,
int* output_padding_offset,
int max_seq_len,
int dim_embed,
int elem_nums,
T* out) {
for (int64_t i = 0; i < elem_nums; ++i) {
int64_t out_token_id = i / dim_embed;
int64_t ori_token_id = out_token_id + output_padding_offset[out_token_id];
int64_t bi = ori_token_id / max_seq_len;
int64_t seq_id = 0;
if (seq_len_decoder[bi] == 0 && seq_len_encoder[bi] == 0) {
continue;
} else if (seq_len_encoder[bi] != 0) {
seq_id = seq_len_encoder[bi] - 1;
}
int64_t input_token_id = ori_token_id - cum_offsets[bi] + seq_id;
int64_t bias_idx = i % dim_embed;
out[i] = full_hidden_states[input_token_id * dim_embed + bias_idx];
}
return api::SUCCESS;
}
template <typename T>
static int xpu3_wrapper(Context* ctx,
T* full_hidden_states,
int* cum_offsets,
int* seq_len_encoder,
int* seq_len_decoder,
int* output_padding_offset,
int max_seq_len,
int dim_embed,
int elem_nums,
T* out) {
xpu3::plugin::RebuildAppendPaddingKernel<T>
<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(full_hidden_states,
cum_offsets,
seq_len_encoder,
seq_len_decoder,
output_padding_offset,
max_seq_len,
dim_embed,
elem_nums,
out);
return api::SUCCESS;
}
template <typename T>
int speculate_rebuild_append_padding(Context* ctx,
T* full_hidden_states,
int* cum_offsets,
int* seq_len_encoder,
int* seq_len_decoder,
int* output_padding_offset,
int max_seq_len,
int dim_embed,
int elem_nums,
T* out) {
WRAPPER_CHECK_CTX(ctx);
WRAPPER_DUMP_FUNCTION_T1(ctx, "speculate_rebuild_append_padding", T);
WRAPPER_DUMP_PARAM5(ctx,
full_hidden_states,
cum_offsets,
seq_len_encoder,
seq_len_decoder,
output_padding_offset);
WRAPPER_DUMP_PARAM4(ctx, max_seq_len, dim_embed, elem_nums, out);
WRAPPER_DUMP(ctx);
WRAPPER_CHECK_PTR(ctx, T, elem_nums, out);
WRAPPER_ASSERT_GT(ctx, max_seq_len, 0);
WRAPPER_ASSERT_GT(ctx, dim_embed, 0);
WRAPPER_ASSERT_GT(ctx, elem_nums, 0);
if (ctx->dev().type() == api::kCPU) {
return cpu_wrapper<T>(ctx,
full_hidden_states,
cum_offsets,
seq_len_encoder,
seq_len_decoder,
output_padding_offset,
max_seq_len,
dim_embed,
elem_nums,
out);
} else if (ctx->dev().type() == api::kXPU3) {
return xpu3_wrapper<T>(ctx,
full_hidden_states,
cum_offsets,
seq_len_encoder,
seq_len_decoder,
output_padding_offset,
max_seq_len,
dim_embed,
elem_nums,
out);
}
WRAPPER_UNIMPLEMENTED(ctx);
}
template int speculate_rebuild_append_padding(
Context*, bfloat16*, int*, int*, int*, int*, int, int, int, bfloat16*);
template int speculate_rebuild_append_padding(
Context*, float16*, int*, int*, int*, int*, int, int, int, float16*);
template int speculate_rebuild_append_padding(
Context*, float*, int*, int*, int*, int*, int, int, int, float*);
} // namespace plugin
} // namespace api
} // namespace xpu
} // namespace baidu
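
An index walk-through of the cpu_wrapper above for the prefill branch only (illustrative values): one prefill request with seq_len_encoder = 3 emits a single output token, whose hidden state must come from the last prompt position (row 2 of full_hidden_states).

#include <cassert>

int main() {
  const int max_seq_len = 4, dim_embed = 2;
  const float full_hidden_states[3 * dim_embed] = {1, 1, 2, 2, 3, 3};
  const int cum_offsets[1] = {0};
  const int seq_len_encoder[1] = {3};
  const int output_padding_offset[1] = {0};
  float out[dim_embed];

  for (int i = 0; i < dim_embed; ++i) {
    int out_token_id = i / dim_embed;                                  // 0
    int ori_token_id = out_token_id + output_padding_offset[out_token_id];
    int bi = ori_token_id / max_seq_len;                               // 0
    int seq_id = seq_len_encoder[bi] != 0 ? seq_len_encoder[bi] - 1 : 0;
    int input_token_id = ori_token_id - cum_offsets[bi] + seq_id;      // 2
    out[i] = full_hidden_states[input_token_id * dim_embed + i % dim_embed];
  }
  assert(out[0] == 3 && out[1] == 3);  // last prompt row was selected
  return 0;
}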

View File

@@ -0,0 +1,224 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <algorithm>
#include <numeric>
#include "xpu/plugin.h"
#include "xpu/refactor/impl_public/wrapper_check.h"
namespace xpu3 {
namespace plugin {
__attribute__((global)) void speculate_set_stop_value_multi_seqs(
bool* stop_flags,
int64_t* accept_tokens,
int* accept_nums,
const int64_t* pre_ids,
const int64_t* step_idx,
const int64_t* stop_seqs,
const int* stop_seqs_len,
const int* seq_lens,
const int64_t* end_ids,
const int bs,
const int accept_tokens_len,
const int stop_seqs_bs,
const int stop_seqs_max_len,
const int pre_ids_len);
} // namespace plugin
} // namespace xpu3
namespace baidu {
namespace xpu {
namespace api {
namespace plugin {
static int cpu_wrapper(Context* ctx,
bool* stop_flags,
int64_t* accept_tokens,
int* accept_nums,
const int64_t* pre_ids,
const int64_t* step_idx,
const int64_t* stop_seqs,
const int* stop_seqs_len,
const int* seq_lens,
const int64_t* end_ids,
const int bs,
const int accept_tokens_len,
const int stop_seqs_bs,
const int stop_seqs_max_len,
const int pre_ids_len) {
for (int bid = 0; bid < bs; ++bid) {
const int64_t* pre_ids_now = pre_ids + bid * pre_ids_len;
int64_t* accept_tokens_now = accept_tokens + bid * accept_tokens_len;
const int accept_num = accept_nums[bid];
const int64_t step_idx_now = step_idx[bid];
for (int tid = 0; tid < stop_seqs_bs; ++tid) {
const int stop_seq_len = stop_seqs_len[tid];
if (stop_seq_len <= 0) continue;
const int64_t* stop_seq_now = stop_seqs + tid * stop_seqs_max_len;
if (!stop_flags[bid]) {
int accept_idx = 0;
bool is_end = false;
        // Iterate over candidate start positions
for (; accept_idx <= accept_num - 1 && !is_end; accept_idx++) {
if (step_idx_now - accept_num + accept_idx + 1 < stop_seq_len) {
continue;
}
          // Walk one stop sequence backwards
for (int i = stop_seq_len - 1; i >= 0; --i) {
int64_t cur_token_idx = -1;
            // Decide from the current offset whether the token lives in
            // pre_ids or in accept_tokens
if (stop_seq_len - 1 - i < accept_idx) {
cur_token_idx =
accept_tokens_now[accept_idx - (stop_seq_len - 1 - i) - 1];
} else {
int pre_ids_idx = step_idx_now - accept_num + accept_idx -
(stop_seq_len - 1 - i);
              // EC3: special concatenation can leave input_ids without a
              // trailing special token (pre_ids[0] may be 23), which would
              // trigger a spurious stop
if (pre_ids_idx <= 0) {
break;
}
cur_token_idx = pre_ids_now[pre_ids_idx];
}
if (cur_token_idx != stop_seq_now[i]) {
break;
}
if (i == 0) {
is_end = true;
}
}
}
if (is_end) {
accept_nums[bid] = accept_idx;
accept_tokens_now[accept_idx - 1] = end_ids[0];
stop_flags[bid] = true;
}
}
}
}
return api::SUCCESS;
}
static int xpu2or3_wrapper(Context* ctx,
bool* stop_flags,
int64_t* accept_tokens,
int* accept_nums,
const int64_t* pre_ids,
const int64_t* step_idx,
const int64_t* stop_seqs,
const int* stop_seqs_len,
const int* seq_lens,
const int64_t* end_ids,
const int bs,
const int accept_tokens_len,
const int stop_seqs_bs,
const int stop_seqs_max_len,
const int pre_ids_len) {
using XPU_INT64 = typename XPUIndexType<int64_t>::type;
xpu3::plugin::speculate_set_stop_value_multi_seqs<<<ctx->ncluster(),
64,
ctx->xpu_stream>>>(
stop_flags,
reinterpret_cast<XPU_INT64*>(accept_tokens),
accept_nums,
reinterpret_cast<const XPU_INT64*>(pre_ids),
reinterpret_cast<const XPU_INT64*>(step_idx),
reinterpret_cast<const XPU_INT64*>(stop_seqs),
stop_seqs_len,
seq_lens,
reinterpret_cast<const XPU_INT64*>(end_ids),
bs,
accept_tokens_len,
stop_seqs_bs,
stop_seqs_max_len,
pre_ids_len);
return api::SUCCESS;
}
int speculate_set_stop_value_multi_seqs(Context* ctx,
bool* stop_flags,
int64_t* accept_tokens,
int* accept_nums,
const int64_t* pre_ids,
const int64_t* step_idx,
const int64_t* stop_seqs,
const int* stop_seqs_len,
const int* seq_lens,
const int64_t* end_ids,
const int bs_now,
const int accept_tokens_len,
const int stop_seqs_bs,
const int stop_seqs_max_len,
const int pre_ids_len) {
WRAPPER_CHECK_CTX(ctx);
WRAPPER_DUMP_FUNCTION_T1(ctx, "speculate_set_stop_value_multi_seqs", int64_t);
WRAPPER_DUMP_PARAM3(ctx, stop_flags, accept_tokens, accept_nums);
WRAPPER_DUMP_PARAM6(
ctx, pre_ids, step_idx, stop_seqs, stop_seqs_len, seq_lens, end_ids);
WRAPPER_DUMP_PARAM5(ctx,
bs_now,
accept_tokens_len,
stop_seqs_bs,
stop_seqs_max_len,
pre_ids_len);
WRAPPER_DUMP(ctx);
WRAPPER_CHECK_PTR(ctx, int64_t, bs_now * accept_tokens_len, accept_tokens);
WRAPPER_CHECK_PTR(ctx, int64_t, stop_seqs_bs * stop_seqs_max_len, stop_seqs);
WRAPPER_ASSERT_GT(ctx, bs_now, 0);
if (ctx->dev().type() == api::kCPU) {
return cpu_wrapper(ctx,
stop_flags,
accept_tokens,
accept_nums,
pre_ids,
step_idx,
stop_seqs,
stop_seqs_len,
seq_lens,
end_ids,
bs_now,
accept_tokens_len,
stop_seqs_bs,
stop_seqs_max_len,
pre_ids_len);
}
if (ctx->dev().type() == api::kXPU3) {
return xpu2or3_wrapper(ctx,
stop_flags,
accept_tokens,
accept_nums,
pre_ids,
step_idx,
stop_seqs,
stop_seqs_len,
seq_lens,
end_ids,
bs_now,
accept_tokens_len,
stop_seqs_bs,
stop_seqs_max_len,
pre_ids_len);
}
WRAPPER_UNIMPLEMENTED(ctx);
}
} // namespace plugin
} // namespace api
} // namespace xpu
} // namespace baidu
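
A toy trace of the matcher in the cpu_wrapper above (values illustrative): the stop sequence {7, 8} straddles the step boundary, so the match reads the 7 from pre_ids (committed in an earlier step) and the 8 from this step's accept_tokens, then truncates accept_num and writes the end id into the following slot.

#include <cassert>
#include <cstdint>

int main() {
  const int64_t stop_seq[2] = {7, 8};
  const int stop_seq_len = 2;
  const int64_t pre_ids[8] = {1, 5, 6, 7, 0, 0, 0, 0};
  int64_t accept_tokens[4] = {8, 9, 0, 0};
  int accept_num = 2;
  const int64_t step_idx = 5;  // total tokens generated incl. this step
  const int64_t end_id = 2;
  bool stop_flag = false;

  int accept_idx = 0;
  bool is_end = false;
  for (; accept_idx <= accept_num - 1 && !is_end; accept_idx++) {
    if (step_idx - accept_num + accept_idx + 1 < stop_seq_len) continue;
    for (int i = stop_seq_len - 1; i >= 0; --i) {
      int64_t tok;
      if (stop_seq_len - 1 - i < accept_idx) {
        tok = accept_tokens[accept_idx - (stop_seq_len - 1 - i) - 1];
      } else {
        int64_t pre_idx =
            step_idx - accept_num + accept_idx - (stop_seq_len - 1 - i);
        if (pre_idx <= 0) break;
        tok = pre_ids[pre_idx];
      }
      if (tok != stop_seq[i]) break;
      if (i == 0) is_end = true;
    }
  }
  if (is_end) {
    accept_num = accept_idx;              // tokens kept, incl. the end slot
    accept_tokens[accept_idx - 1] = end_id;
    stop_flag = true;
  }
  assert(stop_flag && accept_num == 2);
  assert(accept_tokens[0] == 8 && accept_tokens[1] == end_id);
  return 0;
}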

View File

@@ -0,0 +1,157 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "xpu/plugin.h"
#include "xpu/refactor/impl_public/wrapper_check.h"
namespace xpu3 {
namespace plugin {
__attribute__((global)) void speculate_set_value_by_flag_and_id(
int64_t *pre_ids_all,
const int64_t *accept_tokens,
const int *accept_num,
const bool *stop_flags,
const int *seq_lens_encoder,
const int *seq_lens_decoder,
const int64_t *step_idx,
int bs,
int length,
int max_draft_tokens);
} // namespace plugin
} // namespace xpu3
namespace baidu {
namespace xpu {
namespace api {
namespace plugin {
static int cpu_wrapper(Context *ctx,
int64_t *pre_ids_all, // bs * length
const int64_t *accept_tokens, // bs * max_draft_tokens
const int *accept_num, // bs
const bool *stop_flags,
const int *seq_lens_encoder,
const int *seq_lens_decoder,
const int64_t *step_idx,
int bs,
int length,
int max_draft_tokens) {
for (int i = 0; i < bs; i++) {
if (stop_flags[i] || (seq_lens_encoder[i] == 0 && seq_lens_decoder[i] == 0))
continue;
int64_t *pre_ids_all_now = pre_ids_all + i * length;
const int64_t *accept_tokens_now = accept_tokens + i * max_draft_tokens;
int accept_num_now = accept_num[i];
int64_t step_idx_now = step_idx[i];
if (step_idx_now >= 0) {
for (int j = 0; j < accept_num_now; j++) {
pre_ids_all_now[step_idx_now - j] =
accept_tokens_now[accept_num_now - 1 - j];
}
}
}
return SUCCESS;
}
static int xpu2or3_wrapper(Context *ctx,
int64_t *pre_ids_all,
const int64_t *accept_tokens,
const int *accept_num,
const bool *stop_flags,
const int *seq_lens_encoder,
const int *seq_lens_decoder,
const int64_t *step_idx,
int bs,
int length,
int max_draft_tokens) {
ctx_guard RAII_GUARD(ctx);
using XPU_INT64 = typename XPUIndexType<int64_t>::type;
xpu3::plugin::speculate_set_value_by_flag_and_id<<<ctx->ncluster(),
64,
ctx->xpu_stream>>>(
reinterpret_cast<XPU_INT64 *>(pre_ids_all),
reinterpret_cast<const XPU_INT64 *>(accept_tokens),
accept_num,
stop_flags,
seq_lens_encoder,
seq_lens_decoder,
reinterpret_cast<const XPU_INT64 *>(step_idx),
bs,
length,
max_draft_tokens);
return api::SUCCESS;
}
int speculate_set_value_by_flag_and_id(Context *ctx,
int64_t *pre_ids_all,
const int64_t *accept_tokens,
const int *accept_num,
const bool *stop_flags,
const int *seq_lens_encoder,
const int *seq_lens_decoder,
const int64_t *step_idx,
int bs,
int length,
int max_draft_tokens) {
WRAPPER_CHECK_CTX(ctx);
WRAPPER_DUMP_FUNCTION_T1(ctx, "speculate_set_value_by_flag_and_id", int);
WRAPPER_DUMP_PARAM6(ctx,
pre_ids_all,
accept_tokens,
accept_num,
stop_flags,
seq_lens_encoder,
seq_lens_decoder);
WRAPPER_DUMP_PARAM4(ctx, step_idx, bs, length, max_draft_tokens);
WRAPPER_DUMP(ctx);
WRAPPER_ASSERT_LE(ctx, max_draft_tokens, 500);
if (ctx->dev().type() == api::kCPU) {
return cpu_wrapper(ctx,
pre_ids_all,
accept_tokens,
accept_num,
stop_flags,
seq_lens_encoder,
seq_lens_decoder,
step_idx,
bs,
length,
max_draft_tokens);
}
if (ctx->dev().type() == api::kXPU3) {
return xpu2or3_wrapper(ctx,
pre_ids_all,
accept_tokens,
accept_num,
stop_flags,
seq_lens_encoder,
seq_lens_decoder,
step_idx,
bs,
length,
max_draft_tokens);
}
WRAPPER_UNIMPLEMENTED(ctx);
}
} // namespace plugin
} // namespace api
} // namespace xpu
} // namespace baidu
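
The write pattern of the cpu_wrapper above on toy data: the accept_num tokens accepted this step land in pre_ids ending at index step_idx, in generation order. With step_idx = 5 and accept_tokens = {11, 12, 13}, pre_ids[3..5] become 11, 12, 13.

#include <cassert>
#include <cstdint>

int main() {
  int64_t pre_ids[8] = {0};
  const int64_t accept_tokens[3] = {11, 12, 13};
  const int accept_num = 3;
  const int64_t step_idx = 5;

  for (int j = 0; j < accept_num; j++) {
    pre_ids[step_idx - j] = accept_tokens[accept_num - 1 - j];
  }
  assert(pre_ids[3] == 11 && pre_ids[4] == 12 && pre_ids[5] == 13);
  return 0;
}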

View File

@@ -0,0 +1,512 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <algorithm>
#include <numeric>
#include "xpu/plugin.h"
#include "xpu/refactor/impl_public/wrapper_check.h"
namespace xpu3 {
namespace plugin {
template <typename T>
__attribute__((global)) void speculate_min_length_logits_process(
T* logits,
const int64_t* cur_len,
const int64_t* min_len,
const int64_t* eos_token_id,
const int* output_padding_offset,
const int* output_cum_offsets,
const int64_t bs,
const int64_t length,
const int64_t length_id,
const int64_t end_length,
const int64_t token_num,
const int64_t max_seq_len);
__attribute__((global)) void speculate_update_repeat_times(
const int64_t* pre_ids,
const int64_t* cur_len,
int* repeat_times,
const int* output_padding_offset,
const int64_t bs,
const int64_t length,
const int64_t length_id,
const int64_t token_num,
const int64_t max_seq_len);
template <typename T>
__attribute__((global)) void speculate_update_value_by_repeat_times(
const int* repeat_times,
const T* penalty_scores,
const T* frequency_score,
const T* presence_score,
const float* temperatures,
T* logits,
const int* output_padding_offset,
const int64_t bs,
const int64_t length,
const int64_t token_num,
const int64_t max_seq_len);
template <typename T>
__attribute__((global)) void speculate_update_value_by_repeat_times_simd(
const int* repeat_times,
const T* penalty_scores,
const T* frequency_score,
const T* presence_score,
const float* temperatures,
T* logits,
const int* output_padding_offset,
const int64_t bs,
const int64_t length,
const int64_t token_num,
const int64_t max_seq_len);
template <typename T>
__attribute__((global)) void speculate_ban_bad_words(
T* logits,
const int64_t* bad_words_list,
const int* output_padding_offset,
const int64_t bs,
const int64_t length,
const int64_t bad_words_length,
const int64_t token_num,
const int64_t max_seq_len);
} // namespace plugin
} // namespace xpu3
namespace baidu {
namespace xpu {
namespace api {
namespace plugin {
void update_repeat_times_cpu(const int64_t* pre_ids,
const int64_t* cur_len,
int* repeat_times,
const int* output_padding_offset,
const int64_t bs,
const int64_t length,
const int64_t length_id,
const int64_t token_num,
const int64_t max_seq_len) {
for (int64_t i = 0; i < token_num; i++) {
int64_t bi = (i + output_padding_offset[i]) / max_seq_len;
if (bi < bs && cur_len[bi] >= 0) {
for (int64_t j = 0; j < length_id; j++) {
int64_t id = pre_ids[bi * length_id + j];
if (id < 0) {
break;
} else if (id >= length) {
continue;
} else {
repeat_times[i * length + id] += 1;
}
}
}
}
}
void ban_bad_words_cpu(float* logits,
const int64_t* bad_words_list,
const int* output_padding_offset,
const int64_t bs,
const int64_t length,
const int64_t bad_words_length,
const int64_t token_num,
const int64_t max_seq_len) {
for (int64_t i = 0; i < token_num; i++) {
int64_t bi = (i + output_padding_offset[i]) / max_seq_len;
if (bi >= bs) {
continue;
}
float* logits_now = logits + i * length;
for (int64_t j = 0; j < bad_words_length; j++) {
int64_t bad_words_token_id = bad_words_list[j];
if (bad_words_token_id >= length || bad_words_token_id < 0) continue;
logits_now[bad_words_token_id] = -1e10;
}
}
}
template <typename T>
static int cpu_wrapper(Context* ctx,
const int64_t* pre_ids,
T* logits,
const T* penalty_scores,
const T* frequency_scores,
const T* presence_scores,
const float* temperatures,
const int64_t* cur_len,
const int64_t* min_len,
const int64_t* eos_token_id,
const int64_t* bad_words,
const int* output_padding_offset,
const int* output_cum_offsets,
const int64_t bs,
const int64_t length,
const int64_t length_id,
const int64_t end_length,
const int64_t length_bad_words,
const int64_t token_num,
const int64_t max_seq_len) {
std::vector<float> logitsfp32(token_num * length);
std::vector<float> penalty_scoresfp32(bs);
std::vector<float> frequency_scoresfp32(bs);
std::vector<float> presence_scoresfp32(bs);
std::vector<int> repeat_times_buffer(token_num * length, 0);
int ret =
api::cast<T, float>(ctx, logits, logitsfp32.data(), token_num * length);
WRAPPER_ASSERT_SUCCESS(ctx, ret);
ret = api::cast<T, float>(ctx, penalty_scores, penalty_scoresfp32.data(), bs);
WRAPPER_ASSERT_SUCCESS(ctx, ret);
ret = api::cast<T, float>(
ctx, frequency_scores, frequency_scoresfp32.data(), bs);
WRAPPER_ASSERT_SUCCESS(ctx, ret);
ret =
api::cast<T, float>(ctx, presence_scores, presence_scoresfp32.data(), bs);
WRAPPER_ASSERT_SUCCESS(ctx, ret);
for (int64_t i = 0; i < token_num; i++) {
int64_t bi = (i + output_padding_offset[i]) / max_seq_len;
int64_t query_start_token_idx = bi * max_seq_len - output_cum_offsets[bi];
if (bi < bs && cur_len[bi] >= 0 &&
(cur_len[bi] + (i - query_start_token_idx) < min_len[bi])) {
for (int64_t j = 0; j < end_length; j++) {
logitsfp32[i * length + eos_token_id[j]] =
std::is_same<T, float16>::value ? -1e4 : -1e10;
}
}
}
int* repeat_times = &(repeat_times_buffer[0]);
update_repeat_times_cpu(pre_ids,
cur_len,
repeat_times,
output_padding_offset,
bs,
length,
length_id,
token_num,
max_seq_len);
for (int64_t i = 0; i < token_num; i++) {
int64_t bi = (i + output_padding_offset[i]) / max_seq_len;
if (bi >= bs) {
continue;
}
float alpha = penalty_scoresfp32[bi];
float beta = frequency_scoresfp32[bi];
float gamma = presence_scoresfp32[bi];
float temperature = temperatures[bi];
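    // Penalty rule for each vocab id j that appeared in pre_ids: scale by the
    // repetition penalty alpha (multiply when the logit is negative, divide
    // when positive), then subtract count * beta (frequency penalty) and
    // gamma (presence penalty); every logit is finally divided by the
    // temperature.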
for (int64_t j = 0; j < length; j++) {
int times = repeat_times[i * length + j];
float logit_now = logitsfp32[i * length + j];
if (times != 0) {
logit_now = logit_now < 0 ? logit_now * alpha : logit_now / alpha;
logit_now = logit_now - times * beta - gamma;
}
logitsfp32[i * length + j] = logit_now / temperature;
}
}
if (bad_words && length_bad_words > 0) {
ban_bad_words_cpu(logitsfp32.data(),
bad_words,
output_padding_offset,
bs,
length,
length_bad_words,
token_num,
max_seq_len);
}
ret = api::cast<float, T>(ctx, logitsfp32.data(), logits, token_num * length);
return ret;
}
template <typename T>
static int xpu3_wrapper(Context* ctx,
const int64_t* pre_ids,
T* logits,
const T* penalty_scores,
const T* frequency_scores,
const T* presence_scores,
const float* temperatures,
const int64_t* cur_len,
const int64_t* min_len,
const int64_t* eos_token_id,
const int64_t* bad_words,
const int* output_padding_offset,
const int* output_cum_offsets,
const int64_t bs,
const int64_t length,
const int64_t length_id,
const int64_t end_length,
const int64_t length_bad_words,
const int64_t token_num,
const int64_t max_seq_len) {
api::ctx_guard RAII_GUARD(ctx);
using XPU_INT64 = typename XPUIndexType<int64_t>::type;
auto min_length_logits_process_kernel =
xpu3::plugin::speculate_min_length_logits_process<T>;
auto update_repeat_times_kernel = xpu3::plugin::speculate_update_repeat_times;
auto update_value_by_repeat_times_kernel =
xpu3::plugin::speculate_update_value_by_repeat_times<T>;
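  // Prefer the SIMD variant when the vocab length divides evenly into
  // 16-element chunks, which is the layout that kernel requires.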
if (length % 16 == 0) {
update_value_by_repeat_times_kernel =
xpu3::plugin::speculate_update_value_by_repeat_times_simd<T>;
}
auto ban_bad_words_kernel = xpu3::plugin::speculate_ban_bad_words<T>;
int* repeat_times = RAII_GUARD.alloc_l3_or_gm<int>(token_num * length);
WRAPPER_ASSERT_WORKSPACE(ctx, repeat_times);
int ret = api::constant<int>(ctx, repeat_times, token_num * length, 0);
WRAPPER_ASSERT_SUCCESS(ctx, ret);
update_repeat_times_kernel<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
reinterpret_cast<const XPU_INT64*>(pre_ids),
reinterpret_cast<const XPU_INT64*>(cur_len),
repeat_times,
output_padding_offset,
bs,
length,
length_id,
token_num,
max_seq_len);
min_length_logits_process_kernel<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
logits,
reinterpret_cast<const XPU_INT64*>(cur_len),
reinterpret_cast<const XPU_INT64*>(min_len),
reinterpret_cast<const XPU_INT64*>(eos_token_id),
output_padding_offset,
output_cum_offsets,
bs,
length,
length_id,
end_length,
token_num,
max_seq_len);
update_value_by_repeat_times_kernel<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
repeat_times,
penalty_scores,
frequency_scores,
presence_scores,
temperatures,
logits,
output_padding_offset,
bs,
length,
token_num,
max_seq_len);
if (bad_words && length_bad_words > 0) {
ban_bad_words_kernel<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
logits,
reinterpret_cast<const XPU_INT64*>(bad_words),
output_padding_offset,
bs,
length,
length_bad_words,
token_num,
max_seq_len);
}
return api::SUCCESS;
}
template <typename T>
int speculate_token_penalty_multi_scores(Context* ctx,
const int64_t* pre_ids,
T* logits,
const T* penalty_scores,
const T* frequency_scores,
const T* presence_scores,
const float* temperatures,
const int64_t* cur_len,
const int64_t* min_len,
const int64_t* eos_token_id,
const int64_t* bad_words,
const int* output_padding_offset,
const int* output_cum_offsets,
const int64_t bs,
const int64_t length,
const int64_t length_id,
const int64_t end_length,
const int64_t length_bad_words,
const int64_t token_num,
const int64_t max_seq_len) {
WRAPPER_CHECK_CTX(ctx);
WRAPPER_DUMP_FUNCTION_T1(ctx, "speculate_token_penalty_multi_scores", T);
WRAPPER_DUMP_PARAM6(ctx,
pre_ids,
logits,
penalty_scores,
frequency_scores,
presence_scores,
temperatures);
WRAPPER_DUMP_PARAM6(ctx,
cur_len,
min_len,
eos_token_id,
bad_words,
output_padding_offset,
output_cum_offsets);
WRAPPER_DUMP_PARAM4(ctx, bs, length, length_id, end_length);
WRAPPER_DUMP_PARAM3(ctx, length_bad_words, token_num, max_seq_len);
WRAPPER_DUMP(ctx);
// TODO(mayang02) shape check
int64_t pre_ids_len = -1;
int64_t logits_len = -1;
int64_t penalty_scores_len = -1;
int64_t frequency_scores_len = -1;
int64_t presence_scores_len = -1;
int64_t temperatures_len = -1;
int64_t cur_len_len = -1;
int64_t min_len_len = -1;
int64_t eos_token_id_len = -1;
int64_t bad_words_len = -1;
int64_t output_padding_offset_len = -1;
int64_t output_cum_offsets_len = -1;
WRAPPER_ASSERT_LE(ctx, bs, 640);
WRAPPER_CHECK_SHAPE(ctx, &pre_ids_len, {bs, length_id});
WRAPPER_CHECK_SHAPE(ctx, &logits_len, {token_num, length});
WRAPPER_CHECK_SHAPE(ctx, &penalty_scores_len, {bs});
WRAPPER_CHECK_SHAPE(ctx, &frequency_scores_len, {bs});
WRAPPER_CHECK_SHAPE(ctx, &presence_scores_len, {bs});
WRAPPER_CHECK_SHAPE(ctx, &temperatures_len, {bs});
WRAPPER_CHECK_SHAPE(ctx, &cur_len_len, {bs});
WRAPPER_CHECK_SHAPE(ctx, &min_len_len, {bs});
WRAPPER_CHECK_SHAPE(ctx, &eos_token_id_len, {end_length});
WRAPPER_CHECK_SHAPE(ctx, &bad_words_len, {length_bad_words});
WRAPPER_CHECK_SHAPE(ctx, &output_padding_offset_len, {token_num});
WRAPPER_CHECK_SHAPE(ctx, &output_cum_offsets_len, {bs});
WRAPPER_CHECK_PTR(ctx, int64_t, pre_ids_len, pre_ids);
WRAPPER_CHECK_PTR(ctx, T, logits_len, logits);
WRAPPER_CHECK_PTR(ctx, T, penalty_scores_len, penalty_scores);
WRAPPER_CHECK_PTR(ctx, T, frequency_scores_len, frequency_scores);
WRAPPER_CHECK_PTR(ctx, T, presence_scores_len, presence_scores);
WRAPPER_CHECK_PTR(ctx, float, temperatures_len, temperatures);
WRAPPER_CHECK_PTR(ctx, int64_t, cur_len_len, cur_len);
WRAPPER_CHECK_PTR(ctx, int64_t, min_len_len, min_len);
WRAPPER_CHECK_PTR(ctx, int64_t, eos_token_id_len, eos_token_id);
WRAPPER_CHECK_PTR(ctx, int64_t, bad_words_len, bad_words);
WRAPPER_CHECK_PTR(ctx, int, output_padding_offset_len, output_padding_offset);
WRAPPER_CHECK_PTR(ctx, int, output_cum_offsets_len, output_cum_offsets);
if (ctx->dev().type() == api::kCPU) {
return cpu_wrapper<T>(ctx,
pre_ids,
logits,
penalty_scores,
frequency_scores,
presence_scores,
temperatures,
cur_len,
min_len,
eos_token_id,
bad_words,
output_padding_offset,
output_cum_offsets,
bs,
length,
length_id,
end_length,
length_bad_words,
token_num,
max_seq_len);
}
if (ctx->dev().type() == api::kXPU3) {
return xpu3_wrapper<T>(ctx,
pre_ids,
logits,
penalty_scores,
frequency_scores,
presence_scores,
temperatures,
cur_len,
min_len,
eos_token_id,
bad_words,
output_padding_offset,
output_cum_offsets,
bs,
length,
length_id,
end_length,
length_bad_words,
token_num,
max_seq_len);
}
WRAPPER_UNIMPLEMENTED(ctx);
}
template int speculate_token_penalty_multi_scores<float>(
Context* ctx,
const int64_t* pre_ids,
float* logits,
const float* penalty_scores,
const float* frequency_scores,
const float* presence_scores,
const float* temperatures,
const int64_t* cur_len,
const int64_t* min_len,
const int64_t* eos_token_id,
const int64_t* bad_words,
const int* output_padding_offset,
const int* output_cum_offsets,
const int64_t bs,
const int64_t length,
const int64_t length_id,
const int64_t end_length,
const int64_t length_bad_words,
const int64_t token_num,
const int64_t max_seq_len);
template int speculate_token_penalty_multi_scores<float16>(
Context* ctx,
const int64_t* pre_ids,
float16* logits,
const float16* penalty_scores,
const float16* frequency_scores,
const float16* presence_scores,
const float* temperatures,
const int64_t* cur_len,
const int64_t* min_len,
const int64_t* eos_token_id,
const int64_t* bad_words,
const int* output_padding_offset,
const int* output_cum_offsets,
const int64_t bs,
const int64_t length,
const int64_t length_id,
const int64_t end_length,
const int64_t length_bad_words,
const int64_t token_num,
const int64_t max_seq_len);
template int speculate_token_penalty_multi_scores<bfloat16>(
Context* ctx,
const int64_t* pre_ids,
bfloat16* logits,
const bfloat16* penalty_scores,
const bfloat16* frequency_scores,
const bfloat16* presence_scores,
const float* temperatures,
const int64_t* cur_len,
const int64_t* min_len,
const int64_t* eos_token_id,
const int64_t* bad_words,
const int* output_padding_offset,
const int* output_cum_offsets,
const int64_t bs,
const int64_t length,
const int64_t length_id,
const int64_t end_length,
const int64_t length_bad_words,
const int64_t token_num,
const int64_t max_seq_len);
} // namespace plugin
} // namespace api
} // namespace xpu
} // namespace baidu
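
For readers following the math: the per-position penalty applied by both the CPU and XPU3 paths above can be reproduced in a few lines of NumPy. This is an illustrative sketch, not part of the commit; the helper name apply_token_penalties is hypothetical.

import numpy as np

def apply_token_penalties(logits, repeat_times, alpha, beta, gamma, temperature):
    """logits: [length] scores for one token position.
    repeat_times: [length] occurrence counts of each vocab id in pre_ids.
    alpha/beta/gamma: repetition / frequency / presence penalties."""
    logits = logits.astype(np.float32).copy()
    repeated = repeat_times > 0
    # Repetition penalty: amplify negative logits, shrink positive ones.
    penalized = np.where(logits < 0, logits * alpha, logits / alpha)
    # Frequency and presence penalties only touch repeated ids.
    logits = np.where(repeated, penalized - repeat_times * beta - gamma, logits)
    # Temperature scaling applies to every position.
    return logits / temperature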

View File

@@ -0,0 +1,241 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "xpu/plugin.h"
#include "xpu/refactor/impl_public/wrapper_check.h"
namespace xpu3 {
namespace plugin {
template <int THREADBLOCK_SIZE>
__attribute__((global)) void speculate_update_v3(
    int *seq_lens_encoder,          // input  [B_max, ]
    int *seq_lens_decoder,          // output [B_max, ]
    bool *not_need_stop,            // output [1, ]
    int64_t *draft_tokens,          // output [B_max, T_max]
    int *actual_draft_token_nums,   // output [B_max, ]
    const int64_t *accept_tokens,   // input  [B_max, T_max]
    const int *accept_num,          // input  [B_max, ]
    const bool *stop_flags,         // input  [B_max, ]
    const int *seq_lens_this_time,  // input  [B_real, ]
    const bool *is_block_step,      // input  [B_max, ]
    const int64_t *stop_nums,       // input  [1, ]
    const int real_bsz,
    const int max_bsz,
    const int max_draft_tokens);
} // namespace plugin
} // namespace xpu3
namespace baidu {
namespace xpu {
namespace api {
namespace plugin {
static int cpu_wrapper(Context *ctx,
                       int *seq_lens_encoder,          // input  [B_max, ]
                       int *seq_lens_decoder,          // output [B_max, ]
                       bool *not_need_stop,            // [1, ]
                       int64_t *draft_tokens,          // [B_max, T_max]
                       int *actual_draft_token_nums,   // [B_max, ]
                       const int64_t *accept_tokens,   // [B_max, T_max]
                       const int *accept_num,          // [B_max, ]
                       const bool *stop_flags,         // [B_max, ]
                       const int *seq_lens_this_time,  // [B_real, ]
                       const bool *is_block_step,      // [B_max, ]
                       const int64_t *stop_nums,       // [1, ]
const int real_bsz,
const int max_bsz,
const int max_draft_tokens) {
  int64_t stop_sum = 0;
  for (int bid = 0; bid < max_bsz; ++bid) {
    int stop_flag_now_int = 0;
    const bool inactive = (bid >= real_bsz);
    const bool block_step = (!inactive && is_block_step[bid]);
    if (!block_step && !inactive) {
      // 1. Has this sample already hit a stop condition?
      if (stop_flags[bid]) stop_flag_now_int = 1;
      // 2. When the encoder length is 0, the accepted tokens extend the
      // decoder length directly.
      if (seq_lens_encoder[bid] == 0) {
        seq_lens_decoder[bid] += accept_num[bid];
      }
      // 3. Adapt the draft length based on whether every draft token was
      // accepted (decode mode only).
      if (seq_lens_encoder[bid] == 0 &&
          seq_lens_this_time[bid] > 1) {
        int cur_len = actual_draft_token_nums[bid];
        if (accept_num[bid] - 1 == cur_len) {
          // All accepted: try to grow by 2, else by 1, capped at
          // max_draft_tokens - 1.
          if (cur_len + 2 <= max_draft_tokens - 1)
            cur_len += 2;
          else if (cur_len + 1 <= max_draft_tokens - 1)
            cur_len += 1;
          else
            cur_len = max_draft_tokens - 1;
        } else {
          // Some tokens rejected: shrink by 1, floor at 1.
          cur_len = std::max(1, cur_len - 1);
        }
        actual_draft_token_nums[bid] = cur_len;
      }
      // 4. Settle the outstanding encoder length.
      if (seq_lens_encoder[bid] != 0) {
        seq_lens_decoder[bid] += seq_lens_encoder[bid];
        seq_lens_encoder[bid] = 0;
      }
      // 5. If stopped, clear the decoder length.
      if (stop_flag_now_int) {
        seq_lens_decoder[bid] = 0;
      } else {
        // 6. Otherwise write the last accepted token back as the first draft
        // token of the next round (only the valid slots matter).
        draft_tokens[bid * max_draft_tokens] =
            accept_tokens[bid * max_draft_tokens + accept_num[bid] - 1];
      }
    } else if (inactive) {
      // Padding slot: treat as stopped.
      stop_flag_now_int = 1;
    }
    stop_sum += stop_flag_now_int;
  }
  // 7. Write out the global flag: generation continues while the number of
  // stopped samples is below stop_nums.
  not_need_stop[0] = (stop_sum < stop_nums[0]);
return api::SUCCESS;
}
static int xpu3_wrapper(Context *ctx,
                        int *seq_lens_encoder,          // input  [B_max, ]
                        int *seq_lens_decoder,          // output [B_max, ]
                        bool *not_need_stop,            // [1, ]
                        int64_t *draft_tokens,          // [B_max, T_max]
                        int *actual_draft_token_nums,   // [B_max, ]
                        const int64_t *accept_tokens,   // [B_max, T_max]
                        const int *accept_num,          // [B_max, ]
                        const bool *stop_flags,         // [B_max, ]
                        const int *seq_lens_this_time,  // [B_real, ]
                        const bool *is_block_step,      // [B_max, ]
                        const int64_t *stop_nums,       // [1, ]
const int real_bsz,
const int max_bsz,
const int max_draft_tokens) {
constexpr int BlockSize = 512;
using XPU_TI = typename XPUIndexType<int64_t>::type;
xpu3::plugin::speculate_update_v3<BlockSize>
<<<1, 64, ctx->xpu_stream>>>(seq_lens_encoder,
seq_lens_decoder,
not_need_stop,
reinterpret_cast<XPU_TI *>(draft_tokens),
actual_draft_token_nums,
(const XPU_TI *)accept_tokens,
accept_num,
stop_flags,
seq_lens_this_time,
is_block_step,
(const XPU_TI *)stop_nums,
real_bsz,
max_bsz,
max_draft_tokens);
return api::SUCCESS;
}
int speculate_update_v3(Context *ctx,
                        int *seq_lens_encoder,          // input  [B_max, ]
                        int *seq_lens_decoder,          // output [B_max, ]
                        bool *not_need_stop,            // [1, ]
                        int64_t *draft_tokens,          // [B_max, T_max]
                        int *actual_draft_token_nums,   // [B_max, ]
                        const int64_t *accept_tokens,   // [B_max, T_max]
                        const int *accept_num,          // [B_max, ]
                        const bool *stop_flags,         // [B_max, ]
                        const int *seq_lens_this_time,  // [B_real, ]
                        const bool *is_block_step,      // [B_max, ]
                        const int64_t *stop_nums,       // [1, ]
const int real_bsz,
const int max_bsz,
const int max_draft_tokens) {
WRAPPER_CHECK_CTX(ctx);
WRAPPER_DUMP_FUNCTION_T1(ctx, "speculate_update_v3", int);
WRAPPER_DUMP_PARAM4(
ctx, seq_lens_encoder, seq_lens_decoder, not_need_stop, draft_tokens);
WRAPPER_DUMP_PARAM4(
ctx, actual_draft_token_nums, accept_tokens, accept_num, stop_flags);
WRAPPER_DUMP_PARAM4(
ctx, seq_lens_this_time, is_block_step, stop_nums, real_bsz);
WRAPPER_DUMP_PARAM2(ctx, max_bsz, max_draft_tokens);
WRAPPER_DUMP(ctx);
WRAPPER_ASSERT_GT(ctx, real_bsz, 0);
WRAPPER_ASSERT_GT(ctx, max_bsz, 0);
WRAPPER_ASSERT_LE(ctx, max_bsz, 512);
WRAPPER_ASSERT_GT(ctx, max_draft_tokens, 0);
WRAPPER_ASSERT_GE(ctx, max_bsz, real_bsz);
WRAPPER_CHECK_PTR(ctx, int, max_bsz, seq_lens_encoder);
WRAPPER_CHECK_PTR(ctx, int, max_bsz, seq_lens_decoder);
WRAPPER_CHECK_PTR(ctx, bool, 1, not_need_stop);
WRAPPER_CHECK_PTR(ctx, int64_t, max_bsz * max_draft_tokens, draft_tokens);
WRAPPER_CHECK_PTR(ctx, int, max_bsz, actual_draft_token_nums);
WRAPPER_CHECK_PTR(ctx, int64_t, max_bsz * max_draft_tokens, accept_tokens);
WRAPPER_CHECK_PTR(ctx, int, max_bsz, accept_num);
WRAPPER_CHECK_PTR(ctx, bool, max_bsz, stop_flags);
WRAPPER_CHECK_PTR(ctx, int, real_bsz, seq_lens_this_time);
WRAPPER_CHECK_PTR(ctx, bool, max_bsz, is_block_step);
WRAPPER_CHECK_PTR(ctx, int64_t, 1, stop_nums);
if (ctx->dev().type() == api::kCPU) {
return cpu_wrapper(ctx,
seq_lens_encoder,
seq_lens_decoder,
not_need_stop,
draft_tokens,
actual_draft_token_nums,
accept_tokens,
accept_num,
stop_flags,
seq_lens_this_time,
is_block_step,
stop_nums,
real_bsz,
max_bsz,
max_draft_tokens);
} else if (ctx->dev().type() == api::kXPU3) {
return xpu3_wrapper(ctx,
seq_lens_encoder,
seq_lens_decoder,
not_need_stop,
draft_tokens,
actual_draft_token_nums,
accept_tokens,
accept_num,
stop_flags,
seq_lens_this_time,
is_block_step,
stop_nums,
real_bsz,
max_bsz,
max_draft_tokens);
}
WRAPPER_UNIMPLEMENTED(ctx);
}
} // namespace plugin
} // namespace api
} // namespace xpu
} // namespace baidu
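
The draft-length adaptation in step 3 of the CPU reference is the heart of this op and is easiest to read in isolation. A minimal Python sketch of the rule (the helper name next_draft_len is hypothetical):

def next_draft_len(cur_len, accept_num, max_draft_tokens):
    # accept_num counts the bonus token, so "all drafts accepted" means
    # accept_num - 1 == cur_len: grow greedily, capped at max_draft_tokens - 1.
    if accept_num - 1 == cur_len:
        return min(cur_len + 2, max_draft_tokens - 1)
    # Any rejection backs off by one, never below a single draft token.
    return max(1, cur_len - 1)

The min() form is equivalent to the +2 / +1 / cap cascade in the C++ code above.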

View File

@@ -0,0 +1,543 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <algorithm>
#include <numeric>
#include "xpu/plugin.h"
#include "xpu/refactor/impl_public/wrapper_check.h"
namespace xpu3 {
namespace plugin {
typedef uint32_t curandStatePhilox4_32_10_t;
template <bool ENABLE_TOPP, bool USE_TOPK>
__attribute__((global)) void speculate_verify(
int64_t *accept_tokens,
int *accept_num,
int64_t *step_idx,
bool *stop_flags,
const int *seq_lens_encoder,
const int *seq_lens_decoder,
const int64_t *draft_tokens,
const int *actual_draft_token_nums,
const float *dev_curand_states,
const float *topp,
const int *seq_lens_this_time,
const int64_t *verify_tokens,
const float *verify_scores,
const int64_t *max_dec_len,
const int64_t *end_tokens,
const bool *is_block_step,
const int *output_cum_offsets,
const int *actual_candidate_len,
const int real_bsz,
const int max_draft_tokens,
const int end_length,
const int max_seq_len,
const int max_candidate_len,
const int verify_window,
const bool prefill_one_step_stop);
} // namespace plugin
} // namespace xpu3
namespace baidu {
namespace xpu {
namespace api {
namespace plugin {
static inline bool is_in_end(const int64_t id,
const int64_t *end_ids,
int length) {
bool flag = false;
for (int i = 0; i < length; i++) {
if (id == end_ids[i]) {
return true;
}
}
return flag;
}
static inline bool is_in(const int64_t *candidates,
const int64_t draft,
const int candidate_len) {
for (int i = 0; i < candidate_len; i++) {
if (draft == candidates[i]) {
return true;
}
}
return false;
}
static inline unsigned int xorwow(unsigned int &state) { // NOLINT
state ^= state >> 7;
state ^= state << 9;
state ^= state >> 13;
return state;
}
typedef uint32_t curandStatePhilox4_32_10_t;
static int64_t topp_sampling_kernel(const int64_t *candidate_ids,
const float *candidate_scores,
const float *dev_curand_states,
const int candidate_len,
const float topp,
int tid) {
  float sum_scores = 0.0f;
  float rand_top_p = *dev_curand_states * topp;
  for (int i = 0; i < candidate_len; i++) {
    sum_scores += candidate_scores[i];
if (rand_top_p <= sum_scores) {
return candidate_ids[i];
}
}
return candidate_ids[0];
}
template <bool ENABLE_TOPP, bool USE_TOPK>
static int cpu_wrapper(Context *ctx,
int64_t *accept_tokens,
int *accept_num,
int64_t *step_idx,
bool *stop_flags,
const int *seq_lens_encoder,
const int *seq_lens_decoder,
const int64_t *draft_tokens,
const int *actual_draft_token_nums,
const float *dev_curand_states,
const float *topp,
const int *seq_lens_this_time,
const int64_t *verify_tokens,
const float *verify_scores,
const int64_t *max_dec_len,
const int64_t *end_tokens,
const bool *is_block_step,
const int *output_cum_offsets,
const int *actual_candidate_len,
const int real_bsz,
const int max_draft_tokens,
const int end_length,
const int max_seq_len,
const int max_candidate_len,
const int verify_window,
const bool prefill_one_step_stop) {
for (int bid = 0; bid < real_bsz; ++bid) {
const int start_token_id = bid * max_seq_len - output_cum_offsets[bid];
// verify and set stop flags
int accept_num_now = 1;
int stop_flag_now_int = 0;
    if (!(is_block_step[bid] || bid >= real_bsz)) {
      if (stop_flags[bid]) {
        stop_flag_now_int = 1;
      } else {  // The prefill phase also reaches this branch, but its draft
                // tokens are zeroed, so it falls straight through to the
                // final sampling phase.
        auto *verify_tokens_now =
            verify_tokens + start_token_id * max_candidate_len;
        auto *draft_tokens_now = draft_tokens + bid * max_draft_tokens;
        auto *actual_candidate_len_now = actual_candidate_len + start_token_id;
        int i = 0;
        for (; i < seq_lens_this_time[bid] - 1; i++) {
if (seq_lens_encoder[bid] != 0) {
break;
}
          if (USE_TOPK) {
            if (verify_tokens_now[i * max_candidate_len] ==
                draft_tokens_now[i + 1]) {
              step_idx[bid]++;
              auto accept_token = draft_tokens_now[i + 1];
              accept_tokens[bid * max_draft_tokens + i] = accept_token;
              if (is_in_end(accept_token, end_tokens, end_length) ||
                  step_idx[bid] >= max_dec_len[bid]) {
                stop_flags[bid] = true;
                stop_flag_now_int = 1;
                if (step_idx[bid] >= max_dec_len[bid])
                  accept_tokens[bid * max_draft_tokens + i] = end_tokens[0];
                break;
              } else {
                accept_num_now++;
              }
            } else {
              break;
            }
          } else {
auto actual_candidate_len_value =
actual_candidate_len_now[i] > max_candidate_len
? max_candidate_len
: actual_candidate_len_now[i];
if (is_in(verify_tokens_now + i * max_candidate_len,
draft_tokens_now[i + 1],
actual_candidate_len_value)) {
              // Top-P verify
              step_idx[bid]++;
              auto accept_token = draft_tokens_now[i + 1];
              accept_tokens[bid * max_draft_tokens + i] = accept_token;
              if (is_in_end(accept_token, end_tokens, end_length) ||
                  step_idx[bid] >= max_dec_len[bid]) {
                stop_flags[bid] = true;
                stop_flag_now_int = 1;
                if (step_idx[bid] >= max_dec_len[bid])
                  accept_tokens[bid * max_draft_tokens + i] = end_tokens[0];
break;
} else {
accept_num_now++;
}
} else {
// TopK verify
int ii = i;
if (max_candidate_len >= 2 &&
verify_tokens_now[ii * max_candidate_len + 1] ==
draft_tokens_now[ii + 1]) { // top-2
int j = 0;
ii += 1;
for (; j < verify_window && ii < seq_lens_this_time[bid] - 1;
j++, ii++) {
if (verify_tokens_now[ii * max_candidate_len] !=
draft_tokens_now[ii + 1]) {
break;
}
}
if (j >= verify_window) { // accept all
accept_num_now += verify_window + 1;
step_idx[bid] += verify_window + 1;
for (; i < ii; i++) {
auto accept_token = draft_tokens_now[i + 1];
accept_tokens[bid * max_draft_tokens + i] = accept_token;
// printf("bid %d TopK verify write accept %dis "
// "%lld\n",bid,i,accept_token);
if (is_in_end(accept_token, end_tokens, end_length) ||
step_idx[bid] >= max_dec_len[bid]) {
stop_flags[bid] = true;
stop_flag_now_int = 1;
if (step_idx[bid] >= max_dec_len[bid])
accept_tokens[bid * max_draft_tokens + i] =
end_tokens[0];
// printf("bid %d TopK verify write accept %d is %lld\n",
// bid, i,end_tokens[0]);
accept_num_now--;
step_idx[bid]--;
break;
}
}
}
}
break;
}
}
}
      // Sampling phase. Two cases reach it:
      //   1. draft_token[i+1] was rejected, so one token is sampled from
      //      verify_tokens_now[i];
      //   2. i == seq_lens_this_time[bid] - 1, which also samples from
      //      verify_tokens_now[i] (unless the sample already stopped).
if (!stop_flag_now_int) {
int64_t accept_token;
const float *verify_scores_now =
verify_scores + start_token_id * max_candidate_len;
step_idx[bid]++;
if (ENABLE_TOPP) {
auto actual_candidate_len_value =
actual_candidate_len_now[i] > max_candidate_len
? max_candidate_len
: actual_candidate_len_now[i];
accept_token =
topp_sampling_kernel(verify_tokens_now + i * max_candidate_len,
verify_scores_now + i * max_candidate_len,
dev_curand_states + i,
actual_candidate_len_value,
topp[bid],
bid);
} else {
accept_token = verify_tokens_now[i * max_candidate_len];
}
accept_tokens[bid * max_draft_tokens + i] = accept_token;
if (prefill_one_step_stop) {
stop_flags[bid] = true;
}
if (is_in_end(accept_token, end_tokens, end_length) ||
step_idx[bid] >= max_dec_len[bid]) {
stop_flags[bid] = true;
stop_flag_now_int = 1;
if (step_idx[bid] >= max_dec_len[bid])
accept_tokens[bid * max_draft_tokens + i] = end_tokens[0];
}
}
accept_num[bid] = accept_num_now;
}
}
}
return api::SUCCESS;
}
template <bool ENABLE_TOPP, bool USE_TOPK>
static int xpu3_wrapper(Context *ctx,
int64_t *accept_tokens,
int *accept_num,
int64_t *step_idx,
bool *stop_flags,
const int *seq_lens_encoder,
const int *seq_lens_decoder,
const int64_t *draft_tokens,
const int *actual_draft_token_nums,
const float *dev_curand_states,
const float *topp,
const int *seq_lens_this_time,
const int64_t *verify_tokens,
const float *verify_scores,
const int64_t *max_dec_len,
const int64_t *end_tokens,
const bool *is_block_step,
const int *output_cum_offsets,
const int *actual_candidate_len,
const int real_bsz,
const int max_draft_tokens,
const int end_length,
const int max_seq_len,
const int max_candidate_len,
const int verify_window,
const bool prefill_one_step_stop) {
using XPU_INT64 = typename XPUIndexType<int64_t>::type;
xpu3::plugin::speculate_verify<ENABLE_TOPP, USE_TOPK>
<<<1, 64, ctx->xpu_stream>>>(
reinterpret_cast<XPU_INT64 *>(accept_tokens),
accept_num,
reinterpret_cast<XPU_INT64 *>(step_idx),
stop_flags,
seq_lens_encoder,
seq_lens_decoder,
reinterpret_cast<const XPU_INT64 *>(draft_tokens),
actual_draft_token_nums,
dev_curand_states,
topp,
seq_lens_this_time,
reinterpret_cast<const XPU_INT64 *>(verify_tokens),
verify_scores,
reinterpret_cast<const XPU_INT64 *>(max_dec_len),
reinterpret_cast<const XPU_INT64 *>(end_tokens),
is_block_step,
output_cum_offsets,
actual_candidate_len,
real_bsz,
max_draft_tokens,
end_length,
max_seq_len,
max_candidate_len,
verify_window,
prefill_one_step_stop);
return api::SUCCESS;
}
template <bool ENABLE_TOPP, bool USE_TOPK>
int speculate_verify(Context *ctx,
int64_t *accept_tokens,
int *accept_num,
int64_t *step_idx,
bool *stop_flags,
const int *seq_lens_encoder,
const int *seq_lens_decoder,
const int64_t *draft_tokens,
const int *actual_draft_token_nums,
const float *dev_curand_states,
const float *topp,
const int *seq_lens_this_time,
const int64_t *verify_tokens,
const float *verify_scores,
const int64_t *max_dec_len,
const int64_t *end_tokens,
const bool *is_block_step,
const int *output_cum_offsets,
const int *actual_candidate_len,
const int real_bsz,
const int max_draft_tokens,
const int end_length,
const int max_seq_len,
const int max_candidate_len,
const int verify_window,
const bool prefill_one_step_stop) {
WRAPPER_CHECK_CTX(ctx);
WRAPPER_DUMP_FUNCTION_T1(ctx, "speculate_verify", int64_t);
WRAPPER_DUMP_PARAM3(ctx, accept_tokens, accept_num, step_idx);
WRAPPER_DUMP_PARAM6(ctx,
stop_flags,
seq_lens_encoder,
seq_lens_decoder,
draft_tokens,
actual_draft_token_nums,
topp);
WRAPPER_DUMP_PARAM5(ctx,
seq_lens_this_time,
verify_tokens,
verify_scores,
max_dec_len,
end_tokens);
WRAPPER_DUMP_PARAM5(ctx,
is_block_step,
output_cum_offsets,
actual_candidate_len,
real_bsz,
max_draft_tokens);
WRAPPER_DUMP_PARAM5(ctx,
end_length,
max_seq_len,
max_candidate_len,
verify_window,
prefill_one_step_stop);
WRAPPER_DUMP(ctx);
WRAPPER_CHECK_PTR(ctx, int64_t, real_bsz * max_draft_tokens, accept_tokens);
WRAPPER_CHECK_PTR(ctx, int, real_bsz, accept_num);
WRAPPER_CHECK_PTR(ctx, int64_t, real_bsz, step_idx);
WRAPPER_CHECK_PTR(ctx, bool, real_bsz, stop_flags);
WRAPPER_CHECK_PTR(ctx, int, real_bsz, seq_lens_encoder);
WRAPPER_CHECK_PTR(ctx, int, real_bsz, seq_lens_decoder);
WRAPPER_CHECK_PTR(ctx, int64_t, real_bsz * max_draft_tokens, draft_tokens);
WRAPPER_CHECK_PTR(ctx, int, real_bsz, actual_draft_token_nums);
WRAPPER_CHECK_PTR(ctx, float, real_bsz, dev_curand_states);
WRAPPER_CHECK_PTR(ctx, float, real_bsz, topp);
WRAPPER_CHECK_PTR(ctx, int, real_bsz, seq_lens_this_time);
// WRAPPER_CHECK_PTR(ctx, int64_t, real_bsz, verify_tokens);
// WRAPPER_CHECK_PTR(ctx, float, real_bsz, verify_scores);
WRAPPER_CHECK_PTR(ctx, int64_t, real_bsz, max_dec_len);
WRAPPER_CHECK_PTR(ctx, int64_t, end_length, end_tokens);
WRAPPER_CHECK_PTR(ctx, bool, real_bsz, is_block_step);
WRAPPER_CHECK_PTR(ctx, int, real_bsz, output_cum_offsets);
// WRAPPER_CHECK_PTR(ctx, int, real_bsz, actual_candidate_len);
// param check sm size limit
WRAPPER_ASSERT_GT(ctx, real_bsz, 0);
WRAPPER_ASSERT_LE(ctx, real_bsz, 1024);
WRAPPER_ASSERT_LE(ctx, real_bsz * max_candidate_len, 2048);
WRAPPER_ASSERT_LE(ctx, verify_window * max_candidate_len, 128);
// int sum = 0;
// for (int i=0;i < real_bsz; i++){
// sum+= seq_lens_this_time[i];
// }
// WRAPPER_ASSERT_LE(ctx, sum * max_draft_tokens, 2048);
if (ctx->dev().type() == api::kCPU) {
return cpu_wrapper<ENABLE_TOPP, USE_TOPK>(ctx,
accept_tokens,
accept_num,
step_idx,
stop_flags,
seq_lens_encoder,
seq_lens_decoder,
draft_tokens,
actual_draft_token_nums,
dev_curand_states,
topp,
seq_lens_this_time,
verify_tokens,
verify_scores,
max_dec_len,
end_tokens,
is_block_step,
output_cum_offsets,
actual_candidate_len,
real_bsz,
max_draft_tokens,
end_length,
max_seq_len,
max_candidate_len,
verify_window,
prefill_one_step_stop);
}
if (ctx->dev().type() == api::kXPU3) {
return xpu3_wrapper<ENABLE_TOPP, USE_TOPK>(ctx,
accept_tokens,
accept_num,
step_idx,
stop_flags,
seq_lens_encoder,
seq_lens_decoder,
draft_tokens,
actual_draft_token_nums,
dev_curand_states,
topp,
seq_lens_this_time,
verify_tokens,
verify_scores,
max_dec_len,
end_tokens,
is_block_step,
output_cum_offsets,
actual_candidate_len,
real_bsz,
max_draft_tokens,
end_length,
max_seq_len,
max_candidate_len,
verify_window,
prefill_one_step_stop);
}
WRAPPER_UNIMPLEMENTED(ctx);
}
#define INSTANTIATE_SPECULATE_VERIFY(ENABLE_TOPP, USE_TOPK) \
template int \
baidu::xpu::api::plugin::speculate_verify<ENABLE_TOPP, USE_TOPK>( \
baidu::xpu::api::Context *, /* xpu_ctx */ \
int64_t *, /* accept_tokens */ \
int *, /* accept_num */ \
int64_t *, /* step_idx */ \
bool *, /* stop_flags */ \
const int *, /* seq_lens_encoder */ \
const int *, /* seq_lens_decoder */ \
const int64_t *, /* draft_tokens */ \
const int *, /* actual_draft_token_nums */ \
        const float *,   /* dev_curand_states */                         \
        const float *,   /* topp */                                      \
const int *, /* seq_lens_this_time */ \
const int64_t *, /* verify_tokens */ \
const float *, /* verify_scores */ \
const int64_t *, /* max_dec_len */ \
const int64_t *, /* end_tokens */ \
const bool *, /* is_block_step */ \
const int *, /* output_cum_offsets */ \
const int *, /* actual_candidate_len */ \
int, /* real_bsz */ \
int, /* max_draft_tokens */ \
int, /* end_length */ \
int, /* max_seq_len */ \
int, /* max_candidate_len */ \
int, /* verify_window */ \
bool); /* prefill_one_step_stop */
INSTANTIATE_SPECULATE_VERIFY(false, false)
INSTANTIATE_SPECULATE_VERIFY(false, true)
INSTANTIATE_SPECULATE_VERIFY(true, false)
INSTANTIATE_SPECULATE_VERIFY(true, true)
} // namespace plugin
} // namespace api
} // namespace xpu
} // namespace baidu
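
The fallback sampling step, topp_sampling_kernel, is a single inverse-CDF draw over candidate scores that are assumed to be sorted in descending order. A NumPy sketch under that assumption (topp_sample is a hypothetical name, not part of the commit):

import numpy as np

def topp_sample(candidate_ids, candidate_scores, rand_uniform, top_p):
    # Scale the uniform draw by top_p, then return the first candidate at
    # which the cumulative score reaches the threshold.
    threshold = rand_uniform * top_p
    cum = np.cumsum(candidate_scores)
    hits = np.nonzero(cum >= threshold)[0]
    # Like the kernel, fall back to the top candidate if never reached.
    return candidate_ids[hits[0]] if hits.size else candidate_ids[0]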

View File

@@ -0,0 +1,266 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "xpu/plugin.h"
#include "xpu/refactor/impl_public/wrapper_check.h"
namespace xpu3 {
namespace plugin {
template <typename T, int MaxLength, int TopPBeamTopK>
__attribute__((global)) void top_p_candidates(const T* src,
const T* top_ps,
const int* output_padding_offset,
int64_t* out_id,
T* out_val,
int* actual_candidates_lens,
int vocab_size,
int token_num,
int max_candidate_len,
int max_seq_len);
} // namespace plugin
} // namespace xpu3
namespace baidu {
namespace xpu {
namespace api {
namespace plugin {
template <typename T, int MaxLength, int TopPBeamTopK>
static int cpu_wrapper(Context* ctx,
const T* src,
const T* top_ps,
const int* output_padding_offset,
int64_t* out_id,
T* out_val,
int* actual_candidates_lens,
int vocab_size,
int token_num,
int candidate_len,
int max_seq_len) {
int64_t local_out_id[TopPBeamTopK];
T local_out_val[TopPBeamTopK];
for (int64_t i = 0; i < token_num; i++) {
float sum_prob = 0.0f;
for (int j = 0; j < TopPBeamTopK; j++) {
local_out_id[j] = -1;
      // lowest(), not min(): min() is the smallest positive value for floats.
      local_out_val[j] = std::numeric_limits<T>::lowest();
}
const T* cur_row_src = src + i * vocab_size;
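    // Insertion-sort scan: keep the TopPBeamTopK largest probabilities of
    // this row, breaking ties in favour of the smaller token id.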
for (int id = 0; id < vocab_size; id++) {
if (cur_row_src[id] > local_out_val[TopPBeamTopK - 1] ||
(cur_row_src[id] == local_out_val[TopPBeamTopK - 1] &&
id < local_out_id[TopPBeamTopK - 1])) {
local_out_id[TopPBeamTopK - 1] = id;
local_out_val[TopPBeamTopK - 1] = cur_row_src[id];
for (int k = TopPBeamTopK - 2; k >= 0; k--) {
if (local_out_val[k + 1] > local_out_val[k] ||
(local_out_val[k + 1] == local_out_val[k] &&
local_out_id[k + 1] < local_out_id[k])) {
std::swap(local_out_id[k + 1], local_out_id[k]);
std::swap(local_out_val[k + 1], local_out_val[k]);
}
}
}
}
int ori_token_id = i + output_padding_offset[i];
int bid = ori_token_id / max_seq_len;
float top_p_value = static_cast<float>(top_ps[bid]);
bool set_to_default_val = false;
for (int j = 0; j < TopPBeamTopK; j++) {
if (set_to_default_val) {
out_id[i * candidate_len + j] = 0;
out_val[i * candidate_len + j] = 0;
} else {
out_id[i * candidate_len + j] = local_out_id[j];
out_val[i * candidate_len + j] = local_out_val[j];
float val = static_cast<float>(local_out_val[j]);
sum_prob += val;
if (sum_prob >= top_p_value) {
actual_candidates_lens[i] = j + 1;
set_to_default_val = true;
}
}
}
}
return api::SUCCESS;
}
template <typename T, int MaxLength, int TopPBeamTopK>
static int xpu3_wrapper(Context* ctx,
const T* src,
const T* top_ps,
const int* output_padding_offset,
int64_t* out_id,
T* out_val,
int* actual_candidates_lens,
int vocab_size,
int token_num,
int candidate_len,
int max_seq_len) {
using XPU_INT64 = typename XPUIndexType<int64_t>::type;
xpu3::plugin::top_p_candidates<T, MaxLength, TopPBeamTopK>
<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
src,
top_ps,
output_padding_offset,
reinterpret_cast<XPU_INT64*>(out_id),
out_val,
actual_candidates_lens,
vocab_size,
token_num,
candidate_len,
max_seq_len);
return api::SUCCESS;
}
template <typename T, int MaxLength, int TopPBeamTopK>
int top_p_candidates(Context* ctx,
const T* src,
const T* top_ps,
const int* output_padding_offset,
int64_t* out_id,
T* out_val,
int* actual_candidates_lens,
int vocab_size,
int token_num,
int candidate_len,
int max_seq_len) {
WRAPPER_CHECK_CTX(ctx);
WRAPPER_DUMP_FUNCTION_T1(ctx, "top_p_candidates", T);
WRAPPER_DUMP_PARAM5(ctx, src, top_ps, output_padding_offset, out_id, out_val);
WRAPPER_DUMP_PARAM5(ctx,
actual_candidates_lens,
vocab_size,
token_num,
candidate_len,
max_seq_len);
WRAPPER_DUMP(ctx);
WRAPPER_CHECK_PTR(ctx, T, token_num * vocab_size, src);
  WRAPPER_CHECK_PTR(ctx, int, token_num, output_padding_offset);
  WRAPPER_CHECK_PTR(ctx, int64_t, token_num * candidate_len, out_id);
WRAPPER_CHECK_PTR(ctx, T, token_num * candidate_len, out_val);
WRAPPER_ASSERT_GT(ctx, vocab_size, 0);
WRAPPER_ASSERT_GT(ctx, token_num, 0);
WRAPPER_ASSERT_GT(ctx, candidate_len, 0);
WRAPPER_ASSERT_GT(ctx, max_seq_len, 0);
WRAPPER_ASSERT_GT(ctx, TopPBeamTopK, 0);
WRAPPER_ASSERT_LE(ctx, TopPBeamTopK, 10);
if (ctx->dev().type() == api::kCPU) {
return cpu_wrapper<T, MaxLength, TopPBeamTopK>(ctx,
src,
top_ps,
output_padding_offset,
out_id,
out_val,
actual_candidates_lens,
vocab_size,
token_num,
candidate_len,
max_seq_len);
} else if (ctx->dev().type() == api::kXPU3) {
return xpu3_wrapper<T, MaxLength, TopPBeamTopK>(ctx,
src,
top_ps,
output_padding_offset,
out_id,
out_val,
actual_candidates_lens,
vocab_size,
token_num,
candidate_len,
max_seq_len);
}
WRAPPER_UNIMPLEMENTED(ctx);
}
#define _XPU_DEF_TOP_P_CANDIDATES_WRAPPER(T, MaxLength) \
template int top_p_candidates<T, MaxLength, 2>(Context*, \
const T*, \
const T*, \
const int*, \
int64_t*, \
T*, \
int*, \
int, \
int, \
int, \
int); \
template int top_p_candidates<T, MaxLength, 3>(Context*, \
const T*, \
const T*, \
const int*, \
int64_t*, \
T*, \
int*, \
int, \
int, \
int, \
int); \
template int top_p_candidates<T, MaxLength, 4>(Context*, \
const T*, \
const T*, \
const int*, \
int64_t*, \
T*, \
int*, \
int, \
int, \
int, \
int); \
template int top_p_candidates<T, MaxLength, 5>(Context*, \
const T*, \
const T*, \
const int*, \
int64_t*, \
T*, \
int*, \
int, \
int, \
int, \
int); \
template int top_p_candidates<T, MaxLength, 8>(Context*, \
const T*, \
const T*, \
const int*, \
int64_t*, \
T*, \
int*, \
int, \
int, \
int, \
int); \
template int top_p_candidates<T, MaxLength, 10>(Context*, \
const T*, \
const T*, \
const int*, \
int64_t*, \
T*, \
int*, \
int, \
int, \
int, \
int);
_XPU_DEF_TOP_P_CANDIDATES_WRAPPER(bfloat16, 2);
_XPU_DEF_TOP_P_CANDIDATES_WRAPPER(float, 2);
_XPU_DEF_TOP_P_CANDIDATES_WRAPPER(float16, 2);
} // namespace plugin
} // namespace api
} // namespace xpu
} // namespace baidu
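
Putting the pieces together, the CPU reference above selects the top TopPBeamTopK probabilities per token and then truncates the candidate list at the top-p mass. Roughly, in NumPy (top_p_candidates_ref is a hypothetical name; when the cumulative mass never reaches top_p the kernel leaves actual_candidates_lens untouched, which this sketch approximates with k):

import numpy as np

def top_p_candidates_ref(probs, top_p, k):
    # probs: [vocab_size] probabilities for one token position.
    ids = np.argsort(-probs, kind="stable")[:k]  # ties go to the lower id
    vals = probs[ids].astype(np.float64)
    hits = np.nonzero(np.cumsum(vals) >= top_p)[0]
    actual_len = int(hits[0]) + 1 if hits.size else k
    # Slots past the cutoff are zero-filled, matching the kernel.
    ids = ids.copy()
    ids[actual_len:] = 0
    vals[actual_len:] = 0
    return ids, vals, actual_len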

View File

@@ -162,6 +162,11 @@ def xpu_setup_ops():
]
ops = [os.path.join(base_dir, op) for op in ops]
for root, dirs, files in os.walk(base_dir / "ops/mtp_ops"):
    for file in files:
        if file.endswith(".cc"):
            ops.append(os.path.join(root, file))
include_dirs = [
    os.path.join(base_dir, "./"),
    os.path.join(base_dir, "./plugin/include"),

View File

@@ -0,0 +1,93 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import paddle
from fastdeploy.model_executor.ops.xpu import draft_model_postprocess
def draft_model_postprocess_cpu(
    base_model_draft_tokens,  # 2-D tensor: [bsz, base_model_draft_token_len]
    base_model_seq_lens_encoder,  # 1-D tensor: [bsz]
    base_model_stop_flags,  # 1-D tensor: [bsz]
):
    bsz = base_model_draft_tokens.shape[0]
    base_model_draft_token_len = base_model_draft_tokens.shape[1]
    base_model_seq_lens_this_time = paddle.ones((bsz), dtype=paddle.int32)
    # Walk over every sample in the batch.
    for tid in range(bsz):
        if (not base_model_stop_flags[tid]) and (base_model_seq_lens_encoder[tid] == 0):
            # Count the valid (non -1) draft tokens of this sample.
            base_model_draft_tokens_now = base_model_draft_tokens[tid]
            token_num = 0
            for i in range(base_model_draft_token_len):
                if base_model_draft_tokens_now[i] != -1:
                    token_num += 1
            # Update the sequence length for this step.
            base_model_seq_lens_this_time[tid] = token_num
        elif base_model_stop_flags[tid]:
            # Stopped samples get a sequence length of 0.
            base_model_seq_lens_this_time[tid] = 0
    return base_model_seq_lens_this_time
def test_draft_model_postprocess(batch_size=1, base_model_draft_token_len=8192):
    paddle.seed(66)
    base_model_draft_tokens = paddle.randint(
        low=-1,
        high=1,
        shape=[batch_size, base_model_draft_token_len],
        dtype="int64",
    )
    base_model_seq_lens_encoder = paddle.randint(low=0, high=2, shape=[batch_size], dtype="int32")
    random_floats = paddle.rand(shape=[batch_size])
    base_model_stop_flags = random_floats >= 0.5
    base_model_seq_lens_this_time = draft_model_postprocess_cpu(
        base_model_draft_tokens,  # 2-D tensor: [bsz, base_model_draft_token_len]
        base_model_seq_lens_encoder,  # 1-D tensor: [bsz]
        base_model_stop_flags,
    )
    base_model_seq_lens_this_time_xpu = paddle.ones((batch_size), dtype=paddle.int32)
    draft_model_postprocess(
        base_model_draft_tokens,  # 2-D tensor: [bsz, base_model_draft_token_len]
        base_model_seq_lens_this_time_xpu,  # 1-D tensor: [bsz]
        base_model_seq_lens_encoder,  # 1-D tensor: [bsz]
        base_model_stop_flags,
    )
print("test start")
assert np.allclose(base_model_seq_lens_this_time, base_model_seq_lens_this_time_xpu)
print("test passed")
def test_enough_cases():
test_draft_model_postprocess(100, 1024)
test_draft_model_postprocess(1, 11)
test_draft_model_postprocess(1, 8192)
test_draft_model_postprocess(2, 2048)
test_draft_model_postprocess(3, 1023)
test_draft_model_postprocess(4, 2047)
test_draft_model_postprocess(5, 4095)
test_draft_model_postprocess(10, 9191)
test_draft_model_postprocess(20, 618)
test_draft_model_postprocess(30, 703)
test_draft_model_postprocess(100, 1025)
test_draft_model_postprocess(1536, 1026)
if __name__ == "__main__":
test_enough_cases()

View File

@@ -0,0 +1,135 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import paddle
from fastdeploy.model_executor.ops.xpu import draft_model_preprocess
def run_test(device="xpu"):
paddle.seed(2022)
# Define parameters
bsz = 10
draft_tokens_len = 4
input_ids_len = 8
max_draft_token = 10
truncate_first_token = True
splitwise_prefill = False
# Create input tensors
if device == "cpu":
paddle.set_device(device)
draft_tokens = paddle.randint(0, 100, [bsz, draft_tokens_len], dtype="int64")
input_ids = paddle.randint(0, 100, [bsz, input_ids_len], dtype="int64")
stop_flags = paddle.randint(0, 1, [bsz], dtype="int").cast("bool")
seq_lens_this_time = paddle.randint(0, 100, [bsz], dtype="int32")
seq_lens_encoder = paddle.randint(0, 100, [bsz], dtype="int32")
seq_lens_decoder = paddle.randint(0, 100, [bsz], dtype="int32")
step_idx = paddle.randint(0, 100, [bsz], dtype="int64")
seq_lens_encoder_record = paddle.randint(0, 100, [bsz], dtype="int32")
seq_lens_decoder_record = paddle.randint(0, 100, [bsz], dtype="int32")
not_need_stop = paddle.zeros([1], dtype="bool").cpu()
batch_drop = paddle.zeros([bsz], dtype="bool")
# Output tensors
accept_tokens = paddle.randint(0, 100, [bsz, 100], dtype="int64")
accept_num = paddle.randint(1, max_draft_token + 5, [bsz], dtype="int32")
base_model_seq_lens_encoder = paddle.randint(0, 100, [bsz], dtype="int32")
base_model_seq_lens_decoder = paddle.randint(0, 100, [bsz], dtype="int32")
base_model_step_idx = paddle.randint(0, 100, [bsz], dtype="int64")
base_model_stop_flags = paddle.zeros([bsz], dtype="bool")
base_model_is_block_step = paddle.zeros([bsz], dtype="bool")
base_model_draft_tokens = paddle.zeros([bsz, max_draft_token], dtype="int64")
# Run the op
outputs = draft_model_preprocess(
draft_tokens,
input_ids,
stop_flags,
seq_lens_this_time,
seq_lens_encoder,
seq_lens_decoder,
step_idx,
seq_lens_encoder_record,
seq_lens_decoder_record,
not_need_stop,
batch_drop,
accept_tokens,
accept_num,
base_model_seq_lens_encoder,
base_model_seq_lens_decoder,
base_model_step_idx,
base_model_stop_flags,
base_model_is_block_step,
base_model_draft_tokens,
max_draft_token=max_draft_token,
truncate_first_token=truncate_first_token,
splitwise_prefill=splitwise_prefill,
)
# Return results for comparison
results = {
"draft_tokens": draft_tokens.numpy(),
"input_ids": input_ids.numpy(),
"stop_flags": stop_flags.numpy(),
"seq_lens_this_time": seq_lens_this_time.numpy(),
"accept_tokens": accept_tokens.numpy(),
"accept_num": accept_num.numpy(),
"not_need_stop": not_need_stop.numpy(),
"outputs": [x.numpy() for x in outputs],
}
return results
def compare_results(cpu_results, xpu_results):
# Compare all outputs
for key in cpu_results:
if key == "outputs":
for i, (cpu_out, xpu_out) in enumerate(zip(cpu_results[key], xpu_results[key])):
np.testing.assert_allclose(
cpu_out,
xpu_out,
rtol=1e-5,
atol=1e-8,
err_msg=f"Output {i} mismatch between CPU and GPU",
)
else:
np.testing.assert_allclose(
cpu_results[key],
xpu_results[key],
rtol=1e-5,
atol=1e-8,
err_msg=f"{key} mismatch between CPU and GPU",
)
print("CPU and GPU results match!")
def test_draft_model_preprocess():
print("Running XPU test...")
xpu_results = run_test("xpu")
print("Running CPU test...")
cpu_results = run_test("cpu")
print("Comparing results...")
compare_results(cpu_results, xpu_results)
print("Test passed!")
if __name__ == "__main__":
test_draft_model_preprocess()

View File

@@ -0,0 +1,122 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import paddle
from fastdeploy.model_executor.ops.xpu import draft_model_update
def run_paddle_test(device="cpu"):
np.random.seed(42)
paddle.seed(42)
if device == "cpu":
paddle.set_device(device)
elif device == "xpu":
paddle.set_device(device)
else:
raise ValueError(f"Invalid device: {device}")
    # Parameters
max_bsz = 128
max_draft_token = 3
pre_id_length = 3
max_seq_len = 100
max_base_model_draft_token = 4
substep = 2
    # Create random input tensors
inter_next_tokens = paddle.randint(1, 100, shape=(max_bsz, max_seq_len), dtype="int64")
draft_tokens = paddle.randint(1, 100, shape=(max_bsz, max_draft_token), dtype="int64")
pre_ids = paddle.randint(1, 100, shape=(max_bsz, pre_id_length), dtype="int64")
seq_lens_this_time = paddle.randint(1, 2, shape=(max_bsz,), dtype="int32")
seq_lens_encoder = paddle.randint(1, 10, shape=(max_bsz,), dtype="int32")
seq_lens_decoder = paddle.randint(1, 10, shape=(max_bsz,), dtype="int32")
step_idx = paddle.randint(1, 10, shape=(max_bsz,), dtype="int64")
output_cum_offsets = paddle.randint(0, 2, shape=(max_bsz,), dtype="int32")
    output_cum_offsets[0] = 0  # the first offset must be 0
stop_flags = paddle.zeros([max_bsz], dtype="bool")
not_need_stop = paddle.zeros([1], dtype="bool")
max_dec_len = paddle.randint(100, 102, shape=(max_bsz,), dtype="int64")
end_ids = paddle.to_tensor([2], dtype="int64")
base_model_draft_tokens = paddle.randint(1, 10, shape=(max_bsz, max_base_model_draft_token), dtype="int64")
draft_model_update(
inter_next_tokens,
draft_tokens,
pre_ids,
seq_lens_this_time,
seq_lens_encoder,
seq_lens_decoder,
step_idx,
output_cum_offsets,
stop_flags,
not_need_stop,
max_dec_len,
end_ids,
base_model_draft_tokens,
max_seq_len,
substep,
)
# print("draft_tokens after update:", draft_tokens)
# print("pre_ids after update:", pre_ids)
return (
draft_tokens,
pre_ids,
seq_lens_this_time,
seq_lens_encoder,
seq_lens_decoder,
step_idx,
stop_flags,
not_need_stop,
base_model_draft_tokens,
)
if __name__ == "__main__":
res_xpu = run_paddle_test("xpu")
res_cpu = run_paddle_test()
for idx in range(len(res_cpu)):
        # Convert both results to numpy arrays
cpu_arr = res_cpu[idx].numpy()
xpu_arr = res_xpu[idx].numpy()
        # Boolean results must match exactly
        if cpu_arr.dtype == bool:
            assert np.array_equal(cpu_arr, xpu_arr), f"Boolean result mismatch at index {idx}"
else:
            # Numeric results get a looser elementwise comparison
            assert np.allclose(
                cpu_arr, xpu_arr, rtol=1e-4, atol=1e-5
            ), f"Numeric result mismatch at index {idx}, max diff: {np.max(np.abs(cpu_arr - xpu_arr))}"
        print(f"Result {idx} verified")

View File

@@ -0,0 +1,104 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import paddle
from fastdeploy.model_executor.ops.xpu import eagle_get_hidden_states
def test_eagle_get_hidden_states():
bs = np.random.randint(1, 8 + 1, dtype=np.int32)
input_token_num = np.random.randint(2 * 1024, 4 * 1024 + 1, dtype=np.int32)
dim_embed = np.random.randint(1, 4 * 1024 + 1, dtype=np.int32)
actual_draft_token_num = np.random.randint(2, 6, dtype=np.int32)
seq_lens_this_time = np.random.randint(0, 2, bs, dtype=np.int32)
seq_lens_encoder = np.random.randint(0, input_token_num // bs + 1, bs, dtype=np.int32)
accept_nums = np.random.randint(0, actual_draft_token_num + 1, bs, dtype=np.int32)
base_model_seq_lens_this_time = np.random.randint(0, input_token_num // bs + 1, bs, dtype=np.int32)
base_model_seq_lens_encoder = np.random.randint(0, 2, bs, dtype=np.int32)
    # don't care: unused by this op
seq_lens_decoder = np.random.randint(0, input_token_num // bs + 1, bs, dtype=np.int32)
stop_flags = np.random.randint(0, 2, bs, dtype=np.int32)
seq_lens_this_time_tensor = paddle.to_tensor(seq_lens_this_time, dtype=paddle.int32)
seq_lens_encoder_tensor = paddle.to_tensor(seq_lens_encoder, dtype=paddle.int32)
accept_nums_tensor = paddle.to_tensor(accept_nums, dtype=paddle.int32)
base_model_seq_lens_this_time_tensor = paddle.to_tensor(base_model_seq_lens_this_time, dtype=paddle.int32)
base_model_seq_lens_encoder_tensor = paddle.to_tensor(base_model_seq_lens_encoder, dtype=paddle.int32)
    # don't care: unused by this op
seq_lens_decoder_tensor = paddle.to_tensor(seq_lens_decoder, dtype=paddle.int32)
stop_flags_tensor = paddle.to_tensor(stop_flags, dtype=paddle.int32)
# fp32 test
input = np.random.randint(0, 10, (input_token_num, dim_embed), dtype=np.int32)
input_tensor = paddle.to_tensor(input, dtype=paddle.float32)
cpu_out = eagle_get_hidden_states(
input_tensor.cpu(),
seq_lens_this_time_tensor.cpu(),
seq_lens_encoder_tensor.cpu(),
seq_lens_decoder_tensor.cpu(),
stop_flags_tensor.cpu(),
accept_nums_tensor.cpu(),
base_model_seq_lens_this_time_tensor.cpu(),
base_model_seq_lens_encoder_tensor.cpu(),
actual_draft_token_num,
)
xpu_out = eagle_get_hidden_states(
input_tensor,
seq_lens_this_time_tensor,
seq_lens_encoder_tensor,
seq_lens_decoder_tensor,
stop_flags_tensor,
accept_nums_tensor,
base_model_seq_lens_this_time_tensor,
base_model_seq_lens_encoder_tensor,
actual_draft_token_num,
)
assert np.allclose(cpu_out.numpy(), xpu_out.numpy())
# bf16/fp16 test
for dtype in [paddle.bfloat16, paddle.float16]:
input = np.random.randint(0, 10, (input_token_num, dim_embed), dtype=np.int16)
input_tensor = paddle.to_tensor(input, dtype=dtype)
cpu_out = eagle_get_hidden_states(
input_tensor.cpu(),
seq_lens_this_time_tensor.cpu(),
seq_lens_encoder_tensor.cpu(),
seq_lens_decoder_tensor.cpu(),
stop_flags_tensor.cpu(),
accept_nums_tensor.cpu(),
base_model_seq_lens_this_time_tensor.cpu(),
base_model_seq_lens_encoder_tensor.cpu(),
actual_draft_token_num,
)
xpu_out = eagle_get_hidden_states(
input_tensor,
seq_lens_this_time_tensor,
seq_lens_encoder_tensor,
seq_lens_decoder_tensor,
stop_flags_tensor,
accept_nums_tensor,
base_model_seq_lens_this_time_tensor,
base_model_seq_lens_encoder_tensor,
actual_draft_token_num,
)
assert np.allclose(cpu_out.numpy(), xpu_out.numpy())
print("All test passed")
if __name__ == "__main__":
test_eagle_get_hidden_states()

View File

@@ -0,0 +1,132 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import paddle
from fastdeploy.model_executor.ops.xpu import eagle_get_self_hidden_states
def computeOrder(last_seq_lens_this_time, seq_lens_this_time, step_idx, src_map, bsz):
in_offset = 0
out_offset = 0
for i in range(bsz):
cur_seq_lens_this_time = seq_lens_this_time[i]
cur_last_seq_lens_this_time = last_seq_lens_this_time[i]
# 1. encoder
if step_idx[i] == 1 and cur_seq_lens_this_time > 0:
in_offset += 1
src_map[out_offset] = in_offset - 1
out_offset += 1
# 2. decoder
elif cur_seq_lens_this_time > 0:
in_offset += cur_last_seq_lens_this_time
src_map[out_offset] = in_offset - 1
out_offset += 1
# 3. stop
else:
# first token end
if step_idx[i] == 1:
in_offset += 1 if cur_last_seq_lens_this_time > 0 else 0
# normal end
else:
in_offset += cur_last_seq_lens_this_time
return (out_offset, src_map)
def rebuildSelfHiddenStatesKernel(input, src_map, out, dim_embed, elem_cnt):
print(f"input.shape {input.shape}")
print(f"out.shape {out.shape}")
print(f"elem_cnt {elem_cnt}")
for elem_id in range(elem_cnt):
output_token_idx = elem_id // dim_embed
input_token_idx = src_map[output_token_idx]
offset = elem_id % dim_embed
out[output_token_idx * dim_embed + offset] = input[input_token_idx * dim_embed + offset]
return out
def ref_eagle_get_self_hidden_states(input, last_seq_lens_this_time, seq_lens_this_time, step_idx):
input_token_num = input.shape[0]
dim_embed = input.shape[1]
bsz = seq_lens_this_time.shape[0]
src_map = np.full(input_token_num, -1, seq_lens_this_time.dtype)
output_token_num, src_map = computeOrder(last_seq_lens_this_time, seq_lens_this_time, step_idx, src_map, bsz)
out = np.full([output_token_num * dim_embed], -1, input.dtype)
elem_cnt = output_token_num * dim_embed
out = rebuildSelfHiddenStatesKernel(input, src_map, out, dim_embed, elem_cnt)
out = out.reshape([output_token_num, dim_embed])
return out
def test_eagle_get_self_hidden_states():
bs = np.random.randint(1, 8 + 1, dtype=np.int32)
input_token_num = np.random.randint(2 * 1024, 4 * 1024 + 1, dtype=np.int32)
dim_embed = np.random.randint(1, 4 * 1024 + 1, dtype=np.int32)
last_seq_lens_this_time = np.random.randint(0, input_token_num // bs, bs, dtype=np.int32)
seq_lens_this_time = np.random.randint(0, input_token_num // bs, bs, dtype=np.int32)
step_idx = np.arange(0, bs, dtype=np.int32)
last_seq_lens_this_time_tensor = paddle.to_tensor(last_seq_lens_this_time, dtype=paddle.int32)
seq_lens_this_time_tensor = paddle.to_tensor(seq_lens_this_time, dtype=paddle.int32)
step_idx_tensor = paddle.to_tensor(step_idx, dtype=paddle.int64)
# fp32 test
input = np.random.randint(0, 10, (input_token_num, dim_embed), dtype=np.int32)
input_tensor = paddle.to_tensor(input, dtype=paddle.float32)
cpu_out = eagle_get_self_hidden_states(
input_tensor.cpu(),
last_seq_lens_this_time_tensor.cpu(),
seq_lens_this_time_tensor.cpu(),
step_idx_tensor.cpu(),
)
xpu_out = eagle_get_self_hidden_states(
input_tensor,
last_seq_lens_this_time_tensor,
seq_lens_this_time_tensor,
step_idx_tensor,
)
assert np.allclose(cpu_out.numpy(), xpu_out.numpy())
# bf16/fp16 test
for dtype in [paddle.bfloat16, paddle.float16]:
input = np.random.randint(0, 10, (input_token_num, dim_embed), dtype=np.int16)
input_tensor = paddle.to_tensor(input, dtype=dtype)
cpu_out = eagle_get_self_hidden_states(
input_tensor.cpu(),
last_seq_lens_this_time_tensor.cpu(),
seq_lens_this_time_tensor.cpu(),
step_idx_tensor.cpu(),
)
xpu_out = eagle_get_self_hidden_states(
input_tensor,
last_seq_lens_this_time_tensor,
seq_lens_this_time_tensor,
step_idx_tensor,
)
assert np.allclose(cpu_out.numpy(), xpu_out.numpy())
print("All test passed")
if __name__ == "__main__":
test_eagle_get_self_hidden_states()

View File

@@ -0,0 +1,46 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import paddle
from fastdeploy.model_executor.ops.xpu import speculate_clear_accept_nums
np.set_printoptions(threshold=np.inf)  # print arrays in full
np.set_printoptions(linewidth=np.inf)  # keep each array on a single line (optional)
def speculate_clear_accept_nums_np(accept_num, seq_lens_decoder):
for i in range(len(accept_num)):
if seq_lens_decoder[i] == 0:
accept_num[i] = 0
return accept_num, seq_lens_decoder
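# A one-statement vectorized equivalent of the reference loop above (a sketch;
# the loop is kept as the element-wise reference):
def speculate_clear_accept_nums_np_vectorized(accept_num, seq_lens_decoder):
    accept_num[seq_lens_decoder == 0] = 0
    return accept_num, seq_lens_decoder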
max_bs = 1024
accept_num_np = np.random.randint(low=0, high=11, size=[max_bs], dtype="int32")
accept_num_paddle = paddle.to_tensor(accept_num_np)
seq_lens_decoder_np = np.random.randint(low=0, high=2, size=[max_bs], dtype="int32")
seq_lens_decoder_paddle = paddle.to_tensor(seq_lens_decoder_np)
a = accept_num_paddle.numpy()
# print((a - accept_num_np).sum())
assert np.array_equal(a, accept_num_np), "Check failed."
accept_num_np, seq_lens_decoder_np = speculate_clear_accept_nums_np(accept_num_np, seq_lens_decoder_np)
seq_lens_decoder_paddle = speculate_clear_accept_nums(accept_num_paddle, seq_lens_decoder_paddle)
b = accept_num_paddle.numpy()
# print(b)
# print((accept_num_np - b).sum())
assert np.array_equal(accept_num_np, b), "Check failed."

View File

@@ -0,0 +1,66 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import paddle
if paddle.is_compiled_with_xpu():
from fastdeploy.model_executor.ops.xpu import speculate_get_output_padding_offset
else:
from efficientllm.ops.gpu import speculate_get_output_padding_offset
def test_speculate_get_output_padding_offset():
bsz = 256
max_seq_len = 8192
seq_lens_output = np.random.randint(0, 4, size=bsz)
output_token_num = np.sum(seq_lens_output)
seq_lens_output = paddle.to_tensor(seq_lens_output, dtype="int32")
out_token_num = paddle.sum(seq_lens_output)
output_cum_offsets_tmp = paddle.cumsum(max_seq_len - seq_lens_output)
output_padding_offset_xpu, output_cum_offsets_xpu = speculate_get_output_padding_offset(
output_cum_offsets_tmp, out_token_num, seq_lens_output, max_seq_len
)
output_padding_offset_cpu = [-1] * output_token_num
output_cum_offsets_cpu = [-1] * bsz
for bi in range(bsz):
cum_offset = 0 if bi == 0 else output_cum_offsets_tmp[bi - 1]
output_cum_offsets_cpu[bi] = cum_offset
for token_i in range(seq_lens_output[bi]):
output_padding_offset_cpu[bi * max_seq_len - cum_offset + token_i] = cum_offset
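    # A worked instance of the offset math above (toy values, illustration only):
    # with bsz = 2, max_seq_len = 4 and seq_lens_output = [2, 1],
    # output_cum_offsets_tmp = cumsum([2, 3]) = [2, 5], so the reference yields
    # output_cum_offsets = [0, 2] and output_padding_offset = [0, 0, 2].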
# print(f"seq_lens_output: {seq_lens_output}")
# print(f"output_cum_offsets_tmp: {output_cum_offsets_tmp}")
# print(f"output_padding_offset_xpu: {output_padding_offset_xpu}")
# print(f"output_cum_offsets_xpu: {output_cum_offsets_xpu}")
# print(f"output_padding_offset_cpu: {output_padding_offset_cpu}")
# print(f"output_cum_offsets_cpu: {output_cum_offsets_cpu}")
assert np.array_equal(
output_padding_offset_xpu, output_padding_offset_cpu
), "output_padding_offset_xpu != output_padding_offset_cpu"
assert np.array_equal(
output_cum_offsets_xpu, output_cum_offsets_cpu
), "output_cum_offsets_xpu != output_cum_offsets_cpu"
print("test_speculate_get_output_padding_offset passed!")
if __name__ == "__main__":
test_speculate_get_output_padding_offset()

View File

@@ -0,0 +1,525 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import paddle
from fastdeploy.model_executor.ops.xpu import speculate_get_padding_offset
test_failed = False
def ref_speculate_get_padding_offset(cum_offsets, seq_lens, max_seq_len, token_num_data):
bsz = seq_lens.shape[0]
padding_offset = np.zeros([token_num_data], dtype=np.int32)
cum_offsets_out = np.zeros([bsz], dtype=np.int32)
cu_seqlens_q = np.zeros([bsz + 1], dtype=np.int32)
cu_seqlens_k = np.zeros([bsz + 1], dtype=np.int32)
modified_indices = {
"padding_offset": [],
"cum_offsets_out": [],
"cu_seqlens_q": [],
"cu_seqlens_k": [],
}
cu_seqlens_q[0] = 0
cu_seqlens_k[0] = 0
modified_indices["cu_seqlens_q"].append(0)
modified_indices["cu_seqlens_k"].append(0)
for bi in range(bsz):
cum_offset = 0 if bi == 0 else cum_offsets[bi - 1]
cum_offsets_out[bi] = cum_offset
modified_indices["cum_offsets_out"].append(bi)
for i in range(seq_lens[bi]):
idx = bi * max_seq_len - cum_offset + i
if idx >= 0 and idx < token_num_data:
padding_offset[idx] = cum_offset
modified_indices["padding_offset"].append(idx)
cum_seq_len = (bi + 1) * max_seq_len - cum_offsets[bi]
cu_seqlens_q[bi + 1] = cum_seq_len
cu_seqlens_k[bi + 1] = cum_seq_len
modified_indices["cu_seqlens_q"].append(bi + 1)
modified_indices["cu_seqlens_k"].append(bi + 1)
return (
padding_offset,
cum_offsets_out,
cu_seqlens_q,
cu_seqlens_k,
modified_indices,
)
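# A vectorized sketch of the cu_seqlens construction above (illustrative only;
# the tests keep using the element-wise reference):
def cu_seqlens_vectorized(cum_offsets, max_seq_len):
    bsz = cum_offsets.shape[0]
    cum_seq_lens = np.arange(1, bsz + 1, dtype=np.int32) * max_seq_len - cum_offsets
    return np.concatenate([np.zeros(1, dtype=np.int32), cum_seq_lens])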
def test_speculate_get_padding_offset():
global test_failed
print("Testing speculate_get_padding_offset...")
test_cases = [
{
"name": "Basic test case",
"bsz": 4,
"max_seq_len": 10,
"token_num_data": 32,
"cum_offsets": np.array([2, 5, 8, 12], dtype=np.int32),
"seq_lens": np.array([8, 5, 7, 6], dtype=np.int32),
"seq_lens_encoder": np.array([1, 0, 1, 0], dtype=np.int32),
},
{
"name": "Batch copy optimization",
"bsz": 5,
"max_seq_len": 12,
"token_num_data": 50,
"cum_offsets": np.array([1, 4, 8, 13, 19], dtype=np.int32),
"seq_lens": np.array([10, 6, 8, 5, 7], dtype=np.int32),
"seq_lens_encoder": np.array([1, 0, 1, 0, 1], dtype=np.int32),
},
{
"name": "Boundary conditions",
"bsz": 3,
"max_seq_len": 8,
"token_num_data": 20,
"cum_offsets": np.array([3, 8, 14], dtype=np.int32),
"seq_lens": np.array([4, 3, 2], dtype=np.int32),
"seq_lens_encoder": np.array([1, 0, 1], dtype=np.int32),
},
{
"name": "Large sequence length",
"bsz": 2,
"max_seq_len": 2000,
"token_num_data": 3000,
"cum_offsets": np.array([100, 500], dtype=np.int32),
"seq_lens": np.array([1800, 1500], dtype=np.int32),
"seq_lens_encoder": np.array([1, 0], dtype=np.int32),
},
]
max_draft_tokens = 4
all_passed = True
for i, case in enumerate(test_cases):
print(f" Test case {i+1}: {case['name']}")
input_ids = np.random.randint(0, 1000, (case["bsz"], case["max_seq_len"]), dtype=np.int64)
draft_tokens = np.random.randint(0, 1000, (case["bsz"], max_draft_tokens), dtype=np.int64)
token_num = np.array([case["token_num_data"]], dtype=np.int64)
input_ids_tensor = paddle.to_tensor(input_ids)
draft_tokens_tensor = paddle.to_tensor(draft_tokens)
cum_offsets_tensor = paddle.to_tensor(case["cum_offsets"])
seq_lens_tensor = paddle.to_tensor(case["seq_lens"])
seq_lens_encoder_tensor = paddle.to_tensor(case["seq_lens_encoder"])
token_num_tensor = paddle.to_tensor(token_num)
(
x_remove_padding,
cum_offsets_out,
padding_offset,
cu_seqlens_q,
cu_seqlens_k,
) = speculate_get_padding_offset(
input_ids_tensor,
draft_tokens_tensor,
cum_offsets_tensor,
token_num_tensor,
seq_lens_tensor,
seq_lens_encoder_tensor,
)
(
ref_padding_offset,
ref_cum_offsets_out,
ref_cu_seqlens_q,
ref_cu_seqlens_k,
modified_indices,
) = ref_speculate_get_padding_offset(
case["cum_offsets"],
case["seq_lens"],
case["max_seq_len"],
case["token_num_data"],
)
output_arrays = {
"padding_offset": padding_offset.numpy(),
"cum_offsets_out": cum_offsets_out.numpy(),
"cu_seqlens_q": cu_seqlens_q.numpy(),
"cu_seqlens_k": cu_seqlens_k.numpy(),
}
ref_arrays = {
"padding_offset": ref_padding_offset,
"cum_offsets_out": ref_cum_offsets_out,
"cu_seqlens_q": ref_cu_seqlens_q,
"cu_seqlens_k": ref_cu_seqlens_k,
}
case_passed = True
for key in output_arrays:
modified_pos = modified_indices[key]
if case["name"] == "Large sequence length" and key == "padding_offset":
match_count = sum(1 for pos in modified_pos if output_arrays[key][pos] == ref_arrays[key][pos])
total_positions = len(modified_pos)
if match_count != total_positions:
case_passed = False
print(f" \033[91m✗ {key}: {match_count}/{total_positions} positions match\033[0m")
else:
print(f" \033[92m✓ {key}: All {total_positions} positions match\033[0m")
else:
match_count = sum(1 for pos in modified_pos if output_arrays[key][pos] == ref_arrays[key][pos])
if match_count != len(modified_pos):
case_passed = False
print(f" \033[91m✗ {key}: {match_count}/{len(modified_pos)} positions match\033[0m")
else:
print(f" \033[92m✓ {key}: {match_count}/{len(modified_pos)} positions match\033[0m")
if not case_passed:
all_passed = False
test_failed = True
if all_passed:
print("\033[92m✓ All speculate_get_padding_offset tests passed\033[0m\n")
else:
print("\033[91m✗ Some speculate_get_padding_offset tests failed\033[0m\n")
def test_speculate_get_padding_offset_edge_cases():
global test_failed
print("Testing speculate_get_padding_offset edge cases...")
print("Test case 1: Single batch")
bsz = 1
max_seq_len = 10
token_num_data = 10
max_draft_tokens = 3
input_ids = np.random.randint(0, 1000, (bsz, max_seq_len), dtype=np.int64)
draft_tokens = np.random.randint(0, 1000, (bsz, max_draft_tokens), dtype=np.int64)
cum_offsets = np.array([3], dtype=np.int32)
seq_lens = np.array([7], dtype=np.int32)
seq_lens_encoder = np.array([1], dtype=np.int32)
token_num = np.array([token_num_data], dtype=np.int64)
input_ids_tensor = paddle.to_tensor(input_ids)
draft_tokens_tensor = paddle.to_tensor(draft_tokens)
cum_offsets_tensor = paddle.to_tensor(cum_offsets)
seq_lens_tensor = paddle.to_tensor(seq_lens)
seq_lens_encoder_tensor = paddle.to_tensor(seq_lens_encoder)
token_num_tensor = paddle.to_tensor(token_num)
try:
(
x_remove_padding,
cum_offsets_out,
padding_offset,
cu_seqlens_q,
cu_seqlens_k,
) = speculate_get_padding_offset(
input_ids_tensor,
draft_tokens_tensor,
cum_offsets_tensor,
token_num_tensor,
seq_lens_tensor,
seq_lens_encoder_tensor,
)
print(
f"\033[92m✓ Test case 1 passed, shapes: {[x.shape for x in [x_remove_padding, padding_offset, cum_offsets_out, cu_seqlens_q, cu_seqlens_k]]}\033[0m"
)
except Exception as e:
print(f"\033[91m✗ Test case 1 failed: {e}\033[0m")
test_failed = True
print("Test case 2: Large batch")
bsz = 8
max_seq_len = 16
token_num_data = 100
input_ids = np.random.randint(0, 1000, (bsz, max_seq_len), dtype=np.int64)
draft_tokens = np.random.randint(0, 1000, (bsz, max_draft_tokens), dtype=np.int64)
cum_offsets = np.array([1, 3, 6, 10, 15, 21, 28, 36], dtype=np.int32)
seq_lens = np.random.randint(1, max_seq_len, bsz).astype(np.int32)
seq_lens_encoder = np.random.randint(0, 2, bsz).astype(np.int32)
token_num = np.array([token_num_data], dtype=np.int64)
input_ids_tensor = paddle.to_tensor(input_ids)
draft_tokens_tensor = paddle.to_tensor(draft_tokens)
cum_offsets_tensor = paddle.to_tensor(cum_offsets)
seq_lens_tensor = paddle.to_tensor(seq_lens)
seq_lens_encoder_tensor = paddle.to_tensor(seq_lens_encoder)
token_num_tensor = paddle.to_tensor(token_num)
try:
(
x_remove_padding,
cum_offsets_out,
padding_offset,
cu_seqlens_q,
cu_seqlens_k,
) = speculate_get_padding_offset(
input_ids_tensor,
draft_tokens_tensor,
cum_offsets_tensor,
token_num_tensor,
seq_lens_tensor,
seq_lens_encoder_tensor,
)
print(
f"\033[92m✓ Test case 2 passed, shapes: {[x.shape for x in [x_remove_padding, padding_offset, cum_offsets_out, cu_seqlens_q, cu_seqlens_k]]}\033[0m"
)
except Exception as e:
print(f"\033[91m✗ Test case 2 failed: {e}\033[0m")
test_failed = True
print("Test case 3: Small sequences")
bsz = 3
max_seq_len = 5
token_num_data = 12
input_ids = np.random.randint(0, 1000, (bsz, max_seq_len), dtype=np.int64)
draft_tokens = np.random.randint(0, 1000, (bsz, max_draft_tokens), dtype=np.int64)
cum_offsets = np.array([1, 2, 4], dtype=np.int32)
seq_lens = np.array([2, 3, 1], dtype=np.int32)
seq_lens_encoder = np.array([1, 0, 1], dtype=np.int32)
token_num = np.array([token_num_data], dtype=np.int64)
input_ids_tensor = paddle.to_tensor(input_ids)
draft_tokens_tensor = paddle.to_tensor(draft_tokens)
cum_offsets_tensor = paddle.to_tensor(cum_offsets)
seq_lens_tensor = paddle.to_tensor(seq_lens)
seq_lens_encoder_tensor = paddle.to_tensor(seq_lens_encoder)
token_num_tensor = paddle.to_tensor(token_num)
try:
(
x_remove_padding,
cum_offsets_out,
padding_offset,
cu_seqlens_q,
cu_seqlens_k,
) = speculate_get_padding_offset(
input_ids_tensor,
draft_tokens_tensor,
cum_offsets_tensor,
token_num_tensor,
seq_lens_tensor,
seq_lens_encoder_tensor,
)
print(
f"\033[92m✓ Test case 3 passed, shapes: {[x.shape for x in [x_remove_padding, padding_offset, cum_offsets_out, cu_seqlens_q, cu_seqlens_k]]}\033[0m\n"
)
except Exception as e:
print(f"\033[91m✗ Test case 3 failed: {e}\033[0m\n")
test_failed = True
def test_large_scale():
global test_failed
print("Testing large scale data...")
bsz = 32
max_seq_len = 128
token_num_data = 2048
max_draft_tokens = 16
input_ids = np.random.randint(0, 1000, (bsz, max_seq_len), dtype=np.int64)
draft_tokens = np.random.randint(0, 1000, (bsz, max_draft_tokens), dtype=np.int64)
cum_offsets = np.cumsum(np.random.randint(1, 20, bsz)).astype(np.int32)
seq_lens = np.random.randint(1, max_seq_len, bsz).astype(np.int32)
seq_lens_encoder = np.random.randint(0, 2, bsz).astype(np.int32)
token_num = np.array([token_num_data], dtype=np.int64)
input_ids_tensor = paddle.to_tensor(input_ids)
draft_tokens_tensor = paddle.to_tensor(draft_tokens)
cum_offsets_tensor = paddle.to_tensor(cum_offsets)
seq_lens_tensor = paddle.to_tensor(seq_lens)
seq_lens_encoder_tensor = paddle.to_tensor(seq_lens_encoder)
token_num_tensor = paddle.to_tensor(token_num)
try:
(
x_remove_padding,
cum_offsets_out,
padding_offset,
cu_seqlens_q,
cu_seqlens_k,
) = speculate_get_padding_offset(
input_ids_tensor,
draft_tokens_tensor,
cum_offsets_tensor,
token_num_tensor,
seq_lens_tensor,
seq_lens_encoder_tensor,
)
print("\033[92m✓ Large scale speculate_get_padding_offset test passed\033[0m")
print(
f"\033[92m Shapes: {[x.shape for x in [x_remove_padding, padding_offset, cum_offsets_out, cu_seqlens_q, cu_seqlens_k]]}\033[0m\n"
)
except Exception as e:
print(f"\033[91m✗ Large scale speculate_get_padding_offset test failed: {e}\033[0m\n")
test_failed = True
def get_modified_indices_for_consistency_test(cum_offsets, seq_lens, max_seq_len, token_num_data):
bsz = seq_lens.shape[0]
modified_indices = {
"x_remove_padding": [],
"padding_offset": [],
"cum_offsets_out": [],
"cu_seqlens_q": [],
"cu_seqlens_k": [],
}
for bi in range(bsz):
modified_indices["cum_offsets_out"].append(bi)
for i in range(bsz + 1):
modified_indices["cu_seqlens_q"].append(i)
modified_indices["cu_seqlens_k"].append(i)
for bi in range(bsz):
cum_offset = 0 if bi == 0 else cum_offsets[bi - 1]
for i in range(seq_lens[bi]):
padding_idx = bi * max_seq_len - cum_offset + i
if padding_idx >= 0 and padding_idx < token_num_data:
modified_indices["padding_offset"].append(padding_idx)
remove_padding_idx = bi * max_seq_len - cum_offsets[bi] + i
if remove_padding_idx >= 0 and remove_padding_idx < token_num_data:
modified_indices["x_remove_padding"].append(remove_padding_idx)
return modified_indices
def test_consistency():
global test_failed
print("Testing consistency...")
np.random.seed(42)
bsz = 4
max_seq_len = 8
token_num_data = 24
max_draft_tokens = 3
input_ids = np.random.randint(0, 1000, (bsz, max_seq_len), dtype=np.int64)
draft_tokens = np.random.randint(0, 1000, (bsz, max_draft_tokens), dtype=np.int64)
cum_offsets = np.array([1, 3, 6, 10], dtype=np.int32)
seq_lens = np.array([6, 4, 5, 3], dtype=np.int32)
seq_lens_encoder = np.array([1, 0, 1, 0], dtype=np.int32)
token_num = np.array([token_num_data], dtype=np.int64)
input_ids_tensor = paddle.to_tensor(input_ids)
draft_tokens_tensor = paddle.to_tensor(draft_tokens)
cum_offsets_tensor = paddle.to_tensor(cum_offsets)
seq_lens_tensor = paddle.to_tensor(seq_lens)
seq_lens_encoder_tensor = paddle.to_tensor(seq_lens_encoder)
token_num_tensor = paddle.to_tensor(token_num)
modified_indices = get_modified_indices_for_consistency_test(cum_offsets, seq_lens, max_seq_len, token_num_data)
print("Checking consistency for modified positions only:")
for key, indices in modified_indices.items():
print(f" {key}: {len(indices)} positions")
results = []
for run in range(3):
(
x_remove_padding,
cum_offsets_out,
padding_offset,
cu_seqlens_q,
cu_seqlens_k,
) = speculate_get_padding_offset(
input_ids_tensor,
draft_tokens_tensor,
cum_offsets_tensor,
token_num_tensor,
seq_lens_tensor,
seq_lens_encoder_tensor,
)
results.append(
[
x_remove_padding.numpy(),
cum_offsets_out.numpy(),
padding_offset.numpy(),
cu_seqlens_q.numpy(),
cu_seqlens_k.numpy(),
]
)
output_names = [
"x_remove_padding",
"cum_offsets_out",
"padding_offset",
"cu_seqlens_q",
"cu_seqlens_k",
]
consistent = True
for j, name in enumerate(output_names):
indices = modified_indices[name] if name in modified_indices else []
if not indices:
print(f"\033[93m ~ {name}: No modified indices to check\033[0m")
continue
positions_consistent = True
for i in range(1, len(results)):
for idx in indices:
if results[0][j][idx] != results[i][j][idx]:
consistent = False
positions_consistent = False
print(
f"\033[91m ✗ {name}[{idx}]: Run 1 = {results[0][j][idx]}, Run {i+1} = {results[i][j][idx]}\033[0m"
)
break
if not positions_consistent:
break
if positions_consistent:
print(f"\033[92m ✓ {name}: All {len(indices)} modified positions are consistent\033[0m")
if consistent:
print(
"\033[92m✓ Consistency test passed - results are identical across runs (modified positions only)\033[0m\n"
)
else:
print("\033[91m✗ Consistency test failed - some modified positions are inconsistent\033[0m\n")
print("Note: This test now only compares positions that the kernel actually modifies,")
print(" ignoring uninitialized values in other positions.\n")
test_failed = True
if __name__ == "__main__":
print("=" * 60)
print("Testing Speculate Get Padding Offset Kernels")
print("=" * 60)
test_speculate_get_padding_offset()
test_speculate_get_padding_offset_edge_cases()
test_large_scale()
test_consistency()
print("=" * 60)
if test_failed:
print("\033[91mSOME TESTS FAILED! \033[0m")
print("\033[91mPlease check the output above for failed test details.\033[0m")
else:
print("\033[92mALL TESTS PASSED! \033[0m")
print("\033[92mAll speculate_get_padding_offset kernels are working correctly.\033[0m")
print("=" * 60)

View File

@@ -0,0 +1,105 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import paddle
from fastdeploy.model_executor.ops.xpu import speculate_get_seq_lens_output  # assumes the op has been compiled and is importable
def run_seq_lens_test(device="cpu"):
    """Run the sequence-lengths test on the given device."""
    paddle.seed(42)
    np.random.seed(42)
    if device in ("cpu", "xpu"):
        paddle.set_device(device)
    else:
        raise ValueError(f"Invalid device: {device}")
    # create random test data of different sizes
batch_sizes = [1, 4, 16, 64, 128, 192, 256]
results = []
test_times = 100
for _ in range(test_times):
for bsz in batch_sizes:
            # generate random input tensors
            seq_lens_this_time = paddle.randint(0, 10, shape=(bsz,), dtype="int32")
            seq_lens_encoder = paddle.randint(0, 10, shape=(bsz,), dtype="int32")
            seq_lens_decoder = paddle.randint(0, 10, shape=(bsz,), dtype="int32")
            # record the input values for debugging
            input_values = [
                seq_lens_this_time.numpy().copy(),
                seq_lens_encoder.numpy().copy(),
                seq_lens_decoder.numpy().copy(),
            ]
            # run the op
            seq_lens_output = speculate_get_seq_lens_output(seq_lens_this_time, seq_lens_encoder, seq_lens_decoder)[0]
            # collect the results
            results.append((input_values, seq_lens_output.numpy()))
return results
if __name__ == "__main__":
print("\n运行XPU测试...")
xpu_results = run_seq_lens_test("xpu")
print("运行CPU测试...")
cpu_results = run_seq_lens_test("cpu")
print("\n比较结果...")
all_pass = True
# 逐个批次比较结果
for i, (cpu_data, xpu_data) in enumerate(zip(cpu_results, xpu_results)):
# 解包数据
cpu_inputs, cpu_output = cpu_data
xpu_inputs, xpu_output = xpu_data
# 比较输入数据是否相同
for j in range(3):
if not np.array_equal(cpu_inputs[j], xpu_inputs[j]):
print(f"错误: 批次 #{i+1} 输入 {j} 不同 (CPU vs XPU)")
print(f"CPU输入: {cpu_inputs[j]}")
print(f"XPU输入: {xpu_inputs[j]}")
all_pass = False
# 比较输出结果是否相同
if not np.array_equal(cpu_output, xpu_output):
print(f"\n错误: 批次 #{i+1} 输出不同 (CPU vs XPU)")
print(f"CPU输出: {cpu_output}")
print(f"XPU输出: {xpu_output}")
# 打印差异详情
diff_indices = np.where(cpu_output != xpu_output)[0]
for idx in diff_indices:
print(f"索引 {idx}: CPU输出={cpu_output[idx]}, XPU输出={xpu_output[idx]}")
print(
f"对应输入: this_time={cpu_inputs[0][idx]}, "
f"encoder={cpu_inputs[1][idx]}, decoder={cpu_inputs[2][idx]}"
)
all_pass = False
else:
print(f"批次 #{i+1} 结果匹配")
if all_pass:
print("\n所有测试通过! CPU和XPU结果完全一致")
else:
print("\n测试失败: 发现不一致的结果")
exit(1)

View File

@@ -0,0 +1,206 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import paddle
from fastdeploy.model_executor.ops.xpu import speculate_get_token_penalty_multi_scores
paddle.seed(2023)
def allclose_any(a, b, rtol=1e-5, atol=1e-5, equal_nan=False):
    """Check whether every element of two arrays satisfies at least one tolerance condition."""
    # absolute-error condition OR relative-error condition
    condition = (np.abs(a - b) <= atol) | (np.abs(a - b) <= rtol * np.abs(b))
    print(f"cond={condition}")
    # handle NaNs if requested
    if equal_nan:
        nan_mask = np.isnan(a) & np.isnan(b)
        condition = condition | nan_mask
    # require every element to satisfy the condition
    return np.all(condition)
def find_max_diff(arr1, arr2):
    """Locate the largest element-wise difference between two arrays.
    Returns:
        max_diff (float): largest absolute difference
        index (tuple): position of that difference
        actual_diff (float): signed difference at that position
        val1, val2: the two values at that position
    """
    diff = arr1 - arr2
    abs_diff = np.abs(diff)
    flat_idx = np.argmax(abs_diff)
    idx = np.unravel_index(flat_idx, arr1.shape)
    return abs_diff[idx], idx, diff[idx], arr1[idx], arr2[idx]
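# A tiny usage sketch of the two helpers above (toy arrays, illustrative only):
#   a = np.array([1.0, 2.0]); b = np.array([1.0, 2.0 + 5e-6])
#   allclose_any(a, b)   -> True (the 5e-6 gap is within both atol and rtol * |b|)
#   find_max_diff(a, b)  -> (5e-6, (1,), -5e-6, 2.0, 2.000005)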
def test_main(
pre_ids,
logits,
penalty_scores,
frequency_scores,
presence_scores,
temperatures,
bad_tokens,
cur_len,
min_len,
eos_token_id,
seq_len_this_time,
output_padding_offset,
output_cum_offsets,
max_seq_len,
):
pre_ids_ref = pre_ids.cpu()
logits_ref = logits.cpu()
penalty_scores_ref = penalty_scores.cpu()
frequency_scores_ref = frequency_scores.cpu()
presence_scores_ref = presence_scores.cpu()
temperatures_ref = temperatures.cpu()
bad_tokens_ref = bad_tokens.cpu()
cur_len_ref = cur_len.cpu()
min_len_ref = min_len.cpu()
eos_token_id_ref = eos_token_id.cpu()
seq_len_this_time_ref = seq_len_this_time.cpu()
output_padding_offset_ref = output_padding_offset.cpu()
output_cum_offsets_ref = output_cum_offsets.cpu()
speculate_get_token_penalty_multi_scores(
pre_ids,
logits,
penalty_scores,
frequency_scores,
presence_scores,
temperatures,
bad_tokens,
cur_len,
min_len,
eos_token_id,
seq_len_this_time,
output_padding_offset,
output_cum_offsets,
max_seq_len,
)
speculate_get_token_penalty_multi_scores(
pre_ids_ref,
logits_ref,
penalty_scores_ref,
frequency_scores_ref,
presence_scores_ref,
temperatures_ref,
bad_tokens_ref,
cur_len_ref,
min_len_ref,
eos_token_id_ref,
seq_len_this_time_ref,
output_padding_offset_ref,
output_cum_offsets_ref,
max_seq_len,
)
logits_ref_np = logits_ref.astype("float32").numpy()
logits_np = logits.astype("float32").numpy()
np.set_printoptions(threshold=10000)
# print(f"logits_ref={logits_ref_np[:50,:100]}")
# print(f"logits={logits_np[:50,:100]}")
diff_logits = np.sum(np.abs(logits_ref_np - logits_np))
print("diff_logits\n", diff_logits)
abs_diff, idx, diff, val1, val2 = find_max_diff(logits_ref_np, logits_np)
print(f"abs_diff={abs_diff}, index={idx}, diff={diff}, {val1} vs {val2}")
assert allclose_any(logits_ref_np, logits_np, 1e-5, 1e-5)
# assert np.allclose(logits_ref_np, logits_np, 1e-5, 1e-5)
# gtest_speculate_token_penalty_multi_scores<float>(api::kXPU3, "GM", "GM", "GM", "GM", "GM", "GM", "GM", "GM", "GM", "GM", "GM", "GM",
# 84, 100352, 12288, 1, 1, 54, 32768);
def main():
seed = np.random.randint(1, 1e9)
print(f"random seed is {seed}")
np.random.seed(seed)
bs = 64
max_seq_len = 32768 # 1024 #2048 #8192
data_type = "float32" # bfloat16 or float32
# prepare output_padding_offset and output_cum_offsets
tokens = [1] * bs
token_num = np.sum(tokens)
print(f"bs={bs}, tokens={tokens}, token_num={token_num}")
output_padding_offset = []
output_cum_offsets = [0]
opo_offset = 0
for bid in range(bs):
ts = tokens[bid]
for i in range(ts):
output_padding_offset.append(opo_offset)
opo_offset += max_seq_len - ts
output_cum_offsets.append(opo_offset)
output_cum_offsets = output_cum_offsets[:-1]
# print(f"output_padding_offset={output_padding_offset}")
# print(f"output_cum_offsets={output_cum_offsets}")
output_padding_offset = paddle.to_tensor(output_padding_offset, "int32")
output_cum_offsets = paddle.to_tensor(output_cum_offsets, "int32")
# prepare pre_ids and logits
pre_ids_len = 12288
# pre_ids_len = np.random.randint(1, 512)
logits_len = 100352
# print(f"pre_ids_len={pre_ids_len}, logits_len={logits_len}")
pre_ids = np.random.randint(1, logits_len, size=(bs, pre_ids_len))
negative_start = np.random.randint(1, pre_ids_len + 1, size=(bs))
print(negative_start)
    for i in range(bs):
        pre_ids[i, negative_start[i] :] = -1
pre_ids = paddle.to_tensor(pre_ids).astype("int64")
# logits = paddle.to_tensor(
# np.float32(np.random.random([token_num, logits_len]))
# ).astype(data_type)
logits = paddle.to_tensor(np.float32(np.zeros([token_num, logits_len]))).astype(data_type)
# prepare other params
penalty_scores = paddle.to_tensor(np.random.random([bs])).astype(data_type)
frequency_scores = paddle.to_tensor(np.random.random([bs])).astype(data_type)
presence_scores = paddle.to_tensor(np.random.random([bs])).astype(data_type)
temperatures = paddle.to_tensor(np.random.random([bs])).astype("float32")
bad_tokens = paddle.to_tensor(np.random.randint(0, 101, size=(1))).astype("int64")
cur_len = paddle.to_tensor(np.random.randint(1, 50, size=(bs))).astype("int64")
min_len = paddle.to_tensor(np.random.randint(1, 50, size=(bs))).astype("int64")
eos_token_id = paddle.to_tensor(np.random.randint(1, 101, size=(1))).astype("int64")
seq_len_this_time = paddle.to_tensor(
np.random.randint(0, 1, size=(bs)), "int32"
) # value of seq_len_this_time is useless
# test
test_main(
pre_ids,
logits,
penalty_scores,
frequency_scores,
presence_scores,
temperatures,
bad_tokens,
cur_len,
min_len,
eos_token_id,
seq_len_this_time,
output_padding_offset,
output_cum_offsets,
max_seq_len,
)
if __name__ == "__main__":
for i in range(10):
        main()

View File

@@ -0,0 +1,132 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import paddle
from fastdeploy.model_executor.ops.xpu import speculate_rebuild_append_padding
def ref_speculate_rebuild_append_padding(
full_hidden_states,
cum_offsets,
seq_len_encoder,
seq_len_decoder,
output_padding_offset,
max_seq_len,
):
dim_embed = full_hidden_states.shape[1]
output_token_num = output_padding_offset.shape[0]
elem_nums = output_token_num * dim_embed
out = np.zeros(output_token_num * dim_embed, dtype=full_hidden_states.dtype)
full_hidden_states_flatten = full_hidden_states.flatten()
cum_offsets_flatten = cum_offsets.flatten()
seq_len_encoder_flatten = seq_len_encoder.flatten()
seq_len_decoder_flatten = seq_len_decoder.flatten()
output_padding_offset_flatten = output_padding_offset.flatten()
for i in range(elem_nums):
out_token_id = i // dim_embed
ori_token_id = out_token_id + output_padding_offset_flatten[out_token_id]
bi = ori_token_id // max_seq_len
seq_id = 0
if seq_len_decoder_flatten[bi] == 0 and seq_len_encoder_flatten[bi] == 0:
continue
elif seq_len_encoder_flatten[bi] != 0:
            seq_id = seq_len_encoder_flatten[bi] - 1
input_token_id = ori_token_id - cum_offsets_flatten[bi] + seq_id
bias_idx = i % dim_embed
out[i] = full_hidden_states_flatten[input_token_id * dim_embed + bias_idx]
out = np.reshape(out, (output_token_num, dim_embed))
return out
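# A worked instance of the index arithmetic above (toy, decoder-only values,
# illustrative only): with max_seq_len = 8 and seq_lens = [3], cum_offsets is
# [0, 5] and output_padding_offset is [0, 0, 0]; out token 1 then maps to
# ori_token_id = 1, bi = 0, seq_id = 0 (seq_len_encoder == 0), so
# input_token_id = 1 - cum_offsets[0] + 0 = 1 and row 1 is copied through.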
def test_speculate_rebuild_append_padding():
bs = np.random.randint(1, 4 + 1, dtype=np.int32)
max_seq_len = 1 * 1024
dim_embed = np.random.randint(1, 4 * 1024 + 1, dtype=np.int32)
seq_lens = []
for _ in range(bs):
seq_lens.append(np.random.randint(1, max_seq_len + 1, dtype=np.int32))
seq_lens = np.asarray(seq_lens)
cum_offsets = np.cumsum(np.asarray(max_seq_len) - seq_lens)
cum_offsets = np.insert(cum_offsets, 0, 0)
output_padding_offsets = []
for i in range(bs):
offset = cum_offsets[i]
for j in range(seq_lens[i]):
output_padding_offsets.append(offset)
output_padding_offsets = np.asarray(output_padding_offsets)
# TODO: seq_len_encoder with non-zero element
seq_len_decoder = np.random.randint(0, 2 + 1, bs, dtype=np.int32)
seq_len_encoder_zeros = np.zeros(bs, dtype=np.int32)
for dtype in [paddle.bfloat16, paddle.float16]:
full_hidden_states = np.random.randint(0, 10, (np.sum(seq_lens), dim_embed), dtype=np.int16)
full_hidden_states_tensor = paddle.to_tensor(full_hidden_states, dtype=dtype)
cum_offsets_tensor = paddle.to_tensor(cum_offsets, dtype=paddle.int32)
seq_len_encoder_zeros_tensor = paddle.to_tensor(seq_len_encoder_zeros, dtype=paddle.int32)
seq_len_decoder_tensor = paddle.to_tensor(seq_len_decoder, dtype=paddle.int32)
output_padding_offsets_tensor = paddle.to_tensor(output_padding_offsets, dtype=paddle.int32)
cpu_out = speculate_rebuild_append_padding(
full_hidden_states_tensor.cpu(),
cum_offsets_tensor.cpu(),
seq_len_encoder_zeros_tensor.cpu(),
seq_len_decoder_tensor.cpu(),
output_padding_offsets_tensor.cpu(),
max_seq_len,
)
xpu_out = speculate_rebuild_append_padding(
full_hidden_states_tensor,
cum_offsets_tensor,
seq_len_encoder_zeros_tensor,
seq_len_decoder_tensor,
output_padding_offsets_tensor,
max_seq_len,
)
assert np.allclose(cpu_out.numpy(), xpu_out.numpy())
for dtype in [paddle.float32]:
full_hidden_states = np.random.randint(0, 10, (np.sum(seq_lens), dim_embed), dtype=np.int32)
full_hidden_states_tensor = paddle.to_tensor(full_hidden_states, dtype=dtype)
cum_offsets_tensor = paddle.to_tensor(cum_offsets, dtype=paddle.int32)
seq_len_encoder_zeros_tensor = paddle.to_tensor(seq_len_encoder_zeros, dtype=paddle.int32)
seq_len_decoder_tensor = paddle.to_tensor(seq_len_decoder, dtype=paddle.int32)
output_padding_offsets_tensor = paddle.to_tensor(output_padding_offsets, dtype=paddle.int32)
cpu_out = speculate_rebuild_append_padding(
full_hidden_states_tensor.cpu(),
cum_offsets_tensor.cpu(),
seq_len_encoder_zeros_tensor.cpu(),
seq_len_decoder_tensor.cpu(),
output_padding_offsets_tensor.cpu(),
max_seq_len,
)
xpu_out = speculate_rebuild_append_padding(
full_hidden_states_tensor,
cum_offsets_tensor,
seq_len_encoder_zeros_tensor,
seq_len_decoder_tensor,
output_padding_offsets_tensor,
max_seq_len,
)
assert np.allclose(cpu_out.numpy(), xpu_out.numpy())
print("All test passed")
if __name__ == "__main__":
test_speculate_rebuild_append_padding()

View File

@@ -0,0 +1,307 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import paddle
from fastdeploy.model_executor.ops.xpu import speculate_set_stop_value_multi_seqs
def compare_results(cpu_results, xpu_results):
# Compare all outputs
for key in cpu_results:
if key in ["output_accept_tokens", "output_stop_flags"]:
np.testing.assert_array_equal(
cpu_results[key],
xpu_results[key],
err_msg=f"{key} mismatch between CPU and GPU",
)
print("CPU and GPU results match!")
class TestSpeculateSetStopValueMultiSeqs(unittest.TestCase):
def setUp(self):
self.place = paddle.device.XPUPlace(0)
def run_op(
self,
device,
accept_tokens,
accept_num,
pre_ids,
step_idx,
stop_flags,
seq_lens,
stop_seqs,
stop_seqs_len,
end_ids,
):
if device == "cpu":
accept_tokens = accept_tokens.cpu()
accept_num = accept_num.cpu()
pre_ids = pre_ids.cpu()
step_idx = step_idx.cpu()
stop_flags = stop_flags.cpu()
seq_lens = seq_lens.cpu()
stop_seqs = stop_seqs.cpu()
stop_seqs_len = stop_seqs_len.cpu()
end_ids = end_ids.cpu()
accept_tokens_out = accept_tokens.clone()
stop_flags_out = stop_flags.clone()
speculate_set_stop_value_multi_seqs(
accept_tokens_out,
accept_num,
pre_ids,
step_idx,
stop_flags_out,
seq_lens,
stop_seqs,
stop_seqs_len,
end_ids,
)
# Return results for comparison
results = {
"accept_tokens": accept_tokens.numpy(),
"accept_num": accept_num.numpy(),
"pre_ids": pre_ids.numpy(),
"step_idx": step_idx.numpy(),
"stop_flags": stop_flags.numpy(),
"output_accept_tokens": accept_tokens_out.numpy(),
"output_stop_flags": stop_flags_out.numpy(),
}
return results
def test_basic_functionality(self):
# Test basic functionality with one sequence matching stop sequence
accept_tokens = paddle.to_tensor(
[
[4, 5, 0, 0, 0], # batch 0
                [1, 2, 3, 0, 0],  # batch 1 (no match)
],
dtype="int64",
)
accept_num = paddle.to_tensor([3, 4], dtype="int32")
pre_ids = paddle.to_tensor(
[
[7, 8, 9, 3, 4, 5], # batch 0
[7, 8, 9, 1, 2, 3], # batch 1
],
dtype="int64",
)
        step_idx = paddle.to_tensor([6, 6], dtype="int64")  # the last valid pre_ids entry is at index 5
stop_flags = paddle.to_tensor([False, False], dtype="bool")
seq_lens = paddle.to_tensor([6, 6], dtype="int32")
stop_seqs = paddle.to_tensor(
[
[3, 4, 5], # batch 0
[0, 0, 0], # batch 1
],
dtype="int64",
)
stop_seqs_len = paddle.to_tensor([3, 0], dtype="int32")
end_ids = paddle.to_tensor([-1], dtype="int64")
# Run operator
xpu_results = self.run_op(
"xpu",
accept_tokens,
accept_num,
pre_ids,
step_idx,
stop_flags,
seq_lens,
stop_seqs,
stop_seqs_len,
end_ids,
)
cpu_results = self.run_op(
"cpu",
accept_tokens,
accept_num,
pre_ids,
step_idx,
stop_flags,
seq_lens,
stop_seqs,
stop_seqs_len,
end_ids,
)
compare_results(cpu_results, xpu_results)
# Verify results
expected_accept_tokens = np.array([[4, 5, -1, 0, 0], [1, 2, 3, 0, 0]])
expected_stop_flags = np.array([True, False])
np.testing.assert_array_equal(xpu_results["output_accept_tokens"], expected_accept_tokens)
np.testing.assert_array_equal(xpu_results["output_stop_flags"], expected_stop_flags)
def test_no_match(self):
# Test case where no stop sequence matches
# Input tensors
accept_tokens = paddle.to_tensor(
[[10, 20, 30, 0, 0], [40, 50, 60, 0, 0]],
dtype="int64",
place=self.place,
)
accept_num = paddle.to_tensor([3, 3], dtype="int32", place=self.place)
pre_ids = paddle.to_tensor([[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]], dtype="int64", place=self.place)
step_idx = paddle.to_tensor([8, 8], dtype="int64", place=self.place)
stop_flags = paddle.to_tensor([False, False], dtype="bool", place=self.place)
seq_lens = paddle.to_tensor([10, 10], dtype="int32", place=self.place)
# Stop sequences that don't match
stop_seqs = paddle.to_tensor([[11, 12, 13], [14, 15, 16]], dtype="int64", place=self.place)
stop_seqs_len = paddle.to_tensor([3, 3], dtype="int32", place=self.place)
end_ids = paddle.to_tensor([-1], dtype="int64", place=self.place)
# Run operator
xpu_results = self.run_op(
"xpu",
accept_tokens,
accept_num,
pre_ids,
step_idx,
stop_flags,
seq_lens,
stop_seqs,
stop_seqs_len,
end_ids,
)
cpu_results = self.run_op(
"cpu",
accept_tokens,
accept_num,
pre_ids,
step_idx,
stop_flags,
seq_lens,
stop_seqs,
stop_seqs_len,
end_ids,
)
compare_results(cpu_results, xpu_results)
# Verify nothing changed
np.testing.assert_array_equal(xpu_results["output_accept_tokens"], accept_tokens.numpy())
np.testing.assert_array_equal(xpu_results["output_stop_flags"], stop_flags.numpy())
def test_partial_match(self):
# Test case where only part of the sequence matches
# Input tensors
accept_tokens = paddle.to_tensor([[10, 20, 30, 0, 0]], dtype="int64", place=self.place)
accept_num = paddle.to_tensor([3], dtype="int32", place=self.place)
pre_ids = paddle.to_tensor([[1, 2, 3, 4, 5]], dtype="int64", place=self.place)
step_idx = paddle.to_tensor([8], dtype="int64", place=self.place)
stop_flags = paddle.to_tensor([False], dtype="bool", place=self.place)
seq_lens = paddle.to_tensor([10], dtype="int32", place=self.place)
# Stop sequence that partially matches
stop_seqs = paddle.to_tensor(
[[5, 4, 99]], # Only 5,4 matches (from pre_ids), 99 doesn't
dtype="int64",
place=self.place,
)
stop_seqs_len = paddle.to_tensor([3], dtype="int32", place=self.place)
end_ids = paddle.to_tensor([-1], dtype="int64", place=self.place)
# Run operator
xpu_results = self.run_op(
"xpu",
accept_tokens,
accept_num,
pre_ids,
step_idx,
stop_flags,
seq_lens,
stop_seqs,
stop_seqs_len,
end_ids,
)
cpu_results = self.run_op(
"cpu",
accept_tokens,
accept_num,
pre_ids,
step_idx,
stop_flags,
seq_lens,
stop_seqs,
stop_seqs_len,
end_ids,
)
compare_results(cpu_results, xpu_results)
# Verify nothing changed
np.testing.assert_array_equal(xpu_results["output_accept_tokens"], accept_tokens.numpy())
np.testing.assert_array_equal(xpu_results["output_stop_flags"], stop_flags.numpy())
def test_already_stopped(self):
# Test case where sequence is already stopped
# Input tensors
accept_tokens = paddle.to_tensor([[10, 20, 30, 0, 0]], dtype="int64", place=self.place)
accept_num = paddle.to_tensor([3], dtype="int32", place=self.place)
pre_ids = paddle.to_tensor([[1, 2, 3, 4, 5]], dtype="int64", place=self.place)
step_idx = paddle.to_tensor([8], dtype="int64", place=self.place)
stop_flags = paddle.to_tensor([True], dtype="bool", place=self.place) # Already stopped
seq_lens = paddle.to_tensor([10], dtype="int32", place=self.place)
# Stop sequence that would match
stop_seqs = paddle.to_tensor([[5, 4, 3]], dtype="int64", place=self.place)
stop_seqs_len = paddle.to_tensor([3], dtype="int32", place=self.place)
end_ids = paddle.to_tensor([-1], dtype="int64", place=self.place)
# Run operator
xpu_results = self.run_op(
"xpu",
accept_tokens,
accept_num,
pre_ids,
step_idx,
stop_flags,
seq_lens,
stop_seqs,
stop_seqs_len,
end_ids,
)
cpu_results = self.run_op(
"cpu",
accept_tokens,
accept_num,
pre_ids,
step_idx,
stop_flags,
seq_lens,
stop_seqs,
stop_seqs_len,
end_ids,
)
compare_results(cpu_results, xpu_results)
# Verify nothing changed
np.testing.assert_array_equal(xpu_results["output_accept_tokens"], accept_tokens.numpy())
np.testing.assert_array_equal(xpu_results["output_stop_flags"], stop_flags.numpy())
if __name__ == "__main__":
unittest.main()

View File

@@ -0,0 +1,83 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import paddle
if paddle.is_compiled_with_xpu():
from fastdeploy.model_executor.ops.xpu import speculate_set_value_by_flags_and_idx
else:
from efficientllm.ops.gpu import speculate_set_value_by_flags_and_idx
def test_speculate_set_value_by_flags_and_idx():
    # write accept_tokens into the designated positions of pre_ids
bs = 256
length = 8192
max_draft_tokens = 4
pre_ids_all = paddle.to_tensor(np.full((bs, length), -1), dtype="int64")
accept_tokens = np.random.randint(100, 200, size=(bs, max_draft_tokens))
accept_tokens = paddle.to_tensor(accept_tokens, dtype="int64")
accept_num = np.random.randint(0, max_draft_tokens + 1, size=bs)
accept_num = paddle.to_tensor(accept_num, dtype="int32")
stop_flags = np.random.choice([True, False, False, False], size=bs)
stop_flags = paddle.to_tensor(stop_flags, dtype="bool")
seq_lens_this_time = paddle.to_tensor(np.full((bs), 1), dtype="int32")
seq_lens_encoder = paddle.to_tensor(np.full((bs), 0), dtype="int32")
seq_lens_decoder = paddle.to_tensor(np.full((bs), 2), dtype="int32")
step_idx = np.random.randint(max_draft_tokens, length, size=bs)
step_idx = paddle.to_tensor(step_idx, dtype="int64")
out_xpu = speculate_set_value_by_flags_and_idx(
pre_ids_all,
accept_tokens,
accept_num,
stop_flags,
seq_lens_this_time,
seq_lens_encoder,
seq_lens_decoder,
step_idx,
)
out_xpu = out_xpu.numpy()
out_cpu = paddle.to_tensor(np.full((bs, length), -1), dtype="int64")
for i in range(bs):
if stop_flags[i] or (seq_lens_encoder[i] == 0 and seq_lens_decoder[i] == 0):
continue
if step_idx[i] >= 0:
for j in range(accept_num[i]):
out_cpu[i, step_idx[i] - j] = accept_tokens[i, accept_num[i] - 1 - j]
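    # A worked instance of the backward write above (illustrative): with
    # accept_num[i] == 2 and step_idx[i] == 10, the loop writes
    # out_cpu[i, 10] = accept_tokens[i, 1] and out_cpu[i, 9] = accept_tokens[i, 0].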
# print(f"accept_tokens: {accept_tokens}")
# print(f"accept_num: {accept_num}")
# print(f"stop_flags: {stop_flags}")
# print(f"seq_lens_this_time: {seq_lens_this_time}")
# print(f"seq_lens_encoder: {seq_lens_encoder}")
# print(f"seq_lens_decoder: {seq_lens_decoder}")
# print(f"step_idx: {step_idx}")
# print(f"out_xpu: {out_xpu}")
# print(f"out_cpu: {out_cpu}")
assert np.array_equal(out_xpu, out_cpu), "out_xpu != out_cpu"
print("test_speculate_set_value_by_flags_and_idx passed!")
if __name__ == "__main__":
test_speculate_set_value_by_flags_and_idx()

View File

@@ -0,0 +1,210 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
# tests/test_speculate_update_v3.py
import paddle
from fastdeploy.model_executor.ops.xpu import speculate_update_v3
# ---------------- NumPy reference implementation ----------------
def speculate_update_v3_np(
seq_lens_encoder,
seq_lens_decoder,
not_need_stop,
draft_tokens,
actual_draft_token_nums,
accept_tokens,
accept_num,
stop_flags,
seq_lens_this_time,
is_block_step,
stop_nums,
):
"""
完全复现 CPU / CUDA 逻辑的 NumPy 参考版本(就地修改)。
"""
stop_sum = 0
real_bsz = seq_lens_this_time.shape[0]
max_bsz = stop_flags.shape[0]
max_draft_tokens = draft_tokens.shape[1]
for bid in range(max_bsz):
stop_flag_now_int = 0
inactive = bid >= real_bsz
block_step = (not inactive) and is_block_step[bid]
if (not block_step) and (not inactive):
if stop_flags[bid]:
stop_flag_now_int = 1
            # when the encoder length is 0, add the accepted tokens straight to the decoder length
            if seq_lens_encoder[bid] == 0:
                seq_lens_decoder[bid] += accept_num[bid]
            # adapt the draft length
            if (seq_lens_encoder[bid] == 0) and (seq_lens_this_time[bid] > 1):
                cur_len = actual_draft_token_nums[bid]
                if accept_num[bid] - 1 == cur_len:  # everything accepted
                    if cur_len + 2 <= max_draft_tokens - 1:
                        cur_len += 2
                    elif cur_len + 1 <= max_draft_tokens - 1:
                        cur_len += 1
                    else:
                        cur_len = max_draft_tokens - 1
                else:  # at least one token rejected
                    cur_len = max(1, cur_len - 1)
                actual_draft_token_nums[bid] = cur_len
            # fold the outstanding encoder length into the decoder length
            if seq_lens_encoder[bid] != 0:
                seq_lens_decoder[bid] += seq_lens_encoder[bid]
                seq_lens_encoder[bid] = 0
            # write back the first token for the next round
            draft_tokens[bid, 0] = accept_tokens[bid, accept_num[bid] - 1]
            # clear the decoder length once stopped
            if stop_flag_now_int:
                seq_lens_decoder[bid] = 0
        elif inactive:
            stop_flag_now_int = 1  # padding slots count as stopped
        stop_sum += stop_flag_now_int
    # print("stop_sum: ", stop_sum)
    not_need_stop[0] = stop_sum < stop_nums[0]
    # the references are returned only for convenience in the consistency check
return (
seq_lens_encoder,
seq_lens_decoder,
not_need_stop,
draft_tokens,
actual_draft_token_nums,
)
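# The draft-length adaptation rule above, isolated as a minimal sketch (same
# logic as the reference loop; illustrative only):
def adapt_draft_len(cur_len, accept_num, max_draft_tokens):
    if accept_num - 1 == cur_len:  # everything accepted: grow, capped at max_draft_tokens - 1
        return min(cur_len + 2, max_draft_tokens - 1)
    return max(1, cur_len - 1)  # at least one rejection: shrink by one, floor at 1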
# ---------------- generate random inputs ----------------
def gen_inputs(
    max_bsz=512,  # aligned with the CUDA BlockSize
    max_draft_tokens=16,
    real_bsz=123,  # tunable, but must be <= max_bsz
    seed=2022,
):
    rng = np.random.default_rng(seed)
    # basic tensors
seq_lens_encoder = rng.integers(0, 3, size=max_bsz, dtype=np.int32)
seq_lens_decoder = rng.integers(0, 20, size=max_bsz, dtype=np.int32)
not_need_stop = rng.integers(0, 1, size=1, dtype=np.bool_)
draft_tokens = rng.integers(0, 1000, size=(max_bsz, max_draft_tokens), dtype=np.int64)
actual_draft_nums = rng.integers(1, max_draft_tokens, size=max_bsz, dtype=np.int32)
accept_tokens = rng.integers(0, 1000, size=(max_bsz, max_draft_tokens), dtype=np.int64)
accept_num = rng.integers(1, max_draft_tokens, size=max_bsz, dtype=np.int32)
stop_flags = rng.integers(0, 2, size=max_bsz, dtype=np.bool_)
is_block_step = rng.integers(0, 2, size=max_bsz, dtype=np.bool_)
    stop_nums = np.array([5], dtype=np.int64)  # arbitrary threshold
    # seq_lens_this_time only covers the first real_bsz entries
seq_lens_this_time = rng.integers(1, max_draft_tokens, size=real_bsz, dtype=np.int32)
return {
"seq_lens_encoder": seq_lens_encoder,
"seq_lens_decoder": seq_lens_decoder,
"not_need_stop": not_need_stop,
"draft_tokens": draft_tokens,
"actual_draft_token_nums": actual_draft_nums,
"accept_tokens": accept_tokens,
"accept_num": accept_num,
"stop_flags": stop_flags,
"seq_lens_this_time": seq_lens_this_time,
"is_block_step": is_block_step,
"stop_nums": stop_nums,
# real_bsz = real_bsz,
# max_bsz = max_bsz,
# max_draft_tokens = max_draft_tokens
}
# ------------------- main test body -------------------
inputs = gen_inputs(max_bsz=512, max_draft_tokens=32, real_bsz=201)
# ---- Paddle side ----
paddle_inputs = {}
for k, v in inputs.items():
if k in ("real_bsz", "max_bsz", "max_draft_tokens"):
        paddle_inputs[k] = v  # plain Python int
else:
if k == "not_need_stop":
paddle_inputs[k] = paddle.to_tensor(v, place=paddle.CPUPlace())
else:
            # other tensors keep the default place; add place=paddle.CUDAPlace(0) manually to test on GPU
paddle_inputs[k] = paddle.to_tensor(v)
# ---- NumPy side ----
# To keep the initial values identical, copy the numpy values of the Paddle
# inputs before passing them to the reference implementation.
np_inputs = {
k: (paddle_inputs[k].numpy().copy() if isinstance(paddle_inputs[k], paddle.Tensor) else paddle_inputs[k])
for k in paddle_inputs
}
# call the custom op
# print("seq_lens_encoder_xpu_before: ", paddle_inputs["seq_lens_encoder"])
out_pd = speculate_update_v3(**paddle_inputs)
# print("seq_lens_encoder_xpu_after: ", out_pd[0])
# print("not_need_stop: ", out_pd[2])
# speculate_update_v3 returns five tensors (matching its declared Outputs)
(
seq_lens_encoder_pd,
seq_lens_decoder_pd,
not_need_stop_pd,
draft_tokens_pd,
actual_draft_nums_pd,
) = out_pd
# print("seq_lens_encoder_np_before: ", np_inputs["seq_lens_encoder"])
out_np = speculate_update_v3_np(**np_inputs)
# print("seq_lens_encoder_np_after: ", out_np[0])
# print("not_need_stop: ", out_np[2])
# ---------------- verification ----------------
names = [
"seq_lens_encoder",
"seq_lens_decoder",
"not_need_stop",
"draft_tokens",
"actual_draft_token_nums",
]
pd_tensors = [
seq_lens_encoder_pd,
seq_lens_decoder_pd,
not_need_stop_pd,
draft_tokens_pd,
actual_draft_nums_pd,
]
for name, pd_val, np_val in zip(names, pd_tensors, out_np):
pd_arr = pd_val.numpy()
ok = np.array_equal(pd_arr, np_val)
print(f"{name:25s} equal :", ok)
# an assert could be added as well, e.g. for pytest:
# assert all(np.array_equal(p.numpy(), n) for p, n in zip(pd_tensors, out_np))

View File

@@ -0,0 +1,634 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
from typing import List
import numpy as np
# tests/speculate_verify.py
import paddle
from fastdeploy.model_executor.ops.xpu import speculate_verify
def topp_sampling_kernel(candidate_ids, candidate_scores, curand_value, candidate_len, topp, tid=0):
    """
    Python simulation of the Top-p sample selection.
    Args:
    - candidate_ids: [candidate_len] int64 array of candidate tokens
    - candidate_scores: [candidate_len] float32 array of corresponding probabilities
    - curand_value: float in [0, 1), simulating curand_uniform on the GPU
    - candidate_len: int, number of candidates
    - topp: float, Top-p truncation threshold
    - tid: simulated thread id, used only for debugging (optional)
    Returns:
    - the sampled token (int64)
    """
    rand_top_p = curand_value * topp
    sum_scores = 0.0
    for i in range(candidate_len):
        print(
            f"debug sample i:{i} scores:{candidate_scores[i]}, ids:{candidate_ids[i]}, "
            f"curand_value:{curand_value}, topp:{topp}, value*topp:{rand_top_p}"
        )
        sum_scores += candidate_scores[i]
        if rand_top_p <= sum_scores:
            return candidate_ids[i]
    return candidate_ids[0]  # fallback; in theory this is never reached
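# A tiny usage sketch of the simulator above (toy values; assumes the scores
# arrive in the order the kernel expects):
#   topp_sampling_kernel(np.array([7, 3]), np.array([0.6, 0.4]), 0.5, 2, 0.9)
# computes rand_top_p = 0.45; the running sum reaches 0.6 at the first
# candidate, so token 7 is returned.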
# def is_in_end(id: int, end_ids: np.ndarray, length: int) -> bool:
#     """
#     Check whether `id` appears among the first `length` elements of end_ids.
#     """
#     for i in range(length):
#         if id == end_ids[i]:
#             return True
#     return False
# def is_in(candidates: np.ndarray, draft: int, candidate_len: int) -> bool:
#     """
#     Check whether `draft` appears among the first `candidate_len` elements of candidates.
#     """
#     for i in range(candidate_len):
#         if draft == candidates[i]:
#             return True
#     return False
# ---------------- NumPy reference implementation ----------------
def speculate_verify_np(
accept_tokens,
accept_num,
step_idx,
stop_flags,
seq_lens_encoder,
seq_lens_decoder,
draft_tokens,
seq_lens_this_time,
verify_tokens,
verify_scores,
max_dec_len,
end_tokens,
is_block_step,
output_cum_offsets,
actual_candidate_len,
actual_draft_token_nums,
topp,
max_seq_len,
verify_window,
enable_topp,
):
def is_in_end(token, end_tokens, end_length):
return token in end_tokens[:end_length]
def is_in(candidate_list, token, length):
return token in candidate_list[:length]
bsz = accept_tokens.shape[0]
real_bsz = seq_lens_this_time.shape[0]
max_draft_tokens = draft_tokens.shape[1]
end_length = end_tokens.shape[0]
max_candidate_len = verify_tokens.shape[1]
use_topk = False
prefill_one_step_stop = False
# random
initial_seed = 0
infer_seed: List[int] = [initial_seed] * bsz
dev_curand_states: List[float] = []
    # generate one random value per batch entry
    for i in range(bsz):
        current_seed = infer_seed[i]  # current_seed always equals initial_seed here
        # create an independent RNG instance from the current seed;
        # this corresponds to the C++ std::mt19937_64 engine(infer_seed[i]);
        rng = random.Random(current_seed)
        # draw a float in [0.0, 1.0) from the independent generator;
        # this corresponds to the C++ dist(engine);
        dev_curand_states.append(rng.random())
    # --- flatten inside the function ---
    # only the multi-dimensional arrays that C++ accesses via pointer arithmetic need flattening
    accept_tokens_flat = accept_tokens.reshape(-1)
    draft_tokens_flat = draft_tokens.reshape(-1)
    verify_tokens_flat = verify_tokens.reshape(-1)
    verify_scores_flat = verify_scores.reshape(-1)
    print(f"DEBUG: accept_tokens_flat shape: {accept_tokens_flat.shape}")
    print(f"DEBUG: draft_tokens_flat shape: {draft_tokens_flat.shape}")
    print(f"DEBUG: verify_tokens_flat shape: {verify_tokens_flat.shape}")
    print(f"DEBUG: verify_scores_flat shape: {verify_scores_flat.shape}")
    # The remaining arrays (accept_num, step_idx, stop_flags, end_tokens, dev_curand_states,
    # actual_candidate_len, seq_lens_encoder, seq_lens_decoder, actual_draft_token_nums,
    # topp values, seq_lens_this_time, max_dec_len, is_block_step, output_cum_offsets)
    # are one-dimensional in their original C++ definitions, so no extra reshape is
    # needed and they are used as-is here.
print()
    # iterate over every sample in the batch
    for bid in range(real_bsz):
        # C++: const int start_token_id = bid * max_seq_len - output_cum_offsets[bid];
        start_token_id = bid * max_seq_len - output_cum_offsets[bid]
        accept_num_now = 1
        stop_flag_now_int = 0
        print(
            f"DEBUG: start_token_id: {start_token_id}, max_seq_len: {max_seq_len}, output_cum_offsets[{bid}]: {output_cum_offsets[bid]}"
        )
        # C++: if (!(is_block_step[bid] || bid >= real_bsz))
        if not (
            is_block_step[bid] or bid >= real_bsz
        ):  # bid >= real_bsz can never hold in this Python loop; kept for parity with C++
            if stop_flags[bid]:
                stop_flag_now_int = 1
            else:
                # C++: auto *verify_tokens_now = verify_tokens + start_token_id * max_candidate_len;
                # Python: verify_tokens_now is a flat view starting at this batch's verify_tokens,
                # emulating the C++ base address after the pointer offset
                verify_tokens_now = verify_tokens_flat[start_token_id * max_candidate_len :]  # from the base address to the end
                # C++: auto *draft_tokens_now = draft_tokens + bid * max_draft_tokens;
                # Python: draft_tokens_now is a flat view starting at this batch's draft_tokens
                draft_tokens_now = draft_tokens_flat[bid * max_draft_tokens :]  # from the base address to the end
                # C++: auto *actual_candidate_len_now = actual_candidate_len + start_token_id;
                # Python: actual_candidate_len_now is a flat view starting at this batch's actual_candidate_len
                actual_candidate_len_now = actual_candidate_len[start_token_id:]  # actual_candidate_len is already 1-D
                # C++: int i = 0;
                i = 0
                # C++: for (; i < seq_lens_this_time[bid] - 1; i++)
                for loop_i in range(seq_lens_this_time[bid] - 1):  # loop_i is the Python loop variable
                    i = loop_i  # keep C++'s i updated to the current index on every iteration
# C++: if (seq_lens_encoder[bid] != 0)
if seq_lens_encoder[bid] != 0:
break
if use_topk:
# C++: if (verify_tokens_now[i * max_candidate_len] == draft_tokens_now[i + 1])
if verify_tokens_now[i * max_candidate_len] == draft_tokens_now[i + 1]:
step_idx[bid] += 1
accept_token = draft_tokens_now[i + 1]
# C++: accept_tokens[bid * max_draft_tokens + i] = accept_token;
accept_tokens_flat[bid * max_draft_tokens + i] = accept_token
# C++: if (is_in_end(accept_token, end_tokens, end_length) || step_idx[bid] >= max_dec_len[bid])
if is_in_end(accept_token, end_tokens, end_length) or step_idx[bid] >= max_dec_len[bid]:
stop_flags[bid] = True
stop_flag_now_int = 1
if step_idx[bid] >= max_dec_len[bid]:
accept_tokens_flat[bid * max_draft_tokens + i] = end_tokens[0]
break
else:
accept_num_now += 1
else:
break
                    else:  # C++: else branch (top-p verify, with a top-k window fallback)
# C++: auto actual_candidate_len_value = actual_candidate_len_now[i] > max_candidate_len ? max_candidate_len : actual_candidate_len_now[i];
actual_candidate_len_value = min(actual_candidate_len_now[i], max_candidate_len)
# C++: if (is_in(verify_tokens_now + i * max_candidate_len, draft_tokens_now[i + 1], actual_candidate_len_value))
                        # pass a flat view of the current candidate row
verify_tokens_current_candidate_view = verify_tokens_now[
i * max_candidate_len : (i + 1) * max_candidate_len
]
if is_in(
verify_tokens_current_candidate_view,
draft_tokens_now[i + 1],
actual_candidate_len_value,
):
step_idx[bid] += 1
accept_token = draft_tokens_now[i + 1]
accept_tokens_flat[bid * max_draft_tokens + i] = accept_token
if is_in_end(accept_token, end_tokens, end_length) or step_idx[bid] >= max_dec_len[bid]:
stop_flags[bid] = True
stop_flag_now_int = 1
if step_idx[bid] >= max_dec_len[bid]:
accept_tokens_flat[bid * max_draft_tokens + i] = end_tokens[0]
break
else:
accept_num_now += 1
else:
# TopK verify
                            ii = i  # C++ starts ii at i
# C++: if (max_candidate_len >= 2 && verify_tokens_now[ii * max_candidate_len + 1] == draft_tokens_now[ii + 1])
if (
max_candidate_len >= 2
and verify_tokens_now[ii * max_candidate_len + 1] == draft_tokens_now[ii + 1]
): # top-2
j = 0
                                ii += 1  # C++ starts checking from the next position
# C++: for (; j < verify_window && ii < seq_lens_this_time[bid] - 1; j++, ii++)
while j < verify_window and ii < seq_lens_this_time[bid] - 1:
if verify_tokens_now[ii * max_candidate_len] != draft_tokens_now[ii + 1]:
break
j += 1
ii += 1
# C++: if (j >= verify_window)
if j >= verify_window: # accept all
accept_num_now += verify_window + 1
step_idx[bid] += verify_window + 1
                                # C++: for (; i < ii; i++)
                                for k_accepted_idx in range(i, ii):
                                    i = k_accepted_idx  # the C++ loop advances i itself; keep it in sync
                                    accept_token = draft_tokens_now[k_accepted_idx + 1]
                                    accept_tokens_flat[bid * max_draft_tokens + k_accepted_idx] = accept_token
                                    if (
                                        is_in_end(
                                            accept_token,
                                            end_tokens,
                                            end_length,
                                        )
                                        or step_idx[bid] >= max_dec_len[bid]
                                    ):
                                        stop_flags[bid] = True
                                        stop_flag_now_int = 1
                                        if step_idx[bid] >= max_dec_len[bid]:
                                            accept_tokens_flat[bid * max_draft_tokens + k_accepted_idx] = (
                                                end_tokens[0]
                                            )
                                            accept_num_now -= 1
                                            step_idx[bid] -= 1
                                        break  # leave the inner accept loop
                                else:
                                    i = ii  # no break: the C++ loop leaves i == ii on completion
                                break  # leave the main verify loop (TopK path done, matched or not)
                            # reached when is_in failed (top-p miss) and there was no
                            # top-2 TopK match either
                            break  # leave the main verify loop
                # ---- Sampling phase ----
                # In C++, i keeps its final value after the verify loop and is used
                # directly for sampling; the Python code keeps i in sync the same way.
if not stop_flag_now_int:
accept_token: int
                    # C++: const float *verify_scores_now = verify_scores + start_token_id * max_candidate_len;
                    # Python: verify_scores_now is a flat view of verify_scores starting at this sample
                    verify_scores_now = verify_scores_flat[start_token_id * max_candidate_len :]
step_idx[bid] += 1
if enable_topp:
# C++: auto actual_candidate_len_value = actual_candidate_len_now[i] > max_candidate_len ? max_candidate_len : actual_candidate_len_now[i];
actual_candidate_len_value = min(actual_candidate_len_now[i], max_candidate_len)
                        # pass flat views of the current candidate row
verify_tokens_sampling_view = verify_tokens_now[
i * max_candidate_len : (i + 1) * max_candidate_len
]
verify_scores_sampling_view = verify_scores_now[
i * max_candidate_len : (i + 1) * max_candidate_len
]
# C++: accept_token = topp_sampling_kernel(...)
accept_token = topp_sampling_kernel(
verify_tokens_sampling_view,
verify_scores_sampling_view,
                            dev_curand_states[bid],  # one pre-drawn RNG value per sample; indexing by i could overrun the real_bsz-sized list
actual_candidate_len_value,
topp[bid], # C++: topp[bid]
bid, # C++: bid
)
else:
accept_token = int(verify_tokens_now[i * max_candidate_len])
print(
"debug python last accept_token",
accept_token,
"prefill_one_step_stop",
prefill_one_step_stop,
)
# C++: accept_tokens[bid * max_draft_tokens + i] = accept_token;
accept_tokens_flat[bid * max_draft_tokens + i] = accept_token
if prefill_one_step_stop:
stop_flags[bid] = True
if is_in_end(accept_token, end_tokens, end_length) or step_idx[bid] >= max_dec_len[bid]:
stop_flags[bid] = True
stop_flag_now_int = 1
if step_idx[bid] >= max_dec_len[bid]:
accept_tokens_flat[bid * max_draft_tokens + i] = end_tokens[0]
accept_num[bid] = accept_num_now
return accept_tokens, accept_num, step_idx, stop_flags
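

# The helpers used above (is_in, is_in_end, topp_sampling_kernel) are defined
# earlier in this file; the guarded sketches below only illustrate one plausible
# shape consistent with the C++ semantics referenced in the comments. They are
# assumptions for readers of this excerpt, not the canonical implementations.
if "is_in" not in globals():

    def is_in(candidates, token, actual_len):
        # accepted if token appears among the first actual_len candidates
        return any(int(candidates[k]) == int(token) for k in range(int(actual_len)))

    def is_in_end(token, end_tokens, end_length):
        # stop condition: token is one of the end (EOS) tokens
        return any(int(token) == int(end_tokens[k]) for k in range(int(end_length)))

    def topp_sampling_kernel(ids, scores, rand_val, actual_len, topp_val, bid):
        # walk the (descending-sorted) candidate probabilities and return the
        # first token whose cumulative mass exceeds rand_val * topp_val
        threshold = float(rand_val) * float(topp_val)
        cum = 0.0
        for k in range(int(actual_len)):
            cum += float(scores[k])
            if cum > threshold:
                return int(ids[k])
        return int(ids[int(actual_len) - 1])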
# ---------------- Generate random inputs ----------------
def gen_speculate_verify_inputs(
real_bsz=123,
max_draft_tokens=16,
max_seq_len=256,
max_candidate_len=8,
verify_window=2,
end_length=4,
enable_topp=True,
seed=2025,
):
rng = np.random.default_rng(seed)
    # Basic inputs
seq_lens_encoder = rng.integers(0, 3, size=real_bsz, dtype=np.int32)
seq_lens_decoder = rng.integers(1, max_draft_tokens, size=real_bsz, dtype=np.int32)
draft_tokens = rng.integers(0, 1000, size=(real_bsz, max_draft_tokens), dtype=np.int64)
actual_draft_token_nums = rng.integers(1, max_draft_tokens + 1, size=real_bsz, dtype=np.int32)
seq_lens_this_time = rng.integers(1, max_seq_len + 1, size=real_bsz, dtype=np.int32)
sum_seq_this_time = int(np.sum(seq_lens_this_time))
# print("debug param set sum_seq_this_time",sum_seq_this_time)
# print("debug param real_bsz * max_draft_tokens < 2k",real_bsz * max_draft_tokens)
# print("debug sum_seq_this_time * max_candidate_len < 2k",sum_seq_this_time * max_candidate_len)
verify_tokens = rng.integers(0, 1000, size=(sum_seq_this_time, max_candidate_len), dtype=np.int64)
verify_scores = rng.random(size=(sum_seq_this_time, max_candidate_len)).astype(np.float32)
max_dec_len = rng.integers(16, 64, size=real_bsz, dtype=np.int64)
end_tokens = rng.integers(1, 1000, size=end_length, dtype=np.int64)
is_block_step = rng.integers(0, 2, size=real_bsz, dtype=bool)
    # output_cum_offsets[b] is the cumulative padding ("blank") length before
    # sample b, so sample b's tokens start at b * max_seq_len - output_cum_offsets[b]
    # in the packed token layout.
    blank_lengths = max_seq_len - seq_lens_this_time
    output_cum_offsets = np.concatenate([[0], np.cumsum(blank_lengths[:-1])])
    output_cum_offsets = output_cum_offsets.astype("int32")
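    # Worked example (illustrative): max_seq_len = 8, seq_lens_this_time = [3, 5, 2]
    # -> blank_lengths = [5, 3, 6], output_cum_offsets = [0, 5, 8]; sample 2 then
    #    starts at 2 * 8 - 8 = 8, i.e. right after the 3 + 5 real tokens.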
actual_candidate_len = rng.integers(1, max_candidate_len + 1, size=sum_seq_this_time, dtype=np.int32)
topp = (
rng.uniform(0.8, 1.0, size=real_bsz).astype(np.float32)
if enable_topp
else np.zeros(real_bsz, dtype=np.float32)
)
    # Outputs (placeholders)
accept_tokens = np.zeros((real_bsz, max_draft_tokens), dtype=np.int64)
accept_num = np.zeros(real_bsz, dtype=np.int32)
step_idx = np.zeros(real_bsz, dtype=np.int64)
stop_flags = np.zeros(real_bsz, dtype=bool)
return {
"accept_tokens": accept_tokens,
"accept_num": accept_num,
"step_idx": step_idx,
"stop_flags": stop_flags,
"seq_lens_encoder": seq_lens_encoder,
"seq_lens_decoder": seq_lens_decoder,
"draft_tokens": draft_tokens,
"seq_lens_this_time": seq_lens_this_time,
"verify_tokens": verify_tokens,
"verify_scores": verify_scores,
"max_dec_len": max_dec_len,
"end_tokens": end_tokens,
"is_block_step": is_block_step,
"output_cum_offsets": output_cum_offsets,
"actual_candidate_len": actual_candidate_len,
"actual_draft_token_nums": actual_draft_token_nums,
"topp": topp,
"max_seq_len": max_seq_len,
"verify_window": verify_window,
"enable_topp": enable_topp,
}
# ------------------- Test driver -------------------
# ---- Paddle side ----
def run_speculate_verify_test(
real_bsz,
max_draft_tokens,
max_seq_len,
max_candidate_len,
verify_window,
end_length,
enable_topp,
seed,
):
inputs = gen_speculate_verify_inputs(
real_bsz=real_bsz,
max_draft_tokens=max_draft_tokens,
max_seq_len=max_seq_len,
max_candidate_len=max_candidate_len,
verify_window=verify_window,
end_length=end_length,
enable_topp=enable_topp,
seed=seed,
)
    paddle_inputs = {}
    print("========= 1. XPU process ==========")
    for k, v in inputs.items():
        if isinstance(v, (int, bool)):
            paddle_inputs[k] = v
        else:
            paddle_inputs[k] = paddle.to_tensor(v, place=paddle.XPUPlace(0))
out_pd = speculate_verify(**paddle_inputs)
(accept_tokens_pd, accept_num_pd, step_idx_pd, stop_flags_pd) = out_pd
pd_tensors = [accept_tokens_pd, accept_num_pd, step_idx_pd, stop_flags_pd]
print("========= 1 end==========")
print("========= 2 python process==========")
# np_inputs = {k: (paddle_inputs[k].numpy().copy() if isinstance(paddle_inputs[k], paddle.Tensor)
# else paddle_inputs[k])
# for k in paddle_inputs}
# out_np = speculate_verify_np(**np_inputs)
# (accept_tokens_np, accept_num_np, step_idx_np, stop_flags_np) = out_np
# np_tensors = [accept_tokens_np, accept_num_np, step_idx_np, stop_flags_np]
print("=========2 end =======")
print("========= 3 (CPU)==========")
paddle_inputs_cpu = {}
for k, v in inputs.items(): # 重新使用原始的 inputs 字典,确保数据原始状态
if isinstance(v, (int, bool)):
paddle_inputs_cpu[k] = v
# print(f"{k:<25} type: {type(v).__name__}, value: {v}")
else:
# 核心修改:使用 paddle.CPUPlace()
paddle_inputs_cpu[k] = paddle.to_tensor(v, place=paddle.CPUPlace())
# print(f"{k:<25} type: Tensor, dtype: {paddle_inputs_cpu[k].dtype}, shape: {paddle_inputs_cpu[k].shape}")
out_cpu = speculate_verify(**paddle_inputs_cpu)
(accept_tokens_cpu, accept_num_cpu, step_idx_cpu, stop_flags_cpu) = out_cpu
cpu_tensors = [
accept_tokens_cpu,
accept_num_cpu,
step_idx_cpu,
stop_flags_cpu,
]
print("========= 3 (CPU) end==========")
    # ---------------- Verification ----------------
    print("========= cpu vs xpu verify ==========")
    names = ["accept_tokens", "accept_num", "step_idx", "stop_flags"]
    for name, pd_val, cpu_val in zip(names, pd_tensors, cpu_tensors):
        pd_arr = pd_val.numpy()
        cpu_arr = cpu_val.numpy()
        ok = np.array_equal(pd_arr, cpu_arr)
        print(f"{name:20s} equal: {ok}")
        if not ok:
            print(f"{name} mismatch!\nXPU:\n{pd_arr}\n\nCPU:\n{cpu_arr}")
            # report mismatching indices and values, capped so huge diffs
            # do not flood the log
            mismatches = np.where(pd_arr != cpu_arr)
            for idx in list(zip(*mismatches))[:20]:
                print(f"  idx {idx}: XPU = {pd_arr[idx]}, CPU = {cpu_arr[idx]}")
            if len(mismatches[0]) > 20:
                print("  ... (truncated)")
# -------------------------------------
# Test cases
# -------------------------------------
test_configs = [
{
"real_bsz": 4,
"max_draft_tokens": 3,
"max_seq_len": 30,
"max_candidate_len": 4,
"verify_window": 2,
"end_length": 2,
"enable_topp": True,
"seed": 2025,
},
{
"real_bsz": 77,
"max_draft_tokens": 10,
"max_seq_len": 12000,
"max_candidate_len": 8,
"verify_window": 2,
"end_length": 4,
"enable_topp": True,
"seed": 2025,
},
{
"real_bsz": 1,
"max_draft_tokens": 2,
"max_seq_len": 10,
"max_candidate_len": 1,
"verify_window": 1,
"end_length": 1,
"enable_topp": True,
"seed": 42,
},
{
"real_bsz": 128,
"max_draft_tokens": 7,
"max_seq_len": 999,
"max_candidate_len": 5,
"verify_window": 3,
"end_length": 3,
"enable_topp": True,
"seed": 422,
},
{
"real_bsz": 99,
"max_draft_tokens": 5,
"max_seq_len": 10,
"max_candidate_len": 3,
"verify_window": 4,
"end_length": 4,
"enable_topp": True,
"seed": 42,
},
{
"real_bsz": 1,
"max_draft_tokens": 9,
"max_seq_len": 11,
"max_candidate_len": 4,
"verify_window": 2,
"end_length": 5,
"enable_topp": False,
"seed": 42,
},
{
"real_bsz": 33,
"max_draft_tokens": 5,
"max_seq_len": 10111,
"max_candidate_len": 5,
"verify_window": 2,
"end_length": 6,
"enable_topp": False,
"seed": 42,
},
{
"real_bsz": 6,
"max_draft_tokens": 4,
"max_seq_len": 10001,
"max_candidate_len": 6,
"verify_window": 2,
"end_length": 7,
"enable_topp": False,
"seed": 42,
},
{
"real_bsz": 7,
"max_draft_tokens": 3,
"max_seq_len": 777,
"max_candidate_len": 7,
"verify_window": 2,
"end_length": 5,
"enable_topp": False,
"seed": 42,
},
{
"real_bsz": 55,
"max_draft_tokens": 5,
"max_seq_len": 31,
"max_candidate_len": 9,
"verify_window": 2,
"end_length": 3,
"enable_topp": False,
"seed": 42,
},
]
for i, cfg in enumerate(test_configs):
print(f"\n\n======== Running Test Case {i} ========")
run_speculate_verify_test(**cfg)