diff --git a/custom_ops/xpu_ops/src/ops/mtp_ops/draft_model_postprocess.cc b/custom_ops/xpu_ops/src/ops/mtp_ops/draft_model_postprocess.cc
new file mode 100644
index 000000000..c61fda27b
--- /dev/null
+++ b/custom_ops/xpu_ops/src/ops/mtp_ops/draft_model_postprocess.cc
@@ -0,0 +1,52 @@
+// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <paddle/phi/backends/xpu/enforce_xpu.h>
+#include <paddle/phi/backends/xpu/xpu_context.h>
+#include "paddle/extension.h"
+#include "xpu/plugin.h"
+
+void DraftModelPostprocess(const paddle::Tensor& base_model_draft_tokens,
+                           const paddle::Tensor& base_model_seq_lens_this_time,
+                           const paddle::Tensor& base_model_seq_lens_encoder,
+                           const paddle::Tensor& base_model_stop_flags) {
+  phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
+  auto dev_ctx =
+      paddle::experimental::DeviceContextPool::Instance().Get(place);
+  auto xpu_ctx = static_cast<const phi::XPUContext*>(dev_ctx);
+  int real_bsz = base_model_draft_tokens.shape()[0];
+  int base_model_draft_token_len = base_model_draft_tokens.shape()[1];
+  int r = baidu::xpu::api::plugin::draft_model_postprocess(
+      xpu_ctx->x_context(),
+      const_cast<int64_t*>(base_model_draft_tokens.data<int64_t>()),
+      const_cast<int*>(base_model_seq_lens_this_time.data<int>()),
+      const_cast<int*>(base_model_seq_lens_encoder.data<int>()),
+      const_cast<bool*>(base_model_stop_flags.data<bool>()),
+      real_bsz,
+      base_model_draft_token_len);
+  PADDLE_ENFORCE_XDNN_SUCCESS(r, "draft_model_postprocess");
+}
+
+PD_BUILD_OP(draft_model_postprocess)
+    .Inputs({"base_model_draft_tokens",
+             "base_model_seq_lens_this_time",
+             "base_model_seq_lens_encoder",
+             "base_model_stop_flags"})
+    .Outputs({"base_model_draft_tokens_out",
+              "base_model_seq_lens_this_time_out",
+              "base_model_stop_flags_out"})
+    .SetInplaceMap({{"base_model_draft_tokens", "base_model_draft_tokens_out"},
+                    {"base_model_seq_lens_this_time",
+                     "base_model_seq_lens_this_time_out"},
+                    {"base_model_stop_flags", "base_model_stop_flags_out"}})
+    .SetKernelFn(PD_KERNEL(DraftModelPostprocess));
diff --git a/custom_ops/xpu_ops/src/ops/mtp_ops/draft_model_preprocess.cc b/custom_ops/xpu_ops/src/ops/mtp_ops/draft_model_preprocess.cc
new file mode 100644
index 000000000..68551c548
--- /dev/null
+++ b/custom_ops/xpu_ops/src/ops/mtp_ops/draft_model_preprocess.cc
@@ -0,0 +1,138 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
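+
+// draft_model_preprocess: host-side wrapper that syncs the draft model's
+// per-query state (draft tokens, input ids, sequence lengths, stop flags)
+// with the base model before a speculative step. not_need_stop lives on the
+// host, so it is staged through a device copy around the kernel call.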
+
+#include <paddle/phi/backends/xpu/xpu_context.h>
+#include "paddle/extension.h"
+#include "paddle/phi/core/enforce.h"
+#include "xpu/plugin.h"
+
+namespace api = baidu::xpu::api;
+
+void DraftModelPreprocess(const paddle::Tensor& draft_tokens,
+                          const paddle::Tensor& input_ids,
+                          const paddle::Tensor& stop_flags,
+                          const paddle::Tensor& seq_lens_this_time,
+                          const paddle::Tensor& seq_lens_encoder,
+                          const paddle::Tensor& seq_lens_decoder,
+                          const paddle::Tensor& step_idx,
+                          const paddle::Tensor& seq_lens_encoder_record,
+                          const paddle::Tensor& seq_lens_decoder_record,
+                          const paddle::Tensor& not_need_stop,
+                          const paddle::Tensor& batch_drop,
+                          const paddle::Tensor& accept_tokens,
+                          const paddle::Tensor& accept_num,
+                          const paddle::Tensor& base_model_seq_lens_encoder,
+                          const paddle::Tensor& base_model_seq_lens_decoder,
+                          const paddle::Tensor& base_model_step_idx,
+                          const paddle::Tensor& base_model_stop_flags,
+                          const paddle::Tensor& base_model_is_block_step,
+                          const paddle::Tensor& base_model_draft_tokens,
+                          const int max_draft_token,
+                          const bool truncate_first_token,
+                          const bool splitwise_prefill) {
+  phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
+  auto dev_ctx =
+      paddle::experimental::DeviceContextPool::Instance().Get(place);
+  api::Context* ctx =
+      static_cast<const phi::XPUContext*>(dev_ctx)->x_context();
+  if (draft_tokens.is_cpu()) {
+    ctx = new api::Context(api::kCPU);
+  }
+  int real_bsz = seq_lens_this_time.shape()[0];
+  int accept_tokens_len = accept_tokens.shape()[1];
+  int input_ids_len = input_ids.shape()[1];
+  int draft_tokens_len = draft_tokens.shape()[1];
+  int base_model_draft_tokens_len = base_model_draft_tokens.shape()[1];
+  auto not_need_stop_device =
+      not_need_stop.copy_to(seq_lens_this_time.place(), false);
+
+  int r = baidu::xpu::api::plugin::draft_model_preprocess(
+      ctx,
+      const_cast<int64_t*>(draft_tokens.data<int64_t>()),
+      const_cast<int64_t*>(input_ids.data<int64_t>()),
+      const_cast<bool*>(stop_flags.data<bool>()),
+      const_cast<int*>(seq_lens_this_time.data<int>()),
+      const_cast<int*>(seq_lens_encoder.data<int>()),
+      const_cast<int*>(seq_lens_decoder.data<int>()),
+      const_cast<int64_t*>(step_idx.data<int64_t>()),
+      const_cast<int*>(seq_lens_encoder_record.data<int>()),
+      const_cast<int*>(seq_lens_decoder_record.data<int>()),
+      const_cast<bool*>(not_need_stop_device.data<bool>()),
+      const_cast<bool*>(batch_drop.data<bool>()),
+      accept_tokens.data<int64_t>(),
+      accept_num.data<int>(),
+      base_model_seq_lens_encoder.data<int>(),
+      base_model_seq_lens_decoder.data<int>(),
+      base_model_step_idx.data<int64_t>(),
+      base_model_stop_flags.data<bool>(),
+      base_model_is_block_step.data<bool>(),
+      const_cast<int64_t*>(base_model_draft_tokens.data<int64_t>()),
+      real_bsz,
+      max_draft_token,
+      accept_tokens_len,
+      draft_tokens_len,
+      input_ids_len,
+      base_model_draft_tokens_len,
+      truncate_first_token,
+      splitwise_prefill);
+  PD_CHECK(r == 0, "xpu::plugin::draft_model_preprocess failed.");
+  auto not_need_stop_cpu =
+      not_need_stop_device.copy_to(not_need_stop.place(), false);
+  bool* not_need_stop_data = const_cast<bool*>(not_need_stop.data<bool>());
+  not_need_stop_data[0] = not_need_stop_cpu.data<bool>()[0];
+}
+
+PD_BUILD_OP(draft_model_preprocess)
+    .Inputs({"draft_tokens",
+             "input_ids",
+             "stop_flags",
+             "seq_lens_this_time",
+             "seq_lens_encoder",
+             "seq_lens_decoder",
+             "step_idx",
+             "seq_lens_encoder_record",
+             "seq_lens_decoder_record",
+             "not_need_stop",
+             "batch_drop",
+             "accept_tokens",
+             "accept_num",
+             "base_model_seq_lens_encoder",
+             "base_model_seq_lens_decoder",
+             "base_model_step_idx",
+             "base_model_stop_flags",
+             "base_model_is_block_step",
+             "base_model_draft_tokens"})
+    .Outputs({"draft_tokens_out",
+              "input_ids_out",
+              "stop_flags_out",
+              "seq_lens_this_time_out",
+              "seq_lens_encoder_out",
+              "seq_lens_decoder_out",
+              "step_idx_out",
"not_need_stop_out", + "batch_drop_out", + "seq_lens_encoder_record_out", + "seq_lens_decoder_record_out"}) + .Attrs({"max_draft_token: int", + "truncate_first_token: bool", + "splitwise_prefill: bool"}) + .SetInplaceMap({{"draft_tokens", "draft_tokens_out"}, + {"input_ids", "input_ids_out"}, + {"stop_flags", "stop_flags_out"}, + {"seq_lens_this_time", "seq_lens_this_time_out"}, + {"seq_lens_encoder", "seq_lens_encoder_out"}, + {"seq_lens_decoder", "seq_lens_decoder_out"}, + {"step_idx", "step_idx_out"}, + {"not_need_stop", "not_need_stop_out"}, + {"batch_drop", "batch_drop_out"}, + {"seq_lens_encoder_record", "seq_lens_encoder_record_out"}, + {"seq_lens_decoder_record", "seq_lens_decoder_record_out"}}) + .SetKernelFn(PD_KERNEL(DraftModelPreprocess)); diff --git a/custom_ops/xpu_ops/src/ops/mtp_ops/draft_model_update.cc b/custom_ops/xpu_ops/src/ops/mtp_ops/draft_model_update.cc new file mode 100644 index 000000000..930fc7804 --- /dev/null +++ b/custom_ops/xpu_ops/src/ops/mtp_ops/draft_model_update.cc @@ -0,0 +1,122 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include "paddle/extension.h" +#include "paddle/phi/core/enforce.h" +#include "xpu/plugin.h" + +void DraftModelUpdate(const paddle::Tensor& inter_next_tokens, + const paddle::Tensor& draft_tokens, + const paddle::Tensor& pre_ids, + const paddle::Tensor& seq_lens_this_time, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& seq_lens_decoder, + const paddle::Tensor& step_idx, + const paddle::Tensor& output_cum_offsets, + const paddle::Tensor& stop_flags, + const paddle::Tensor& not_need_stop, + const paddle::Tensor& max_dec_len, + const paddle::Tensor& end_ids, + const paddle::Tensor& base_model_draft_tokens, + const int max_seq_len, + const int substep) { + // printf("enter clear \n"); + phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId()); + auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place); + baidu::xpu::api::Context* ctx = + static_cast(dev_ctx)->x_context(); + + if (draft_tokens.is_cpu()) { + ctx = new baidu::xpu::api::Context(baidu::xpu::api::kCPU); + } + + auto seq_lens_this_time_shape = seq_lens_this_time.shape(); + const int real_bsz = seq_lens_this_time_shape[0]; + auto not_need_stop_device = + not_need_stop.copy_to(seq_lens_this_time.place(), false); + const int end_ids_len = end_ids.shape()[0]; + const int max_draft_token = draft_tokens.shape()[1]; + const int pre_id_length = pre_ids.shape()[1]; + const int max_base_model_draft_token = base_model_draft_tokens.shape()[1]; + constexpr int BlockSize = 512; + bool prefill_one_step_stop = false; + if (const char* env_p = std::getenv("PREFILL_NODE_ONE_STEP_STOP")) { + // std::cout << "Your PATH is: " << env_p << '\n'; + if (env_p[0] == '1') { + prefill_one_step_stop = true; + } + } + + int r = baidu::xpu::api::plugin::draft_model_update( + ctx, + inter_next_tokens.data(), + const_cast(draft_tokens.data()), + const_cast(pre_ids.data()), 
+  int r = baidu::xpu::api::plugin::draft_model_update(
+      ctx,
+      inter_next_tokens.data<int64_t>(),
+      const_cast<int64_t*>(draft_tokens.data<int64_t>()),
+      const_cast<int64_t*>(pre_ids.data<int64_t>()),
+      const_cast<int*>(seq_lens_this_time.data<int>()),
+      const_cast<int*>(seq_lens_encoder.data<int>()),
+      const_cast<int*>(seq_lens_decoder.data<int>()),
+      const_cast<int64_t*>(step_idx.data<int64_t>()),
+      output_cum_offsets.data<int>(),
+      const_cast<bool*>(stop_flags.data<bool>()),
+      const_cast<bool*>(not_need_stop_device.data<bool>()),
+      max_dec_len.data<int64_t>(),
+      end_ids.data<int64_t>(),
+      const_cast<int64_t*>(base_model_draft_tokens.data<int64_t>()),
+      real_bsz,
+      max_draft_token,
+      pre_id_length,
+      max_base_model_draft_token,
+      end_ids_len,
+      max_seq_len,
+      substep,
+      prefill_one_step_stop);
+
+  PD_CHECK(r == 0, "draft_model_update failed.");
+}
+
+PD_BUILD_OP(draft_model_update)
+    .Inputs({"inter_next_tokens",
+             "draft_tokens",
+             "pre_ids",
+             "seq_lens_this_time",
+             "seq_lens_encoder",
+             "seq_lens_decoder",
+             "step_idx",
+             "output_cum_offsets",
+             "stop_flags",
+             "not_need_stop",
+             "max_dec_len",
+             "end_ids",
+             "base_model_draft_tokens"})
+    .Attrs({"max_seq_len: int", "substep: int"})
+    .Outputs({"draft_tokens_out",
+              "pre_ids_out",
+              "seq_lens_this_time_out",
+              "seq_lens_encoder_out",
+              "seq_lens_decoder_out",
+              "step_idx_out",
+              "stop_flags_out",
+              "not_need_stop_out",
+              "base_model_draft_tokens_out"})
+    .SetInplaceMap({{"draft_tokens", "draft_tokens_out"},
+                    {"pre_ids", "pre_ids_out"},
+                    {"seq_lens_this_time", "seq_lens_this_time_out"},
+                    {"seq_lens_encoder", "seq_lens_encoder_out"},
+                    {"seq_lens_decoder", "seq_lens_decoder_out"},
+                    {"step_idx", "step_idx_out"},
+                    {"stop_flags", "stop_flags_out"},
+                    {"not_need_stop", "not_need_stop_out"},
+                    {"base_model_draft_tokens", "base_model_draft_tokens_out"}})
+    .SetKernelFn(PD_KERNEL(DraftModelUpdate));
diff --git a/custom_ops/xpu_ops/src/ops/mtp_ops/eagle_get_hidden_states.cc b/custom_ops/xpu_ops/src/ops/mtp_ops/eagle_get_hidden_states.cc
new file mode 100644
index 000000000..b45c8febd
--- /dev/null
+++ b/custom_ops/xpu_ops/src/ops/mtp_ops/eagle_get_hidden_states.cc
@@ -0,0 +1,116 @@
+// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
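+
+// eagle_get_hidden_states: two-pass gather. compute_order builds a
+// position_map over the input rows (initialized to -1 for rows that are
+// dropped) plus the number of surviving output tokens; rebuild_hidden_states
+// then copies the selected hidden-state rows into a compact output tensor.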
+
+#include <paddle/phi/backends/xpu/xpu_context.h>
+#include "paddle/extension.h"
+#include "xpu/plugin.h"
+
+namespace api = baidu::xpu::api;
+
+std::vector<paddle::Tensor> EagleGetHiddenStates(
+    const paddle::Tensor& input,
+    const paddle::Tensor& seq_lens_this_time,
+    const paddle::Tensor& seq_lens_encoder,
+    const paddle::Tensor& seq_lens_decoder,
+    const paddle::Tensor& stop_flags,
+    const paddle::Tensor& accept_nums,
+    const paddle::Tensor& base_model_seq_lens_this_time,
+    const paddle::Tensor& base_model_seq_lens_encoder,
+    const int actual_draft_token_num) {
+  phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
+  auto dev_ctx =
+      paddle::experimental::DeviceContextPool::Instance().Get(place);
+  auto xpu_ctx = static_cast<const phi::XPUContext*>(dev_ctx);
+  api::Context* ctx = xpu_ctx->x_context();
+  if (input.is_cpu()) {
+    ctx = new api::Context(api::kCPU);
+  }
+
+  auto input_token_num = input.shape()[0];
+  auto dim_embed = input.shape()[1];
+  int bsz = seq_lens_this_time.shape()[0];
+  auto position_map = paddle::full(
+      {input_token_num}, -1, seq_lens_this_time.dtype(), input.place());
+  auto output_token_num = paddle::empty(
+      {1}, seq_lens_this_time.dtype(), seq_lens_this_time.place());
+
+  int r = api::plugin::compute_order(ctx,
+                                     seq_lens_this_time.data<int>(),
+                                     seq_lens_encoder.data<int>(),
+                                     base_model_seq_lens_this_time.data<int>(),
+                                     base_model_seq_lens_encoder.data<int>(),
+                                     accept_nums.data<int>(),
+                                     position_map.data<int>(),
+                                     output_token_num.data<int>(),
+                                     bsz,
+                                     actual_draft_token_num,
+                                     input_token_num);
+  PD_CHECK(r == 0, "xpu::plugin::compute_order failed.");
+
+  int output_token_num_cpu =
+      output_token_num.copy_to(paddle::CPUPlace(), false).data<int>()[0];
+  auto out = paddle::empty(
+      {output_token_num_cpu, dim_embed}, input.dtype(), input.place());
+  int elem_cnt = input_token_num * dim_embed;
+
+  switch (input.dtype()) {
+    case paddle::DataType::BFLOAT16: {
+      using XPUTypeBF16 = typename XPUTypeTrait<paddle::bfloat16>::Type;
+      r = api::plugin::rebuild_hidden_states(
+          ctx,
+          reinterpret_cast<const XPUTypeBF16*>(input.data<paddle::bfloat16>()),
+          position_map.data<int>(),
+          reinterpret_cast<XPUTypeBF16*>(out.data<paddle::bfloat16>()),
+          dim_embed,
+          elem_cnt);
+      PD_CHECK(r == 0, "xpu::plugin::rebuild_hidden_states failed.");
+      return {out};
+    }
+    case paddle::DataType::FLOAT16: {
+      using XPUTypeFP16 = typename XPUTypeTrait<paddle::float16>::Type;
+      r = api::plugin::rebuild_hidden_states(
+          ctx,
+          reinterpret_cast<const XPUTypeFP16*>(input.data<paddle::float16>()),
+          position_map.data<int>(),
+          reinterpret_cast<XPUTypeFP16*>(out.data<paddle::float16>()),
+          dim_embed,
+          elem_cnt);
+      PD_CHECK(r == 0, "xpu::plugin::rebuild_hidden_states failed.");
+      return {out};
+    }
+    case paddle::DataType::FLOAT32: {
+      r = api::plugin::rebuild_hidden_states(ctx,
+                                             input.data<float>(),
+                                             position_map.data<int>(),
+                                             out.data<float>(),
+                                             dim_embed,
+                                             elem_cnt);
+      PD_CHECK(r == 0, "xpu::plugin::rebuild_hidden_states failed.");
+      return {out};
+    }
+    default:
+      PD_THROW("Unsupported data type.");
+  }
+}
+
+PD_BUILD_OP(eagle_get_hidden_states)
+    .Inputs({"input",
+             "seq_lens_this_time",
+             "seq_lens_encoder",
+             "seq_lens_decoder",
+             "stop_flags",
+             "accept_nums",
+             "base_model_seq_lens_this_time",
+             "base_model_seq_lens_encoder"})
+    .Attrs({"actual_draft_token_num: int"})
+    .Outputs({"out"})
+    .SetKernelFn(PD_KERNEL(EagleGetHiddenStates));
diff --git a/custom_ops/xpu_ops/src/ops/mtp_ops/eagle_get_self_hidden_states.cc b/custom_ops/xpu_ops/src/ops/mtp_ops/eagle_get_self_hidden_states.cc
new file mode 100644
index 000000000..68d09662a
--- /dev/null
+++ b/custom_ops/xpu_ops/src/ops/mtp_ops/eagle_get_self_hidden_states.cc
@@ -0,0 +1,104 @@
+// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <paddle/phi/backends/xpu/xpu_context.h>
+#include "paddle/extension.h"
+#include "xpu/plugin.h"
+
+namespace api = baidu::xpu::api;
+
+std::vector<paddle::Tensor> EagleGetSelfHiddenStates(
+    const paddle::Tensor& input,
+    const paddle::Tensor& last_seq_lens_this_time,
+    const paddle::Tensor& seq_lens_this_time,
+    const paddle::Tensor& step_idx) {
+  phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
+  auto dev_ctx =
+      paddle::experimental::DeviceContextPool::Instance().Get(place);
+  auto xpu_ctx = static_cast<const phi::XPUContext*>(dev_ctx);
+  api::Context* ctx = xpu_ctx->x_context();
+  if (input.is_cpu()) {
+    ctx = new api::Context(api::kCPU);
+  }
+
+  int input_token_num = input.shape()[0];
+  int dim_embed = input.shape()[1];
+  int bsz = seq_lens_this_time.shape()[0];
+  auto src_map = paddle::empty({input_token_num},
+                               seq_lens_this_time.dtype(),
+                               seq_lens_this_time.place());
+  auto output_token_num = paddle::empty(
+      {1}, seq_lens_this_time.dtype(), seq_lens_this_time.place());
+
+  int r = api::plugin::compute_self_order(ctx,
+                                          last_seq_lens_this_time.data<int>(),
+                                          seq_lens_this_time.data<int>(),
+                                          step_idx.data<int64_t>(),
+                                          src_map.data<int>(),
+                                          output_token_num.data<int>(),
+                                          bsz);
+  PD_CHECK(r == 0, "xpu::plugin::compute_self_order failed.");
+
+  int output_token_num_cpu =
+      output_token_num.copy_to(paddle::CPUPlace(), false).data<int>()[0];
+
+  auto out = paddle::empty(
+      {output_token_num_cpu, dim_embed}, input.dtype(), input.place());
+
+  int elem_cnt = output_token_num_cpu * dim_embed;
+
+  switch (input.dtype()) {
+    case paddle::DataType::BFLOAT16: {
+      using XPUTypeBF16 = typename XPUTypeTrait<paddle::bfloat16>::Type;
+      r = api::plugin::rebuild_self_hidden_states(
+          ctx,
+          reinterpret_cast<const XPUTypeBF16*>(input.data<paddle::bfloat16>()),
+          src_map.data<int>(),
+          reinterpret_cast<XPUTypeBF16*>(out.data<paddle::bfloat16>()),
+          dim_embed,
+          elem_cnt);
+      PD_CHECK(r == 0, "xpu::plugin::rebuild_self_hidden_states failed.");
+      return {out};
+    }
+    case paddle::DataType::FLOAT16: {
+      using XPUTypeFP16 = typename XPUTypeTrait<paddle::float16>::Type;
+      r = api::plugin::rebuild_self_hidden_states(
+          ctx,
+          reinterpret_cast<const XPUTypeFP16*>(input.data<paddle::float16>()),
+          src_map.data<int>(),
+          reinterpret_cast<XPUTypeFP16*>(out.data<paddle::float16>()),
+          dim_embed,
+          elem_cnt);
+      PD_CHECK(r == 0, "xpu::plugin::rebuild_self_hidden_states failed.");
+      return {out};
+    }
+    case paddle::DataType::FLOAT32: {
+      r = api::plugin::rebuild_self_hidden_states(ctx,
+                                                  input.data<float>(),
+                                                  src_map.data<int>(),
+                                                  out.data<float>(),
+                                                  dim_embed,
+                                                  elem_cnt);
+      PD_CHECK(r == 0, "xpu::plugin::rebuild_self_hidden_states failed.");
+      return {out};
+    }
+    default:
+      PD_THROW("Unsupported data type.");
+  }
+}
+
+PD_BUILD_OP(eagle_get_self_hidden_states)
+    .Inputs(
+        {"input", "last_seq_lens_this_time", "seq_lens_this_time", "step_idx"})
+    .Outputs({"out"})
+    .SetKernelFn(PD_KERNEL(EagleGetSelfHiddenStates));
diff --git a/custom_ops/xpu_ops/src/ops/mtp_ops/mtp_save_first_token.cc b/custom_ops/xpu_ops/src/ops/mtp_ops/mtp_save_first_token.cc
new file mode 100644
index 000000000..eccea7730
--- /dev/null
+++ b/custom_ops/xpu_ops/src/ops/mtp_ops/mtp_save_first_token.cc
@@ -0,0 +1,159 @@
+// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <sys/ipc.h>
+#include <sys/msg.h>
+#include <cstdio>
+#include <cstdlib>
+#include <iostream>
+#include <string>
+#include "paddle/extension.h"
+
+#define MAX_BSZ 256
+// #define SAVE_WITH_OUTPUT_DEBUG
+#define MAX_DRAFT_TOKENS 6
+
+struct msgdata {
+  long mtype;  // NOLINT
+  int mtext[2 + MAX_BSZ +
+            MAX_BSZ * MAX_DRAFT_TOKENS];  // stop_flag, token_num, tokens
+};
+
+void MTPSaveFirstToken(const paddle::Tensor& x,
+                       const paddle::Tensor& not_need_stop,
+                       int64_t rank_id,
+                       int msg_queue_id,
+                       bool save_each_rank) {
+  if (!save_each_rank && rank_id > 0) {
+    return;
+  }
+  int x_dim = x.shape()[1];
+  auto x_cpu = x.copy_to(paddle::CPUPlace(), false);
+  int64_t* x_data = x_cpu.data<int64_t>();
+  static struct msgdata msg_sed;
+
+  if (const char* inference_msg_queue_id_env_p =
+          std::getenv("INFERENCE_MSG_QUEUE_ID")) {
+    std::string inference_msg_queue_id_env_str(inference_msg_queue_id_env_p);
+    int inference_msg_queue_id_from_env =
+        std::stoi(inference_msg_queue_id_env_str);
+#ifdef SAVE_WITH_OUTPUT_DEBUG
+    std::cout << "Your INFERENCE_MSG_QUEUE_ID is: "
+              << inference_msg_queue_id_from_env << std::endl;
+#endif
+    msg_queue_id = inference_msg_queue_id_from_env;
+  }
+  static key_t key = ftok("./", msg_queue_id);
+  static int msgid = msgget(key, IPC_CREAT | 0666);
+
+  msg_sed.mtype = 1;
+  bool not_need_stop_data = not_need_stop.data<bool>()[0];
+  int inference_msg_id_from_env = 1;
+  if (const char* inference_msg_id_env_p = std::getenv("INFERENCE_MSG_ID")) {
+    std::string inference_msg_id_env_str(inference_msg_id_env_p);
+    inference_msg_id_from_env = std::stoi(inference_msg_id_env_str);
+    if (inference_msg_id_from_env == 2) {
+      // 2 and -2 are reserved for the no-output indication.
+      throw std::runtime_error(
+          "INFERENCE_MSG_ID cannot be 2, please use another number.");
+    }
+    if (inference_msg_id_from_env < 0) {
+      throw std::runtime_error(
+          "INFERENCE_MSG_ID cannot be negative, please use another number.");
+    }
+#ifdef SAVE_WITH_OUTPUT_DEBUG
+    std::cout << "Your INFERENCE_MSG_ID is: " << inference_msg_id_from_env
+              << std::endl;
+#endif
+  } else {
+#ifdef SAVE_WITH_OUTPUT_DEBUG
+    std::cout << "Failed to get INFERENCE_MSG_ID from env, using (int)1 as "
+                 "default."
+              << std::endl;
+#endif
+  }
+#ifdef SAVE_WITH_OUTPUT_DEBUG
+  std::cout << "save_output_key: " << key << std::endl;
+  std::cout << "save msgid: " << msgid << std::endl;
+#endif
+  msg_sed.mtext[0] = not_need_stop_data ? inference_msg_id_from_env
+                                        : -inference_msg_id_from_env;
+  int bsz = x.shape()[0];
+  msg_sed.mtext[1] = bsz;
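+  // Message layout: mtext[0] = message flag (positive while generation is
+  // running, negated once stopped), mtext[1] = batch size,
+  // mtext[2 .. MAX_BSZ+1] = tokens saved per query (always 2 for the
+  // first-token step), then MAX_DRAFT_TOKENS token slots per query.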
2: %d.\n", + i, + static_cast(x_data[i * x_dim]), + static_cast(x_data[i * x_dim + 1])); +#endif + msg_sed.mtext[i + 2] = 2; + msg_sed.mtext[i * MAX_DRAFT_TOKENS + 2 + MAX_BSZ] = + static_cast(x_data[i * x_dim]); + msg_sed.mtext[i * MAX_DRAFT_TOKENS + 1 + 2 + MAX_BSZ] = + static_cast(x_data[i * x_dim + 1]); +#ifdef SAVE_WITH_OUTPUT_DEBUG + printf("mtext[%d]:%d. mtext[%d]:%d. \n", + i * MAX_DRAFT_TOKENS + 2 + MAX_BSZ, + msg_sed.mtext[i * MAX_DRAFT_TOKENS + 2 + MAX_BSZ], + i * MAX_DRAFT_TOKENS + 1 + 2 + MAX_BSZ, + msg_sed.mtext[i * MAX_DRAFT_TOKENS + 1 + 2 + MAX_BSZ]); +#endif + } + +#ifdef SAVE_WITH_OUTPUT_DEBUG + std::cout << "msg data: "; + for (int i = 0; i < bsz; i++) { + std::cout << " " << static_cast(x_data[2 * i]) << " "; + std::cout << " " << static_cast(x_data[2 * i + 1]); + } + std::cout << std::endl; +#endif + if ((msgsnd(msgid, + &msg_sed, + (2 + MAX_BSZ + MAX_BSZ * MAX_DRAFT_TOKENS) * 4, + 0)) == -1) { + printf("full msg buffer\n"); + } + return; +} + +void MTPSaveFirstTokenStatic(const paddle::Tensor& x, + const paddle::Tensor& not_need_stop, + int64_t rank_id, + bool save_each_rank) { + MTPSaveFirstToken(x, not_need_stop, rank_id, 1, save_each_rank); +} + +void MTPSaveFirstTokenDynamic(const paddle::Tensor& x, + const paddle::Tensor& not_need_stop, + int64_t rank_id, + int msg_queue_id, + bool save_each_rank) { + MTPSaveFirstToken(x, not_need_stop, rank_id, msg_queue_id, save_each_rank); +} + +PD_BUILD_OP(mtp_save_first_token) + .Inputs({"x", "not_need_stop"}) + .Attrs({"rank_id: int64_t", "save_each_rank: bool"}) + .Outputs({"x_out"}) + .SetInplaceMap({{"x", "x_out"}}) + .SetKernelFn(PD_KERNEL(MTPSaveFirstTokenStatic)); + +PD_BUILD_OP(mtp_save_first_token_dynamic) + .Inputs({"x", "not_need_stop"}) + .Attrs({"rank_id: int64_t", "msg_queue_id: int", "save_each_rank: bool"}) + .Outputs({"x_out"}) + .SetInplaceMap({{"x", "x_out"}}) + .SetKernelFn(PD_KERNEL(MTPSaveFirstTokenDynamic)); diff --git a/custom_ops/xpu_ops/src/ops/mtp_ops/mtp_step_paddle.cc b/custom_ops/xpu_ops/src/ops/mtp_ops/mtp_step_paddle.cc new file mode 100644 index 000000000..c7bf2d7a1 --- /dev/null +++ b/custom_ops/xpu_ops/src/ops/mtp_ops/mtp_step_paddle.cc @@ -0,0 +1,90 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+#include <paddle/phi/backends/xpu/xpu_context.h>
+#include "paddle/extension.h"
+#include "paddle/phi/core/enforce.h"
+#include "xpu/plugin.h"
+
+namespace api = baidu::xpu::api;
+
+void MTPStepPaddle(
+    const paddle::Tensor &base_model_stop_flags,
+    const paddle::Tensor &stop_flags,
+    const paddle::Tensor &batch_drop,
+    const paddle::Tensor &seq_lens_this_time,
+    const paddle::Tensor &seq_lens_encoder,
+    const paddle::Tensor &seq_lens_decoder,
+    const paddle::Tensor &block_tables,  // [bsz, block_num_per_seq]
+    const paddle::Tensor &encoder_block_lens,
+    const paddle::Tensor &used_list_len,
+    const paddle::Tensor &free_list,
+    const paddle::Tensor &free_list_len,
+    const int block_size,
+    const int max_draft_tokens) {
+  phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
+  auto dev_ctx =
+      paddle::experimental::DeviceContextPool::Instance().Get(place);
+  auto xpu_ctx = static_cast<const phi::XPUContext *>(dev_ctx);
+  api::Context *ctx = xpu_ctx->x_context();
+  if (base_model_stop_flags.is_cpu()) {
+    ctx = new api::Context(api::kCPU);
+  }
+  const int bsz = seq_lens_this_time.shape()[0];
+  const int block_num_per_seq = block_tables.shape()[1];
+
+  int r = baidu::xpu::api::plugin::mtp_free_and_dispatch_block(
+      ctx,
+      const_cast<bool *>(base_model_stop_flags.data<bool>()),
+      const_cast<bool *>(stop_flags.data<bool>()),
+      const_cast<bool *>(batch_drop.data<bool>()),
+      const_cast<int *>(seq_lens_this_time.data<int>()),
+      const_cast<int *>(seq_lens_decoder.data<int>()),
+      const_cast<int *>(block_tables.data<int>()),
+      const_cast<int *>(encoder_block_lens.data<int>()),
+      const_cast<int *>(used_list_len.data<int>()),
+      const_cast<int *>(free_list.data<int>()),
+      const_cast<int *>(free_list_len.data<int>()),
+      bsz,
+      block_size,
+      block_num_per_seq,
+      max_draft_tokens);
+  PD_CHECK(r == 0, "free_and_dispatch_block failed.");
+  if (base_model_stop_flags.is_cpu() && ctx != nullptr) {
+    delete ctx;
+  }
+}
+
+PD_BUILD_OP(mtp_step_paddle)
+    .Inputs({"base_model_stop_flags",
+             "stop_flags",
+             "batch_drop",
+             "seq_lens_this_time",
+             "seq_lens_encoder",
+             "seq_lens_decoder",
+             "block_tables",
+             "encoder_block_lens",
+             "used_list_len",
+             "free_list",
+             "free_list_len"})
+    .Attrs({"block_size: int", "max_draft_tokens: int"})
+    .Outputs({"block_tables_out",
+              "stop_flags_out",
+              "used_list_len_out",
+              "free_list_out",
+              "free_list_len_out"})
+    .SetInplaceMap({{"block_tables", "block_tables_out"},
+                    {"stop_flags", "stop_flags_out"},
+                    {"used_list_len", "used_list_len_out"},
+                    {"free_list", "free_list_out"},
+                    {"free_list_len", "free_list_len_out"}})
+    .SetKernelFn(PD_KERNEL(MTPStepPaddle));
diff --git a/custom_ops/xpu_ops/src/ops/mtp_ops/speculate_clear_accept_nums.cc b/custom_ops/xpu_ops/src/ops/mtp_ops/speculate_clear_accept_nums.cc
new file mode 100644
index 000000000..f47244169
--- /dev/null
+++ b/custom_ops/xpu_ops/src/ops/mtp_ops/speculate_clear_accept_nums.cc
@@ -0,0 +1,38 @@
+// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <paddle/phi/backends/xpu/xpu_context.h>
+#include "paddle/extension.h"
+#include "xpu/plugin.h"
+
+void SpeculateClearAcceptNums(const paddle::Tensor& accept_num,
+                              const paddle::Tensor& seq_lens_decoder) {
+  phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
+  auto dev_ctx =
+      paddle::experimental::DeviceContextPool::Instance().Get(place);
+  auto xpu_ctx = static_cast<const phi::XPUContext*>(dev_ctx);
+  const int max_bsz = seq_lens_decoder.shape()[0];
+  int r = baidu::xpu::api::plugin::speculate_clear_accept_nums(
+      xpu_ctx->x_context(),
+      const_cast<int*>(accept_num.data<int>()),
+      seq_lens_decoder.data<int>(),
+      max_bsz);
+  PD_CHECK(r == 0, "speculate_clear_accept_nums_kernel failed.");
+}
+
+PD_BUILD_OP(speculate_clear_accept_nums)
+    .Inputs({"accept_num", "seq_lens_decoder"})
+    .Outputs({"seq_lens_decoder_out"})
+    .SetInplaceMap({{"seq_lens_decoder", "seq_lens_decoder_out"}})
+    .SetKernelFn(PD_KERNEL(SpeculateClearAcceptNums));
diff --git a/custom_ops/xpu_ops/src/ops/mtp_ops/speculate_get_output.cc b/custom_ops/xpu_ops/src/ops/mtp_ops/speculate_get_output.cc
new file mode 100644
index 000000000..f248a1088
--- /dev/null
+++ b/custom_ops/xpu_ops/src/ops/mtp_ops/speculate_get_output.cc
@@ -0,0 +1,113 @@
+// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <sys/ipc.h>
+#include <sys/msg.h>
+#include <cstdio>
+#include <cstdlib>
+#include <iostream>
+#include <string>
+#include "paddle/extension.h"
+
+#define MAX_BSZ 256
+#define MAX_DRAFT_TOKENS 6
+
+struct msgdata {
+  long mtype;  // NOLINT
+  int mtext[MAX_BSZ * MAX_DRAFT_TOKENS + MAX_BSZ +
+            2];  // stop_flag, bsz, accept_num*bsz, tokens...
+};
+
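+// Reads one message from the queue. With wait_flag == false the call uses
+// IPC_NOWAIT; if no message is available, x[0] is set to -2 and x[1] to 0 so
+// the consumer can tell "no output yet" apart from real tokens.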
+void SpeculateGetOutput(const paddle::Tensor& x,
+                        int64_t rank_id,
+                        bool wait_flag,
+                        int msg_queue_id,
+                        bool get_each_rank) {
+  if (!get_each_rank && rank_id > 0) {
+    return;
+  }
+
+  if (const char* inference_msg_queue_id_env_p =
+          std::getenv("INFERENCE_MSG_QUEUE_ID")) {
+    std::string inference_msg_queue_id_env_str(inference_msg_queue_id_env_p);
+    int inference_msg_queue_id_from_env =
+        std::stoi(inference_msg_queue_id_env_str);
+#ifdef GET_OUTPUT_DEBUG
+    std::cout << "Your INFERENCE_MSG_QUEUE_ID is: "
+              << inference_msg_queue_id_from_env << std::endl;
+#endif
+    msg_queue_id = inference_msg_queue_id_from_env;
+  }
+
+  static struct msgdata msg_rcv;
+  static key_t key = ftok("./", msg_queue_id);
+  static int msgid = msgget(key, IPC_CREAT | 0666);
+
+  int64_t* out_data = const_cast<int64_t*>(x.data<int64_t>());
+  int ret = -1;
+  if (!wait_flag) {
+    ret = msgrcv(msgid,
+                 &msg_rcv,
+                 (MAX_BSZ * MAX_DRAFT_TOKENS + MAX_BSZ + 2) * 4,
+                 0,
+                 IPC_NOWAIT);
+  } else {
+    ret = msgrcv(
+        msgid, &msg_rcv, (MAX_BSZ * MAX_DRAFT_TOKENS + MAX_BSZ + 2) * 4, 0, 0);
+  }
+  if (ret == -1) {
+    out_data[0] = -2;
+    out_data[1] = 0;
+    return;
+  }
+
+  for (int64_t i = 0; i < MAX_BSZ * MAX_DRAFT_TOKENS + MAX_BSZ + 2; i++) {
+    out_data[i] = (int64_t)msg_rcv.mtext[i];
+  }
+  return;
+}
+
+void SpeculateGetOutputStatic(const paddle::Tensor& x,
+                              int64_t rank_id,
+                              bool wait_flag,
+                              bool get_each_rank) {
+  SpeculateGetOutput(x, rank_id, wait_flag, 1, get_each_rank);
+}
+
+void SpeculateGetOutputDynamic(const paddle::Tensor& x,
+                               int64_t rank_id,
+                               bool wait_flag,
+                               int msg_queue_id,
+                               bool get_each_rank) {
+  SpeculateGetOutput(x, rank_id, wait_flag, msg_queue_id, get_each_rank);
+}
+
+PD_BUILD_OP(speculate_get_output)
+    .Inputs({"x"})
+    .Attrs({"rank_id: int64_t", "wait_flag: bool", "get_each_rank: bool"})
+    .Outputs({"x_out"})
+    .SetInplaceMap({{"x", "x_out"}})
+    .SetKernelFn(PD_KERNEL(SpeculateGetOutputStatic));
+
+PD_BUILD_OP(speculate_get_output_dynamic)
+    .Inputs({"x"})
+    .Attrs({"rank_id: int64_t",
+            "wait_flag: bool",
+            "msg_queue_id: int",
+            "get_each_rank: bool"})
+    .Outputs({"x_out"})
+    .SetInplaceMap({{"x", "x_out"}})
+    .SetKernelFn(PD_KERNEL(SpeculateGetOutputDynamic));
diff --git a/custom_ops/xpu_ops/src/ops/mtp_ops/speculate_get_output_padding_offset.cc b/custom_ops/xpu_ops/src/ops/mtp_ops/speculate_get_output_padding_offset.cc
new file mode 100644
index 000000000..b29240a08
--- /dev/null
+++ b/custom_ops/xpu_ops/src/ops/mtp_ops/speculate_get_output_padding_offset.cc
@@ -0,0 +1,78 @@
+// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <paddle/phi/backends/xpu/xpu_context.h>
+#include "paddle/extension.h"
+#include "xpu/plugin.h"
+
+std::vector<paddle::Tensor> SpeculateGetOutputPaddingOffset(
+    const paddle::Tensor& output_cum_offsets_tmp,
+    const paddle::Tensor& out_token_num,
+    const paddle::Tensor& seq_lens_output,
+    const int max_seq_len) {
+  phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
+  auto dev_ctx =
+      paddle::experimental::DeviceContextPool::Instance().Get(place);
+  baidu::xpu::api::Context* ctx =
+      static_cast<const phi::XPUContext*>(dev_ctx)->x_context();
+
+  if (output_cum_offsets_tmp.is_cpu()) {
+    ctx = new baidu::xpu::api::Context(baidu::xpu::api::kCPU);
+  }
+  std::vector<int64_t> output_cum_offsets_tmp_shape =
+      output_cum_offsets_tmp.shape();
+  const int bsz = output_cum_offsets_tmp_shape[0];
+  auto cpu_out_token_num = out_token_num.copy_to(paddle::CPUPlace(), false);
+
+  auto output_padding_offset = paddle::full({cpu_out_token_num},
+                                            0,
+                                            paddle::DataType::INT32,
+                                            output_cum_offsets_tmp.place());
+  auto output_cum_offsets =
+      output_cum_offsets_tmp.copy_to(output_cum_offsets_tmp.place(), false);
+
+  int r = baidu::xpu::api::plugin::speculate_get_output_padding_offset(
+      ctx,
+      output_padding_offset.mutable_data<int>(),
+      output_cum_offsets.mutable_data<int>(),
+      output_cum_offsets_tmp.data<int>(),
+      seq_lens_output.data<int>(),
+      bsz,
+      max_seq_len);
+  PD_CHECK(r == 0, "speculate_get_output_padding_offset failed.");
+
+  return {output_padding_offset, output_cum_offsets};
+}
+
+std::vector<std::vector<int64_t>> SpeculateGetOutputPaddingOffsetInferShape(
+    const std::vector<int64_t>& output_cum_offsets_tmp_shape,
+    const std::vector<int64_t>& out_token_num_shape,
+    const std::vector<int64_t>& seq_lens_output_shape) {
+  int64_t bsz = output_cum_offsets_tmp_shape[0];
+  return {{-1}, {bsz}};
+}
+
+std::vector<paddle::DataType> SpeculateGetOutputPaddingOffsetInferDtype(
+    const paddle::DataType& output_cum_offsets_tmp_dtype,
+    const paddle::DataType& out_token_num_dtype,
+    const paddle::DataType& seq_lens_output_dtype) {
+  return {output_cum_offsets_tmp_dtype, output_cum_offsets_tmp_dtype};
+}
+
+PD_BUILD_OP(speculate_get_output_padding_offset)
+    .Inputs({"output_cum_offsets_tmp", "out_token_num", "seq_lens_output"})
+    .Outputs({"output_padding_offset", "output_cum_offsets"})
+    .Attrs({"max_seq_len: int"})
+    .SetKernelFn(PD_KERNEL(SpeculateGetOutputPaddingOffset))
+    .SetInferShapeFn(PD_INFER_SHAPE(SpeculateGetOutputPaddingOffsetInferShape))
+    .SetInferDtypeFn(
+        PD_INFER_DTYPE(SpeculateGetOutputPaddingOffsetInferDtype));
diff --git a/custom_ops/xpu_ops/src/ops/mtp_ops/speculate_get_padding_offset.cc b/custom_ops/xpu_ops/src/ops/mtp_ops/speculate_get_padding_offset.cc
new file mode 100644
index 000000000..bd06ef2be
--- /dev/null
+++ b/custom_ops/xpu_ops/src/ops/mtp_ops/speculate_get_padding_offset.cc
@@ -0,0 +1,127 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
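+
+// speculate_get_padding_offset: flattens the padded input_ids (prefill
+// queries) and draft_tokens (decode queries) into x_remove_padding, and
+// produces the padding-offset / cu_seqlens bookkeeping used by attention.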
+
+#include <paddle/phi/backends/xpu/xpu_context.h>
+#include "paddle/extension.h"
+#include "xpu/plugin.h"
+
+std::vector<paddle::Tensor> SpeculateGetPaddingOffset(
+    const paddle::Tensor& input_ids,
+    const paddle::Tensor& draft_tokens,
+    const paddle::Tensor& cum_offsets,
+    const paddle::Tensor& token_num,
+    const paddle::Tensor& seq_len,
+    const paddle::Tensor& seq_lens_encoder) {
+  phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
+  auto dev_ctx =
+      paddle::experimental::DeviceContextPool::Instance().Get(place);
+  auto xpu_ctx = static_cast<const phi::XPUContext*>(dev_ctx);
+
+  std::vector<int64_t> input_ids_shape = input_ids.shape();
+  const int bsz = seq_len.shape()[0];
+  const int seq_length = input_ids_shape[1];
+  const int max_draft_tokens = draft_tokens.shape()[1];
+  auto cum_offsets_out = cum_offsets.copy_to(cum_offsets.place(), false);
+  auto cpu_token_num = token_num.copy_to(paddle::CPUPlace(), false);
+
+  const int token_num_data = cpu_token_num.data<int64_t>()[0];
+  auto x_remove_padding = paddle::empty(
+      {token_num_data}, paddle::DataType::INT64, input_ids.place());
+  auto padding_offset = paddle::empty(
+      {token_num_data}, paddle::DataType::INT32, input_ids.place());
+  auto cu_seqlens_q =
+      paddle::empty({bsz + 1}, paddle::DataType::INT32, input_ids.place());
+  auto cu_seqlens_k =
+      paddle::empty({bsz + 1}, paddle::DataType::INT32, input_ids.place());
+
+  PD_CHECK(input_ids.is_contiguous(), "Input ids tensor must be contiguous");
+  PD_CHECK(draft_tokens.is_contiguous(),
+           "Draft tokens tensor must be contiguous");
+  PD_CHECK(cum_offsets.is_contiguous(),
+           "Cum offsets tensor must be contiguous");
+  PD_CHECK(seq_len.is_contiguous(), "Seq lens tensor must be contiguous");
+
+  int r = baidu::xpu::api::plugin::speculate_get_padding_offset(
+      xpu_ctx->x_context(),
+      padding_offset.data<int>(),
+      cum_offsets_out.data<int>(),
+      cu_seqlens_q.data<int>(),
+      cu_seqlens_k.data<int>(),
+      cum_offsets.data<int>(),
+      seq_len.data<int>(),
+      seq_length,
+      bsz);
+  PD_CHECK(r == 0, "XPU speculate_get_padding_offset failed");
+
+  r = baidu::xpu::api::plugin::speculate_remove_padding(
+      xpu_ctx->x_context(),
+      x_remove_padding.data<int64_t>(),
+      input_ids.data<int64_t>(),
+      draft_tokens.data<int64_t>(),
+      seq_len.data<int>(),
+      seq_lens_encoder.data<int>(),
+      cum_offsets_out.data<int>(),
+      seq_length,
+      max_draft_tokens,
+      bsz,
+      token_num_data);
+  PD_CHECK(r == 0, "XPU speculate_remove_padding failed");
+
+  return {x_remove_padding,
+          cum_offsets_out,
+          padding_offset,
+          cu_seqlens_q,
+          cu_seqlens_k};
+}
+
+std::vector<std::vector<int64_t>> SpeculateGetPaddingOffsetInferShape(
+    const std::vector<int64_t>& input_ids_shape,
+    const std::vector<int64_t>& draft_tokens_shape,
+    const std::vector<int64_t>& cum_offsets_shape,
+    const std::vector<int64_t>& token_num_shape,
+    const std::vector<int64_t>& seq_len_shape,
+    const std::vector<int64_t>& seq_lens_encoder_shape) {
+  int64_t bsz = seq_len_shape[0];
+  int64_t seq_len = input_ids_shape[1];
+  return {{-1}, {bsz}, {-1}, {bsz + 1}, {bsz + 1}};
+}
+
+std::vector<paddle::DataType> SpeculateGetPaddingOffsetInferDtype(
+    const paddle::DataType& input_ids_dtype,
+    const paddle::DataType& draft_tokens_dtype,
+    const paddle::DataType& cum_offsets_dtype,
+    const paddle::DataType& token_num_dtype,
+    const paddle::DataType& seq_len_dtype,
+    const paddle::DataType& seq_lens_encoder_dtype) {
+  return {input_ids_dtype,
+          seq_len_dtype,
+          seq_len_dtype,
+          seq_len_dtype,
+          seq_len_dtype};
+}
+
+PD_BUILD_OP(speculate_get_padding_offset)
+    .Inputs({"input_ids",
+             "draft_tokens",
+             "cum_offsets",
+             "token_num",
+             "seq_len",
+             "seq_lens_encoder"})
+    .Outputs({"x_remove_padding",
+              "cum_offsets_out",
+              "padding_offset",
+              "cu_seqlens_q",
+              "cu_seqlens_k"})
+    .SetKernelFn(PD_KERNEL(SpeculateGetPaddingOffset))
+    .SetInferShapeFn(PD_INFER_SHAPE(SpeculateGetPaddingOffsetInferShape))
+    .SetInferDtypeFn(PD_INFER_DTYPE(SpeculateGetPaddingOffsetInferDtype));
diff --git a/custom_ops/xpu_ops/src/ops/mtp_ops/speculate_get_seq_lens_output.cc b/custom_ops/xpu_ops/src/ops/mtp_ops/speculate_get_seq_lens_output.cc
new file mode 100644
index 000000000..3caf47696
--- /dev/null
+++ b/custom_ops/xpu_ops/src/ops/mtp_ops/speculate_get_seq_lens_output.cc
@@ -0,0 +1,69 @@
+// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <paddle/phi/backends/xpu/xpu_context.h>
+#include "paddle/extension.h"
+#include "xpu/plugin.h"
+
+std::vector<paddle::Tensor> SpeculateGetSeqLensOutput(
+    const paddle::Tensor& seq_lens_this_time,
+    const paddle::Tensor& seq_lens_encoder,
+    const paddle::Tensor& seq_lens_decoder) {
+  phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
+  auto dev_ctx =
+      paddle::experimental::DeviceContextPool::Instance().Get(place);
+  baidu::xpu::api::Context* ctx =
+      static_cast<const phi::XPUContext*>(dev_ctx)->x_context();
+
+  if (seq_lens_this_time.is_cpu()) {
+    ctx = new baidu::xpu::api::Context(baidu::xpu::api::kCPU);
+  }
+  std::vector<int64_t> seq_lens_this_time_shape = seq_lens_this_time.shape();
+  const int bsz = seq_lens_this_time_shape[0];
+
+  auto seq_lens_output = paddle::full(
+      {bsz}, 0, paddle::DataType::INT32, seq_lens_this_time.place());
+
+  int r = baidu::xpu::api::plugin::speculate_get_seq_lens_output(
+      ctx,
+      seq_lens_output.data<int>(),
+      seq_lens_this_time.data<int>(),
+      seq_lens_encoder.data<int>(),
+      seq_lens_decoder.data<int>(),
+      bsz);
+  PD_CHECK(r == 0, "speculate_get_seq_lens_output failed.");
+
+  return {seq_lens_output};
+}
+
+std::vector<std::vector<int64_t>> SpeculateGetSeqLensOutputInferShape(
+    const std::vector<int64_t>& seq_lens_this_time_shape,
+    const std::vector<int64_t>& seq_lens_encoder_shape,
+    const std::vector<int64_t>& seq_lens_decoder_shape) {
+  int64_t bsz = seq_lens_this_time_shape[0];
+  return {{bsz}};
+}
+
+std::vector<paddle::DataType> SpeculateGetSeqLensOutputInferDtype(
+    const paddle::DataType& seq_lens_this_time_dtype,
+    const paddle::DataType& seq_lens_encoder_dtype,
+    const paddle::DataType& seq_lens_decoder_dtype) {
+  return {seq_lens_this_time_dtype};
+}
+
+PD_BUILD_OP(speculate_get_seq_lens_output)
+    .Inputs({"seq_lens_this_time", "seq_lens_encoder", "seq_lens_decoder"})
+    .Outputs({"seq_lens_output"})
+    .SetKernelFn(PD_KERNEL(SpeculateGetSeqLensOutput))
+    .SetInferShapeFn(PD_INFER_SHAPE(SpeculateGetSeqLensOutputInferShape))
+    .SetInferDtypeFn(PD_INFER_DTYPE(SpeculateGetSeqLensOutputInferDtype));
diff --git a/custom_ops/xpu_ops/src/ops/mtp_ops/speculate_msg.h b/custom_ops/xpu_ops/src/ops/mtp_ops/speculate_msg.h
new file mode 100644
index 000000000..64bd87eab
--- /dev/null
+++ b/custom_ops/xpu_ops/src/ops/mtp_ops/speculate_msg.h
@@ -0,0 +1,31 @@
+// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
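+
+// Shared System V message-buffer layout for the speculative-decoding
+// save/get ops: mtext[0] = stop flag, mtext[1] = batch size, followed by
+// per-query accept counts and up to MAX_DRAFT_TOKENS tokens per query.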
+#pragma once
+
+#include <sys/ipc.h>
+#include <sys/msg.h>
+#include <cstdio>
+#include <cstdlib>
+#include <iostream>
+#include "paddle/extension.h"
+
+#define MAX_BSZ 256
+#define MAX_DRAFT_TOKENS 6
+
+struct speculate_msgdata {
+  long mtype;  // NOLINT
+  int mtext[MAX_BSZ * MAX_DRAFT_TOKENS + MAX_BSZ +
+            2];  // stop_flag, bsz, tokens
+};
diff --git a/custom_ops/xpu_ops/src/ops/mtp_ops/speculate_rebuild_append_padding.cc b/custom_ops/xpu_ops/src/ops/mtp_ops/speculate_rebuild_append_padding.cc
new file mode 100644
index 000000000..041d8c65e
--- /dev/null
+++ b/custom_ops/xpu_ops/src/ops/mtp_ops/speculate_rebuild_append_padding.cc
@@ -0,0 +1,130 @@
+// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <paddle/phi/backends/xpu/xpu_context.h>
+#include "paddle/extension.h"
+#include "xpu/plugin.h"
+
+namespace api = baidu::xpu::api;
+
+std::vector<paddle::Tensor> RebuildAppendPadding(
+    const paddle::Tensor& full_hidden_states,
+    const paddle::Tensor& cum_offsets,
+    const paddle::Tensor& seq_len_encoder,
+    const paddle::Tensor& seq_len_decoder,
+    const paddle::Tensor& output_padding_offset,
+    int max_seq_len) {
+  phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
+  auto dev_ctx =
+      paddle::experimental::DeviceContextPool::Instance().Get(place);
+  auto xpu_ctx = static_cast<const phi::XPUContext*>(dev_ctx);
+  api::Context* ctx = xpu_ctx->x_context();
+  if (full_hidden_states.is_cpu()) {
+    ctx = new api::Context(api::kCPU);
+  }
+
+  int dim_embed = full_hidden_states.shape()[1];
+  int output_token_num = output_padding_offset.shape()[0];
+  int elem_nums = output_token_num * dim_embed;
+
+  auto out = paddle::full({output_token_num, dim_embed},
+                          0,
+                          full_hidden_states.dtype(),
+                          full_hidden_states.place());
+
+  int r;
+  switch (full_hidden_states.dtype()) {
+    case paddle::DataType::BFLOAT16: {
+      using XPUTypeBF16 = typename XPUTypeTrait<paddle::bfloat16>::Type;
+      r = api::plugin::speculate_rebuild_append_padding(
+          ctx,
+          const_cast<XPUTypeBF16*>(reinterpret_cast<const XPUTypeBF16*>(
+              full_hidden_states.data<paddle::bfloat16>())),
+          const_cast<int*>(cum_offsets.data<int>()),
+          const_cast<int*>(seq_len_encoder.data<int>()),
+          const_cast<int*>(seq_len_decoder.data<int>()),
+          const_cast<int*>(output_padding_offset.data<int>()),
+          max_seq_len,
+          dim_embed,
+          elem_nums,
+          reinterpret_cast<XPUTypeBF16*>(out.data<paddle::bfloat16>()));
+      PD_CHECK(r == 0,
+               "xpu::plugin::speculate_rebuild_append_padding failed.");
+      return {out};
+    }
+    case paddle::DataType::FLOAT16: {
+      using XPUTypeFP16 = typename XPUTypeTrait<paddle::float16>::Type;
+      r = api::plugin::speculate_rebuild_append_padding(
+          ctx,
+          const_cast<XPUTypeFP16*>(reinterpret_cast<const XPUTypeFP16*>(
+              full_hidden_states.data<paddle::float16>())),
+          const_cast<int*>(cum_offsets.data<int>()),
+          const_cast<int*>(seq_len_encoder.data<int>()),
+          const_cast<int*>(seq_len_decoder.data<int>()),
+          const_cast<int*>(output_padding_offset.data<int>()),
+          max_seq_len,
+          dim_embed,
+          elem_nums,
+          reinterpret_cast<XPUTypeFP16*>(out.data<paddle::float16>()));
+      PD_CHECK(r == 0,
+               "xpu::plugin::speculate_rebuild_append_padding failed.");
+      return {out};
+    }
+    case paddle::DataType::FLOAT32: {
+      r = api::plugin::speculate_rebuild_append_padding(
+          ctx,
+          const_cast<float*>(full_hidden_states.data<float>()),
+          const_cast<int*>(cum_offsets.data<int>()),
+          const_cast<int*>(seq_len_encoder.data<int>()),
+          const_cast<int*>(seq_len_decoder.data<int>()),
+          const_cast<int*>(output_padding_offset.data<int>()),
+          max_seq_len,
+          dim_embed,
+          elem_nums,
+          out.data<float>());
+      PD_CHECK(r == 0,
+               "xpu::plugin::speculate_rebuild_append_padding failed.");
+      return {out};
+    }
+    default:
+      PD_THROW("Unsupported data type.");
+  }
+}
+
+std::vector<std::vector<int64_t>> RebuildAppendPaddingInferShape(
+    const std::vector<int64_t>& full_hidden_states_shape,
+    const std::vector<int64_t>& cum_offsets_shape,
+    const std::vector<int64_t>& seq_len_encoder_shape,
+    const std::vector<int64_t>& seq_len_decoder_shape,
+    const std::vector<int64_t>& output_padding_offset_shape) {
+  const int64_t output_token_num = output_padding_offset_shape[0];
+  const int64_t dim_embed = full_hidden_states_shape[1];
+  std::vector<int64_t> out_shape = {output_token_num, dim_embed};
+  return {out_shape};
+}
+
+std::vector<paddle::DataType> RebuildAppendPaddingInferDtype(
+    const paddle::DataType& full_hidden_states_dtype,
+    const paddle::DataType& cum_offsets_dtype,
+    const paddle::DataType& seq_len_encoder_dtype,
+    const paddle::DataType& seq_len_decoder_dtype,
+    const paddle::DataType& output_padding_offset_dtype) {
+  return {full_hidden_states_dtype};
+}
+
+PD_BUILD_OP(speculate_rebuild_append_padding)
+    .Inputs({"full_hidden_states",
+             "cum_offsets",
+             "seq_len_encoder",
+             "seq_len_decoder",
+             "output_padding_offset"})
+    .Attrs({"max_seq_len: int"})
+    .Outputs({"out"})
+    .SetKernelFn(PD_KERNEL(RebuildAppendPadding))
+    .SetInferShapeFn(PD_INFER_SHAPE(RebuildAppendPaddingInferShape))
+    .SetInferDtypeFn(PD_INFER_DTYPE(RebuildAppendPaddingInferDtype));
diff --git a/custom_ops/xpu_ops/src/ops/mtp_ops/speculate_save_output.cc b/custom_ops/xpu_ops/src/ops/mtp_ops/speculate_save_output.cc
new file mode 100644
index 000000000..60764b26a
--- /dev/null
+++ b/custom_ops/xpu_ops/src/ops/mtp_ops/speculate_save_output.cc
@@ -0,0 +1,162 @@
+// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
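+
+// speculate_save_output: copies each query's accepted tokens to the host and
+// publishes them on a System V message queue; the consumer side is
+// speculate_get_output.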
+
+#include <sys/ipc.h>
+#include <sys/msg.h>
+#include <cstdio>
+#include <cstdlib>
+#include <iostream>
+#include <string>
+#include "paddle/extension.h"
+#include "xpu/plugin.h"
+
+#define MAX_BSZ 256
+#define MAX_DRAFT_TOKENS 6
+
+struct msgdata {
+  long mtype;  // NOLINT
+  int mtext[MAX_BSZ * MAX_DRAFT_TOKENS + MAX_BSZ +
+            2];  // stop_flag, bsz, tokens
+};
+
+void SpeculateSaveWithOutputMsg(const paddle::Tensor& accept_tokens,
+                                const paddle::Tensor& accept_num,
+                                const paddle::Tensor& not_need_stop,
+                                int64_t rank_id,
+                                int msg_queue_id,
+                                bool save_each_rank) {
+  if (!save_each_rank && rank_id > 0) {
+    return;
+  }
+
+  int max_draft_tokens = accept_tokens.shape()[1];
+
+  auto accept_tokens_cpu = accept_tokens.copy_to(paddle::CPUPlace(), true);
+  auto accept_num_cpu = accept_num.copy_to(paddle::CPUPlace(), true);
+  int64_t* accept_tokens_data = accept_tokens_cpu.data<int64_t>();
+  int* accept_num_data = accept_num_cpu.data<int>();
+
+  if (const char* inference_msg_queue_id_env_p =
+          std::getenv("INFERENCE_MSG_QUEUE_ID")) {
+    std::string inference_msg_queue_id_env_str(inference_msg_queue_id_env_p);
+    int inference_msg_queue_id_from_env =
+        std::stoi(inference_msg_queue_id_env_str);
+#ifdef GET_OUTPUT_DEBUG
+    std::cout << "Your INFERENCE_MSG_QUEUE_ID is: "
+              << inference_msg_queue_id_from_env << std::endl;
+#endif
+    msg_queue_id = inference_msg_queue_id_from_env;
+  }
+  static struct msgdata msg_sed;
+  static key_t key = ftok("./", msg_queue_id);
+  static int msgid = msgget(key, IPC_CREAT | 0666);
+
+  msg_sed.mtype = 1;
+  bool not_need_stop_data = not_need_stop.data<bool>()[0];
+
+  int inference_msg_id_from_env = 1;
+  if (const char* inference_msg_id_env_p = std::getenv("INFERENCE_MSG_ID")) {
+    std::string inference_msg_id_env_str(inference_msg_id_env_p);
+    inference_msg_id_from_env = std::stoi(inference_msg_id_env_str);
+    if (inference_msg_id_from_env == 2) {
+      // 2 and -2 are reserved for the no-output indication.
+      throw std::runtime_error(
+          "INFERENCE_MSG_ID cannot be 2, please use another number.");
+    }
+    if (inference_msg_id_from_env < 0) {
+      throw std::runtime_error(
+          "INFERENCE_MSG_ID cannot be negative, please use another number.");
+    }
+#ifdef SAVE_WITH_OUTPUT_DEBUG
+    std::cout << "Your INFERENCE_MSG_ID is: " << inference_msg_id_from_env
+              << std::endl;
+#endif
+  } else {
+#ifdef SAVE_WITH_OUTPUT_DEBUG
+    std::cout << "Failed to get INFERENCE_MSG_ID from env, using (int)1 as "
+                 "default."
+              << std::endl;
+#endif
+  }
+
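+  // mtext[0] carries the message id, negated once generation has stopped;
+  // mtext[1] is the batch size, followed by per-query accept counts and the
+  // accepted tokens (zero-padded out to MAX_BSZ / MAX_DRAFT_TOKENS slots).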
+  msg_sed.mtext[0] = not_need_stop_data ? inference_msg_id_from_env
+                                        : -inference_msg_id_from_env;
+  int bsz = accept_tokens.shape()[0];
+  msg_sed.mtext[1] = bsz;
+
+  for (int i = 2; i < MAX_BSZ + 2; i++) {
+    if (i - 2 >= bsz) {
+      msg_sed.mtext[i] = 0;
+    } else {
+      msg_sed.mtext[i] = static_cast<int>(accept_num_data[i - 2]);
+    }
+  }
+  for (int i = MAX_BSZ + 2; i < MAX_BSZ * MAX_DRAFT_TOKENS + MAX_BSZ + 2;
+       i++) {
+    int token_id = i - MAX_BSZ - 2;
+    int bid = token_id / MAX_DRAFT_TOKENS;
+    int local_token_id = token_id % MAX_DRAFT_TOKENS;
+    if (bid >= bsz) {
+      msg_sed.mtext[i] = 0;
+    } else {
+      msg_sed.mtext[i] =
+          accept_tokens_data[bid * max_draft_tokens + local_token_id];
+    }
+  }
+  if ((msgsnd(msgid,
+              &msg_sed,
+              (MAX_BSZ * MAX_DRAFT_TOKENS + MAX_BSZ + 2) * 4,
+              0)) == -1) {
+    printf("full msg buffer\n");
+  }
+  return;
+}
+
+void SpeculateSaveWithOutputMsgStatic(const paddle::Tensor& accept_tokens,
+                                      const paddle::Tensor& accept_num,
+                                      const paddle::Tensor& not_need_stop,
+                                      int64_t rank_id,
+                                      bool save_each_rank) {
+  SpeculateSaveWithOutputMsg(
+      accept_tokens, accept_num, not_need_stop, rank_id, 1, save_each_rank);
+}
+
+void SpeculateSaveWithOutputMsgDynamic(const paddle::Tensor& accept_tokens,
+                                       const paddle::Tensor& accept_num,
+                                       const paddle::Tensor& not_need_stop,
+                                       int64_t rank_id,
+                                       int msg_queue_id,
+                                       bool save_each_rank) {
+  SpeculateSaveWithOutputMsg(accept_tokens,
+                             accept_num,
+                             not_need_stop,
+                             rank_id,
+                             msg_queue_id,
+                             save_each_rank);
+}
+
+PD_BUILD_OP(speculate_save_output)
+    .Inputs({"accept_tokens", "accept_num", "not_need_stop"})
+    .Attrs({"rank_id: int64_t", "save_each_rank: bool"})
+    .Outputs({"x_out"})
+    .SetInplaceMap({{"accept_tokens", "x_out"}})
+    .SetKernelFn(PD_KERNEL(SpeculateSaveWithOutputMsgStatic));
+
+PD_BUILD_OP(speculate_save_output_dynamic)
+    .Inputs({"accept_tokens", "accept_num", "not_need_stop"})
+    .Attrs({"rank_id: int64_t", "msg_queue_id: int", "save_each_rank: bool"})
+    .Outputs({"x_out"})
+    .SetInplaceMap({{"accept_tokens", "x_out"}})
+    .SetKernelFn(PD_KERNEL(SpeculateSaveWithOutputMsgDynamic));
diff --git a/custom_ops/xpu_ops/src/ops/mtp_ops/speculate_set_stop_value_multi_seqs.cc b/custom_ops/xpu_ops/src/ops/mtp_ops/speculate_set_stop_value_multi_seqs.cc
new file mode 100644
index 000000000..f54cba6c1
--- /dev/null
+++ b/custom_ops/xpu_ops/src/ops/mtp_ops/speculate_set_stop_value_multi_seqs.cc
@@ -0,0 +1,80 @@
+// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
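+
+// speculate_set_stop_value_multi_seqs: checks the freshly accepted tokens
+// against the configured stop sequences (stop_seqs/stop_seqs_len) and end
+// ids, raising stop_flags for matching queries (kernel in xpu/plugin).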
+ +#include +#include "paddle/extension.h" +#include "paddle/phi/core/enforce.h" +#include "xpu/plugin.h" + +namespace api = baidu::xpu::api; +void SpecGetStopFlagsMultiSeqs(const paddle::Tensor &accept_tokens, + const paddle::Tensor &accept_num, + const paddle::Tensor &pre_ids, + const paddle::Tensor &step_idx, + const paddle::Tensor &stop_flags, + const paddle::Tensor &seq_lens, + const paddle::Tensor &stop_seqs, + const paddle::Tensor &stop_seqs_len, + const paddle::Tensor &end_ids) { + phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId()); + auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place); + api::Context *ctx = + static_cast(dev_ctx)->x_context(); + if (accept_tokens.is_cpu()) { + ctx = new api::Context(api::kCPU); + } + PD_CHECK(accept_tokens.dtype() == paddle::DataType::INT64); + PD_CHECK(stop_flags.dtype() == paddle::DataType::BOOL); + + std::vector shape = accept_tokens.shape(); + std::vector stop_seqs_shape = stop_seqs.shape(); + int bs_now = shape[0]; + int stop_seqs_bs = stop_seqs_shape[0]; + int stop_seqs_max_len = stop_seqs_shape[1]; + int pre_ids_len = pre_ids.shape()[1]; + int accept_tokens_len = accept_tokens.shape()[1]; + + int r = baidu::xpu::api::plugin::speculate_set_stop_value_multi_seqs( + ctx, + const_cast(stop_flags.data()), + const_cast(accept_tokens.data()), + const_cast(accept_num.data()), + pre_ids.data(), + step_idx.data(), + stop_seqs.data(), + stop_seqs_len.data(), + seq_lens.data(), + end_ids.data(), + bs_now, + accept_tokens_len, + stop_seqs_bs, + stop_seqs_max_len, + pre_ids_len); + PD_CHECK(r == 0, "xpu::plugin::speculate_set_stop_value_multi_seqs failed."); +} + +PD_BUILD_OP(speculate_set_stop_value_multi_seqs) + .Inputs({"accept_tokens", + "accept_num", + "pre_ids", + "step_idx", + "stop_flags", + "seq_lens", + "stop_seqs", + "stop_seqs_len", + "end_ids"}) + .Outputs({"accept_tokens_out", "stop_flags_out"}) + .SetInplaceMap({{"accept_tokens", "accept_tokens_out"}, + {"stop_flags", "stop_flags_out"}}) + .SetKernelFn(PD_KERNEL(SpecGetStopFlagsMultiSeqs)); diff --git a/custom_ops/xpu_ops/src/ops/mtp_ops/speculate_set_value_by_flags.cc b/custom_ops/xpu_ops/src/ops/mtp_ops/speculate_set_value_by_flags.cc new file mode 100644 index 000000000..60843e88e --- /dev/null +++ b/custom_ops/xpu_ops/src/ops/mtp_ops/speculate_set_value_by_flags.cc @@ -0,0 +1,67 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
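The kernel behind speculate_set_stop_value_multi_seqs above lives in the XPU plugin, so only its operand layout is visible in this diff. One plausible host-side reading of those operands is that a sequence stops once the tail of its generated history matches any configured stop sequence; the reference below is illustrative only and is not the plugin implementation.

#include <cstdint>

// Illustrative reference; the real check runs inside the XPU kernel.
void check_stop_seqs(bool* stop_flags, const int64_t* pre_ids,
                     const int64_t* step_idx, const int64_t* stop_seqs,
                     const int* stop_seqs_len, int bs, int pre_ids_len,
                     int stop_seqs_bs, int stop_seqs_max_len) {
  for (int bid = 0; bid < bs; bid++) {
    if (stop_flags[bid]) continue;
    const int64_t* history = pre_ids + bid * pre_ids_len;
    int64_t cur_len = step_idx[bid];  // tokens generated so far
    for (int sid = 0; sid < stop_seqs_bs; sid++) {
      int len = stop_seqs_len[sid];
      if (len <= 0 || len > cur_len) continue;
      bool match = true;
      for (int k = 0; k < len; k++) {
        if (history[cur_len - len + k] !=
            stop_seqs[sid * stop_seqs_max_len + k]) {
          match = false;
          break;
        }
      }
      if (match) {
        stop_flags[bid] = true;  // sequence is done
        break;
      }
    }
  }
}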
+
+#include
+#include "paddle/extension.h"
+#include "xpu/plugin.h"
+
+void SpeculateSetValueByFlagsAndIdx(const paddle::Tensor &pre_ids_all,
+                                    const paddle::Tensor &accept_tokens,
+                                    const paddle::Tensor &accept_num,
+                                    const paddle::Tensor &stop_flags,
+                                    const paddle::Tensor &seq_lens_this_time,
+                                    const paddle::Tensor &seq_lens_encoder,
+                                    const paddle::Tensor &seq_lens_decoder,
+                                    const paddle::Tensor &step_idx) {
+  phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
+  auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
+  baidu::xpu::api::Context *ctx =
+      static_cast<const phi::XPUContext *>(dev_ctx)->x_context();
+  if (pre_ids_all.is_cpu()) {
+    ctx = new baidu::xpu::api::Context(baidu::xpu::api::kCPU);
+  }
+  std::vector<int64_t> pre_ids_all_shape = pre_ids_all.shape();
+  int bs = seq_lens_this_time.shape()[0];
+  int length = pre_ids_all_shape[1];
+  int max_draft_tokens = accept_tokens.shape()[1];
+
+  int r = baidu::xpu::api::plugin::speculate_set_value_by_flag_and_id(
+      ctx,
+      const_cast<int64_t *>(pre_ids_all.data<int64_t>()),
+      accept_tokens.data<int64_t>(),
+      accept_num.data<int>(),
+      stop_flags.data<bool>(),
+      seq_lens_encoder.data<int>(),
+      seq_lens_decoder.data<int>(),
+      step_idx.data<int64_t>(),
+      bs,
+      length,
+      max_draft_tokens);
+  PD_CHECK(r == 0, "xpu::plugin::speculate_set_value_by_flag_and_id failed.");
+}
+
+PD_BUILD_OP(speculate_set_value_by_flags_and_idx)
+    .Inputs({"pre_ids_all",
+             "accept_tokens",
+             "accept_num",
+             "stop_flags",
+             "seq_lens_this_time",
+             "seq_lens_encoder",
+             "seq_lens_decoder",
+             "step_idx"})
+    .Outputs({"pre_ids_all_out"})
+    .SetInplaceMap({{"pre_ids_all", "pre_ids_all_out"}})
+    .SetKernelFn(PD_KERNEL(SpeculateSetValueByFlagsAndIdx));
diff --git a/custom_ops/xpu_ops/src/ops/mtp_ops/speculate_step_reschedule.cc b/custom_ops/xpu_ops/src/ops/mtp_ops/speculate_step_reschedule.cc
new file mode 100644
index 000000000..bc3675d4c
--- /dev/null
+++ b/custom_ops/xpu_ops/src/ops/mtp_ops/speculate_step_reschedule.cc
@@ -0,0 +1,216 @@
+// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
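For speculate_set_value_by_flags_and_idx above, the plugin kernel records the tokens accepted this step into the pre_ids_all history. A scalar host-side sketch of that bookkeeping follows, assuming step_idx already points at the newest accepted position; the guard conditions and names are assumptions, not the plugin code.

#include <cstdint>

// Illustrative reference for the bookkeeping the XPU kernel performs.
void set_value_by_flags(int64_t* pre_ids_all, const int64_t* accept_tokens,
                        const int* accept_num, const bool* stop_flags,
                        const int64_t* step_idx, int bs, int length,
                        int max_draft_tokens) {
  for (int bid = 0; bid < bs; bid++) {
    if (stop_flags[bid] || step_idx[bid] <= 0) continue;
    int n = accept_num[bid];
    // Write the n accepted tokens into the tail of this sequence's history.
    for (int i = 0; i < n; i++) {
      int64_t pos = step_idx[bid] - (n - 1) + i;
      if (pos >= 0 && pos < length) {
        pre_ids_all[bid * length + pos] =
            accept_tokens[bid * max_draft_tokens + i];
      }
    }
  }
}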
+
+#include
+#include "paddle/phi/core/enforce.h"
+#include "speculate_msg.h"  // NOLINT
+#include "xpu/plugin.h"
+
+// To keep existing call sites working, the parameter list is left unchanged
+// for now.
+void SpeculateStepSchedule(
+    const paddle::Tensor &stop_flags,
+    const paddle::Tensor &seq_lens_this_time,
+    const paddle::Tensor &ori_seq_lens_encoder,
+    const paddle::Tensor &seq_lens_encoder,
+    const paddle::Tensor &seq_lens_decoder,
+    const paddle::Tensor &block_tables,  // [bsz, block_num_per_seq]
+    const paddle::Tensor &encoder_block_lens,
+    const paddle::Tensor &is_block_step,
+    const paddle::Tensor &step_block_list,
+    const paddle::Tensor &step_lens,
+    const paddle::Tensor &recover_block_list,
+    const paddle::Tensor &recover_lens,
+    const paddle::Tensor &need_block_list,
+    const paddle::Tensor &need_block_len,
+    const paddle::Tensor &used_list_len,
+    const paddle::Tensor &free_list,
+    const paddle::Tensor &free_list_len,
+    const paddle::Tensor &input_ids,
+    const paddle::Tensor &pre_ids,
+    const paddle::Tensor &step_idx,
+    const paddle::Tensor &next_tokens,
+    const paddle::Tensor &first_token_ids,
+    const paddle::Tensor &accept_num,
+    const int block_size,
+    const int encoder_decoder_block_num,
+    const int max_draft_tokens) {
+  namespace api = baidu::xpu::api;
+  phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
+  auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
+  auto xpu_ctx = static_cast<const phi::XPUContext *>(dev_ctx);
+  api::Context *ctx = xpu_ctx->x_context();
+  if (stop_flags.is_cpu()) {
+    ctx = new api::Context(api::kCPU);
+  }
+  const int bsz = seq_lens_this_time.shape()[0];
+  const int block_num_per_seq = block_tables.shape()[1];
+  const int length = input_ids.shape()[1];
+  const int pre_id_length = pre_ids.shape()[1];
+  constexpr int BlockSize = 256;  // bsz <= 256
+  const int max_decoder_block_num =
+      length / block_size -
+      encoder_decoder_block_num;  // blocks covering the max output length,
+                                  // minus the blocks the service reserves
+                                  // for decoding
+  auto step_lens_inkernel =
+      paddle::full({1}, 0, paddle::DataType::INT32, stop_flags.place());
+  auto step_bs_list =
+      paddle::full({bsz}, 0, paddle::DataType::INT32, stop_flags.place());
+  int r = baidu::xpu::api::plugin::speculate_free_and_reschedule(
+      ctx,
+      const_cast<bool *>(stop_flags.data<bool>()),
+      const_cast<int *>(seq_lens_this_time.data<int>()),
+      const_cast<int *>(seq_lens_decoder.data<int>()),
+      const_cast<int *>(block_tables.data<int>()),
+      const_cast<int *>(encoder_block_lens.data<int>()),
+      const_cast<bool *>(is_block_step.data<bool>()),
+      const_cast<int *>(step_bs_list.data<int>()),
+      const_cast<int *>(step_lens_inkernel.data<int>()),
+      const_cast<int *>(recover_block_list.data<int>()),
+      const_cast<int *>(recover_lens.data<int>()),
+      const_cast<int *>(need_block_list.data<int>()),
+      const_cast<int *>(need_block_len.data<int>()),
+      const_cast<int *>(used_list_len.data<int>()),
+      const_cast<int *>(free_list.data<int>()),
+      const_cast<int *>(free_list_len.data<int>()),
+      const_cast<int64_t *>(first_token_ids.data<int64_t>()),
+      bsz,
+      block_size,
+      block_num_per_seq,
+      max_decoder_block_num,
+      max_draft_tokens);
+  PD_CHECK(r == 0, "speculate_free_and_reschedule failed.");
+  // save output
+  auto step_lens_cpu = step_lens_inkernel.copy_to(paddle::CPUPlace(), false);
+  if (step_lens_cpu.data<int>()[0] > 0) {
+    auto step_bs_list_cpu = step_bs_list.copy_to(paddle::CPUPlace(), false);
+    auto next_tokens =
+        paddle::full({bsz}, -1, paddle::DataType::INT64, paddle::CPUPlace());
+    for (int i = 0; i < step_lens_cpu.data<int>()[0]; i++) {
+      const int step_bid = step_bs_list_cpu.data<int>()[i];
+      next_tokens.data<int64_t>()[step_bid] = -3;  // need reschedule
+    }
+    const int rank_id = static_cast<int>(stop_flags.place().GetDeviceId());
+    printf("reschedule rank_id: %d, step_lens: %d\n",
+           rank_id,
+           step_lens_cpu.data<int>()[0]);
+    const int64_t *x_data = next_tokens.data<int64_t>();
+    static struct speculate_msgdata msg_sed;
+    int msg_queue_id = rank_id;
+    if (const char *inference_msg_queue_id_env_p =
+            std::getenv("INFERENCE_MSG_QUEUE_ID")) {
+      std::string inference_msg_queue_id_env_str(inference_msg_queue_id_env_p);
+      int inference_msg_queue_id_from_env =
+          std::stoi(inference_msg_queue_id_env_str);
+      msg_queue_id = inference_msg_queue_id_from_env;
+    } else {
+      std::cout << "Failed to get INFERENCE_MSG_QUEUE_ID from env, using the "
+                   "default."
+                << std::endl;
+    }
+    int inference_msg_id_from_env = 1;
+    if (const char *inference_msg_id_env_p = std::getenv("INFERENCE_MSG_ID")) {
+      std::string inference_msg_id_env_str(inference_msg_id_env_p);
+      inference_msg_id_from_env = std::stoi(inference_msg_id_env_str);
+      if (inference_msg_id_from_env == 2) {
+        // 2 and -2 are reserved for the no-output indication.
+        throw std::runtime_error(
+            " INFERENCE_MSG_ID cannot be 2, please use another number.");
+      }
+      if (inference_msg_id_from_env < 0) {
+        throw std::runtime_error(
+            " INFERENCE_MSG_ID cannot be negative, please use another "
+            "number.");
+      }
+    }
+    static key_t key = ftok("./", msg_queue_id);
+    static int msgid = msgget(key, IPC_CREAT | 0666);
+    msg_sed.mtype = 1;
+    msg_sed.mtext[0] = inference_msg_id_from_env;
+    msg_sed.mtext[1] = bsz;
+    for (int i = 2; i < bsz + 2; i++) {
+      msg_sed.mtext[i] = static_cast<int>(x_data[i - 2]);
+    }
+    if ((msgsnd(msgid, &msg_sed, (MAX_BSZ + 2) * 4, 0)) == -1) {
+      printf("full msg buffer\n");
+    }
+  }
+}
+
+PD_BUILD_OP(speculate_step_reschedule)
+    .Inputs({"stop_flags",
+             "seq_lens_this_time",
+             "ori_seq_lens_encoder",
+             "seq_lens_encoder",
+             "seq_lens_decoder",
+             "block_tables",
+             "encoder_block_lens",
+             "is_block_step",
+             "step_block_list",
+             "step_lens",
+             "recover_block_list",
+             "recover_lens",
+             "need_block_list",
+             "need_block_len",
+             "used_list_len",
+             "free_list",
+             "free_list_len",
+             "input_ids",
+             "pre_ids",
+             "step_idx",
+             "next_tokens",
+             "first_token_ids",
+             "accept_num"})
+    .Attrs({"block_size: int",
+            "encoder_decoder_block_num: int",
+            "max_draft_tokens: int"})
+    .Outputs({"stop_flags_out",
+              "seq_lens_this_time_out",
+              "seq_lens_encoder_out",
+              "seq_lens_decoder_out",
+              "block_tables_out",
+              "encoder_block_lens_out",
+              "is_block_step_out",
+              "step_block_list_out",
+              "step_lens_out",
+              "recover_block_list_out",
+              "recover_lens_out",
+              "need_block_list_out",
+              "need_block_len_out",
+              "used_list_len_out",
+              "free_list_out",
+              "free_list_len_out",
+              "input_ids_out",
+              "first_token_ids_out"})
+    .SetInplaceMap({{"stop_flags", "stop_flags_out"},
+                    {"seq_lens_this_time", "seq_lens_this_time_out"},
+                    {"seq_lens_encoder", "seq_lens_encoder_out"},
+                    {"seq_lens_decoder", "seq_lens_decoder_out"},
+                    {"block_tables", "block_tables_out"},
+                    {"encoder_block_lens", "encoder_block_lens_out"},
+                    {"is_block_step", "is_block_step_out"},
+                    {"step_block_list", "step_block_list_out"},
+                    {"step_lens", "step_lens_out"},
+                    {"recover_block_list", "recover_block_list_out"},
+                    {"recover_lens", "recover_lens_out"},
+                    {"need_block_list", "need_block_list_out"},
+                    {"need_block_len", "need_block_len_out"},
+                    {"used_list_len", "used_list_len_out"},
+                    {"free_list", "free_list_out"},
+                    {"free_list_len", "free_list_len_out"},
+                    {"input_ids", "input_ids_out"},
+                    {"first_token_ids", "first_token_ids_out"}})
+    .SetKernelFn(PD_KERNEL(SpeculateStepSchedule));
diff --git a/custom_ops/xpu_ops/src/ops/mtp_ops/speculate_token_penalty_multi_scores.cc
b/custom_ops/xpu_ops/src/ops/mtp_ops/speculate_token_penalty_multi_scores.cc new file mode 100644 index 000000000..0ecd4e139 --- /dev/null +++ b/custom_ops/xpu_ops/src/ops/mtp_ops/speculate_token_penalty_multi_scores.cc @@ -0,0 +1,157 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include "paddle/extension.h" +#include "paddle/phi/core/enforce.h" +#include "xpu/plugin.h" + +void TokenPenaltyMultiScores(const paddle::Tensor& pre_ids, + const paddle::Tensor& logits, + const paddle::Tensor& penalty_scores, + const paddle::Tensor& frequency_scores, + const paddle::Tensor& presence_scores, + const paddle::Tensor& temperatures, + const paddle::Tensor& bad_tokens, + const paddle::Tensor& cur_len, + const paddle::Tensor& min_len, + const paddle::Tensor& eos_token_id, + const paddle::Tensor& seq_lens_this_time, + const paddle::Tensor& output_padding_offset, + const paddle::Tensor& output_cum_offsets, + const int max_seq_len) { + namespace api = baidu::xpu::api; + phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId()); + auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place); + auto xpu_ctx = static_cast(dev_ctx); + api::Context* ctx = xpu_ctx->x_context(); + if (pre_ids.is_cpu()) { + ctx = new api::Context(api::kCPU); + } + + int64_t bs = seq_lens_this_time.shape()[0]; + int64_t token_num = logits.shape()[0]; + PADDLE_ENFORCE_LE(bs, + 640, + phi::errors::InvalidArgument( + "Only support bs <= 640, but received bsz is %d", bs)); + int64_t length = logits.shape()[1]; + int64_t length_id = pre_ids.shape()[1]; + int64_t length_bad_words = bad_tokens.shape()[0]; + int64_t end_length = eos_token_id.shape()[0]; + switch (logits.type()) { + case paddle::DataType::BFLOAT16: { + using XPUType = typename XPUTypeTrait::Type; + typedef paddle::bfloat16 data_t; + int r = baidu::xpu::api::plugin::speculate_token_penalty_multi_scores( + ctx, + pre_ids.data(), + reinterpret_cast( + const_cast(logits.data())), + reinterpret_cast(penalty_scores.data()), + reinterpret_cast(frequency_scores.data()), + reinterpret_cast(presence_scores.data()), + temperatures.data(), + cur_len.data(), + min_len.data(), + eos_token_id.data(), + bad_tokens.data(), + output_padding_offset.data(), + output_cum_offsets.data(), + bs, + length, + length_id, + end_length, + length_bad_words, + token_num, + max_seq_len); + PD_CHECK(r == 0, "xpu::plugin::token_penalty_multi_scores failed."); + } break; + case paddle::DataType::FLOAT16: { + using XPUType = typename XPUTypeTrait::Type; + typedef paddle::float16 data_t; + int r = baidu::xpu::api::plugin::speculate_token_penalty_multi_scores( + ctx, + pre_ids.data(), + reinterpret_cast( + const_cast(logits.data())), + reinterpret_cast(penalty_scores.data()), + reinterpret_cast(frequency_scores.data()), + reinterpret_cast(presence_scores.data()), + temperatures.data(), + cur_len.data(), + min_len.data(), + eos_token_id.data(), + bad_tokens.data(), + output_padding_offset.data(), 
+          output_cum_offsets.data<int>(),
+          bs,
+          length,
+          length_id,
+          end_length,
+          length_bad_words,
+          token_num,
+          max_seq_len);
+      PD_CHECK(r == 0, "xpu::plugin::token_penalty_multi_scores failed.");
+    } break;
+    case paddle::DataType::FLOAT32: {
+      int r = baidu::xpu::api::plugin::speculate_token_penalty_multi_scores(
+          ctx,
+          pre_ids.data<int64_t>(),
+          const_cast<float*>(logits.data<float>()),
+          penalty_scores.data<float>(),
+          frequency_scores.data<float>(),
+          presence_scores.data<float>(),
+          temperatures.data<float>(),
+          cur_len.data<int64_t>(),
+          min_len.data<int64_t>(),
+          eos_token_id.data<int64_t>(),
+          bad_tokens.data<int64_t>(),
+          output_padding_offset.data<int>(),
+          output_cum_offsets.data<int>(),
+          bs,
+          length,
+          length_id,
+          end_length,
+          length_bad_words,
+          token_num,
+          max_seq_len);
+      PD_CHECK(r == 0, "xpu::plugin::token_penalty_multi_scores failed.");
+    } break;
+    default:
+      PD_THROW(
+          "Not supported data type. "
+          "Only bfloat16, float16 and float32 are supported.");
+      break;
+  }
+}
+
+PD_BUILD_OP(speculate_get_token_penalty_multi_scores)
+    .Inputs({"pre_ids",
+             "logits",
+             "penalty_scores",
+             "frequency_scores",
+             "presence_scores",
+             "temperatures",
+             "bad_tokens",
+             "cur_len",
+             "min_len",
+             "eos_token_id",
+             "seq_lens_this_time",
+             "output_padding_offset",
+             "output_cum_offsets"})
+    .Outputs({"logits_out"})
+    .Attrs({"max_seq_len: int"})
+    .SetInplaceMap({{"logits", "logits_out"}})
+    .SetKernelFn(PD_KERNEL(TokenPenaltyMultiScores));
diff --git a/custom_ops/xpu_ops/src/ops/mtp_ops/speculate_update_input_ids_cpu.cc b/custom_ops/xpu_ops/src/ops/mtp_ops/speculate_update_input_ids_cpu.cc
new file mode 100644
index 000000000..fceeb129e
--- /dev/null
+++ b/custom_ops/xpu_ops/src/ops/mtp_ops/speculate_update_input_ids_cpu.cc
@@ -0,0 +1,38 @@
+// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/extension.h"
+
+void UpdateInputIdsCPU(const paddle::Tensor& input_ids_cpu,
+                       const std::vector<int64_t>& task_input_ids,
+                       const int bid,
+                       const int max_seq_len) {
+  int64_t* input_ids_cpu_data =
+      const_cast<int64_t*>(input_ids_cpu.data<int64_t>());
+  for (size_t i = 0; i < task_input_ids.size(); i++) {
+    input_ids_cpu_data[bid * max_seq_len + i] = task_input_ids[i];
+  }
+}
+
+PD_BUILD_OP(speculate_update_input_ids_cpu)
+    .Inputs({"input_ids_cpu"})
+    .Outputs({"input_ids_cpu_out"})
+    .Attrs({"task_input_ids: std::vector<int64_t>",
+            "bid: int",
+            "max_seq_len: int"})
+    .SetInplaceMap({{"input_ids_cpu", "input_ids_cpu_out"}})
+    .SetKernelFn(PD_KERNEL(UpdateInputIdsCPU));
diff --git a/custom_ops/xpu_ops/src/ops/mtp_ops/speculate_update_v3.cc b/custom_ops/xpu_ops/src/ops/mtp_ops/speculate_update_v3.cc
new file mode 100644
index 000000000..7d06582d9
--- /dev/null
+++ b/custom_ops/xpu_ops/src/ops/mtp_ops/speculate_update_v3.cc
@@ -0,0 +1,91 @@
+// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include "paddle/extension.h" +#include "xpu/plugin.h" + +namespace api = baidu::xpu::api; + +void SpeculateUpdateV3(const paddle::Tensor &seq_lens_encoder, + const paddle::Tensor &seq_lens_decoder, + const paddle::Tensor ¬_need_stop, + const paddle::Tensor &draft_tokens, + const paddle::Tensor &actual_draft_token_nums, + const paddle::Tensor &accept_tokens, + const paddle::Tensor &accept_num, + const paddle::Tensor &stop_flags, + const paddle::Tensor &seq_lens_this_time, + const paddle::Tensor &is_block_step, + const paddle::Tensor &stop_nums) { + const int real_bsz = seq_lens_this_time.shape()[0]; + const int max_bsz = stop_flags.shape()[0]; + auto max_draft_tokens = draft_tokens.shape()[1]; + + phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId()); + auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place); + api::Context *ctx = + static_cast(dev_ctx)->x_context(); + if (draft_tokens.is_cpu()) { + ctx = new api::Context(api::kCPU); + } + + auto not_need_stop_xpu = not_need_stop.copy_to(stop_flags.place(), false); + int r = baidu::xpu::api::plugin::speculate_update_v3( + ctx, + const_cast(seq_lens_encoder.data()), + const_cast(seq_lens_decoder.data()), + const_cast(not_need_stop_xpu.data()), + const_cast(draft_tokens.data()), + const_cast(actual_draft_token_nums.data()), + accept_tokens.data(), + accept_num.data(), + stop_flags.data(), + seq_lens_this_time.data(), + is_block_step.data(), + stop_nums.data(), + real_bsz, + max_bsz, + max_draft_tokens); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "speculate_update_v3"); + + auto not_need_stop_cpu = + not_need_stop_xpu.copy_to(not_need_stop.place(), true); + bool *not_need_stop_data = const_cast(not_need_stop.data()); + not_need_stop_data[0] = not_need_stop_cpu.data()[0]; +} + +PD_BUILD_OP(speculate_update_v3) + .Inputs({"seq_lens_encoder", + "seq_lens_decoder", + "not_need_stop", + "draft_tokens", + "actual_draft_token_nums", + "accept_tokens", + "accept_num", + "stop_flags", + "seq_lens_this_time", + "is_block_step", + "stop_nums"}) + .Outputs({"seq_lens_encoder_out", + "seq_lens_decoder_out", + "not_need_stop_out", + "draft_tokens_out", + "actual_draft_token_nums_out"}) + .SetInplaceMap({{"seq_lens_encoder", "seq_lens_encoder_out"}, + {"seq_lens_decoder", "seq_lens_decoder_out"}, + {"not_need_stop", "not_need_stop_out"}, + {"draft_tokens", "draft_tokens_out"}, + {"actual_draft_token_nums", "actual_draft_token_nums_out"}}) + .SetKernelFn(PD_KERNEL(SpeculateUpdateV3)); diff --git a/custom_ops/xpu_ops/src/ops/mtp_ops/speculate_verify.cc b/custom_ops/xpu_ops/src/ops/mtp_ops/speculate_verify.cc new file mode 100644 index 000000000..2316d5ad7 --- /dev/null +++ b/custom_ops/xpu_ops/src/ops/mtp_ops/speculate_verify.cc @@ -0,0 +1,251 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include "paddle/common/flags.h" +#include "paddle/extension.h" +#include "paddle/phi/backends/xpu/enforce_xpu.h" +#include "ops/utility/debug.h" +#include "xpu/internal/infra_op.h" +#include "xpu/plugin.h" + +namespace api = baidu::xpu::api; + +void SpeculateVerify(const paddle::Tensor &accept_tokens, + const paddle::Tensor &accept_num, + const paddle::Tensor &step_idx, + const paddle::Tensor &stop_flags, + const paddle::Tensor &seq_lens_encoder, + const paddle::Tensor &seq_lens_decoder, + const paddle::Tensor &draft_tokens, + const paddle::Tensor &seq_lens_this_time, + const paddle::Tensor &verify_tokens, + const paddle::Tensor &verify_scores, + const paddle::Tensor &max_dec_len, + const paddle::Tensor &end_tokens, + const paddle::Tensor &is_block_step, + const paddle::Tensor &output_cum_offsets, + const paddle::Tensor &actual_candidate_len, + const paddle::Tensor &actual_draft_token_nums, + const paddle::Tensor &topp, + int max_seq_len, + int verify_window, + bool enable_topp) { + auto bsz = accept_tokens.shape()[0]; + int real_bsz = seq_lens_this_time.shape()[0]; + auto max_draft_tokens = draft_tokens.shape()[1]; + auto end_length = end_tokens.shape()[0]; + auto max_candidate_len = verify_tokens.shape()[1]; + + constexpr int BlockSize = 512; + // set topp_seed if needed + const paddle::optional &topp_seed = nullptr; + + phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId()); + auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place); + api::Context *ctx = + static_cast(dev_ctx)->x_context(); + bool xpu_ctx_flag = true; + if (draft_tokens.is_cpu()) { + ctx = new api::Context(api::kCPU); + xpu_ctx_flag = false; + } + + // phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId()); + // auto dev_ctx = + // paddle::experimental::DeviceContextPool::Instance().Get(place); auto + // xpu_ctx = static_cast(dev_ctx); + + bool use_topk = false; + char *env_var = getenv("SPECULATE_VERIFY_USE_TOPK"); + if (env_var) { + use_topk = static_cast(std::stoi(env_var)); + } + bool prefill_one_step_stop = false; + if (const char *env_p = std::getenv("PREFILL_NODE_ONE_STEP_STOP")) { + // std::cout << "Your PATH is: " << env_p << '\n'; + if (env_p[0] == '1') { + prefill_one_step_stop = true; + } + } + // random + int random_seed = 0; + std::vector infer_seed(bsz, random_seed); + std::uniform_real_distribution dist(0.0, 1.0); + std::vector dev_curand_states_cpu; + for (int i = 0; i < bsz; i++) { + std::mt19937_64 engine(infer_seed[i]); + dev_curand_states_cpu.push_back(dist(engine)); + } + float *dev_curand_states_xpu; + if (xpu_ctx_flag) { + xpu::ctx_guard RAII_GUARD(ctx); + dev_curand_states_xpu = + RAII_GUARD.alloc(dev_curand_states_cpu.size()); + xpu_memcpy(dev_curand_states_xpu, + dev_curand_states_cpu.data(), + dev_curand_states_cpu.size() * sizeof(float), + XPUMemcpyKind::XPU_HOST_TO_DEVICE); + } + + auto dev_curand_states = + !xpu_ctx_flag ? 
dev_curand_states_cpu.data() : dev_curand_states_xpu; + if (use_topk) { + if (enable_topp) { + baidu::xpu::api::plugin::speculate_verify( + ctx, + const_cast(accept_tokens.data()), + const_cast(accept_num.data()), + const_cast(step_idx.data()), + const_cast(stop_flags.data()), + seq_lens_encoder.data(), + seq_lens_decoder.data(), + draft_tokens.data(), + actual_draft_token_nums.data(), + dev_curand_states, + topp.data(), + seq_lens_this_time.data(), + verify_tokens.data(), + verify_scores.data(), + max_dec_len.data(), + end_tokens.data(), + is_block_step.data(), + output_cum_offsets.data(), + actual_candidate_len.data(), + real_bsz, + max_draft_tokens, + end_length, + max_seq_len, + max_candidate_len, + verify_window, + prefill_one_step_stop); + } else { + baidu::xpu::api::plugin::speculate_verify( + ctx, + const_cast(accept_tokens.data()), + const_cast(accept_num.data()), + const_cast(step_idx.data()), + const_cast(stop_flags.data()), + seq_lens_encoder.data(), + seq_lens_decoder.data(), + draft_tokens.data(), + actual_draft_token_nums.data(), + dev_curand_states, + topp.data(), + seq_lens_this_time.data(), + verify_tokens.data(), + verify_scores.data(), + max_dec_len.data(), + end_tokens.data(), + is_block_step.data(), + output_cum_offsets.data(), + actual_candidate_len.data(), + real_bsz, + max_draft_tokens, + end_length, + max_seq_len, + max_candidate_len, + verify_window, + prefill_one_step_stop); + } + } else { + if (enable_topp) { + baidu::xpu::api::plugin::speculate_verify( + ctx, + const_cast(accept_tokens.data()), + const_cast(accept_num.data()), + const_cast(step_idx.data()), + const_cast(stop_flags.data()), + seq_lens_encoder.data(), + seq_lens_decoder.data(), + draft_tokens.data(), + actual_draft_token_nums.data(), + dev_curand_states, + topp.data(), + seq_lens_this_time.data(), + verify_tokens.data(), + verify_scores.data(), + max_dec_len.data(), + end_tokens.data(), + is_block_step.data(), + output_cum_offsets.data(), + actual_candidate_len.data(), + real_bsz, + max_draft_tokens, + end_length, + max_seq_len, + max_candidate_len, + verify_window, + prefill_one_step_stop); + } else { + baidu::xpu::api::plugin::speculate_verify( + ctx, + const_cast(accept_tokens.data()), + const_cast(accept_num.data()), + const_cast(step_idx.data()), + const_cast(stop_flags.data()), + seq_lens_encoder.data(), + seq_lens_decoder.data(), + draft_tokens.data(), + actual_draft_token_nums.data(), + dev_curand_states, + topp.data(), + seq_lens_this_time.data(), + verify_tokens.data(), + verify_scores.data(), + max_dec_len.data(), + end_tokens.data(), + is_block_step.data(), + output_cum_offsets.data(), + actual_candidate_len.data(), + real_bsz, + max_draft_tokens, + end_length, + max_seq_len, + max_candidate_len, + verify_window, + prefill_one_step_stop); + } + } +} + +PD_BUILD_OP(speculate_verify) + .Inputs({"accept_tokens", + "accept_num", + "step_idx", + "stop_flags", + "seq_lens_encoder", + "seq_lens_decoder", + "draft_tokens", + "seq_lens_this_time", + "verify_tokens", + "verify_scores", + "max_dec_len", + "end_tokens", + "is_block_step", + "output_cum_offsets", + "actual_candidate_len", + "actual_draft_token_nums", + "topp"}) + .Outputs({"accept_tokens_out", + "accept_num_out", + "step_idx_out", + "stop_flags_out"}) + .Attrs({"max_seq_len: int", "verify_window: int", "enable_topp: bool"}) + .SetInplaceMap({{"accept_tokens", "accept_tokens_out"}, + {"accept_num", "accept_num_out"}, + {"step_idx", "step_idx_out"}, + {"stop_flags", "stop_flags_out"}}) + .SetKernelFn(PD_KERNEL(SpeculateVerify)); 
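SpeculateVerify above reads two ad-hoc environment switches, SPECULATE_VERIFY_USE_TOPK (parsed with std::stoi) and PREFILL_NODE_ONE_STEP_STOP (a first-character check). A small helper, sketched here as a possible cleanup rather than anything present in this diff, would make the two reads uniform:

#include <cstdlib>
#include <string>

// Possible consolidation of the two getenv checks above (not in the diff).
// Treats any value other than "0" as true, which approximates, but does not
// exactly reproduce, the op's two different parsers.
static bool env_flag(const char* name, bool default_value = false) {
  const char* v = std::getenv(name);
  if (v == nullptr || *v == '\0') return default_value;
  return std::string(v) != "0";
}

// Usage:
//   bool use_topk = env_flag("SPECULATE_VERIFY_USE_TOPK");
//   bool prefill_one_step_stop = env_flag("PREFILL_NODE_ONE_STEP_STOP");

One design note: the verification path also seeds every per-batch engine with the same fixed random_seed of 0, so the host-generated uniform values are deterministic across runs; any change to that behavior would belong in the op itself, not in this helper.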
diff --git a/custom_ops/xpu_ops/src/ops/mtp_ops/top_p_candidates.cc b/custom_ops/xpu_ops/src/ops/mtp_ops/top_p_candidates.cc new file mode 100644 index 000000000..f5c47ce7d --- /dev/null +++ b/custom_ops/xpu_ops/src/ops/mtp_ops/top_p_candidates.cc @@ -0,0 +1,158 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include "paddle/extension.h" +#include "xpu/plugin.h" + +#define FIXED_TOPK_BASE(topk, ...) \ + case (topk): { \ + constexpr auto kTopK = topk; \ + __VA_ARGS__; \ + } break + +#define FIXED_TOPK(...) \ + FIXED_TOPK_BASE(2, ##__VA_ARGS__); \ + FIXED_TOPK_BASE(3, ##__VA_ARGS__); \ + FIXED_TOPK_BASE(4, ##__VA_ARGS__); \ + FIXED_TOPK_BASE(5, ##__VA_ARGS__); \ + FIXED_TOPK_BASE(8, ##__VA_ARGS__); \ + FIXED_TOPK_BASE(10, ##__VA_ARGS__); \ + default: { \ + PD_THROW("Unsupported candidates_len."); \ + } + +namespace api = baidu::xpu::api; +std::vector TopPCandidates( + const paddle::Tensor& probs, + const paddle::Tensor& top_p, + const paddle::Tensor& output_padding_offset, + int candidates_len, + int max_seq_len) { + phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId()); + auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place); + auto xpu_ctx = static_cast(dev_ctx); + api::Context* ctx = xpu_ctx->x_context(); + if (probs.is_cpu()) { + ctx = new api::Context(api::kCPU); + } + + std::vector input_shape = probs.shape(); + const int token_num = input_shape[0]; + const int vocab_size = input_shape[1]; + + auto verify_scores = + paddle::empty({token_num, candidates_len}, probs.dtype(), probs.place()); + auto verify_tokens = paddle::empty( + {token_num, candidates_len}, paddle::DataType::INT64, probs.place()); + auto actual_candidate_lens = + paddle::empty({token_num}, paddle::DataType::INT32, probs.place()); + + constexpr int TopKMaxLength = 2; + int r; + switch (probs.dtype()) { + case paddle::DataType::BFLOAT16: + using XPUTypeBF16 = typename XPUTypeTrait::Type; + typedef paddle::bfloat16 bf16_data_t; + switch (candidates_len) { + FIXED_TOPK( + r = api::plugin::top_p_candidates( + ctx, + reinterpret_cast(probs.data()), + reinterpret_cast(top_p.data()), + output_padding_offset.data(), + verify_tokens.data(), + reinterpret_cast( + verify_scores.data()), + actual_candidate_lens.data(), + vocab_size, + token_num, + candidates_len, + max_seq_len); + PD_CHECK(r == 0, "xpu::plugin::top_p_candidates failed."); + return {verify_scores, verify_tokens, actual_candidate_lens}); + } + case paddle::DataType::FLOAT16: + using XPUTypeFP16 = typename XPUTypeTrait::Type; + typedef paddle::float16 fp16_data_t; + switch (candidates_len) { + FIXED_TOPK( + r = api::plugin::top_p_candidates( + ctx, + reinterpret_cast(probs.data()), + reinterpret_cast(top_p.data()), + output_padding_offset.data(), + verify_tokens.data(), + reinterpret_cast( + verify_scores.data()), + actual_candidate_lens.data(), + vocab_size, + token_num, + candidates_len, + max_seq_len); + PD_CHECK(r == 0, 
"xpu::plugin::top_p_candidates failed."); + return {verify_scores, verify_tokens, actual_candidate_lens}); + } + case paddle::DataType::FLOAT32: + switch (candidates_len) { + FIXED_TOPK( + r = api::plugin::top_p_candidates( + ctx, + probs.data(), + top_p.data(), + output_padding_offset.data(), + verify_tokens.data(), + verify_scores.data(), + actual_candidate_lens.data(), + vocab_size, + token_num, + candidates_len, + max_seq_len); + PD_CHECK(r == 0, "xpu::plugin::top_p_candidates failed."); + return {verify_scores, verify_tokens, actual_candidate_lens}); + } + default: + PD_THROW("Unsupported data type."); + } +} + +std::vector> TopPCandidatesInferShape( + const std::vector& probs_shape, + const std::vector& top_p_shape, + const std::vector& output_padding_offset_shape, + int max_candidates_len) { + int token_num = probs_shape[0]; + return {{token_num, max_candidates_len}, + {token_num, max_candidates_len}, + {token_num}}; +} + +std::vector TopPCandidatesInferDtype( + const paddle::DataType& probs_dtype, + const paddle::DataType& top_p_dtype, + const paddle::DataType& output_padding_offset_dtype) { + return {probs_dtype, paddle::DataType::INT64, paddle::DataType::INT32}; +} + +PD_BUILD_OP(top_p_candidates) + .Inputs({"probs", "top_p", "output_padding_offset"}) + .Outputs({"verify_scores", "verify_tokens", "actual_candidate_lens"}) + .Attrs({"candidates_len: int", "max_seq_len: int"}) + .SetKernelFn(PD_KERNEL(TopPCandidates)) + .SetInferShapeFn(PD_INFER_SHAPE(TopPCandidatesInferShape)) + .SetInferDtypeFn(PD_INFER_DTYPE(TopPCandidatesInferDtype)); diff --git a/custom_ops/xpu_ops/src/ops/utility/debug.h b/custom_ops/xpu_ops/src/ops/utility/debug.h new file mode 100755 index 000000000..8a42998d9 --- /dev/null +++ b/custom_ops/xpu_ops/src/ops/utility/debug.h @@ -0,0 +1,95 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+#pragma once
+
+#include
+#include
+#include
+#include "paddle/extension.h"
+
+namespace paddle {
+
+std::string string_format(const std::string fmt_str, ...);
+
+template <typename T>
+static T string_parse(const std::string& v) {
+  return v;
+}
+
+template <>
+int32_t string_parse<int32_t>(const std::string& v) {
+  return std::stoi(v);
+}
+
+template <>
+int64_t string_parse<int64_t>(const std::string& v) {
+  return std::stoll(v);
+}
+
+template <>
+float string_parse<float>(const std::string& v) {
+  return std::stof(v);
+}
+
+template <>
+double string_parse<double>(const std::string& v) {
+  return std::stod(v);
+}
+
+template <>
+bool string_parse<bool>(const std::string& v) {
+  std::string upper;
+  for (size_t i = 0; i < v.length(); i++) {
+    char ch = v[i];
+    if (ch >= 'a' && ch <= 'z') {
+      ch = ch - 'a' + 'A';
+    }
+    upper.push_back(ch);
+  }
+  return upper == "TRUE" || upper == "1";
+}
+
+template <typename T>
+static std::vector<T> string_split(const std::string& original,
+                                   const std::string& separator) {
+  std::vector<T> results;
+  std::string::size_type pos1, pos2;
+  pos2 = original.find(separator);
+  pos1 = 0;
+  while (std::string::npos != pos2) {
+    results.push_back(string_parse<T>(original.substr(pos1, pos2 - pos1)));
+    pos1 = pos2 + separator.size();
+    pos2 = original.find(separator, pos1);
+  }
+  if (pos1 != original.length()) {
+    results.push_back(string_parse<T>(original.substr(pos1)));
+  }
+  return results;
+}
+
+std::string shape_to_string(const std::vector<int64_t>& shape);
+
+template <typename T>
+void DebugPrintXPUTensor(const phi::XPUContext* xpu_ctx,
+                         const paddle::Tensor& input,
+                         std::string tag = "",
+                         int len = 1);
+
+template <typename T>
+void DebugPrintXPUTensorv2(const paddle::Tensor& input,
+                           std::string tag = "",
+                           int len = 1);
+
+}  // namespace paddle
diff --git a/custom_ops/xpu_ops/src/plugin/CMakeLists.txt b/custom_ops/xpu_ops/src/plugin/CMakeLists.txt
index 2830db578..7941025f0 100644
--- a/custom_ops/xpu_ops/src/plugin/CMakeLists.txt
+++ b/custom_ops/xpu_ops/src/plugin/CMakeLists.txt
@@ -174,6 +174,10 @@ macro(
   separate_arguments(arg_device_o_extra_flags)
   set(arg_host_o_extra_flags ${host_o_extra_flags})
   separate_arguments(arg_host_o_extra_flags)
+  set(MTP_KERNEL_COMPILE_FLAGS "")
+  if(${kernel_path} MATCHES "mtp_kernel")
+    list(APPEND MTP_KERNEL_COMPILE_FLAGS -mllvm -fix-mfence=all)
+  endif()
   add_custom_command(
     OUTPUT ${kernel_name}.device.bin.o ${kernel_name}.o
@@ -181,8 +185,8 @@
       ${XPU_CLANG} -std=c++11 ${OPT_LEVEL} ${arg_device_o_extra_flags} -c
       ${kernel_path} -D ${xpu_n_macro} --target=${TARGET_ARCH} ${HOST_XPU_FLAGS}
       --basename ${kernel_name} -fno-builtin --xpu-arch=${xpu_n} -fPIC
-      -Wno-int-to-void-pointer-cast -Wno-int-to-pointer-cast -Werror -mllvm
-      --xpu-inline-cost -mllvm --xpu-inline-hot-call -I${XDNN_INC_DIR} -I${XRE_INC_DIR}
+      -Wno-int-to-void-pointer-cast -Wno-int-to-pointer-cast -Werror ${MTP_KERNEL_COMPILE_FLAGS}
+      -mllvm --xpu-inline-cost -mllvm --xpu-inline-hot-call -I${XDNN_INC_DIR} -I${XRE_INC_DIR}
       -fxpu-launch-return -I${CMAKE_CURRENT_SOURCE_DIR}/include
       -I${CMAKE_CURRENT_SOURCE_DIR}/src
       -I${CMAKE_CURRENT_SOURCE_DIR}/src/kernel
diff --git a/custom_ops/xpu_ops/src/plugin/include/xpu/plugin.h b/custom_ops/xpu_ops/src/plugin/include/xpu/plugin.h
index ce6262044..0033e89de 100644
--- a/custom_ops/xpu_ops/src/plugin/include/xpu/plugin.h
+++ b/custom_ops/xpu_ops/src/plugin/include/xpu/plugin.h
@@ -139,6 +139,307 @@
 template <typename TX, typename TSCALE, typename TY>
 DLL_EXPORT int quant2d_per_channel(api::Context *ctx, const TX *x,
                                    const TSCALE *scale_in, TY *y,
                                    TSCALE *scale_out, int64_t m, int64_t n);
+
+
+/*--------------------------------------- MTP begin
--------------------------------------------*/ + +template +DLL_EXPORT int speculate_token_penalty_multi_scores( + Context* ctx, + const int64_t* pre_ids, + T* logits, + const T* penalty_scores, + const T* frequency_scores, + const T* presence_scores, + const float* temperatures, + const int64_t* cur_len, + const int64_t* min_len, + const int64_t* eos_token_id, + const int64_t* bad_words, + const int* output_padding_offset, + const int* output_cum_offsets, + const int64_t bs, + const int64_t length, + const int64_t length_id, + const int64_t end_length, + const int64_t length_bad_words, + const int64_t token_num, + const int64_t max_seq_len); +DLL_EXPORT int mtp_free_and_dispatch_block(Context* ctx, + bool* base_model_stop_flags, + bool* stop_flags, + bool* batch_drop, + int* seq_lens_this_time, + int* seq_lens_decoder, + int* block_tables, + int* encoder_block_lens, + int* used_list_len, + int* free_list, + int* free_list_len, + const int bsz, + const int block_size, + const int block_num_per_seq, + const int max_draft_tokens); + + +template +DLL_EXPORT int speculate_verify(Context* ctx, + int64_t* accept_tokens, + int* accept_num, + int64_t* step_idx, + bool* stop_flags, + const int* seq_lens_encoder, + const int* seq_lens_decoder, + const int64_t* draft_tokens, + const int* actual_draft_token_nums, + const float* dev_curand_states, + const float* topp, + const int* seq_lens_this_time, + const int64_t* verify_tokens, + const float* verify_scores, + const int64_t* max_dec_len, + const int64_t* end_tokens, + const bool* is_block_step, + const int* output_cum_offsets, + const int* actual_candidate_len, + const int real_bsz, + const int max_draft_tokens, + const int end_length, + const int max_seq_len, + const int max_candidate_len, + const int verify_window, + const bool prefill_one_step_stop); + +DLL_EXPORT int speculate_clear_accept_nums(Context* ctx, + int* accept_num, + const int* seq_lens_decoder, + const int max_bsz); + +DLL_EXPORT int speculate_get_seq_lens_output(Context* ctx, + int* seq_lens_output, + const int* seq_lens_this_time, + const int* seq_lens_encoder, + const int* seq_lens_decoder, + const int real_bsz); + +DLL_EXPORT int draft_model_update(Context* ctx, + const int64_t* inter_next_tokens, + int64_t* draft_tokens, + int64_t* pre_ids, + int* seq_lens_this_time, + int* seq_lens_encoder, + int* seq_lens_decoder, + int64_t* step_idx, + const int* output_cum_offsets, + bool* stop_flags, + bool* not_need_stop, + const int64_t* max_dec_len, + const int64_t* end_ids, + int64_t* base_model_draft_tokens, + const int bsz, + const int max_draft_token, + const int pre_id_length, + const int max_base_model_draft_token, + const int end_ids_len, + const int max_seq_len, + const int substep, + const bool prefill_one_step_stop); + +DLL_EXPORT int draft_model_preprocess(api::Context* ctx, + int64_t* draft_tokens, + int64_t* input_ids, + bool* stop_flags, + int* seq_lens_this_time, + int* seq_lens_encoder, + int* seq_lens_decoder, + int64_t* step_idx, + int* seq_lens_encoder_record, + int* seq_lens_decoder_record, + bool* not_need_stop, + bool* batch_drop, + const int64_t* accept_tokens, + const int* accept_num, + const int* base_model_seq_lens_encoder, + const int* base_model_seq_lens_decoder, + const int64_t* base_model_step_idx, + const bool* base_model_stop_flags, + const bool* base_model_is_block_step, + int64_t* base_model_draft_tokens, + int real_bsz, + int max_draft_token, + int accept_tokens_len, + int draft_tokens_len, + int input_ids_len, + int base_model_draft_tokens_len, + bool 
truncate_first_token, + bool splitwise_prefill); + +DLL_EXPORT int speculate_set_stop_value_multi_seqs(Context* ctx, + bool* stop_flags, + int64_t* accept_tokens, + int* accept_nums, + const int64_t* pre_ids, + const int64_t* step_idx, + const int64_t* stop_seqs, + const int* stop_seqs_len, + const int* seq_lens, + const int64_t* end_ids, + const int bs_now, + const int accept_tokens_len, + const int stop_seqs_bs, + const int stop_seqs_max_len, + const int pre_ids_len); +template +DLL_EXPORT int speculate_rebuild_append_padding(api::Context* ctx, + T* full_hidden_states, + int* cum_offsets, + int* seq_len_encoder, + int* seq_len_decoder, + int* output_padding_offset, + int max_seq_len, + int dim_embed, + int elem_nums, + T* out); + +template +DLL_EXPORT int speculate_remove_padding(Context* ctx, + T* x_remove_padding, + const T* input_ids, + const T* draft_tokens, + const int* seq_lens, + const int* seq_lens_encoder, + const int* cum_offsets_out, + int seq_length, + int max_draft_tokens, + int bsz, + int token_num_data); + +DLL_EXPORT int speculate_get_padding_offset(Context* ctx, + int* padding_offset, + int* cum_offsets_out, + int* cu_seqlens_q, + int* cu_seqlens_k, + const int* cum_offsets, + const int* seq_lens, + const int max_seq_len, + int bsz); + +DLL_EXPORT int compute_self_order(api::Context* ctx, + const int* last_seq_lens_this_time, + const int* seq_lens_this_time, + const int64_t* step_idx, + int* src_map, + int* output_token_num, + int bsz); + +DLL_EXPORT int compute_order(api::Context* ctx, + const int* seq_lens_this_time, + const int* seq_lens_encoder, + const int* base_model_seq_lens_this_time, + const int* base_model_seq_lens_encoder, + const int* accept_nums, + int* position_map, + int* output_token_num, + const int bsz, + const int actual_draft_token_num, + const int input_token_num); + +DLL_EXPORT int draft_model_postprocess(Context* ctx, + const int64_t* base_model_draft_tokens, + int* base_model_seq_lens_this_time, + const int* base_model_seq_lens_encoder, + const bool* base_model_stop_flags, + int bsz, + int base_model_draft_token_len); + +DLL_EXPORT int speculate_set_value_by_flag_and_id(Context* ctx, + int64_t* pre_ids_all, + const int64_t* accept_tokens, + const int* accept_num, + const bool* stop_flags, + const int* seq_lens_encoder, + const int* seq_lens_decoder, + const int64_t* step_idx, + int bs, + int length, + int max_draft_tokens); + +DLL_EXPORT int speculate_get_output_padding_offset( + Context* ctx, + int* output_padding_offset, + int* output_cum_offsets, + const int* output_cum_offsets_tmp, + const int* seq_lens_output, + const int bsz, + const int max_seq_len); + +template +DLL_EXPORT int top_p_candidates(api::Context* ctx, + const T* src, + const T* top_ps, + const int* output_padding_offset, + int64_t* out_id, + T* out_val, + int* actual_candidates_lens, + int vocab_size, + int token_num, + int max_cadidate_len, + int max_seq_len); + +DLL_EXPORT int speculate_free_and_reschedule(Context* ctx, + bool* stop_flags, + int* seq_lens_this_time, + int* seq_lens_decoder, + int* block_tables, + int* encoder_block_lens, + bool* is_block_step, + int* step_block_list, // [bsz] + int* step_len, + int* recover_block_list, + int* recover_len, + int* need_block_list, + int* need_block_len, + int* used_list_len, + int* free_list, + int* free_list_len, + int64_t* first_token_ids, + const int bsz, + const int block_size, + const int block_num_per_seq, + const int max_decoder_block_num, + const int max_draft_tokens); + +DLL_EXPORT int speculate_update_v3(Context* ctx, 
+ int* seq_lens_encoder, + int* seq_lens_decoder, + bool* not_need_stop, + int64_t* draft_tokens, + int* actual_draft_token_nums, + const int64_t* accept_tokens, + const int* accept_num, + const bool* stop_flags, + const int* seq_lens_this_time, + const bool* is_block_step, + const int64_t* stop_nums, + const int real_bsz, + const int max_bsz, + const int max_draft_tokens); +template +DLL_EXPORT int rebuild_hidden_states(api::Context* ctx, + const T* input, + const int* position_map, + T* out, + int dim_embed, + int elem_cnt); +template +DLL_EXPORT int rebuild_self_hidden_states(api::Context* ctx, + const T* input, + int* src_map, + T* output, + int dim_embed, + int elem_cnt); +/*--------------------------------------- MTP end --------------------------------------------*/ + } // namespace plugin } // namespace api } // namespace xpu diff --git a/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/compute_order.xpu b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/compute_order.xpu new file mode 100644 index 000000000..7cd399d09 --- /dev/null +++ b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/compute_order.xpu @@ -0,0 +1,112 @@ +#include "xpu/kernel/cluster.h" +#include "xpu/kernel/cluster_partition.h" +#include "xpu/kernel/cluster_primitive.h" + +namespace xpu3 { +namespace plugin { + +__global__ void ComputeOrderKernel(const int* seq_lens_this_time, + const int* seq_lens_encoder, + const int* base_model_seq_lens_this_time, + const int* base_model_seq_lens_encoder, + const int* accept_nums, + int* position_map, + int* output_token_num, + const int bsz, + const int actual_draft_token_num, + const int input_token_num) { + int tid = core_id() * cluster_num() + cluster_id(); + if (tid != 0) { + return; + } + + char lm[6 * 1024]; + int buf_size = 6 * 1024 / (6 * sizeof(int)); + int* lm_base_model_seq_lens_this_time = (int*)lm; + int* lm_base_model_seq_lens_encoder = + lm_base_model_seq_lens_this_time + buf_size; + int* lm_seq_lens_this_time = lm_base_model_seq_lens_encoder + buf_size; + int* lm_accept_nums = lm_seq_lens_this_time + buf_size; + int* lm_seq_lens_encoder = lm_accept_nums + buf_size; + int* lm_position_map = lm_seq_lens_encoder + buf_size; + + int in_offset = 0; + int out_offset = 0; + for (int i = 0; i < bsz; i += buf_size) { + int64_t read_size = min(static_cast(bsz - i), buf_size); + GM2LM_ASYNC(base_model_seq_lens_this_time + i, + lm_base_model_seq_lens_this_time, + read_size * sizeof(int)); + GM2LM_ASYNC(base_model_seq_lens_encoder + i, + lm_base_model_seq_lens_encoder, + read_size * sizeof(int)); + GM2LM_ASYNC( + seq_lens_this_time + i, lm_seq_lens_this_time, read_size * sizeof(int)); + GM2LM_ASYNC(accept_nums + i, lm_accept_nums, read_size * sizeof(int)); + GM2LM(seq_lens_encoder + i, lm_seq_lens_encoder, read_size * sizeof(int)); + for (int j = 0; j < read_size; j++) { + int cur_base_model_seq_lens_this_time = + lm_base_model_seq_lens_this_time[j]; + int cur_base_model_seq_lens_encoder = lm_base_model_seq_lens_encoder[j]; + int cur_seq_lens_this_time = lm_seq_lens_this_time[j]; + int accept_num = lm_accept_nums[j]; + int cur_seq_lens_encoder = lm_seq_lens_encoder[j]; + // 1. eagle encoder. 
Base step=1 + if (cur_seq_lens_encoder > 0) { + for (int k = 0; k < cur_seq_lens_encoder; k += buf_size) { + int64_t write_size = + min(static_cast(cur_seq_lens_encoder - k), + static_cast(buf_size)); + for (int l = 0; l < write_size; l++) { + lm_position_map[l] = out_offset; + out_offset++; + } + mfence_lm(); + LM2GM(lm_position_map, + position_map + in_offset, + write_size * sizeof(int)); + in_offset += write_size; + } + mfence_lm(); + // 2. base model encoder. Base step=0 + } else if (cur_base_model_seq_lens_encoder != 0) { + // nothing happens + // 3. New end + } else if (cur_base_model_seq_lens_this_time != 0 && + cur_seq_lens_this_time == 0) { + in_offset += cur_base_model_seq_lens_this_time; + // 4. stopped + } else if (cur_base_model_seq_lens_this_time == 0 && + cur_seq_lens_this_time == 0) { + // nothing happens + } else { + if (accept_num <= actual_draft_token_num) { + int position_map_val = out_offset; + LM2GM(&position_map_val, + position_map + in_offset + accept_num - 1, + sizeof(int)); + out_offset++; + in_offset += cur_base_model_seq_lens_this_time; + } else { + int position_map_val_1 = out_offset; + LM2GM(&position_map_val_1, + position_map + in_offset + accept_num - 2, + sizeof(int)); + out_offset++; + int position_map_val_2 = out_offset; + LM2GM(&position_map_val_2, + position_map + in_offset + accept_num - 1, + sizeof(int)); + out_offset++; + in_offset += cur_base_model_seq_lens_this_time; + } + mfence_lm(); + } + } + } + mfence_lm(); + LM2GM(&out_offset, output_token_num, sizeof(int)); +} + +} // namespace plugin +} // namespace xpu3 diff --git a/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/compute_self_order.xpu b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/compute_self_order.xpu new file mode 100644 index 000000000..a64f46caa --- /dev/null +++ b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/compute_self_order.xpu @@ -0,0 +1,75 @@ +#include "xpu/kernel/cluster.h" +#include "xpu/kernel/cluster_partition.h" +#include "xpu/kernel/cluster_primitive.h" + +namespace xpu3 { +namespace plugin { + +__global__ void ComputeSelfOrderKernel(const int* last_seq_lens_this_time, + const int* seq_lens_this_time, + const int64_t* step_idx, + int* src_map, + int* output_token_num, + int bsz) { + int tid = core_id() * cluster_num() + cluster_id(); + if (tid != 0) { + return; + } + + char lm[6 * 1024]; + int buf_size = 256; + int* lm_last_seq_lens_this_time = (int*)lm; + int* lm_seq_lens_this_time = lm_last_seq_lens_this_time + buf_size; + int64_t* lm_step_idx = (int64_t*)(lm_seq_lens_this_time + buf_size); + int* lm_src_map = (int*)(lm_step_idx + buf_size); + + int in_offset = 0; + int out_offset = 0; + int previous_out_offset = out_offset; + for (int i = 0; i < bsz; i += buf_size) { + int64_t read_size = min(static_cast(bsz - i), buf_size); + GM2LM_ASYNC(last_seq_lens_this_time + i, + lm_last_seq_lens_this_time, + read_size * sizeof(int)); + GM2LM_ASYNC( + seq_lens_this_time + i, lm_seq_lens_this_time, read_size * sizeof(int)); + GM2LM(step_idx + i, lm_step_idx, read_size * sizeof(int64_t)); + for (int j = 0; j < read_size; j++) { + int cur_seq_lens_this_time = lm_seq_lens_this_time[j]; + int cur_last_seq_lens_this_time = lm_last_seq_lens_this_time[j]; + int64_t cur_step_idx = lm_step_idx[j]; + // 1. encoder + if (cur_step_idx == 1 && cur_seq_lens_this_time > 0) { + in_offset += 1; + lm_src_map[j] = in_offset - 1; + out_offset++; + // 2. 
decoder + } else if (cur_seq_lens_this_time > 0) /* =1 */ { + in_offset += cur_last_seq_lens_this_time; + lm_src_map[j] = in_offset - 1; + out_offset++; + // 3. stop + } else { + // first token end + if (cur_step_idx == 1) { + in_offset += cur_last_seq_lens_this_time > 0 ? 1 : 0; + // normal end + } else { + in_offset += cur_last_seq_lens_this_time; + } + } + } + mfence_lm(); + if (out_offset > previous_out_offset) { + LM2GM_ASYNC(lm_src_map, + src_map + previous_out_offset, + (out_offset - previous_out_offset) * sizeof(int)); + } + previous_out_offset = out_offset; + } + mfence_lm(); + LM2GM(&out_offset, output_token_num, sizeof(int)); +} + +} // namespace plugin +} // namespace xpu3 diff --git a/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/draft_model_postprocess.xpu b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/draft_model_postprocess.xpu new file mode 100644 index 000000000..61c21e79d --- /dev/null +++ b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/draft_model_postprocess.xpu @@ -0,0 +1,189 @@ +#include "xpu/kernel/cluster.h" +#include "xpu/kernel/cluster_partition.h" +#include "xpu/kernel/cluster_primitive.h" +#include "xpu/kernel/cluster_primitive_template.h" +#include "xpu/kernel/cluster_simd.h" +#include "xpu/kernel/xtdk_io.h" +namespace xpu3 { +namespace plugin { + +static inline __device__ int v_reduce(int32x16_t v) { + auto v0 = vsrlp_int32x16(256, v); + v = vvadd_int32x16(v0, v); + v0 = vsrlp_int32x16(128, v); + v = vvadd_int32x16(v0, v); + v0 = vsrlp_int32x16(64, v); + v = vvadd_int32x16(v0, v); + v0 = vsrlp_int32x16(32, v); + v = vvadd_int32x16(v0, v); + int res; + res = vextract_int32x16(v); + return res; +} + +__device__ int do_calc(int64_t* lmptr, int read_len) { + int res = 0; + int32x16_t v0; + int32x16_t v1; + int32x16_t v2 = {0}; + int* lmptr_i16 = (int*)lmptr; + int rounddown_size = rounddown16(read_len * 2); + int comp = -1; + int i = 0; + for (; i < rounddown_size; i += 16) { + v0 = vload_lm_int32x16(lmptr_i16 + i); + v1 = vload_lm_int32x16(lmptr_i16 + i); + unsigned int mask0 = + static_cast(sveq_int32x16_mz(comp, v0, 0xAAAA)); + unsigned int mask1 = + static_cast(sveq_int32x16_mz(comp, v1, 0x5555)); + mask1 = mask1 << 1; + unsigned int mask2 = (mask0 & 0xFFFFFFFF) & (mask1 & 0xFFFFFFFF); + v2 = svadd_int32x16_mh(1, v2, v2, mask2); + } + res = i / 2 - v_reduce(v2); + mfence_lm(); + for (int j = i / 2; j < read_len; j++) { + if (lmptr[j] != -1) { + res += 1; + } + } + return res; +} + +__global__ void draft_model_postprocess(const int64_t* base_model_draft_tokens, + int* base_model_seq_lens_this_time, + const int* base_model_seq_lens_encoder, + const bool* base_model_stop_flags, + int bsz, + int base_model_draft_token_len) { + int cid = core_id(); + int ncores = core_num(); + int nclusters = cluster_num(); + int nthreads = ncores * nclusters; + const int max_sm_len = 256 * 1024 / sizeof(int); + const int core_limit_row = max_sm_len / ncores; + const int clusetr_limit_row = max_sm_len * nclusters; + int bsz_start_cluster; + int bsz_end_cluster; + int bsz_start_core; + int bsz_end_core; + + int row_to_partition = min(bsz, clusetr_limit_row); + // cluster partition + partition(cluster_id(), + nclusters, + row_to_partition, + 1, + &bsz_start_cluster, + &bsz_end_cluster); + if (bsz_start_cluster >= bsz_end_cluster) { + return; + } + int rows_cluster = + bsz_end_cluster - bsz_start_cluster; // total rows for a cluster + // core partition + partition( + core_id(), core_num(), rows_cluster, 1, &bsz_start_core, 
+__global__ void draft_model_postprocess(const int64_t* base_model_draft_tokens, + int* base_model_seq_lens_this_time, + const int* base_model_seq_lens_encoder, + const bool* base_model_stop_flags, + int bsz, + int base_model_draft_token_len) { + int cid = core_id(); + int ncores = core_num(); + int nclusters = cluster_num(); + int nthreads = ncores * nclusters; + const int max_sm_len = 256 * 1024 / sizeof(int); + const int core_limit_row = max_sm_len / ncores; + const int cluster_limit_row = max_sm_len * nclusters; + int bsz_start_cluster; + int bsz_end_cluster; + int bsz_start_core; + int bsz_end_core; + + int row_to_partition = min(bsz, cluster_limit_row); + // cluster partition + partition(cluster_id(), + nclusters, + row_to_partition, + 1, + &bsz_start_cluster, + &bsz_end_cluster); + if (bsz_start_cluster >= bsz_end_cluster) { + return; + } + int rows_cluster = + bsz_end_cluster - bsz_start_cluster; // total rows for a cluster + // core partition + partition( + core_id(), core_num(), rows_cluster, 1, &bsz_start_core, &bsz_end_core); + __shared__ int base_model_sm[max_sm_len]; + const int LM_SIZE = 3072; + const int BUFSIZE = LM_SIZE / sizeof(int64_t); + __simd__ int64_t output_lm[BUFSIZE * 2]; + DoublePtr<BUFSIZE, LmPtr<int64_t>> local_base_model( + (LmPtr<int64_t>((int64_t*)output_lm))); + const int BSZ_BUF = 16; + __simd__ bool base_model_stop_lm[BSZ_BUF]; + __simd__ int base_model_seq_lm[BSZ_BUF]; + + bsz_start_core += bsz_start_cluster; + bsz_end_core += bsz_start_cluster; + int read_len_sm = 0; + int offset_loop = 0; + int offset_cluster = 0; + int cur_row_to_all_cluster = 0; + for (int limit_loop = 0; limit_loop < roundup_div(bsz, cluster_limit_row); + limit_loop += 1) { + offset_loop = limit_loop * cluster_limit_row; + if (bsz_start_core + offset_loop >= bsz) { + break; + } + cur_row_to_all_cluster = min(bsz - offset_loop, cluster_limit_row); + offset_cluster = 0; + // compute offset_cluster + for (int start_cluster = 0; start_cluster < cluster_id(); + start_cluster += 1) { + offset_cluster += (rounddown_div(cur_row_to_all_cluster, nclusters) + + (start_cluster < cur_row_to_all_cluster % nclusters)); + } + if (core_id() == 0) { + if (cur_row_to_all_cluster < nclusters) { + // bsz is small and clusters outnumber rows, so read a single entry + read_len_sm = 1; + } else { + // bsz is large enough for every cluster to read a share, capped at max_sm_len + read_len_sm = + min(max_sm_len, + rounddown_div(cur_row_to_all_cluster, nclusters) + + (cluster_id() < (cur_row_to_all_cluster % nclusters))); + } + GM2SM(base_model_seq_lens_this_time + offset_loop + offset_cluster, + base_model_sm + offset_cluster, + sizeof(int) * read_len_sm); + } + cur_row_to_all_cluster -= cluster_limit_row; + sync_cluster(); + for (int bsz_index = bsz_start_core + offset_loop; + (bsz_index < bsz_end_core + offset_loop) && (bsz_index < bsz); + bsz_index += 1) { + int bsz_offset = bsz_index - bsz_start_core; + if (bsz_offset % BSZ_BUF == 0) { + int64_t readm = min(bsz - bsz_index, BSZ_BUF); + GM2LM_ASYNC(base_model_stop_flags + bsz_index, + base_model_stop_lm, + sizeof(bool) * readm); + GM2LM(base_model_seq_lens_encoder + bsz_index, + base_model_seq_lm, + sizeof(int) * readm); + } + if (!base_model_stop_lm[bsz_offset % BSZ_BUF] && + (base_model_seq_lm[bsz_offset % BSZ_BUF] == 0)) { + // count valid tokens (entries that are not -1) + int token_num = 0; + int j = 0; + int read_len = min(base_model_draft_token_len - j, BUFSIZE); + local_base_model.gm_load(base_model_draft_tokens + + bsz_index * base_model_draft_token_len + j, + read_len); + for (; j < base_model_draft_token_len; j += BUFSIZE) { + int next_idx = j + BUFSIZE; + int read_len_next = + min(base_model_draft_token_len - next_idx, BUFSIZE); + if (read_len_next > 0) { + local_base_model.next().gm_load_async( + base_model_draft_tokens + + bsz_index * base_model_draft_token_len + next_idx, + read_len_next); + } + + token_num += do_calc(local_base_model.ptr, read_len); + read_len = read_len_next; + local_base_model.toggle(); + mfence_lm(); + } + base_model_sm[bsz_index % max_sm_len] = token_num; + } else if (base_model_stop_lm[bsz_offset % BSZ_BUF]) { + int token_num = 0; + base_model_sm[bsz_index % max_sm_len] = token_num; + } + } + sync_cluster(); + if (core_id() == 0) { + SM2GM(base_model_sm + offset_cluster, + base_model_seq_lens_this_time + offset_loop + offset_cluster, + sizeof(int) * read_len_sm); + } + sync_cluster(); + } +} +} // namespace plugin +} // namespace xpu3 diff --git a/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/draft_model_preprocess.xpu b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/draft_model_preprocess.xpu new file mode 100644 index 
000000000..9471fd096 --- /dev/null +++ b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/draft_model_preprocess.xpu @@ -0,0 +1,243 @@ +#include "xpu/kernel/cluster.h" +#include "xpu/kernel/cluster_debug.h" +#include "xpu/kernel/cluster_partition.h" +#include "xpu/kernel/cluster_primitive.h" +#include "xpu/kernel/cluster_simd.h" + +namespace xpu3 { +namespace plugin { +__global__ void draft_model_preprocess(int64_t* draft_tokens, + int64_t* input_ids, + bool* stop_flags, + int* seq_lens_this_time, + int* seq_lens_encoder, + int* seq_lens_decoder, + int64_t* step_idx, + int* seq_lens_encoder_record, + int* seq_lens_decoder_record, + bool* not_need_stop, + bool* batch_drop, + const int64_t* accept_tokens, + const int* accept_num, + const int* base_model_seq_lens_encoder, + const int* base_model_seq_lens_decoder, + const int64_t* base_model_step_idx, + const bool* base_model_stop_flags, + const bool* base_model_is_block_step, + int64_t* base_model_draft_tokens, + int real_bsz, + int max_draft_token, + int accept_tokens_len, + int draft_tokens_len, + int input_ids_len, + int base_model_draft_tokens_len, + bool truncate_first_token, + bool splitwise_prefill) { + int cid = core_id(); + int ncores = core_num(); + int clusterid = cluster_id(); + int nclusters = cluster_num(); + int tid = clusterid * ncores + cid; + __shared__ int not_stop_flag_sm[64]; + not_stop_flag_sm[cid] = 0; + int64_t accept_tokens_now[128]; + + int value_zero = 0; + int64_t value_fu = -1; + + if (splitwise_prefill) { + for (; tid < real_bsz; tid += ncores * nclusters) { + int64_t base_model_step_idx_now = 0; + int seq_lens_encoder_now = 0; + int seq_lens_this_time_now = 0; + bool stop_flags_now = false; + int64_t base_model_first_token; + int seq_lens_encoder_record_now = 0; + int64_t input_ids_now = 0; + + GM2LM_ASYNC( + base_model_step_idx + tid, &base_model_step_idx_now, sizeof(int64_t)); + GM2LM_ASYNC(seq_lens_encoder_record + tid, + &seq_lens_encoder_record_now, + sizeof(int)); + GM2LM(accept_tokens + tid * accept_tokens_len, + &base_model_first_token, + sizeof(int64_t)); + + if (base_model_step_idx_now == 1 && seq_lens_encoder_record_now > 0) { + not_stop_flag_sm[cid] += 1; + int seq_len_encoder_record = seq_lens_encoder_record_now; + seq_lens_encoder_now = seq_len_encoder_record; + seq_lens_encoder_record_now = -1; + stop_flags_now = false; + int position = seq_len_encoder_record; + if (truncate_first_token) { + position = position - 1; + input_ids_now = base_model_first_token; + seq_lens_this_time_now = seq_len_encoder_record; + } else { + input_ids_now = base_model_first_token; + seq_lens_this_time_now = seq_len_encoder_record + 1; + } + LM2GM_ASYNC(&input_ids_now, + input_ids + tid * input_ids_len + position, + sizeof(int64_t)); + LM2GM_ASYNC(&seq_lens_encoder_record_now, + seq_lens_encoder_record + tid, + sizeof(int)); + + } else { + stop_flags_now = true; + seq_lens_this_time_now = 0; + seq_lens_encoder_now = 0; + not_stop_flag_sm[cid] += 0; + LM2GM_ASYNC(&value_zero, seq_lens_decoder + tid, sizeof(int)); + } + LM2GM_ASYNC(&seq_lens_encoder_now, seq_lens_encoder + tid, sizeof(int)); + LM2GM_ASYNC(&stop_flags_now, stop_flags + tid, sizeof(bool)); + LM2GM(&seq_lens_this_time_now, seq_lens_this_time + tid, sizeof(int)); + } + } else { + for (; tid < real_bsz; tid += ncores * nclusters) { + bool base_model_stop_flags_now = false; + bool base_model_is_block_step_now = false; + bool batch_drop_now = false; + bool stop_flags_now = false; + int seq_lens_this_time_now = 0; + int seq_lens_encoder_record_now = 
0; + int seq_lens_encoder_now = 0; + int seq_lens_decoder_new = 0; + int seq_lens_decoder_record_now = 0; + int accept_num_now = 0; + int base_model_seq_lens_decoder_now = 0; + int64_t step_id_now = 0; + int64_t base_model_step_idx_now; + mfence(); + GM2LM_ASYNC(base_model_stop_flags + tid, + &base_model_stop_flags_now, + sizeof(bool)); + GM2LM_ASYNC(base_model_is_block_step + tid, + &base_model_is_block_step_now, + sizeof(bool)); + GM2LM_ASYNC(batch_drop + tid, &batch_drop_now, sizeof(bool)); + GM2LM_ASYNC(stop_flags + tid, &stop_flags_now, sizeof(bool)); + GM2LM_ASYNC(seq_lens_encoder_record + tid, + &seq_lens_encoder_record_now, + sizeof(int)); + GM2LM_ASYNC(seq_lens_decoder_record + tid, + &seq_lens_decoder_record_now, + sizeof(int)); + GM2LM_ASYNC(seq_lens_encoder + tid, &seq_lens_encoder_now, sizeof(int)); + GM2LM_ASYNC(seq_lens_decoder + tid, &seq_lens_decoder_new, sizeof(int)); + + GM2LM_ASYNC(accept_tokens + tid * accept_tokens_len, + accept_tokens_now, + accept_tokens_len * sizeof(int64_t)); + GM2LM_ASYNC(accept_num + tid, &accept_num_now, sizeof(int)); + + GM2LM_ASYNC(base_model_seq_lens_decoder + tid, + &base_model_seq_lens_decoder_now, + sizeof(int)); + GM2LM_ASYNC(step_idx + tid, &step_id_now, sizeof(int64_t)); + GM2LM( + base_model_step_idx + tid, &base_model_step_idx_now, sizeof(int64_t)); + + for (int i = 1; i < base_model_draft_tokens_len; i++) { + LM2GM_ASYNC( + &value_fu, + base_model_draft_tokens + tid * base_model_draft_tokens_len + i, + sizeof(int64_t)); + } + if (base_model_stop_flags_now && base_model_is_block_step_now) { + batch_drop_now = true; + stop_flags_now = true; + } + + if (!(base_model_stop_flags_now || batch_drop_now)) { + not_stop_flag_sm[cid] += 1; + if (base_model_step_idx_now == 0) { + seq_lens_this_time_now = 0; + not_stop_flag_sm[cid] -= 1; // incremented above; undo it so the step_idx == 0 case counts as stopped + } else if (base_model_step_idx_now == 1 && + seq_lens_encoder_record_now > 0) { + int seq_len_encoder_record = seq_lens_encoder_record_now; + seq_lens_encoder_now = seq_len_encoder_record; + seq_lens_encoder_record_now = -1; + seq_lens_decoder_new = seq_lens_decoder_record_now; + seq_lens_decoder_record_now = 0; + stop_flags_now = false; + int64_t base_model_first_token = accept_tokens_now[0]; + int position = seq_len_encoder_record; + if (truncate_first_token) { + LM2GM(&base_model_first_token, + input_ids + tid * input_ids_len + position - 1, + sizeof(int64_t)); + seq_lens_this_time_now = seq_len_encoder_record; + } else { + LM2GM(&base_model_first_token, + input_ids + tid * input_ids_len + position, + sizeof(int64_t)); + seq_lens_this_time_now = seq_len_encoder_record + 1; + } + } else if (accept_num_now <= max_draft_token) { + if (stop_flags_now) { + stop_flags_now = false; + seq_lens_decoder_new = base_model_seq_lens_decoder_now; + step_id_now = base_model_step_idx_now; + } else { + seq_lens_decoder_new -= max_draft_token - accept_num_now; + step_id_now -= max_draft_token - accept_num_now; + } + int64_t modified_token = accept_tokens_now[accept_num_now - 1]; + LM2GM(&modified_token, + draft_tokens + tid * draft_tokens_len, + sizeof(int64_t)); + seq_lens_this_time_now = 1; + + } else /*Accept all draft tokens*/ { + LM2GM(accept_tokens_now + max_draft_token, + draft_tokens + tid * draft_tokens_len + 1, + sizeof(int64_t)); + seq_lens_this_time_now = 2; + } + + } else { + stop_flags_now = true; + seq_lens_this_time_now = 0; + seq_lens_encoder_now = 0; + seq_lens_decoder_new = 0; + } + LM2GM_ASYNC(&stop_flags_now, stop_flags + tid, sizeof(bool)); + LM2GM_ASYNC(&batch_drop_now, batch_drop + tid, sizeof(bool)); + + LM2GM_ASYNC(&seq_lens_decoder_new, seq_lens_decoder + tid, sizeof(int)); + LM2GM_ASYNC( + &seq_lens_this_time_now, seq_lens_this_time + tid, sizeof(int)); + LM2GM_ASYNC(&seq_lens_encoder_now, seq_lens_encoder + tid, sizeof(int)); + LM2GM_ASYNC(&seq_lens_encoder_record_now, + seq_lens_encoder_record + tid, + sizeof(int)); + LM2GM_ASYNC(&seq_lens_decoder_record_now, + seq_lens_decoder_record + tid, + sizeof(int)); + LM2GM_ASYNC(&step_id_now, step_idx + tid, sizeof(int64_t)); + } + } + mfence(); + sync_cluster(); + bool value_true = true; + bool value_false = false; + if (cid == 0) { + for (int i = 1; i < ncores; i++) { + not_stop_flag_sm[0] += not_stop_flag_sm[i]; + } + if (not_stop_flag_sm[0] > 0) { + LM2GM(&value_true, not_need_stop, sizeof(bool)); + } else { + LM2GM(&value_false, not_need_stop, sizeof(bool)); + } + } +} + +} // namespace plugin +} // namespace xpu3 diff --git a/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/draft_model_update.xpu b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/draft_model_update.xpu new file mode 100644 index 000000000..0334995f9 --- /dev/null +++ b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/draft_model_update.xpu @@ -0,0 +1,114 @@ +#include "xpu/kernel/cluster.h" +#include "xpu/kernel/cluster_debug.h" +#include "xpu/kernel/cluster_partition.h" +#include "xpu/kernel/cluster_primitive.h" +#include "xpu/kernel/cluster_primitive_template.h" +namespace xpu3 { +namespace plugin { +inline __device__ bool is_in_end(const int64_t id, + const __global_ptr__ int64_t* end_ids, + int length) { + bool flag = false; + for (int i = 0; i < length; i++) { + if (id == end_ids[i]) { + return true; + } + } + return flag; +}
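+// Summary of the kernel below (descriptive only): each substep copies the +// token sampled for a running query out of inter_next_tokens into +// draft_tokens[0] and base_model_draft_tokens[substep + 1], appends the new +// token(s) to pre_ids, advances seq_lens_decoder and step_idx, and raises +// stop_flags once an end id is produced or step_idx reaches max_dec_len. +// Illustrative trace: with seq_len_this_time = 2 and step_idx = 5, the two +// sampled tokens land in pre_ids[6..7] and step_idx becomes 7.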
+__global__ void draft_model_update(const int64_t* inter_next_tokens, + int64_t* draft_tokens, + int64_t* pre_ids, + int* seq_lens_this_time, + int* seq_lens_encoder, + int* seq_lens_decoder, + int64_t* step_idx, + const int* output_cum_offsets, + bool* stop_flags, + bool* not_need_stop, + const int64_t* max_dec_len, + const int64_t* end_ids, + int64_t* base_model_draft_tokens, + const int bsz, + const int max_draft_token, + const int pre_id_length, + const int max_base_model_draft_token, + const int end_ids_len, + const int max_seq_len, + const int substep, + const bool prefill_one_step_stop) { + int cid = core_id(); + int ncores = core_num(); + __shared__ int stop_flag_now_int_sm[64]; + stop_flag_now_int_sm[cid] = 0; + for (int tid = cid; tid < bsz; tid += ncores) { + auto* draft_token_now = draft_tokens + tid * max_draft_token; + auto* pre_ids_now = pre_ids + tid * pre_id_length; + auto* base_model_draft_tokens_now = + base_model_draft_tokens + tid * max_base_model_draft_token; + const int next_tokens_start_id = + tid * max_seq_len - output_cum_offsets[tid]; + auto* next_tokens_start = inter_next_tokens + next_tokens_start_id; + auto seq_len_this_time = seq_lens_this_time[tid]; + auto seq_len_encoder = seq_lens_encoder[tid]; + auto seq_len_decoder = seq_lens_decoder[tid]; + if (!stop_flags[tid] /* seq_lens_decoder > 0 or seq_lens_encoder > 0 */) { + int64_t token_this_time = -1; + // decoder step + if (seq_len_decoder > 0 && seq_len_encoder <= 0) { + seq_lens_decoder[tid] += seq_len_this_time; + token_this_time = next_tokens_start[seq_len_this_time - 1]; + draft_token_now[0] = next_tokens_start[seq_len_this_time - 1]; + base_model_draft_tokens_now[substep + 1] = token_this_time; + for (int i = 0; i < seq_len_this_time; ++i) { + pre_ids_now[step_idx[tid] + 1 + i] = next_tokens_start[i]; + } + step_idx[tid] += seq_len_this_time; + } else { + token_this_time = next_tokens_start[0]; + seq_lens_decoder[tid] = seq_len_encoder + seq_len_decoder; + // mfence(); + seq_lens_encoder[tid] = 0; + pre_ids_now[1] = token_this_time; + step_idx[tid] += 1; + draft_token_now[0] = token_this_time; + base_model_draft_tokens_now[substep + 1] = token_this_time; + } + // multi_end + if (is_in_end(token_this_time, end_ids, end_ids_len) || + prefill_one_step_stop) { + stop_flags[tid] = true; + stop_flag_now_int_sm[cid] += 1; + // max_dec_len + } else if (step_idx[tid] >= max_dec_len[tid]) { + stop_flags[tid] = true; + draft_token_now[seq_len_this_time - 1] = end_ids[0]; + base_model_draft_tokens_now[substep + 1] = end_ids[0]; + stop_flag_now_int_sm[cid] += 1; + } + } else { + draft_token_now[0] = -1; + base_model_draft_tokens_now[substep + 1] = -1; + stop_flag_now_int_sm[cid] += 1; + } + // 2. set end + if (!stop_flags[tid]) { + seq_lens_this_time[tid] = 1; + } else { + seq_lens_this_time[tid] = 0; + seq_lens_encoder[tid] = 0; + } + } + mfence(); + sync_all(); + if (cid == 0) { + int sum_stop = 0; + for (int i = 0; i < 64; i++) { + sum_stop += stop_flag_now_int_sm[i]; + } + not_need_stop[0] = sum_stop < bsz; + } + mfence(); +} +} // namespace plugin +} // namespace xpu3 diff --git a/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/mtp_free_and_dispatch_block.xpu b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/mtp_free_and_dispatch_block.xpu new file mode 100644 index 000000000..bf0952b58 --- /dev/null +++ b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/mtp_free_and_dispatch_block.xpu @@ -0,0 +1,209 @@ +#include "xpu/kernel/cluster.h" +#include "xpu/kernel/cluster_partition.h" +#include "xpu/kernel/cluster_primitive.h" + +namespace xpu3 { +namespace plugin { + +static __device__ inline int loada_float(_shared_ptr_ const int *ptr) { + int ret; + __asm__ __volatile__("loada.w %0,%1" : "=&r"(ret) : "r"(ptr)); + return ret; +} + +static __device__ inline bool storea_float(_shared_ptr_ int *ptr, int value) { + bool ret; + __asm__ __volatile__("storea.w %0,%1,%2" : "=&r"(ret) : "r"(value), "r"(ptr)); + return ret; +} + +static __device__ int atomic_add(_shared_ptr_ int *ptr, int value) { + bool fail = true; + int old_value; + while (fail) { + old_value = loada_float(ptr); + int new_value = old_value + value; + fail = storea_float(ptr, new_value); + } + return old_value; +}
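+// Sketch of the policy implemented below (illustrative numbers): +// 1) Reclaim: queries whose base model stopped, or that were dropped, return +// their decoder blocks from block_tables to free_list. +// 2) Demand: a running query whose next slot, +// (seq_lens_decoder + max_draft_tokens + 1) / block_size, is still -1 is +// queued in need_block_list_sm; e.g. with block_size = 64 and +// max_draft_tokens = 5, seq_lens_decoder = 122 needs slot (122+6)/64 = 2. +// 3) Preempt: while demand exceeds free_list_len, the running query holding +// the largest used_list_len is stopped and its blocks are reclaimed. +// 4) Dispatch: each queued query then pops one block from the free list.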
+__global__ void mtp_free_and_dispatch_block(bool *base_model_stop_flags, + bool *stop_flags, + bool *batch_drop, + int *seq_lens_this_time, + int *seq_lens_decoder, + int *block_tables, + int *encoder_block_lens, + int *used_list_len, + int *free_list, + int *free_list_len, + const int bsz, + const int block_size, + const int block_num_per_seq, + const int max_draft_tokens) { + int cid = core_id(); + int ncores = core_num(); + int clusterid = cluster_id(); + if (clusterid != 0 || cid >= bsz) return; + + // assert bsz <= 640 + const int max_bs = 640; + int value_zero = 0; + bool flag_true = true; + + __shared__ int free_list_len_sm; + // at most block_table_now_len free_list entries are processed per pass + const int block_table_now_len = 128; + int block_table_now[block_table_now_len]; + for (int i = 0; i < block_table_now_len; i++) { + block_table_now[i] = -1; + } + __shared__ bool base_model_stop_flags_sm[max_bs]; + __shared__ bool batch_drop_sm[max_bs]; + __shared__ int encoder_block_lens_sm[max_bs]; + __shared__ int seq_lens_decoder_sm[max_bs]; + + int free_list_now[block_table_now_len]; + __shared__ int need_block_len_sm; + __shared__ int need_block_list_sm[max_bs]; + __shared__ int used_list_len_sm[max_bs]; + __shared__ bool step_max_block_flag; + + if (cid == 0) { + // len = 1 + need_block_len_sm = 0; + GM2SM_ASYNC(free_list_len, &free_list_len_sm, sizeof(int)); + // len = bsz + GM2SM_ASYNC( + base_model_stop_flags, &base_model_stop_flags_sm, bsz * sizeof(bool)); + GM2SM_ASYNC(batch_drop, &batch_drop_sm, bsz * sizeof(bool)); + GM2SM_ASYNC(encoder_block_lens, &encoder_block_lens_sm, bsz * sizeof(int)); + GM2SM_ASYNC(used_list_len, used_list_len_sm, bsz * sizeof(int)); + GM2SM_ASYNC(seq_lens_decoder, seq_lens_decoder_sm, bsz * sizeof(int)); + } + for (int tid = cid; tid < bsz; tid += ncores) { + need_block_list_sm[tid] = 0; + } + mfence(); + sync_all(); + + for (int tid = cid; tid < bsz; tid += ncores) { + int64_t first_token_id_lm = -1; + if (base_model_stop_flags_sm[tid] || batch_drop_sm[tid]) { + // reclaim this query's blocks + const int encoder_block_len_lm = encoder_block_lens_sm[tid]; + const int decoder_used_len_lm = used_list_len_sm[tid]; + if (decoder_used_len_lm > 0) { + const int ori_free_list_len = + atomic_add(&free_list_len_sm, decoder_used_len_lm); + for (int i = 0; i < decoder_used_len_lm; i += block_table_now_len) { + int process_len = min(block_table_now_len, decoder_used_len_lm - i); + GM2LM( + block_tables + tid * block_num_per_seq + encoder_block_len_lm + i, + free_list_now, + process_len * sizeof(int)); + LM2GM(free_list_now, + free_list + ori_free_list_len + i, + process_len * sizeof(int)); + LM2GM( + block_table_now, + block_tables + tid * block_num_per_seq + encoder_block_len_lm + i, + process_len * sizeof(int)); + } + encoder_block_lens_sm[tid] = 0; + used_list_len_sm[tid] = 0; + } + mfence(); + } + int max_possible_block_idx = + (seq_lens_decoder_sm[tid] + max_draft_tokens + 1) / block_size; + int next_block_id; + GM2LM(block_tables + tid * block_num_per_seq + max_possible_block_idx, + &next_block_id, + sizeof(int)); + + if (!base_model_stop_flags[tid] && !batch_drop[tid] && + max_possible_block_idx < block_num_per_seq && next_block_id == -1) { + // record which queries need a block and how many in total + const int ori_need_block_len = atomic_add(&need_block_len_sm, 1); + need_block_list_sm[ori_need_block_len] = tid; + mfence(); + } + + } // for + sync_cluster(); + + if (cid == 0) { + while (need_block_len_sm > free_list_len_sm) { + // schedule blocks: reclaim from queries in descending used_list_len order until need_block_len is satisfied + int max_used_list_len_id = 0; + int max_used_list_len = 0; + for (int i = 0; i < bsz; i++) { + if ((!base_model_stop_flags_sm[i]) && + (used_list_len_sm[i] > max_used_list_len)) { + max_used_list_len_id = i; + max_used_list_len = used_list_len_sm[i]; + } + } + const int encoder_block_len_lm = + encoder_block_lens_sm[max_used_list_len_id]; + for (int i = 0; i < max_used_list_len; i += block_table_now_len) { + int process_len = min(block_table_now_len, max_used_list_len - i); + GM2LM(block_tables + max_used_list_len_id * block_num_per_seq + + encoder_block_len_lm + i, + free_list_now, + process_len * sizeof(int)); + LM2GM(free_list_now, + free_list + free_list_len_sm + i, + process_len * sizeof(int)); + LM2GM(block_table_now, + block_tables + max_used_list_len_id * block_num_per_seq + + encoder_block_len_lm + i, + process_len * sizeof(int)); + } + free_list_len_sm += max_used_list_len; + LM2GM_ASYNC(&flag_true, stop_flags + max_used_list_len_id, sizeof(bool)); + LM2GM_ASYNC( + &value_zero, seq_lens_this_time + max_used_list_len_id, sizeof(int)); + + // still needed below, so update the SM copies and write back to GM at the end + batch_drop_sm[max_used_list_len_id] = true; + seq_lens_decoder_sm[max_used_list_len_id] = 0; + used_list_len_sm[max_used_list_len_id] = 0; + mfence(); + } + } + sync_cluster(); + + int need_block_len_all = need_block_len_sm; + for (int tid = cid; tid < need_block_len_all; tid += ncores) { + // assign blocks to the queries that need one, one block per query + const int need_block_id = need_block_list_sm[tid]; + if (!batch_drop_sm[need_block_id]) { + used_list_len_sm[need_block_id]++; + const int ori_free_list_len = atomic_add(&free_list_len_sm, -1); + int free_block_id; + GM2LM(free_list + ori_free_list_len - 1, &free_block_id, sizeof(int)); + LM2GM(&free_block_id, + block_tables + need_block_id * block_num_per_seq + + (seq_lens_decoder_sm[need_block_id] + max_draft_tokens + 1) / + block_size, + sizeof(int)); + } + } + sync_cluster(); + + if (cid == 0) { + mfence(); + SM2GM_ASYNC(&free_list_len_sm, free_list_len, sizeof(int)); + SM2GM_ASYNC(used_list_len_sm, used_list_len, sizeof(int) * bsz); + SM2GM_ASYNC(seq_lens_decoder_sm, seq_lens_decoder, sizeof(int) * bsz); + SM2GM_ASYNC(batch_drop_sm, batch_drop, sizeof(bool) * bsz); + SM2GM_ASYNC(encoder_block_lens_sm, encoder_block_lens, sizeof(int) * bsz); + mfence(); + } +} + +} // namespace plugin +} // namespace xpu3 diff --git a/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/rebuild_append_padding.xpu b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/rebuild_append_padding.xpu new file mode 100644 index 000000000..a098a0e6f --- /dev/null +++ b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/rebuild_append_padding.xpu @@ -0,0 +1,90 @@ +#include "xpu/kernel/cluster.h" +#include "xpu/kernel/cluster_partition.h" +#include "xpu/kernel/cluster_primitive.h" + +namespace xpu3 { +namespace plugin { + +template <typename T> +__global__ void RebuildAppendPaddingKernel(const T *full_hidden_states, + const int *cum_offset, + const int *seq_len_encoder, + const int *seq_len_decoder, + const int *output_padding_offset, + int max_seq_len, + int dim_embed, + int elem_nums, + T *out) { + int ncores = core_num(); + int cid = core_id(); + int tid = cid * cluster_num() + cluster_id(); + int nthreads = cluster_num() * ncores; + int64_t mstart = -1; + int64_t mend = -1; + int64_t nstart = -1; + int64_t nend = -1; + partition2d(tid, + nthreads, + elem_nums / dim_embed, + dim_embed, + &mstart, + &mend, + &nstart, + &nend); + + const int64_t BUFFER_LEN = rounddown(6144 / sizeof(T), 64); + __simd__ T lm_full_hidden_states[BUFFER_LEN]; + int output_padding_offset_val, cum_offset_val, seq_len_encoder_val, + seq_len_decoder_val; + + for (int64_t _m = mstart; _m < mend; _m++) { + int out_token_id = _m; + GM2LM(output_padding_offset + out_token_id, + &output_padding_offset_val, + sizeof(int)); + int ori_token_id = out_token_id + output_padding_offset_val; + int bi = ori_token_id / max_seq_len; + GM2LM_ASYNC(seq_len_encoder + bi, &seq_len_encoder_val, sizeof(int)); + GM2LM(seq_len_decoder + bi, &seq_len_decoder_val, sizeof(int)); + int seq_id = 0; + if (seq_len_encoder_val == 0 && seq_len_decoder_val == 0) { + continue; + } else if (seq_len_encoder_val != 0) { + seq_id = seq_len_encoder_val - 1; + } + GM2LM(cum_offset + bi, &cum_offset_val, sizeof(int)); + int input_token_id = ori_token_id - cum_offset_val + seq_id; + for (int64_t _n = nstart; _n < nend; _n += BUFFER_LEN) { + int64_t read_size = min(BUFFER_LEN, nend - _n); + // out[i] = full_hidden_states[(i / dim_embed + + // output_padding_offset[i / dim_embed] - cum_offset[(i / dim_embed + // + output_padding_offset[i / dim_embed]) / max_seq_len] + seq_id) 
+ // * dim_embed + i % dim_embed] + GM2LM(full_hidden_states + input_token_id * dim_embed + _n, + lm_full_hidden_states, + read_size * sizeof(T)); + LM2GM(lm_full_hidden_states, + out + _m * dim_embed + _n, + read_size * sizeof(T)); + } + } +} + +#define _XPU_DEF_REBUILD_APPEND_PADDING_KERNEL(T) \ + template __global__ void RebuildAppendPaddingKernel( \ + const T *full_hidden_states, \ + const int *cum_offset, \ + const int *seq_len_encoder, \ + const int *seq_len_decoder, \ + const int *output_padding_offset, \ + int max_seq_len, \ + int dim_embed, \ + int elem_nums, \ + T *out); + +_XPU_DEF_REBUILD_APPEND_PADDING_KERNEL(bfloat16); +_XPU_DEF_REBUILD_APPEND_PADDING_KERNEL(float16); +_XPU_DEF_REBUILD_APPEND_PADDING_KERNEL(float); + +} // namespace plugin +} // namespace xpu3 diff --git a/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/rebuild_hidde_states.xpu b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/rebuild_hidde_states.xpu new file mode 100644 index 000000000..293dc0f40 --- /dev/null +++ b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/rebuild_hidde_states.xpu @@ -0,0 +1,65 @@ +#include "xpu/kernel/cluster.h" +#include "xpu/kernel/cluster_partition.h" +#include "xpu/kernel/cluster_primitive.h" + +namespace xpu3 { +namespace plugin { + +template <typename T> +__global__ void rebuildHiddenStatesKernel(const T* input, + const int* position_map, + T* output, + int dim_embed, + int elem_cnt) { + int ncores = core_num(); + int cid = core_id(); + int tid = cid * cluster_num() + cluster_id(); + int nthreads = cluster_num() * ncores; + int64_t mstart = -1; + int64_t mend = -1; + int64_t nstart = -1; + int64_t nend = -1; + partition2d(tid, + nthreads, + elem_cnt / dim_embed, + dim_embed, + &mstart, + &mend, + &nstart, + &nend); + + const int64_t BUFFER_LEN = 6144 / sizeof(T); + T lm_input[BUFFER_LEN]; + + for (int64_t _m = mstart; _m < mend; _m++) { + int ori_token_idx = _m; + int token_idx; + GM2LM(position_map + _m, &token_idx, sizeof(int)); + if (token_idx >= 0) { + for (int64_t _n = nstart; _n < nend; _n += BUFFER_LEN) { + int64_t read_size = min(BUFFER_LEN, nend - _n); + GM2LM(input + ori_token_idx * dim_embed + _n, + lm_input, + read_size * sizeof(T)); + LM2GM(lm_input, + output + token_idx * dim_embed + _n, + read_size * sizeof(T)); + } + } + } +} + +#define _XPU_DEF_REBUILD_HIDDEN_STATES_KERNEL(T) \ + template __global__ void rebuildHiddenStatesKernel( \ + const T* input, \ + const int* position_map, \ + T* output, \ + int dim_embed, \ + int elem_cnt); + +_XPU_DEF_REBUILD_HIDDEN_STATES_KERNEL(bfloat16); +_XPU_DEF_REBUILD_HIDDEN_STATES_KERNEL(float); +_XPU_DEF_REBUILD_HIDDEN_STATES_KERNEL(float16); + +} // namespace plugin +} // namespace xpu3 diff --git a/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/rebuild_self_hidde_states.xpu b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/rebuild_self_hidde_states.xpu new file mode 100644 index 000000000..f697bfdaf --- /dev/null +++ b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/rebuild_self_hidde_states.xpu @@ -0,0 +1,56 @@ +#include "xpu/kernel/cluster.h" +#include "xpu/kernel/cluster_partition.h" +#include "xpu/kernel/cluster_primitive.h" + +namespace xpu3 { +namespace plugin { + +template <typename T> +__global__ void rebuildSelfHiddenStatesKernel( + const T* input, int* src_map, T* output, int dim_embed, int elem_cnt) { + int ncores = core_num(); + int cid = core_id(); + int tid = cid * cluster_num() + cluster_id(); + int nthreads = cluster_num() * ncores; + int64_t mstart = -1; + int64_t mend = -1; + int64_t nstart = -1; + int64_t nend = -1; + partition2d(tid, + nthreads, + elem_cnt / dim_embed, + dim_embed, + &mstart, + &mend, + &nstart, + &nend); + + const int64_t BUFFER_LEN = 6144 / sizeof(T); + T lm_input[BUFFER_LEN]; + + for (int64_t _m = mstart; _m < mend; _m++) { + int output_token_idx = _m; + int input_token_idx; + GM2LM(src_map + _m, &input_token_idx, sizeof(int)); + if (input_token_idx >= 0) { + for (int64_t _n = nstart; _n < nend; _n += BUFFER_LEN) { + int64_t read_size = min(BUFFER_LEN, nend - _n); + GM2LM(input + input_token_idx * dim_embed + _n, + lm_input, + read_size * sizeof(T)); + LM2GM(lm_input, output + _m * dim_embed + _n, read_size * sizeof(T)); + } + } + } +} + +#define _XPU_DEF_REBUILD_SELF_HIDDEN_STATES_KERNEL(T) \ + template __global__ void rebuildSelfHiddenStatesKernel( \ + const T* input, int* src_map, T* output, int dim_embed, int elem_cnt); + +_XPU_DEF_REBUILD_SELF_HIDDEN_STATES_KERNEL(bfloat16); +_XPU_DEF_REBUILD_SELF_HIDDEN_STATES_KERNEL(float); +_XPU_DEF_REBUILD_SELF_HIDDEN_STATES_KERNEL(float16); + +} // namespace plugin +} // namespace xpu3 diff --git a/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/speculate_ban_bad_words.xpu b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/speculate_ban_bad_words.xpu new file mode 100644 index 000000000..114337f07 --- /dev/null +++ b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/speculate_ban_bad_words.xpu @@ -0,0 +1,78 @@ +#include "xpu/kernel/cluster.h" +#include "xpu/kernel/cluster_partition.h" +#include "xpu/kernel/cluster_primitive.h" + +namespace xpu3 { +namespace plugin { + +template <typename T> +inline __device__ void update_bad_words_logit(_global_ptr_ T* logits) { + __local__ T min_value = -1e10f; + mfence_lm(); + LM2GM((void*)&(min_value), logits, sizeof(T)); +} + +template <> +inline __device__ void update_bad_words_logit( + _global_ptr_ float16* logits) { + __local__ short min_value = 0xFBFF; + mfence_lm(); + LM2GM((void*)&(min_value), logits, sizeof(float16)); +} + +template <typename T> +__global__ void speculate_ban_bad_words(T* logits, + const int64_t* bad_words_list, + const int* output_padding_offset, + const int64_t bs, + const int64_t length, + const int64_t bad_words_length, + const int64_t token_num, + const int64_t max_seq_len) { + int tid = core_id() * cluster_num() + cluster_id(); + int nthreads = cluster_num() * core_num(); + int start = -1; + int end = -1; + int output_padding_offset_lm; + partition(tid, + nthreads, + static_cast<int>(token_num * bad_words_length), + 1, + &start, + &end); + for (int i = start; i < end; i++) { + int token_idx = i / bad_words_length; + GM2LM(output_padding_offset + token_idx, + &output_padding_offset_lm, + sizeof(int)); + int bs_idx = (token_idx + output_padding_offset_lm) / max_seq_len; + if (bs_idx >= bs) { + continue; + } + int bad_words_idx = i - token_idx * bad_words_length; + int64_t bad_words_token_id = -1; + mfence_lm(); + GM2LM(bad_words_list + bad_words_idx, + (void*)&(bad_words_token_id), + sizeof(int64_t)); + if (bad_words_token_id >= length || bad_words_token_id < 0) continue; + update_bad_words_logit(logits + token_idx * length + bad_words_token_id); + } +} + +#define _XPU_DEF__BAN_BAD_WORDS_(DATA_TYPE) \ + template __global__ void speculate_ban_bad_words( \ + DATA_TYPE* logits, \ + const int64_t* bad_words_list, \ + const int* output_padding_offset, \ + const int64_t bs, \ + const int64_t length, \ + const int64_t bad_words_length, \ + const int64_t token_num, \ + const int64_t 
max_seq_len); +_XPU_DEF__BAN_BAD_WORDS_(float); +_XPU_DEF__BAN_BAD_WORDS_(float16); +_XPU_DEF__BAN_BAD_WORDS_(bfloat16); + +} // namespace plugin +} // namespace xpu3 diff --git a/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/speculate_clear_accept_nums.xpu b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/speculate_clear_accept_nums.xpu new file mode 100644 index 000000000..e7917f962 --- /dev/null +++ b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/speculate_clear_accept_nums.xpu @@ -0,0 +1,44 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +/* + * copyright (C) 2022 KUNLUNXIN, Inc + */ + +#include "xpu/kernel/cluster.h" +#include "xpu/kernel/cluster_partition.h" +#include "xpu/kernel/cluster_primitive.h" +#include "xpu/kernel/xtdk_io.h" + +namespace xpu3 { +namespace plugin { + +__global__ void speculate_clear_accept_nums(int* accept_num, + const int* seq_lens_decoder, + const int max_bsz) { + int cid = core_id(); + int ncores = core_num(); + + int accept_num_lm = 0; + int seq_lens_decoder_lm; + for (int i = cid; i < max_bsz; i += ncores) { + GM2LM(seq_lens_decoder + i, &seq_lens_decoder_lm, sizeof(int)); + if (seq_lens_decoder_lm == 0) { + LM2GM_ASYNC(&accept_num_lm, accept_num + i, sizeof(int)); + } + mfence_lm(); + } +} + +} // namespace plugin +} // namespace xpu3 diff --git a/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/speculate_free_and_reschedule.xpu b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/speculate_free_and_reschedule.xpu new file mode 100644 index 000000000..e3bb8cc51 --- /dev/null +++ b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/speculate_free_and_reschedule.xpu @@ -0,0 +1,288 @@ +#include "xpu/kernel/cluster.h" +#include "xpu/kernel/cluster_partition.h" +#include "xpu/kernel/cluster_primitive.h" + +namespace xpu3 { +namespace plugin { + +static __device__ inline int loada_float(_shared_ptr_ const int *ptr) { + int ret; + __asm__ __volatile__("loada.w %0,%1" : "=&r"(ret) : "r"(ptr)); + return ret; +} + +static __device__ inline bool storea_float(_shared_ptr_ int *ptr, int value) { + bool ret; + __asm__ __volatile__("storea.w %0,%1,%2" : "=&r"(ret) : "r"(value), "r"(ptr)); + return ret; +} + +static __device__ int atomic_add(_shared_ptr_ int *ptr, int value) { + bool fail = true; + int old_value; + while (fail) { + old_value = loada_float(ptr); + int new_value = old_value + value; + fail = storea_float(ptr, new_value); + } + return old_value; +} + +static __device__ bool in_need_block_list(const int qid, + _shared_ptr_ int *need_block_list, + const int need_block_len) { + bool res = false; + for (int i = 0; i < need_block_len; i++) { + if (qid == need_block_list[i]) { + need_block_list[i] = -1; + res = true; + break; + } + } + return res; +} + +__global__ void speculate_free_and_reschedule(bool *stop_flags, + int *seq_lens_this_time, + int *seq_lens_decoder, + int *block_tables, + int 
*encoder_block_lens, + bool *is_block_step, + int *step_block_list, // [bsz] + int *step_len, + int *recover_block_list, + int *recover_len, + int *need_block_list, + int *need_block_len, + int *used_list_len, + int *free_list, + int *free_list_len, + int64_t *first_token_ids, + const int bsz, + const int block_size, + const int block_num_per_seq, + const int max_decoder_block_num, + const int max_draft_tokens) { + int cid = core_id(); + int ncores = core_num(); + int clusterid = cluster_id(); + if (clusterid != 0 || cid >= bsz) return; + + // assert bsz <= 640 + const int max_bs = 640; + int value_zero = 0; + bool flag_true = true; + + // 128 = seq_len(8192) / block_size(64) + // at most 128 block_table entries are processed per pass + const int block_table_now_len = 128; + int block_table_now[block_table_now_len]; + for (int i = 0; i < block_table_now_len; i++) { + block_table_now[i] = -1; + } + bool stop_flag_lm; + int seq_lens_decoder_lm; + + __shared__ int free_list_len_sm; + // at most block_table_now_len free_list entries are processed per pass + int free_list_now[block_table_now_len]; + __shared__ int need_block_len_sm; + __shared__ int need_block_list_sm[max_bs]; + __shared__ int used_list_len_sm[max_bs]; + __shared__ bool step_max_block_flag; + __shared__ int in_need_block_list_len; + if (cid == 0) { + step_max_block_flag = false; + in_need_block_list_len = 0; + GM2SM_ASYNC(free_list_len, &free_list_len_sm, sizeof(int)); + GM2SM_ASYNC(need_block_len, &need_block_len_sm, sizeof(int)); + mfence(); + if (need_block_len_sm > 0) { + GM2SM_ASYNC( + need_block_list, need_block_list_sm, sizeof(int) * need_block_len_sm); + } + GM2SM_ASYNC(used_list_len, used_list_len_sm, sizeof(int) * bsz); + mfence(); + } + sync_cluster(); + + for (int tid = cid; tid < bsz; tid += ncores) { + int seq_lens_this_time_lm; + mfence(); + GM2LM_ASYNC(stop_flags + tid, &stop_flag_lm, sizeof(bool)); + GM2LM_ASYNC(seq_lens_decoder + tid, &seq_lens_decoder_lm, sizeof(int)); + GM2LM_ASYNC(seq_lens_this_time + tid, &seq_lens_this_time_lm, sizeof(int)); + mfence(); + int max_possible_block_idx = + (seq_lens_decoder_lm + max_draft_tokens + 1) / block_size; + if (stop_flag_lm) { + // reclaim this query's blocks + int64_t first_token_id_lm = -1; + mfence_lm(); + LM2GM(&first_token_id_lm, first_token_ids + tid, sizeof(int64_t)); + int encoder_block_len_lm; + int decoder_used_len_lm = used_list_len_sm[tid]; + GM2LM(encoder_block_lens + tid, &encoder_block_len_lm, sizeof(int)); + if (decoder_used_len_lm > 0) { + const int ori_free_list_len = + atomic_add(&free_list_len_sm, decoder_used_len_lm); + for (int i = 0; i < decoder_used_len_lm; i += block_table_now_len) { + int process_len = min(block_table_now_len, decoder_used_len_lm - i); + GM2LM( + block_tables + tid * block_num_per_seq + encoder_block_len_lm + i, + free_list_now, + process_len * sizeof(int)); + LM2GM(free_list_now, + free_list + ori_free_list_len + i, + process_len * sizeof(int)); + LM2GM( + block_table_now, + block_tables + tid * block_num_per_seq + encoder_block_len_lm + i, + process_len * sizeof(int)); + } + used_list_len_sm[tid] = 0; + mfence(); + LM2GM(&value_zero, encoder_block_lens + tid, sizeof(int)); + } + } else if (seq_lens_this_time_lm != 0 && + max_possible_block_idx < block_num_per_seq) { + int next_block_id; + GM2LM(block_tables + tid * block_num_per_seq + + (seq_lens_decoder_lm + max_draft_tokens + 1) / block_size, + &next_block_id, + sizeof(int)); + if (next_block_id == -1) { + // record which queries need a block and how many in total + const int ori_need_block_len = atomic_add(&need_block_len_sm, 1); + need_block_list_sm[ori_need_block_len] = tid; + } + } + } + sync_cluster(); + + bool is_block_step_lm[max_bs]; + int step_len_lm; + int step_block_list_lm[max_bs]; + int recover_len_lm; + int recover_block_list_lm[max_bs]; + if (cid == 0) { + GM2LM_ASYNC(is_block_step, is_block_step_lm, sizeof(bool) * bsz); + GM2LM_ASYNC(step_len, &step_len_lm, sizeof(int)); + GM2LM_ASYNC(step_block_list, step_block_list_lm, sizeof(int) * bsz); + GM2LM_ASYNC(recover_len, &recover_len_lm, sizeof(int)); + GM2LM_ASYNC(recover_block_list, recover_block_list_lm, sizeof(int) * bsz); + mfence(); + }
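+ // Illustrative example for the scheduling loop below: with + // free_list_len_sm = 0 and two queries queued in need_block_list_sm, the + // query holding the most decoder blocks (largest used_list_len) is stopped + // and all of its blocks are reclaimed, repeating until the freed blocks + // cover the demand; a preempted query that was itself queued is dropped + // from the queue via in_need_block_list().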
+ if (cid == 0) { + while (need_block_len_sm > free_list_len_sm) { + // schedule blocks: reclaim from queries in descending used_list_len order until need_block_len is satisfied; queries already decoding their last block do not take part (they finish soon) + int max_used_list_len_id = 0; + int max_used_list_len = 0; + for (int i = 0; i < bsz; i++) { + if (used_list_len_sm[i] > max_used_list_len) { + max_used_list_len_id = i; + max_used_list_len = used_list_len_sm[i]; + } + } + + if (max_used_list_len == 0) { + step_max_block_flag = true; + } else { + int encoder_block_len; + GM2LM(encoder_block_lens + max_used_list_len_id, + &encoder_block_len, + sizeof(int)); + for (int i = 0; i < max_used_list_len; i += block_table_now_len) { + int process_len = min(block_table_now_len, max_used_list_len - i); + GM2LM(block_tables + max_used_list_len_id * block_num_per_seq + + encoder_block_len + i, + free_list_now, + process_len * sizeof(int)); + LM2GM(free_list_now, + free_list + free_list_len_sm + i, + process_len * sizeof(int)); + LM2GM(block_table_now, + block_tables + max_used_list_len_id * block_num_per_seq + + encoder_block_len + i, + process_len * sizeof(int)); + } + step_block_list_lm[step_len_lm] = max_used_list_len_id; + int need_block_len_all = need_block_len_sm + in_need_block_list_len; + if (in_need_block_list( + max_used_list_len_id, need_block_list_sm, need_block_len_all)) { + need_block_len_sm--; + in_need_block_list_len++; + } + step_len_lm++; + free_list_len_sm += max_used_list_len; + LM2GM_ASYNC( + &flag_true, stop_flags + max_used_list_len_id, sizeof(bool)); + LM2GM_ASYNC(&value_zero, + seq_lens_this_time + max_used_list_len_id, + sizeof(int)); + LM2GM_ASYNC( + &value_zero, seq_lens_decoder + max_used_list_len_id, sizeof(int)); + LM2GM_ASYNC(&value_zero, + encoder_block_lens + max_used_list_len_id, + sizeof(int)); + used_list_len_sm[max_used_list_len_id] = 0; + mfence(); + } + } + } + sync_cluster(); + + int need_block_len_all = need_block_len_sm + in_need_block_list_len; + for (int tid = cid; tid < need_block_len_all; tid += ncores) { + // assign blocks to the queries that need one, one block per query + const int need_block_id = need_block_list_sm[tid]; + if (need_block_id != -1) { + GM2LM(stop_flags + need_block_id, &stop_flag_lm, sizeof(bool)); + if (!stop_flag_lm) { + // if this slot is one that was just released in the scheduling step above, skip it + used_list_len_sm[need_block_id]++; + const int ori_free_list_len = atomic_add(&free_list_len_sm, -1); + int tmp_seq_lens_decoder; + GM2LM(seq_lens_decoder + need_block_id, + &tmp_seq_lens_decoder, + sizeof(int)); + int free_block_id; + GM2LM(free_list + ori_free_list_len - 1, &free_block_id, sizeof(int)); + LM2GM(&free_block_id, + block_tables + need_block_id * block_num_per_seq + + (tmp_seq_lens_decoder + max_draft_tokens + 1) / block_size, + sizeof(int)); + } + need_block_list_sm[tid] = -1; + } + } + sync_cluster(); + + int ori_need_block_len; + if (cid == 0) { + ori_need_block_len = need_block_len_sm; + need_block_len_sm = 0; + } + + if (cid == 0) { + mfence(); + LM2GM_ASYNC(step_block_list_lm, step_block_list, sizeof(int) * bsz); + LM2GM_ASYNC(is_block_step_lm, is_block_step, sizeof(bool) 
* bsz); + LM2GM_ASYNC(&step_len_lm, step_len, sizeof(int)); + LM2GM_ASYNC(&recover_len_lm, recover_len, sizeof(int)); + LM2GM_ASYNC(recover_block_list_lm, recover_block_list, sizeof(int) * bsz); + SM2GM_ASYNC(&free_list_len_sm, free_list_len, sizeof(int)); + SM2GM_ASYNC(&need_block_len_sm, need_block_len, sizeof(int)); + if (ori_need_block_len > 0) { + SM2GM_ASYNC(need_block_list_sm, + need_block_list, + sizeof(int) * ori_need_block_len); + } + SM2GM_ASYNC(used_list_len_sm, used_list_len, sizeof(int) * bsz); + mfence(); + } +} + +} // namespace plugin +} // namespace xpu3 diff --git a/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/speculate_get_output_padding_offset.xpu b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/speculate_get_output_padding_offset.xpu new file mode 100644 index 000000000..b4ffcccfe --- /dev/null +++ b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/speculate_get_output_padding_offset.xpu @@ -0,0 +1,63 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +/* + * copyright (C) 2022 KUNLUNXIN, Inc + */ + +#include "xpu/kernel/cluster.h" +#include "xpu/kernel/cluster_partition.h" +#include "xpu/kernel/cluster_primitive.h" +#include "xpu/kernel/xtdk_io.h" + +namespace xpu3 { +namespace plugin { + +__global__ void speculate_get_output_padding_offset( + int* output_padding_offset, + int* output_cum_offsets, + const int* output_cum_offsets_tmp, + const int* seq_lens_output, + const int bsz, + const int max_seq_len) { + int cid = core_id(); + int ncores = core_num(); + int clusterid = cluster_id(); + int nclusters = cluster_num(); + + int seq_lens_output_lm; + int cum_offset_lm; + + for (int bi = clusterid; bi < bsz; bi += nclusters) { + if (bi == 0) { + cum_offset_lm = 0; + } else { + GM2LM_ASYNC(output_cum_offsets_tmp + bi - 1, &cum_offset_lm, sizeof(int)); + } + GM2LM_ASYNC(seq_lens_output + bi, &seq_lens_output_lm, sizeof(int)); + mfence_lm(); + + for (int i = cid; i < seq_lens_output_lm; i += ncores) { + LM2GM_ASYNC(&cum_offset_lm, + output_padding_offset + bi * max_seq_len - cum_offset_lm + i, + sizeof(int)); + } + if (cid == 0) { + LM2GM_ASYNC(&cum_offset_lm, output_cum_offsets + bi, sizeof(int)); + } + mfence_lm(); + } +} + +} // namespace plugin +} // namespace xpu3 diff --git a/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/speculate_get_padding_offset.xpu b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/speculate_get_padding_offset.xpu new file mode 100644 index 000000000..c08d756d7 --- /dev/null +++ b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/speculate_get_padding_offset.xpu @@ -0,0 +1,122 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +/* + * copyright (C) 2022 KUNLUNXIN, Inc + */ + +#include "xpu/kernel/cluster.h" +#include "xpu/kernel/cluster_partition.h" +#include "xpu/kernel/cluster_primitive.h" +#include "xpu/kernel/cluster_simd.h" +#include "xpu/kernel/xtdk.h" + +namespace xpu3 { +namespace plugin { + +template <typename T> +__global__ void speculate_remove_padding(T* output_data, + const T* input_data, + const T* draft_tokens, + const int* seq_lens, + const int* seq_lens_encoder, + const int* cum_offsets, + int sequence_length, + int max_draft_tokens, + int bsz, + int token_num_data) { + int bid = cluster_id(); + int tid = core_id(); + int ncores = core_num(); + int nclusters = cluster_num(); + + int seq_lens_now = 0; + int seq_lens_encoder_now = 0; + int cum_offsets_now = 0; + T input_data_now; + for (int bi = bid; bi < bsz; bi += nclusters) { + GM2LM(seq_lens + bi, &seq_lens_now, sizeof(int)); + GM2LM(seq_lens_encoder + bi, &seq_lens_encoder_now, sizeof(int)); + GM2LM(cum_offsets + bi, &cum_offsets_now, sizeof(int)); + + for (int i = tid; i < seq_lens_now; i += ncores) { + const int tgt_seq_id = bi * sequence_length - cum_offsets_now + i; + + if (seq_lens_encoder_now > 0) { + const int src_seq_id = bi * sequence_length + i; + GM2LM(input_data + src_seq_id, &input_data_now, sizeof(T)); + LM2GM(&input_data_now, output_data + tgt_seq_id, sizeof(T)); + } else { + const int src_seq_id = bi * max_draft_tokens + i; + GM2LM(draft_tokens + src_seq_id, &input_data_now, sizeof(T)); + LM2GM(&input_data_now, output_data + tgt_seq_id, sizeof(T)); + } + } + } +}
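+// Worked example for the kernel below (illustrative values): with +// max_seq_len = 4, seq_lens = [3, 1] and cum_offsets = [1, 4] (cumulative +// padding per batch entry), it writes padding_offset[0..2] = 0 and +// padding_offset[3] = 1, so packed token i maps back to padded slot +// i + padding_offset[i], and cu_seqlens_q[1..2] = {3, 4}.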
+__global__ void speculate_get_padding_offset(int* padding_offset, + int* cum_offsets_out, + int* cu_seqlens_q, + int* cu_seqlens_k, + const int* cum_offsets, + const int* seq_lens, + const int max_seq_len, + int bsz) { + int bid = cluster_id(); + int tid = core_id(); + int ncores = core_num(); + int nclusters = cluster_num(); + int seq_lens_now = 0; + int cum_offsets_now = 0; + int cum_offsets_now_ind = 0; + for (int bi = bid; bi < bsz; bi += nclusters) { + GM2LM(seq_lens + bi, &seq_lens_now, sizeof(int)); + if (bi == 0) { + cum_offsets_now = 0; + } else { + GM2LM(cum_offsets + bi - 1, &cum_offsets_now, sizeof(int)); + } + GM2LM(cum_offsets + bi, &cum_offsets_now_ind, sizeof(int)); + + for (int i = tid; i < seq_lens_now; i += ncores) { + LM2GM(&cum_offsets_now, + padding_offset + bi * max_seq_len - cum_offsets_now + i, + sizeof(int)); + } + LM2GM(&cum_offsets_now, cum_offsets_out + bi, sizeof(int)); + int cum_seq_len = (bi + 1) * max_seq_len - cum_offsets_now_ind; + LM2GM(&cum_seq_len, cu_seqlens_q + bi + 1, sizeof(int)); + LM2GM(&cum_seq_len, cu_seqlens_k + bi + 1, sizeof(int)); + } +} + +#define _XPU_DEF_SPECULATE_KERNELS_(T) \ + template __global__ void speculate_remove_padding(T*, \ + const T*, \ + const T*, \ + const int*, \ + const int*, \ + const int*, \ + int, \ + int, \ + int, \ + int); + +_XPU_DEF_SPECULATE_KERNELS_(float); +_XPU_DEF_SPECULATE_KERNELS_(float16); +_XPU_DEF_SPECULATE_KERNELS_(bfloat16); +_XPU_DEF_SPECULATE_KERNELS_(int64_t); + +} // namespace plugin +} // namespace xpu3 diff --git a/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/speculate_get_seq_lens_output.xpu b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/speculate_get_seq_lens_output.xpu new file mode 100644 index 000000000..5044ceb0e --- /dev/null +++ b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/speculate_get_seq_lens_output.xpu @@ -0,0 +1,58 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +/* + * copyright (C) 2022 KUNLUNXIN, Inc + */ + +#include "xpu/kernel/cluster.h" +#include "xpu/kernel/cluster_partition.h" +#include "xpu/kernel/cluster_primitive.h" +#include "xpu/kernel/xtdk_io.h" + +namespace xpu3 { +namespace plugin { + +__global__ void speculate_get_seq_lens_output(int* seq_lens_output, + const int* seq_lens_this_time, + const int* seq_lens_encoder, + const int* seq_lens_decoder, + const int real_bsz) { + int cid = core_id(); + int ncores = core_num(); + int clusterid = cluster_id(); + int nclusters = cluster_num(); + int thread_num = ncores * nclusters; + int bid = clusterid * ncores + cid; + + int one = 1; + int lm_seq_lens_this_time; + int lm_seq_lens_encoder; + for (; bid < real_bsz; bid += thread_num) { + GM2LM_ASYNC(seq_lens_this_time + bid, &lm_seq_lens_this_time, sizeof(int)); + GM2LM(seq_lens_encoder + bid, &lm_seq_lens_encoder, sizeof(int)); + if (lm_seq_lens_this_time == 0) { + continue; + } else if (lm_seq_lens_this_time == 1) { + LM2GM_ASYNC(&one, seq_lens_output + bid, sizeof(int)); + } else if (lm_seq_lens_encoder != 0) { + LM2GM_ASYNC(&one, seq_lens_output + bid, sizeof(int)); + } else { + LM2GM_ASYNC(&lm_seq_lens_this_time, seq_lens_output + bid, sizeof(int)); + } + mfence_lm(); + } +} + +} // namespace plugin +} // namespace xpu3 diff --git a/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/speculate_min_length_logits_process.xpu b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/speculate_min_length_logits_process.xpu new file mode 100644 index 000000000..54ae6abeb --- /dev/null +++ b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/speculate_min_length_logits_process.xpu @@ -0,0 +1,91 @@ +#include "xpu/kernel/cluster.h" +#include "xpu/kernel/cluster_debug.h" +#include "xpu/kernel/cluster_partition.h" +#include "xpu/kernel/cluster_primitive.h" + +namespace xpu3 { +namespace plugin { +
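+// Note on the kernel below: while a query's generated length is still below +// min_len, the logit of every eos token is forced to a large negative value +// (-1e4 for float16, whose range cannot represent -1e10; -1e10 otherwise) so +// that sampling cannot terminate the sequence early.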
+template <typename T> +__global__ void speculate_min_length_logits_process( + T* logits, + const int64_t* cur_len, + const int64_t* min_len, + const int64_t* eos_token_id, + const int* output_padding_offset, + const int* output_cum_offsets, + const int64_t bs, + const int64_t length, + const int64_t length_id, + const int64_t end_length, + const int64_t token_num, + const int64_t max_seq_len) { + int ncores = core_num(); + int cid = core_id(); + int tid = cluster_num() * cid + cluster_id(); + int nthreads = cluster_num() * ncores; + + int64_t cur_len_now; + int64_t min_len_now; + int64_t eos_token_id_now; + int64_t bi; + int64_t end_num; + int output_padding_offset_now; + int output_cum_offsets_now; + __simd__ float float32logits_now[32]; + + for (int64_t i = tid; i < token_num * end_length; i += nthreads) { + int64_t token_idx = i / end_length; + GM2LM(output_padding_offset + token_idx, + &output_padding_offset_now, + sizeof(int)); + bi = (token_idx + output_padding_offset_now) / max_seq_len; + if (bi >= bs) { + continue; + } + end_num = i % end_length; + GM2LM_ASYNC( + output_cum_offsets + bi, (void*)&output_cum_offsets_now, sizeof(int)); + GM2LM_ASYNC(cur_len + bi, (void*)&(cur_len_now), sizeof(int64_t)); + GM2LM_ASYNC(min_len + bi, (void*)&(min_len_now), sizeof(int64_t)); + mfence(); + int query_start_token_idx = bi * max_seq_len - output_cum_offsets_now; + if (cur_len_now >= 0 && + (cur_len_now + (token_idx - query_start_token_idx) < min_len_now)) { + GM2LM( + eos_token_id + end_num, (void*)&(eos_token_id_now), sizeof(int64_t)); + GM2LM(logits + token_idx * length + eos_token_id_now, + (void*)float32logits_now, + sizeof(T)); + primitive_cast( + (const T*)(float32logits_now), float32logits_now, 1); + float32logits_now[0] = std::is_same<T, float16>::value ? -1e4 : -1e10; + mfence_lm(); + primitive_cast(float32logits_now, (T*)float32logits_now, 1); + LM2GM((void*)float32logits_now, + logits + token_idx * length + eos_token_id_now, + sizeof(T)); + } + } +} + +#define _XPU_DEF__UPDATE_LOGITS_REPEAT_TIMES_(DATA_TYPE) \ + template __global__ void speculate_min_length_logits_process( \ + DATA_TYPE * logits, \ + const int64_t* cur_len, \ + const int64_t* min_len, \ + const int64_t* eos_token_id, \ + const int* output_padding_offset, \ + const int* output_cum_offsets, \ + const int64_t bs, \ + const int64_t length, \ + const int64_t length_id, \ + const int64_t end_length, \ + const int64_t token_num, \ + const int64_t max_seq_len); +_XPU_DEF__UPDATE_LOGITS_REPEAT_TIMES_(float); +_XPU_DEF__UPDATE_LOGITS_REPEAT_TIMES_(float16); +_XPU_DEF__UPDATE_LOGITS_REPEAT_TIMES_(bfloat16); + +} // namespace plugin +} // namespace xpu3 diff --git a/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/speculate_set_stop_value_multi_seqs.xpu b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/speculate_set_stop_value_multi_seqs.xpu new file mode 100644 index 000000000..507aab9d3 --- /dev/null +++ b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/speculate_set_stop_value_multi_seqs.xpu @@ -0,0 +1,98 @@ +#include "xpu/kernel/cluster_debug.h" +#include "xpu/kernel/cluster_partition.h" +#include "xpu/kernel/xtdk.h" +#include "xpu/kernel/xtdk_math.h" +#include "xpu/kernel/xtdk_simd.h" + +namespace xpu3 { +namespace plugin { +
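+// Matching scheme of the kernel below (descriptive only): for each accepted +// position, every candidate stop sequence is compared backwards against the +// most recent tokens; tokens inside the accept window are read from +// accept_tokens and older ones from pre_ids. On a full match the output is +// truncated at that position: the matched token is replaced with end_ids[0] +// and stop_flags is raised.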
+__global__ void speculate_set_stop_value_multi_seqs(bool *stop_flags, + int64_t *accept_tokens, + int *accept_nums, + const int64_t *pre_ids, + const int64_t *step_idx, + const int64_t *stop_seqs, + const int *stop_seqs_len, + const int *seq_lens, + const int64_t *end_ids, + const int bs, + const int accept_tokens_len, + const int stop_seqs_bs, + const int stop_seqs_max_len, + const int pre_ids_len) { + int cls_id = cluster_id(); + int cid = core_id(); + int ncores = core_num(); + int nclusters = cluster_num(); + + int accept_num = 0; + int64_t step_idx_now = 0; + bool stop_flags_now = false; + int stop_seq_len = 0; + for (int bid = cls_id; bid < bs; bid += nclusters) { + GM2LM_ASYNC(accept_nums + bid, &accept_num, sizeof(int)); + GM2LM_ASYNC(step_idx + bid, &step_idx_now, sizeof(int64_t)); + GM2LM(stop_flags + bid, &stop_flags_now, sizeof(bool)); + if (stop_flags_now) { + continue; + } + for (int tid = cid; tid < stop_seqs_bs; tid += ncores) { + GM2LM_ASYNC(stop_seqs_len + tid, &stop_seq_len, sizeof(int)); + if (stop_seq_len <= 0) { + continue; + } + int accept_idx = 0; + bool is_end = false; + int64_t stop_seq_now_lm = 0; + for (; accept_idx <= accept_num - 1 && !is_end; accept_idx++) { + if (step_idx_now - accept_num + accept_idx + 1 < stop_seq_len) { + continue; + } + // walk one stop sequence + for (int i = stop_seq_len - 1; i >= 0; --i) { + int64_t cur_token_idx = -1; + // decide from the current position whether the token lives in pre_ids or in accept_tokens + if (stop_seq_len - 1 - i < accept_idx) { + GM2LM(accept_tokens + bid * accept_tokens_len + accept_idx - + (stop_seq_len - 1 - i) - 1, + &cur_token_idx, + sizeof(int64_t)); + } else { + int pre_ids_idx = + step_idx_now - accept_num + accept_idx - (stop_seq_len - 1 - i); + // EC3 + // special concatenation can leave input_ids without a special token + // at the end, i.e. pre_ids[0] may be 23, causing a spurious stop + if (pre_ids_idx <= 0) { + break; + } + GM2LM(pre_ids + bid * pre_ids_len + pre_ids_idx, + &cur_token_idx, + sizeof(int64_t)); + } + GM2LM(stop_seqs + tid * stop_seqs_max_len + i, + &stop_seq_now_lm, + sizeof(int64_t)); + if (cur_token_idx != stop_seq_now_lm) { + break; + } + if (i == 0) { + is_end = true; + } + } + } + if (is_end) { + int64_t end_id_lm; + bool value_true = true; + GM2LM(end_ids, &end_id_lm, sizeof(int64_t)); + LM2GM_ASYNC(&end_id_lm, + accept_tokens + bid * accept_tokens_len + accept_idx - 1, + sizeof(int64_t)); + LM2GM(&value_true, stop_flags + bid, sizeof(bool)); + } + } + } +} +} // namespace plugin +} // namespace xpu3 diff --git a/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/speculate_set_value_by_flags.xpu b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/speculate_set_value_by_flags.xpu new file mode 100644 index 000000000..e9bdfa348 --- /dev/null +++ b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/speculate_set_value_by_flags.xpu @@ -0,0 +1,83 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+/* + * copyright (C) 2022 KUNLUNXIN, Inc + */ + +#include "xpu/kernel/cluster.h" +#include "xpu/kernel/cluster_partition.h" +#include "xpu/kernel/cluster_primitive.h" +#include "xpu/kernel/xtdk_io.h" + +namespace xpu3 { +namespace plugin { + +__global__ void speculate_set_value_by_flag_and_id(int64_t *pre_ids_all, + const int64_t *accept_tokens, + const int *accept_num, + const bool *stop_flags, + const int *seq_lens_encoder, + const int *seq_lens_decoder, + const int64_t *step_idx, + int bs, + int length, + int max_draft_tokens) { + int cid = core_id(); + int ncores = core_num(); + int clusterid = cluster_id(); + if (clusterid != 0) return; + + int64_t pre_ids_all_lm[max_draft_tokens]; + int64_t accept_tokens_lm[max_draft_tokens]; + int accept_num_lm; + bool stop_flags_lm; + int seq_lens_encoder_lm; + int seq_lens_decoder_lm; + int64_t step_idx_lm; + + for (int i = cid; i < bs; i += ncores) { + GM2LM_ASYNC(stop_flags + i, &stop_flags_lm, sizeof(bool)); + GM2LM_ASYNC(seq_lens_encoder + i, &seq_lens_encoder_lm, sizeof(int)); + GM2LM_ASYNC(seq_lens_decoder + i, &seq_lens_decoder_lm, sizeof(int)); + GM2LM_ASYNC(step_idx + i, &step_idx_lm, sizeof(int64_t)); + GM2LM_ASYNC(accept_num + i, &accept_num_lm, sizeof(int)); + mfence_lm(); + + if (stop_flags_lm || + (seq_lens_encoder_lm == 0 && seq_lens_decoder_lm == 0) || + step_idx_lm < 0) + continue; + + // Avoid loading large amounts of data + int pre_ids_start_idx = i * length + step_idx_lm - max_draft_tokens + 1; + GM2LM_ASYNC(pre_ids_all + pre_ids_start_idx, + pre_ids_all_lm, + max_draft_tokens * sizeof(int64_t)); + GM2LM_ASYNC(accept_tokens + i * max_draft_tokens, + accept_tokens_lm, + max_draft_tokens * sizeof(int64_t)); + mfence_lm(); + + for (int j = 0; j < accept_num_lm; j++) { + pre_ids_all_lm[max_draft_tokens - 1 - j] = + accept_tokens_lm[accept_num_lm - 1 - j]; + } + LM2GM(&pre_ids_all_lm, + pre_ids_all + pre_ids_start_idx, + max_draft_tokens * sizeof(int64_t)); + } +} + +} // namespace plugin +} // namespace xpu3 diff --git a/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/speculate_update_repeat_times.xpu b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/speculate_update_repeat_times.xpu new file mode 100644 index 000000000..ce3898fb2 --- /dev/null +++ b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/speculate_update_repeat_times.xpu @@ -0,0 +1,268 @@ +#include "xpu/kernel/cluster.h" +#include "xpu/kernel/cluster_partition.h" +#include "xpu/kernel/cluster_primitive.h" +#include "xpu/kernel/cluster_primitive_template.h" + +namespace xpu3 { +namespace plugin { + +static __device__ void atomic_add(_shared_ptr_ int *ptr, int v) { + bool fail = true; + while (fail) { + int a; + __asm__ __volatile__("loada.w %0,%1" : "=&r"(a) : "r"(ptr)); + a += v; + __asm__ __volatile__("storea.w %0,%1,%2" : "=&r"(fail) : "r"(a), "r"(ptr)); + } +} + +// original version +__device__ void speculate_update_repeat_times_normal( + char *lm, + __shared_ptr__ char *sm, + __global_ptr__ const int64_t *pre_ids, + __global_ptr__ const int64_t *cur_len, + __global_ptr__ int *repeat_times, + __global_ptr__ const int *output_padding_offset, + const int64_t bs, + const int64_t length, + const int64_t length_id, + const int64_t token_num, + const int64_t max_seq_len) { + int cid = core_id(); + int ncores = core_num(); + int clusterid = cluster_id(); + int nclusters = cluster_num(); + int tid = clusterid * ncores + cid; + + const int max_sm_len = 256 * 1024 / sizeof(int); + __shared_ptr__ int *repeated_times_sm = 
(__shared_ptr__ int *)sm; + int64_t pre_id_lm; + int n_length = (length + max_sm_len - 1) / max_sm_len; + + int64_t *cur_len_lm = (int64_t *)lm; + int output_padding_offset_now; + GM2LM(cur_len, cur_len_lm, bs * sizeof(int64_t)); + + for (int nli = 0; nli < n_length; nli++) { + int step = nli * max_sm_len; + int cur_length = min(max_sm_len, length - step); + for (int64_t i = clusterid; i < token_num; i += nclusters) { + GM2LM(output_padding_offset + i, &output_padding_offset_now, sizeof(int)); + int64_t bi = (i + output_padding_offset_now) / max_seq_len; + if (bi >= bs || cur_len_lm[bi] < 0) { + continue; + } + if (cid == 0) { + GM2SM_ASYNC(repeat_times + i * length + step, + repeated_times_sm, + sizeof(int) * cur_length); + } + mfence(); + sync_cluster(); + for (int j = cid; j < length_id; j += ncores) { + GM2LM(pre_ids + bi * length_id + j, &pre_id_lm, sizeof(int64_t)); + if (pre_id_lm < 0) { + break; + } + if (pre_id_lm >= step && pre_id_lm < step + cur_length) { + atomic_add(repeated_times_sm + pre_id_lm - step, 1); + } + mfence(); + } + sync_cluster(); + if (cid == 0) { + SM2GM_ASYNC(repeated_times_sm, + repeat_times + i * length + step, + sizeof(int) * cur_length); + } + mfence(); + sync_cluster(); + } + } +} + +// best optimized version +// about 49000+ ns +__device__ void speculate_update_repeat_times_optimized( + char *lm, + __shared_ptr__ char *sm, + __global_ptr__ const int64_t *pre_ids, // {bs, length_id} + __global_ptr__ const int64_t *cur_len, // {bs} + __global_ptr__ int *repeat_times, // {token_num, length} + __global_ptr__ const int *output_padding_offset, // {token_num} + const int64_t bs, + const int64_t length, + const int64_t length_id, + const int64_t token_num, + const int64_t max_seq_len) { + int cid = core_id(); + int ncores = core_num(); + int clusterid = cluster_id(); + int nclusters = cluster_num(); + int tid = clusterid * ncores + cid; + + const int repeat_times_sm_len = 250 * 1024 / sizeof(int); + __shared_ptr__ int *repeat_times_sm = (__shared_ptr__ int *)sm; + + // assert bs <= 640 + int cur_len_sm_len = 640; + __shared_ptr__ int64_t *cur_len_sm = + (__shared_ptr__ int64_t *)(repeat_times_sm + repeat_times_sm_len); + __shared_ptr__ int *output_padding_offset_sm = + (__shared_ptr__ int *)(cur_len_sm + cur_len_sm_len); + DoublePtr<1, SmPtr> buffer_ptr_output_padding_offset( + (SmPtr(output_padding_offset_sm))); + int pre_ids_lm_len = 4; + int64_t *pre_ids_lm = (int64_t *)lm; + DoublePtr<4, LmPtr> buffer_ptr_pre_ids((LmPtr(pre_ids_lm))); + + int64_t i = clusterid; + if (i < token_num && cid == 0) { + GM2SM_ASYNC(cur_len, cur_len_sm, bs * sizeof(int64_t)); + buffer_ptr_output_padding_offset.gm_load_async(output_padding_offset + i, + 1); + mfence_sm(); + } + sync_all(); + for (; i < token_num; i += nclusters) { + if (cid == 0 && i + nclusters < token_num) { + buffer_ptr_output_padding_offset.next().gm_load_async( + output_padding_offset + i + nclusters, 1); + } + int64_t bi = (i + (buffer_ptr_output_padding_offset.ptr[0])) / max_seq_len; + buffer_ptr_output_padding_offset.toggle(); + if (bi >= bs || cur_len_sm[bi] < 0) { + mfence_sm(); + sync_all(); + continue; + } + int64_t boundary = -1; + for (int64_t repeat_times_start = 0; repeat_times_start < length; + repeat_times_start += repeat_times_sm_len) { + int64_t repeat_times_read_size = + min(length - repeat_times_start, repeat_times_sm_len); + int64_t start, end; + partition(cid, ncores, repeat_times_read_size, 1, &start, &end); + int64_t load_start = repeat_times_start + start; + int64_t 
repeat_times_read_size_per_core = end - start; + if (repeat_times_read_size_per_core > 0) { + GM2SM(repeat_times + i * length + load_start, + repeat_times_sm + start, + repeat_times_read_size_per_core * sizeof(int)); + } + sync_all(); + // each core walks pre_ids step by step and records in `boundary` the + // first index whose value is negative + if (repeat_times_start == 0) { + bool do_prune = false; + int64_t j = cid * pre_ids_lm_len; + int64_t pre_ids_read_size = + min(static_cast<int64_t>(pre_ids_lm_len), length_id - j); + buffer_ptr_pre_ids.gm_load(pre_ids + bi * length_id + j, + pre_ids_read_size); + for (; j < length_id && !do_prune; j += ncores * pre_ids_lm_len) { + int64_t pre_ids_read_size_next = + min(static_cast<int64_t>(pre_ids_lm_len), + length_id - (j + ncores * pre_ids_lm_len)); + if (buffer_ptr_pre_ids.ptr[pre_ids_read_size - 1] >= 0 && + pre_ids_read_size_next > 0) { + buffer_ptr_pre_ids.next().gm_load_async( + pre_ids + bi * length_id + j + ncores * pre_ids_lm_len, + pre_ids_read_size_next); + } + for (int k = 0; k < pre_ids_read_size; k++) { + if (buffer_ptr_pre_ids.ptr[k] < 0) { + do_prune = true; + boundary = j + k; + break; + } + if (buffer_ptr_pre_ids.ptr[k] >= repeat_times_start && + buffer_ptr_pre_ids.ptr[k] < + repeat_times_start + repeat_times_read_size) { + atomic_add(repeat_times_sm + buffer_ptr_pre_ids.ptr[k] - + repeat_times_start, + 1); + } + } + mfence_lm(); + pre_ids_read_size = pre_ids_read_size_next; + buffer_ptr_pre_ids.toggle(); + } + } + // each core loads all the pre_ids it needs into lm, without an mfence in + // between, using the boundary recorded by the first iteration + else { + int cnt = -1; + int64_t pre_ids_read_size = 0; + for (int64_t j = cid * pre_ids_lm_len; j < boundary; + j += ncores * pre_ids_lm_len) { + cnt++; + pre_ids_read_size = + min(static_cast<int64_t>(pre_ids_lm_len), boundary - j); + GM2LM_ASYNC(pre_ids + bi * length_id + j, + pre_ids_lm + cnt * pre_ids_lm_len, + pre_ids_read_size * sizeof(int64_t)); + } + mfence_lm(); + cnt = max(0, cnt); + for (int k = 0; k < cnt * pre_ids_lm_len + pre_ids_read_size; k++) { + if (pre_ids_lm[k] >= repeat_times_start && + pre_ids_lm[k] < repeat_times_start + repeat_times_read_size) { + atomic_add(repeat_times_sm + pre_ids_lm[k] - repeat_times_start, 1); + } + } + } + mfence_sm(); + sync_cluster(); + if (repeat_times_read_size_per_core > 0) { + SM2GM(repeat_times_sm + start, + repeat_times + i * length + load_start, + repeat_times_read_size_per_core * sizeof(int)); + } + sync_all(); + } + } +} + +__global__ void speculate_update_repeat_times(const int64_t *pre_ids, + const int64_t *cur_len, + int *repeat_times, + const int *output_padding_offset, + const int64_t bs, + const int64_t length, + const int64_t length_id, + const int64_t token_num, + const int64_t max_seq_len) { + char lm[6 * 1024]; + __shared__ char sm[256 * 1024]; + + if (length_id <= 6 * 1024 * 64 / sizeof(int64_t)) { + speculate_update_repeat_times_optimized(lm, + sm, + pre_ids, + cur_len, + repeat_times, + output_padding_offset, + bs, + length, + length_id, + token_num, + max_seq_len); + } else { + speculate_update_repeat_times_normal(lm, + sm, + pre_ids, + cur_len, + repeat_times, + output_padding_offset, + bs, + length, + length_id, + token_num, + max_seq_len); + } +} + +} // namespace plugin +} // namespace xpu3 diff --git a/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/speculate_update_v3.xpu b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/speculate_update_v3.xpu new file mode 100644 index 
000000000..d58ba574a --- /dev/null +++ b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/speculate_update_v3.xpu @@ -0,0 +1,202 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +/* + * copyright (C) 2025 KUNLUNXIN, Inc + */ + +#include "xpu/kernel/cluster.h" +#include "xpu/kernel/cluster_partition.h" +#include "xpu/kernel/cluster_primitive.h" +#include "xpu/kernel/cluster_primitive_template.h" + +namespace xpu3 { +namespace plugin { + +static inline __device__ int v_reduce(int32x16_t &v0, int32x16_t &v1) { + int res; + v1 = vvadd_int32x16(v0, v1); + auto v = vsrlp_int32x16(256, v1); + v1 = vvadd_int32x16(v, v1); + v = vsrlp_int32x16(128, v1); + v1 = vvadd_int32x16(v, v1); + v = vsrlp_int32x16(64, v1); + v1 = vvadd_int32x16(v, v1); + v = vsrlp_int32x16(32, v1); + v1 = vvadd_int32x16(v, v1); + res = vextract_int32x16(v1, 1); + return res; +} + +static inline __device__ int ClusterReduce( + const _shared_ptr_ int *stop_flag_now_int_sm, int len) { + int sum = 0; + if (core_id() == 0) { + int32x16_t vec_x_0; + int32x16_t vec_x_1; + int32x16_t vec_y_0 = vzero(); + int32x16_t vec_y_1 = vzero(); + for (int i = 0; i < len; i += 32) { + vload2_sm(stop_flag_now_int_sm + i, vec_x_0, vec_x_1); + vec_y_0 = vvadd_int32x16(vec_y_0, vec_x_0); + vec_y_1 = vvadd_int32x16(vec_y_1, vec_x_1); + } + sum = v_reduce(vec_y_0, vec_y_1); + } + return sum; +} + +template <int THREADBLOCK_SIZE> +__global__ void speculate_update_v3( + int *seq_lens_encoder, // in/out [B_max] + int *seq_lens_decoder, // in/out [B_max] + bool *not_need_stop, // out [1] + int64_t *draft_tokens, // out [B_max, T_max] + int *actual_draft_token_nums, // in/out [B_max] + const int64_t *accept_tokens, // in [B_max, T_max] + const int *accept_num, // in [B_max] + const bool *stop_flags, // in [B_max] + const int *seq_lens_this_time, // in [B_real] + const bool *is_block_step, // in [B_max] + const int64_t *stop_nums, // in [1] + const int real_bsz, + const int max_bsz, + const int max_draft_tokens) { + // real_bsz <= max_bsz <= THREADBLOCK_SIZE; + + const int cid = core_id(); + const int tid = core_id() * cluster_num() + cluster_id(); + const int nthreads = core_num() * cluster_num(); + + __shared__ int seq_lens_encoder_sm[THREADBLOCK_SIZE]; // in/out [B_max] 2K + __shared__ int seq_lens_decoder_sm[THREADBLOCK_SIZE]; // in/out [B_max] 2K + __shared__ int + actual_draft_token_nums_sm[THREADBLOCK_SIZE]; // out [B_max] 2K + __shared__ int accept_num_sm[THREADBLOCK_SIZE]; // in/out [B_max] 2K + __shared__ bool stop_flags_sm[THREADBLOCK_SIZE]; // in [B_max] 512B + __shared__ int seq_lens_this_time_sm[THREADBLOCK_SIZE]; // in [B_real] 2K + __shared__ bool is_block_step_sm[THREADBLOCK_SIZE]; // in [B_max] 512B + __shared__ int stop_flag_now_int_sm[64]; + + bool not_need_stop_lm; // out [1] + int64_t stop_nums_lm; // in [1] + + int bid_start_core, bid_end_core; + partition(tid, nthreads, max_bsz, 1, &bid_start_core, &bid_end_core); + + if (cid == 0) { + GM2SM_ASYNC(seq_lens_encoder, 
seq_lens_encoder_sm, max_bsz * sizeof(int)); + GM2SM_ASYNC(seq_lens_decoder, seq_lens_decoder_sm, max_bsz * sizeof(int)); + GM2SM_ASYNC(actual_draft_token_nums, + actual_draft_token_nums_sm, + max_bsz * sizeof(int)); + GM2SM_ASYNC(accept_num, accept_num_sm, max_bsz * sizeof(int)); + GM2SM_ASYNC(stop_flags, stop_flags_sm, max_bsz * sizeof(bool)); + GM2SM_ASYNC( + seq_lens_this_time, seq_lens_this_time_sm, max_bsz * sizeof(int)); + GM2SM_ASYNC(is_block_step, is_block_step_sm, max_bsz * sizeof(bool)); + GM2LM_ASYNC(stop_nums, &stop_nums_lm, sizeof(int64_t)); + mfence_lm_sm(); + } + sync_all(); + + stop_flag_now_int_sm[cid] = 0; + for (int bid = bid_start_core; bid < bid_end_core; bid++) { + const int accept_num_now = accept_num_sm[bid]; + int stop_flag_now_int = 0; + if (!is_block_step_sm[bid] && bid < real_bsz) { + if (stop_flags_sm[bid]) { + stop_flag_now_int = 1; + } + if (seq_lens_encoder_sm[bid] == 0) { + seq_lens_decoder_sm[bid] += accept_num_now; + } + + // In append mode, decide from the acceptance result whether to lower + // the number of draft tokens for the next step + if (seq_lens_this_time_sm[bid] > 1 && seq_lens_encoder_sm[bid] == 0) { + auto current_actual_draft_token_num = actual_draft_token_nums_sm[bid]; + if (accept_num_now - 1 == current_actual_draft_token_num) { + if (current_actual_draft_token_num + 2 <= max_draft_tokens - 1) { + actual_draft_token_nums_sm[bid] = + current_actual_draft_token_num + 2; + } else if (current_actual_draft_token_num + 1 <= + max_draft_tokens - 1) { + actual_draft_token_nums_sm[bid] = + current_actual_draft_token_num + 1; + } else { + actual_draft_token_nums_sm[bid] = max_draft_tokens - 1; + } + } else { + actual_draft_token_nums_sm[bid] = + actual_draft_token_nums_sm[bid] - 1 >= 1 + ? actual_draft_token_nums_sm[bid] - 1 + : 1; + } + } + + if (seq_lens_encoder_sm[bid] != 0) { + seq_lens_decoder_sm[bid] += seq_lens_encoder_sm[bid]; + seq_lens_encoder_sm[bid] = 0; + } + + if (stop_flag_now_int) { + seq_lens_decoder_sm[bid] = 0; + } else { + // trying out a new compiler feature here: direct GM access + draft_tokens[bid * max_draft_tokens] = + accept_tokens[bid * max_draft_tokens + accept_num_now - 1]; + } + } else if (bid >= real_bsz && bid < max_bsz) { + stop_flag_now_int = 1; + } + stop_flag_now_int_sm[cid] += stop_flag_now_int; + mfence_lm(); + } + mfence_sm(); + sync_all(); + int64_t stop_sum = ClusterReduce(stop_flag_now_int_sm, 64); + sync_all(); + + if (cid == 0) { + not_need_stop_lm = stop_sum < stop_nums_lm; + mfence_lm(); + SM2GM_ASYNC(seq_lens_encoder_sm, seq_lens_encoder, max_bsz * sizeof(int)); + SM2GM_ASYNC(seq_lens_decoder_sm, seq_lens_decoder, max_bsz * sizeof(int)); + LM2GM_ASYNC(&not_need_stop_lm, not_need_stop, 1 * sizeof(bool)); + SM2GM_ASYNC(actual_draft_token_nums_sm, + actual_draft_token_nums, + max_bsz * sizeof(int)); + mfence(); + } +} + +template __global__ void speculate_update_v3<512>(int *seq_lens_encoder, + int *seq_lens_decoder, + bool *not_need_stop, + int64_t *draft_tokens, + int *actual_draft_token_nums, + const int64_t *accept_tokens, + const int *accept_num, + const bool *stop_flags, + const int *seq_lens_this_time, + const bool *is_block_step, + const int64_t *stop_nums, + const int real_bsz, + const int max_bsz, + const int max_draft_tokens); + +} // namespace plugin +} // namespace xpu3 diff --git a/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/speculate_update_value_by_repeat_times.xpu 
b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/speculate_update_value_by_repeat_times.xpu new file mode 100644 index 000000000..f685fcf9e --- /dev/null +++ b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/speculate_update_value_by_repeat_times.xpu @@ -0,0 +1,281 @@ +#include "xpu/kernel/cluster.h" +#include "xpu/kernel/cluster_debug.h" +#include "xpu/kernel/cluster_partition.h" +#include "xpu/kernel/cluster_primitive.h" + +namespace xpu3 { +namespace plugin { + +__device__ void do_cast(const int *xlm, float *ylm, int64_t len) { + for (int64_t i = 0; i < len; i += 32) { + int32x16_t xl = vload_lm_int32x16(xlm + i); + int32x16_t xh = vload_lm_int32x16(xlm + i + 16); + float32x16_t yl = vfix2float(xl); + float32x16_t yh = vfix2float(xh); + vstore_lm_float32x16(ylm + i, yl); + vstore_lm_float32x16(ylm + i + 16, yh); + } + mfence_lm(); +} + +template <typename T> +__global__ void speculate_update_value_by_repeat_times( + const int *repeat_times, + const T *penalty_scores, + const T *frequency_score, + const T *presence_score, + const float *temperatures, + T *logits, + const int *output_padding_offset, + const int64_t bs, + const int64_t length, + const int64_t token_num, + const int64_t max_seq_len) { + int ncores = core_num(); + int cid = core_id(); + int thread_id = cid * cluster_num() + cluster_id(); + int nthreads = cluster_num() * ncores; + int64_t start = -1; + int64_t end = -1; + partition(thread_id, nthreads, token_num * length, 1, &start, &end); + if (start >= end) { + return; + } + + int64_t token_start = start / length; + int64_t token_end = end / length; + if (token_end >= token_num) { + token_end = token_num - 1; + } + int output_padding_offset_start_lm; + int output_padding_offset_end_lm; + GM2LM_ASYNC(output_padding_offset + token_start, + (void *)&output_padding_offset_start_lm, + sizeof(int)); + GM2LM(output_padding_offset + token_end, + (void *)&output_padding_offset_end_lm, + sizeof(int)); + int64_t bs_start = + (token_start + output_padding_offset_start_lm) / max_seq_len; + int64_t bs_end = (token_end + output_padding_offset_end_lm) / max_seq_len; + const int param_len = 256; + // ncores = 64 for xpu3 + __shared__ __simd__ float alpha_buf[param_len * 64]; + __shared__ __simd__ float beta_buf[param_len * 64]; + __shared__ __simd__ float gamma_buf[param_len * 64]; + __shared__ __simd__ float temperatures_buf[param_len * 64]; + _shared_ptr_ float *alpha_sm = alpha_buf + cid * param_len; + _shared_ptr_ float *beta_sm = beta_buf + cid * param_len; + _shared_ptr_ float *gamma_sm = gamma_buf + cid * param_len; + _shared_ptr_ float *temperatures_sm = temperatures_buf + cid * param_len; + int read_param_len = bs_end - bs_start + 1; + GM2SM_ASYNC(penalty_scores + bs_start, alpha_sm, read_param_len * sizeof(T)); + GM2SM_ASYNC(frequency_score + bs_start, beta_sm, read_param_len * sizeof(T)); + GM2SM_ASYNC(presence_score + bs_start, gamma_sm, read_param_len * sizeof(T)); + GM2SM( + temperatures + bs_start, temperatures_sm, read_param_len * sizeof(float)); + primitive_cast_sm<T, float>( + (const _shared_ptr_ T *)(alpha_sm), alpha_sm, read_param_len); + primitive_cast_sm<T, float>( + (const _shared_ptr_ T *)(beta_sm), beta_sm, read_param_len); + primitive_cast_sm<T, float>( + (const _shared_ptr_ T *)(gamma_sm), gamma_sm, read_param_len); + + float logit_now; + float alpha; + float beta; + float gamma; + float temperature; + int time; + const int buffer_len = 512; + __simd__ float logits_lm[buffer_len]; + int times_lm[buffer_len]; + int output_padding_offset_lm[buffer_len]; + + for (int64_t i = 
start; i < end; i += buffer_len) { + int read_len = min(end - i, buffer_len); + GM2LM_ASYNC(logits + i, logits_lm, read_len * sizeof(T)); + GM2LM_ASYNC(output_padding_offset + i / length, + output_padding_offset_lm, + ((read_len + length - 1) / length + 1) * sizeof(int)); + GM2LM(repeat_times + i, times_lm, read_len * sizeof(int)); + primitive_cast<T, float>((const T *)(logits_lm), logits_lm, read_len); + for (int j = 0; j < read_len; j++) { + time = times_lm[j]; + logit_now = logits_lm[j]; + int token_idx = (i + j) / length; + int bs_idx = + (token_idx + output_padding_offset_lm[token_idx - i / length]) / + max_seq_len; + if (bs_idx >= bs) { + continue; + } + int param_idx = bs_idx - bs_start; + temperature = temperatures_sm[param_idx]; + if (time != 0) { + alpha = alpha_sm[param_idx]; + beta = beta_sm[param_idx]; + gamma = gamma_sm[param_idx]; + logit_now = logit_now < 0.0f ? logit_now * alpha : logit_now / alpha; + logit_now = logit_now - time * beta - gamma; + } + logits_lm[j] = logit_now / temperature; + } + mfence_lm(); + primitive_cast<float, T>(logits_lm, (T *)logits_lm, read_len); + LM2GM(logits_lm, logits + i, read_len * sizeof(T)); + } +} + +#define _XPU_DEF__UPDATE_VALUE_BY_REPEAT_TIMES_(DATA_TYPE) \ + template __global__ void speculate_update_value_by_repeat_times( \ + const int *repeat_times, \ + const DATA_TYPE *penalty_scores, \ + const DATA_TYPE *frequency_score, \ + const DATA_TYPE *presence_score, \ + const float *temperatures, \ + DATA_TYPE *logits, \ + const int *output_padding_offset, \ + const int64_t bs, \ + const int64_t length, \ + const int64_t token_num, \ + const int64_t max_seq_len); +_XPU_DEF__UPDATE_VALUE_BY_REPEAT_TIMES_(float); +_XPU_DEF__UPDATE_VALUE_BY_REPEAT_TIMES_(float16); +_XPU_DEF__UPDATE_VALUE_BY_REPEAT_TIMES_(bfloat16); + +template <typename T> +__global__ void speculate_update_value_by_repeat_times_simd( + const int *repeat_times, // [bs * length] + const T *penalty_scores, // [bs] + const T *frequency_score, // [bs] + const T *presence_score, // [bs] + const float *temperatures, // [bs] + T *logits, // [bs * length] + const int *output_padding_offset, + const int64_t bs, + const int64_t length, + const int64_t token_num, + const int64_t max_seq_len) { + int ncores = core_num(); + int cid = core_id(); + int thread_id = cid * cluster_num() + cluster_id(); + int nthreads = cluster_num() * ncores; + int64_t start = -1; + int64_t end = -1; + partition(thread_id, nthreads, token_num * length, 16, &start, &end); + if (start >= end) { + return; + } + + const int param_len = 256; + // ncores = 64 for xpu3 + __shared__ __simd__ float alpha_buf[param_len * 64]; + __shared__ __simd__ float beta_buf[param_len * 64]; + __shared__ __simd__ float gamma_buf[param_len * 64]; + __shared__ __simd__ float temperatures_buf[param_len * 64]; + // assert bs <= param_len * 64 + if (cid == 0) { + GM2SM_ASYNC(penalty_scores, alpha_buf, bs * sizeof(T)); + GM2SM_ASYNC(frequency_score, beta_buf, bs * sizeof(T)); + GM2SM_ASYNC(presence_score, gamma_buf, bs * sizeof(T)); + GM2SM(temperatures, temperatures_buf, bs * sizeof(float)); + primitive_cast_sm<T, float>( + (const _shared_ptr_ T *)(alpha_buf), alpha_buf, bs); + primitive_cast_sm<T, float>( + (const _shared_ptr_ T *)(beta_buf), beta_buf, bs); + primitive_cast_sm<T, float>( + (const _shared_ptr_ T *)(gamma_buf), gamma_buf, bs); + } + mfence(); + sync_all(); + + float logit_now; + float alpha; + float beta; + float gamma; + float temperature; + int time; + const int buffer_len = 512; + __simd__ float logits_lm[buffer_len]; + __simd__ float times_lm[buffer_len]; + int 
output_padding_offset_lm[buffer_len]; + + float32x16_t logits_; + float32x16_t logits_tmp_0; + float32x16_t logits_tmp_1; + float32x16_t time_; + + for (int64_t i = start; i < end; i += buffer_len) { + int read_len = min(end - i, buffer_len); + GM2LM_ASYNC(logits + i, logits_lm, read_len * sizeof(T)); + GM2LM_ASYNC(output_padding_offset + i / length, + output_padding_offset_lm, + ((read_len + length - 1) / length + 1) * sizeof(int)); + GM2LM(repeat_times + i, times_lm, read_len * sizeof(int)); + primitive_cast<T, float>((const T *)(logits_lm), logits_lm, read_len); + do_cast((const int *)(times_lm), times_lm, read_len); + int time_mask = 0; + int logit_mask = 0; + for (int j = 0; j < read_len; j += 16) { + time_ = vload_lm_float32x16(times_lm + j); + logits_ = vload_lm_float32x16(logits_lm + j); + int token_idx = (i + j) / length; + int bs_idx = + (token_idx + output_padding_offset_lm[token_idx - i / length]) / + max_seq_len; + if (bs_idx >= bs) { + continue; + } + int param_idx = bs_idx; + temperature = temperatures_buf[param_idx]; + alpha = alpha_buf[param_idx]; + beta = beta_buf[param_idx]; + gamma = gamma_buf[param_idx]; + time_mask = svneq_float32x16(0.f, time_); // time != 0 mask + logit_mask = svle_float32x16(0.f, logits_); // logit >= 0 mask + time_ = svmul_float32x16(beta, time_); // time * beta + time_ = svadd_float32x16(gamma, time_); // time * beta + gamma + logits_ = svmul_float32x16_mh( + alpha, + logits_, + logits_, + (time_mask & + ~logit_mask)); // when time != 0 && logit < 0, do alpha * logit + logits_ = svmul_float32x16_mh( + 1.0f / alpha, + logits_, + logits_, + (time_mask & logit_mask)); // when time != 0 && logit >= 0, do + // logit / alpha + logits_ = vvsub_float32x16_mh( + logits_, time_, logits_, time_mask); // when time != 0, do logit = + // logit - time * beta - gamma; + logits_ = + svmul_float32x16(1.0f / temperature, logits_); // logit / temperature + vstore_lm_float32x16(logits_lm + j, logits_); + } + mfence_lm(); + primitive_cast<float, T>(logits_lm, (T *)logits_lm, read_len); + LM2GM(logits_lm, logits + i, read_len * sizeof(T)); + } +} + +#define _XPU_DEF__UPDATE_VALUE_BY_REPEAT_TIMES_SIMD(DATA_TYPE) \ + template __global__ void speculate_update_value_by_repeat_times_simd( \ + const int *repeat_times, \ + const DATA_TYPE *penalty_scores, \ + const DATA_TYPE *frequency_score, \ + const DATA_TYPE *presence_score, \ + const float *temperatures, \ + DATA_TYPE *logits, \ + const int *output_padding_offset, \ + const int64_t bs, \ + const int64_t length, \ + const int64_t token_num, \ + const int64_t max_seq_len); +_XPU_DEF__UPDATE_VALUE_BY_REPEAT_TIMES_SIMD(float); +_XPU_DEF__UPDATE_VALUE_BY_REPEAT_TIMES_SIMD(float16); +_XPU_DEF__UPDATE_VALUE_BY_REPEAT_TIMES_SIMD(bfloat16); + +} // namespace plugin +} // namespace xpu3 diff --git a/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/speculate_verify.xpu b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/speculate_verify.xpu new file mode 100644 index 000000000..68eb2bd60 --- /dev/null +++ b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/speculate_verify.xpu @@ -0,0 +1,335 @@ +#include "xpu/kernel/cluster_debug.h" +#include "xpu/kernel/cluster_partition.h" +#include "xpu/kernel/xtdk.h" +#include "xpu/kernel/xtdk_math.h" +#include "xpu/kernel/xtdk_simd.h" +namespace xpu3 { +namespace plugin { +static inline __device__ int v_reduce(int32x16_t &v0, int32x16_t 
&v1) { + int res; + v1 = vvadd_int32x16(v0, v1); + auto v = vsrlp_int32x16(256, v1); + v1 = vvadd_int32x16(v, v1); + v = vsrlp_int32x16(128, v1); + v1 = vvadd_int32x16(v, v1); + v = vsrlp_int32x16(64, v1); + v1 = vvadd_int32x16(v, v1); + v = vsrlp_int32x16(32, v1); + v1 = vvadd_int32x16(v, v1); + res = vextract_int32x16(v1, 1); + return res; +} +static inline __device__ int ClusterReduce( + const _shared_ptr_ int *stop_flag_now_int_sm, int len) { + int sum = 0; + if (core_id() == 0) { + int32x16_t vec_x_0; + int32x16_t vec_x_1; + int32x16_t vec_y_0 = vzero(); + int32x16_t vec_y_1 = vzero(); + for (int i = 0; i < len; i += 32) { + vload2_sm(stop_flag_now_int_sm + i, vec_x_0, vec_x_1); + vec_y_0 = vvadd_int32x16(vec_y_0, vec_x_0); + vec_y_1 = vvadd_int32x16(vec_y_1, vec_x_1); + } + sum = v_reduce(vec_y_0, vec_y_1); + } + return sum; +} +__device__ bool is_in_end(const int64_t id, + __global_ptr__ const int64_t *end_ids, + int length) { + bool flag = false; + for (int i = 0; i < length; i++) { + if (id == end_ids[i]) { + return true; + } + } + return flag; +} +__device__ inline bool is_in(__global_ptr__ const int64_t *candidates, + const int64_t draft, + const int candidate_len) { + for (int i = 0; i < candidate_len; i++) { + if (draft == candidates[i]) { + return true; + } + } + return false; +} +static __device__ inline unsigned int xorwow(unsigned int &state) { + state ^= state >> 7; + state ^= state << 9; + state ^= state >> 13; + return state; +} +typedef uint32_t curandStatePhilox4_32_10_t; +__device__ int64_t +topp_sampling_kernel(__global_ptr__ const int64_t *candidate_ids, + __global_ptr__ const float *candidate_scores, + __global_ptr__ const float *dev_curand_states, + const int candidate_len, + const float topp) { + const int tid = core_id(); + float sum_scores = 0.0f; + float rand_top_p = *dev_curand_states * topp; + for (int i = 0; i < candidate_len; i++) { + sum_scores += candidate_scores[i]; + if (rand_top_p <= sum_scores) { + return candidate_ids[i]; + } + } + return candidate_ids[0]; +} +#define sm_size 1024 +template <bool ENABLE_TOPP, bool USE_TOPK> +__global__ void speculate_verify( + int64_t *accept_tokens, // out [real_bsz, max_draft_tokens], tokens + // finally accepted (via verification or sampling) + int *accept_num, // out [real_bsz], number of tokens finally accepted + // per sequence (only verified ones are counted) + int64_t + *step_idx, // out [real_bsz], tokens already generated or accepted per bid + bool *stop_flags, // out [real_bsz], per-sequence stop flag, set true when + // an end token is hit or the length limit is exceeded + const int *seq_lens_encoder, // [real_bsz], encoder input length per + // sample, used to detect the prefill stage + const int *seq_lens_decoder, // [real_bsz], tokens output by the decoder + // per sample (i.e. the draft token count) + const int64_t * + draft_tokens, // [real_bsz, max_draft_tokens], tokens from the draft model + const int *actual_draft_token_nums, // [real_bsz], number of actually + // valid tokens in draft_tokens + const float *dev_curand_states, // used for random + const float *topp, // [real_bsz], TopP threshold (e.g. 0.9) controlling + // nucleus truncation probability and candidate count + const int *seq_lens_this_time, // [real_bsz], tokens each sample actually + // verifies in this round + const int64_t + *verify_tokens, // [sum(seq_lens_this_time), max_candidate_len], + // candidate tokens output by the verify decoder + const float + *verify_scores, // same shape, probability of each verify token, used for sampling + const int64_t *max_dec_len, // [real_bsz], maximum generation length per + // sample (exceeding it triggers termination) + const int64_t + *end_tokens, // [end_length], list of end tokens (e.g. eos); hitting one terminates + const bool *is_block_step, // [real_bsz], whether this is currently a + // block step (verify is skipped when true) + const int + *output_cum_offsets, // [real_bsz], start offset into verify_tokens, + // used to locate a token's verify index + const int *actual_candidate_len, // [sum(seq_lens_this_time)], usable + // candidate count per verify token (for TopP truncation) + const int real_bsz, // batch size + const int max_draft_tokens, // scalar, maximum draft tokens per sample + const int end_length, + const int max_seq_len, // scalar, maximum tokens per sequence (used for offset computation) + const int max_candidate_len, // scalar, maximum candidates per verify + // token (for verification or sampling) + const int verify_window, // scalar, TopK verification window (number of + // consecutive top-1 matches allowed) + const bool prefill_one_step_stop) { + const int cid = core_id(); + const int64_t tid = cluster_id() * core_num() + core_id(); + const int64_t nthreads = cluster_num() * core_num(); + for (int64_t bid = tid; bid < real_bsz; bid += nthreads) { + int stop_flag_now_int = 0; + int accept_num_now = 1; + if (is_block_step[bid]) { + continue; + } + const int start_token_id = bid * max_seq_len - output_cum_offsets[bid]; + if (stop_flags[bid]) { + stop_flag_now_int = 1; + } else { // the prefill stage also reaches this branch, but since its + // draft tokens are zeroed it falls straight through to the + // final sampling stage + auto *verify_tokens_now = + verify_tokens + start_token_id * max_candidate_len; + auto *draft_tokens_now = draft_tokens + bid * max_draft_tokens; + auto *actual_candidate_len_now = actual_candidate_len + start_token_id; + int i = 0; + for (; i < seq_lens_this_time[bid] - 1; i++) { + if (seq_lens_encoder[bid] != 0) { + break; + } + if (USE_TOPK) { + if (verify_tokens_now[i * max_candidate_len] == + draft_tokens_now[i + 1]) { + step_idx[bid]++; + auto accept_token = draft_tokens_now[i + 1]; + accept_tokens[bid * max_draft_tokens + i] = accept_token; + if (is_in_end(accept_token, end_tokens, end_length) || + step_idx[bid] >= max_dec_len[bid]) { + stop_flags[bid] = true; + stop_flag_now_int = 1; + if (step_idx[bid] >= max_dec_len[bid]) + accept_tokens[bid * max_draft_tokens + i] = end_tokens[0]; + break; + } else { + accept_num_now++; + } + } else { + break; + } + } else { + auto actual_candidate_len_value = + actual_candidate_len_now[i] > max_candidate_len + ? max_candidate_len + : actual_candidate_len_now[i]; + if (is_in(verify_tokens_now + i * max_candidate_len, + draft_tokens_now[i + 1], + actual_candidate_len_value)) { + // Top P verify + step_idx[bid]++; + auto accept_token = draft_tokens_now[i + 1]; + accept_tokens[bid * max_draft_tokens + i] = accept_token; + if (is_in_end(accept_token, end_tokens, end_length) || + step_idx[bid] >= max_dec_len[bid]) { + stop_flags[bid] = true; + stop_flag_now_int = 1; + if (step_idx[bid] >= max_dec_len[bid]) + accept_tokens[bid * max_draft_tokens + i] = end_tokens[0]; + break; + } else { + accept_num_now++; + } + } else { + // TopK verify + int ii = i; + if (max_candidate_len >= 2 && + verify_tokens_now[ii * max_candidate_len + 1] == + draft_tokens_now[ii + 1]) { // top-2 + int j = 0; + ii += 1; + for (; j < verify_window && ii < seq_lens_this_time[bid] - 1; + j++, ii++) { + if (verify_tokens_now[ii * max_candidate_len] != + draft_tokens_now[ii + 1]) { + break; + } + } + if (j >= verify_window) { // accept all + accept_num_now += verify_window + 1; + step_idx[bid] += verify_window + 1; + for (; i < ii; i++) { + auto accept_token = draft_tokens_now[i + 1]; + accept_tokens[bid * max_draft_tokens + i] = accept_token; + if (is_in_end(accept_token, end_tokens, end_length) || + step_idx[bid] >= max_dec_len[bid]) { + stop_flags[bid] = true; + stop_flag_now_int = 1; + if (step_idx[bid] >= max_dec_len[bid]) + accept_tokens[bid * max_draft_tokens + i] = end_tokens[0]; + accept_num_now--; + step_idx[bid]--; + break; + } + } + } + } + break; + } + } + } + // sampling stage, two cases: + // 1. draft_token[i+1] was rejected, pick a token from verify_tokens_now[i] + // 2. i == seq_lens_this_time[bid]-1, likewise pick from + // verify_tokens_now[i]; sequences that already stopped are excluded + if (!stop_flag_now_int) { + int64_t accept_token; + __global_ptr__ const float *verify_scores_now = + verify_scores + start_token_id * max_candidate_len; + step_idx[bid]++; + if (ENABLE_TOPP) { + auto actual_candidate_len_value = + actual_candidate_len_now[i] > max_candidate_len + ? 
max_candidate_len + : actual_candidate_len_now[i]; + accept_token = + topp_sampling_kernel(verify_tokens_now + i * max_candidate_len, + verify_scores_now + i * max_candidate_len, + dev_curand_states, + actual_candidate_len_value, + topp[bid]); + } else { + accept_token = verify_tokens_now[i * max_candidate_len]; + } + accept_tokens[bid * max_draft_tokens + i] = accept_token; + if (prefill_one_step_stop) { + stop_flags[bid] = true; + } + if (is_in_end(accept_token, end_tokens, end_length) || + step_idx[bid] >= max_dec_len[bid]) { + stop_flags[bid] = true; + stop_flag_now_int = 1; + if (step_idx[bid] >= max_dec_len[bid]) + accept_tokens[bid * max_draft_tokens + i] = end_tokens[0]; + } + } + accept_num[bid] = accept_num_now; + } + } +} +#define SPECULATE_VERIFY_INSTANTIATE(ENABLE_TOPP, USE_TOPK) \ + template __global__ void speculate_verify<ENABLE_TOPP, USE_TOPK>( \ + int64_t * accept_tokens, \ + int *accept_num, \ + int64_t *step_idx, \ + bool *stop_flags, \ + const int *seq_lens_encoder, \ + const int *seq_lens_decoder, \ + const int64_t *draft_tokens, \ + const int *actual_draft_token_nums, \ + const float *dev_curand_states, \ + const float *topp, \ + const int *seq_lens_this_time, \ + const int64_t *verify_tokens, \ + const float *verify_scores, \ + const int64_t *max_dec_len, \ + const int64_t *end_tokens, \ + const bool *is_block_step, \ + const int *output_cum_offsets, \ + const int *actual_candidate_len, \ + int real_bsz, \ + int max_draft_tokens, \ + int end_length, \ + int max_seq_len, \ + int max_candidate_len, \ + int verify_window, \ + bool prefill_one_step_stop); +SPECULATE_VERIFY_INSTANTIATE(true, true) +SPECULATE_VERIFY_INSTANTIATE(true, false) +SPECULATE_VERIFY_INSTANTIATE(false, true) +SPECULATE_VERIFY_INSTANTIATE(false, false) +} // namespace plugin +} // namespace xpu3 diff --git a/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/top_p_candidates.xpu b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/top_p_candidates.xpu new file mode 100644 index 000000000..d3b4ce8a6 --- /dev/null +++ b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/top_p_candidates.xpu @@ -0,0 +1,349 @@ +#include "xpu/kernel/cluster.h" +#include "xpu/kernel/cluster_partition.h" +#include "xpu/kernel/cluster_primitive.h" +#include "xpu/kernel/cluster_primitive_template.h" + +namespace xpu3 { +namespace plugin { + +template <typename T, int TopPBeamTopK> +__device__ void top_p_candidates_big_n( + char* lm, + __global_ptr__ const T* src, + __global_ptr__ const T* top_ps, + __global_ptr__ const int* output_padding_offset, + __global_ptr__ int64_t* out_id, + __global_ptr__ T* out_val, + __global_ptr__ int* actual_candidates_lens, + int vocab_size, + int token_num, + int max_cadidate_len, + int max_seq_len) { + int ncores = core_num(); + int cid = core_id(); + int tid = cid * cluster_num() + cluster_id(); + int nthreads = cluster_num() * ncores; + + int64_t buf_size = 6 * 1024 / sizeof(T); + T* lm_src = (T*)lm; + int64_t lm_out_id[TopPBeamTopK]; + T lm_out_val[TopPBeamTopK]; + + __shared__ int64_t sm_out_id[64 * TopPBeamTopK]; + __shared__ T sm_out_val[64 * TopPBeamTopK]; + + // only used in core 0 + int lm_output_padding_offset; + + for (int64_t i = cluster_id(); i < token_num; i += cluster_num()) { + if (cid == 0) { + GM2LM(output_padding_offset + i, &lm_output_padding_offset, sizeof(int)); + } + for (int64_t j = 0; j < TopPBeamTopK; j++) { + lm_out_id[j] = -1; + } + for (int j = cid * buf_size; j < vocab_size; j += ncores * buf_size) { + int64_t read_size = min(buf_size, static_cast<int64_t>(vocab_size - j)); + GM2LM(src + i 
* vocab_size + j, lm_src, read_size * sizeof(T)); + for (int k = 0; k < read_size; k++) { + if (lm_out_id[TopPBeamTopK - 1] == -1 || + lm_src[k] > lm_out_val[TopPBeamTopK - 1] || + (lm_src[k] == lm_out_val[TopPBeamTopK - 1] && + (j + k) < lm_out_id[TopPBeamTopK - 1])) { + int l = TopPBeamTopK - 2; + for (; l >= 0; l--) { + if (lm_out_id[l] == -1 || lm_src[k] > lm_out_val[l] || + (lm_src[k] == lm_out_val[l] && (j + k) < lm_out_id[l])) { + lm_out_id[l + 1] = lm_out_id[l]; + lm_out_val[l + 1] = lm_out_val[l]; + } else { + break; + } + } + lm_out_id[l + 1] = j + k; + lm_out_val[l + 1] = lm_src[k]; + } + } + mfence_lm(); + } + + if (cid % 16 != 0) { + for (int64_t j = 0; j < TopPBeamTopK; j++) { + sm_out_id[cid * TopPBeamTopK + j] = lm_out_id[j]; + sm_out_val[cid * TopPBeamTopK + j] = lm_out_val[j]; + } + } + mfence_sm(); + sync_all(); + + if (cid % 16 == 0) { + int64_t local_sm_out_id; + T local_sm_out_val; + for (int j = cid + 1; j < cid + 16; j += 1) { + for (int offset = 0; offset < TopPBeamTopK; offset++) { + local_sm_out_id = sm_out_id[j * TopPBeamTopK + offset]; + local_sm_out_val = sm_out_val[j * TopPBeamTopK + offset]; + if (local_sm_out_val > lm_out_val[TopPBeamTopK - 1] || + (local_sm_out_val == lm_out_val[TopPBeamTopK - 1] && + local_sm_out_id < lm_out_id[TopPBeamTopK - 1])) { + int k = TopPBeamTopK - 2; + for (; k >= 0; k--) { + if (local_sm_out_val > lm_out_val[k] || + (local_sm_out_val == lm_out_val[k] && + local_sm_out_id < lm_out_id[k])) { + lm_out_id[k + 1] = lm_out_id[k]; + lm_out_val[k + 1] = lm_out_val[k]; + } else { + break; + } + } + lm_out_id[k + 1] = local_sm_out_id; + lm_out_val[k + 1] = local_sm_out_val; + } else { + break; + } + } + } + if (cid != 0) { + for (int64_t j = 0; j < TopPBeamTopK; j++) { + sm_out_id[cid * TopPBeamTopK + j] = lm_out_id[j]; + sm_out_val[cid * TopPBeamTopK + j] = lm_out_val[j]; + } + } + } + mfence_sm(); + sync_all(); + + if (cid == 0) { + int64_t local_sm_out_id; + T local_sm_out_val; + for (int j = cid + 16; j < ncores; j += 16) { + for (int offset = 0; offset < TopPBeamTopK; offset++) { + local_sm_out_id = sm_out_id[j * TopPBeamTopK + offset]; + local_sm_out_val = sm_out_val[j * TopPBeamTopK + offset]; + if (local_sm_out_val > lm_out_val[TopPBeamTopK - 1] || + (local_sm_out_val == lm_out_val[TopPBeamTopK - 1] && + local_sm_out_id < lm_out_id[TopPBeamTopK - 1])) { + int k = TopPBeamTopK - 2; + for (; k >= 0; k--) { + if (local_sm_out_val > lm_out_val[k] || + (local_sm_out_val == lm_out_val[k] && + local_sm_out_id < lm_out_id[k])) { + lm_out_id[k + 1] = lm_out_id[k]; + lm_out_val[k + 1] = lm_out_val[k]; + } else { + break; + } + } + lm_out_id[k + 1] = local_sm_out_id; + lm_out_val[k + 1] = local_sm_out_val; + } else { + break; + } + } + } + + int ori_token_id = i + lm_output_padding_offset; + int bid = ori_token_id / max_seq_len; + T lm_top_p; + GM2LM(top_ps + bid, &lm_top_p, sizeof(T)); + float top_p_value = static_cast<float>(lm_top_p); + T default_val = static_cast<T>(0.f); + int lm_actual_candidates_len = 0; + + // accumulate first, then test the threshold, so each candidate's + // probability is counted exactly once (matches the normal variant) + float sum_prob = 0.0f; + for (int j = 0; j < TopPBeamTopK; j++) { + sum_prob += static_cast<float>(lm_out_val[j]); + if (sum_prob >= top_p_value) { + for (int k = j + 1; k < TopPBeamTopK; k++) { + lm_out_id[k] = 0; + lm_out_val[k] = default_val; + } + lm_actual_candidates_len = j + 1; + break; + } + } + mfence_lm(); + LM2GM_ASYNC( + &lm_actual_candidates_len, actual_candidates_lens + i, sizeof(int)); + LM2GM_ASYNC(lm_out_id, + out_id + i * max_cadidate_len, + TopPBeamTopK * sizeof(int64_t)); + LM2GM_ASYNC( + lm_out_val, out_val + i 
* max_cadidate_len, TopPBeamTopK * sizeof(T)); + } + mfence(); + sync_all(); + } +} + +template <typename T, int TopPBeamTopK> +__device__ void top_p_candidates_normal( + char* lm, + __global_ptr__ const T* src, + __global_ptr__ const T* top_ps, + __global_ptr__ const int* output_padding_offset, + __global_ptr__ int64_t* out_id, + __global_ptr__ T* out_val, + __global_ptr__ int* actual_candidates_lens, + int vocab_size, + int token_num, + int max_cadidate_len, + int max_seq_len) { + int ncores = core_num(); + int cid = core_id(); + int tid = cid * cluster_num() + cluster_id(); + int nthreads = cluster_num() * ncores; + + int64_t buf_size = 6 * 1024 / sizeof(T); + T* lm_src = (T*)lm; + int64_t lm_out_id[TopPBeamTopK]; + T lm_out_val[TopPBeamTopK]; + + int lm_output_padding_offset; + T lm_top_p; + int64_t default_id = 0; + T default_val = static_cast<T>(0.f); + + for (int64_t i = tid; i < token_num; i += nthreads) { + float sum_prob = 0.0f; + for (int64_t j = 0; j < TopPBeamTopK; j++) { + lm_out_id[j] = -1; + } + for (int j = 0; j < vocab_size; j += buf_size) { + int64_t read_size = min(buf_size, static_cast<int64_t>(vocab_size - j)); + GM2LM(src + i * vocab_size + j, lm_src, read_size * sizeof(T)); + for (int k = 0; k < read_size; k++) { + if (lm_out_id[TopPBeamTopK - 1] == -1 || + lm_src[k] > lm_out_val[TopPBeamTopK - 1] || + (lm_src[k] == lm_out_val[TopPBeamTopK - 1] && + (j + k) < lm_out_id[TopPBeamTopK - 1])) { + lm_out_id[TopPBeamTopK - 1] = j + k; + lm_out_val[TopPBeamTopK - 1] = lm_src[k]; + for (int l = TopPBeamTopK - 2; l >= 0; l--) { + if (lm_out_id[l] == -1 || lm_out_val[l + 1] > lm_out_val[l] || + (lm_out_val[l + 1] == lm_out_val[l] && + lm_out_id[l + 1] < lm_out_id[l])) { + int64_t swap_id = lm_out_id[l]; + T swap_val = lm_out_val[l]; + lm_out_id[l] = lm_out_id[l + 1]; + lm_out_val[l] = lm_out_val[l + 1]; + lm_out_id[l + 1] = swap_id; + lm_out_val[l + 1] = swap_val; + } + } + } + } + mfence_lm(); + } + GM2LM(output_padding_offset + i, &lm_output_padding_offset, sizeof(int)); + int ori_token_id = i + lm_output_padding_offset; + int bid = ori_token_id / max_seq_len; + GM2LM(top_ps + bid, &lm_top_p, sizeof(T)); + float top_p_value = static_cast<float>(lm_top_p); + bool set_to_default_val = false; + int lm_actual_candidates_len = 0; + for (int j = 0; j < TopPBeamTopK; j++) { + if (set_to_default_val) { + LM2GM_ASYNC( + &default_id, out_id + i * max_cadidate_len + j, sizeof(int64_t)); + LM2GM_ASYNC( + &default_val, out_val + i * max_cadidate_len + j, sizeof(T)); + } else { + LM2GM_ASYNC( + lm_out_id + j, out_id + i * max_cadidate_len + j, sizeof(int64_t)); + LM2GM_ASYNC( + lm_out_val + j, out_val + i * max_cadidate_len + j, sizeof(T)); + sum_prob += static_cast<float>(lm_out_val[j]); + if (sum_prob >= top_p_value) { + lm_actual_candidates_len = j + 1; + mfence_lm(); + LM2GM_ASYNC(&lm_actual_candidates_len, + actual_candidates_lens + i, + sizeof(int)); + set_to_default_val = true; + } + } + } + mfence_lm(); + } +} + +template <typename T, int MaxLength, int TopPBeamTopK> +__global__ void top_p_candidates(const T* src, + const T* top_ps, + const int* output_padding_offset, + int64_t* out_id, + T* out_val, + int* actual_candidates_lens, + int vocab_size, + int token_num, + int max_cadidate_len, + int max_seq_len) { + char lm[6 * 1024]; + if (token_num % (core_num() * cluster_num()) != 0 && + vocab_size >= core_num() * (6 * 1024 / sizeof(T)) && + vocab_size >= core_num() * TopPBeamTopK) { + top_p_candidates_big_n<T, TopPBeamTopK>(lm, + src, + top_ps, + output_padding_offset, + out_id, + out_val, + actual_candidates_lens, + vocab_size, + token_num, + max_cadidate_len, + max_seq_len); + } else { + 
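// fallback: token_num divides evenly across all threads, or the vocab is + // too small to split across one cluster, so each thread scans whole + // tokens on its own + 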
top_p_candidates_normal<T, TopPBeamTopK>(lm, + src, + top_ps, + output_padding_offset, + out_id, + out_val, + actual_candidates_lens, + vocab_size, + token_num, + max_cadidate_len, + max_seq_len); + } +} + +#define _XPU_DEF_TOP_P_CANDIDATES_KERNEL(T, MaxLength, TopPBeamTopK) \ + template __global__ void top_p_candidates<T, MaxLength, TopPBeamTopK>( \ + const T* src, \ + const T* top_ps, \ + const int* output_padding_offset, \ + int64_t* out_id, \ + T* out_val, \ + int* actual_candidates_lens, \ + int vocab_size, \ + int token_num, \ + int max_cadidate_len, \ + int max_seq_len); + +_XPU_DEF_TOP_P_CANDIDATES_KERNEL(bfloat16, 2, 2); +_XPU_DEF_TOP_P_CANDIDATES_KERNEL(bfloat16, 2, 3); +_XPU_DEF_TOP_P_CANDIDATES_KERNEL(bfloat16, 2, 4); +_XPU_DEF_TOP_P_CANDIDATES_KERNEL(bfloat16, 2, 5); +_XPU_DEF_TOP_P_CANDIDATES_KERNEL(bfloat16, 2, 8); +_XPU_DEF_TOP_P_CANDIDATES_KERNEL(bfloat16, 2, 10); + +_XPU_DEF_TOP_P_CANDIDATES_KERNEL(float16, 2, 2); +_XPU_DEF_TOP_P_CANDIDATES_KERNEL(float16, 2, 3); +_XPU_DEF_TOP_P_CANDIDATES_KERNEL(float16, 2, 4); +_XPU_DEF_TOP_P_CANDIDATES_KERNEL(float16, 2, 5); +_XPU_DEF_TOP_P_CANDIDATES_KERNEL(float16, 2, 8); +_XPU_DEF_TOP_P_CANDIDATES_KERNEL(float16, 2, 10); + +_XPU_DEF_TOP_P_CANDIDATES_KERNEL(float, 2, 2); +_XPU_DEF_TOP_P_CANDIDATES_KERNEL(float, 2, 3); +_XPU_DEF_TOP_P_CANDIDATES_KERNEL(float, 2, 4); +_XPU_DEF_TOP_P_CANDIDATES_KERNEL(float, 2, 5); +_XPU_DEF_TOP_P_CANDIDATES_KERNEL(float, 2, 8); +_XPU_DEF_TOP_P_CANDIDATES_KERNEL(float, 2, 10); + +} // namespace plugin +} // namespace xpu3 diff --git a/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/compute_order.cpp b/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/compute_order.cpp new file mode 100644 index 000000000..64a45ad9b --- /dev/null +++ b/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/compute_order.cpp @@ -0,0 +1,181 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
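+// Worked example (illustrative note added for clarity): for a pure decode +// step with bsz = 1, accept_num = 2, actual_draft_token_num = 3 and +// base_model_seq_lens_this_time = 4, cpu_wrapper below keeps only the last +// accepted position: position_map[1] = 0, in_offset advances to 4, and +// output_token_num[0] ends up as 1.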
+ +#include "xpu/plugin.h" +#include "xpu/refactor/impl_public/wrapper_check.h" + +namespace xpu3 { +namespace plugin { +__attribute__((global)) void ComputeOrderKernel( + const int* seq_lens_this_time, + const int* seq_lens_encoder, + const int* base_model_seq_lens_this_time, + const int* base_model_seq_lens_encoder, + const int* accept_nums, + int* position_map, + int* output_token_num, + const int bsz, + const int actual_draft_token_num, + const int input_token_num); +} // namespace plugin +} // namespace xpu3 + +namespace baidu { +namespace xpu { +namespace api { +namespace plugin { + +static int cpu_wrapper(Context* ctx, + const int* seq_lens_this_time, + const int* seq_lens_encoder, + const int* base_model_seq_lens_this_time, + const int* base_model_seq_lens_encoder, + const int* accept_nums, + int* position_map, + int* output_token_num, + const int bsz, + const int actual_draft_token_num, + const int input_token_num) { + int in_offset = 0; // input_offset(long) + int out_offset = 0; // output_offset(short) + for (int i = 0; i < bsz; ++i) { + int cur_base_model_seq_lens_this_time = base_model_seq_lens_this_time[i]; + int cur_base_model_seq_lens_encoder = base_model_seq_lens_encoder[i]; + int cur_seq_lens_this_time = seq_lens_this_time[i]; + int accept_num = accept_nums[i]; + int cur_seq_lens_encoder = seq_lens_encoder[i]; + + // 1. eagle encoder. Base step=1 + if (cur_seq_lens_encoder > 0) { + for (int j = 0; j < cur_seq_lens_encoder; j++) { + position_map[in_offset++] = out_offset++; + } + // 2. base model encoder. Base step=0 + } else if (cur_base_model_seq_lens_encoder != 0) { + // nothing happens + // 3. New end + } else if (cur_base_model_seq_lens_this_time != 0 && + cur_seq_lens_this_time == 0) { + in_offset += cur_base_model_seq_lens_this_time; + // 4. 
stopped + } else if (cur_base_model_seq_lens_this_time == 0 && + cur_seq_lens_this_time == 0) /* end */ { + // nothing happens + } else { + if (accept_num <= + actual_draft_token_num) /*Accept partial draft tokens*/ { + position_map[in_offset + accept_num - 1] = out_offset++; + in_offset += cur_base_model_seq_lens_this_time; + } else /*Accept all draft tokens*/ { + position_map[in_offset + accept_num - 2] = out_offset++; + position_map[in_offset + accept_num - 1] = out_offset++; + in_offset += cur_base_model_seq_lens_this_time; + } + } + } + output_token_num[0] = out_offset; + return api::SUCCESS; +} + +static int xpu3_wrapper(Context* ctx, + const int* seq_lens_this_time, + const int* seq_lens_encoder, + const int* base_model_seq_lens_this_time, + const int* base_model_seq_lens_encoder, + const int* accept_nums, + int* position_map, + int* output_token_num, + const int bsz, + const int actual_draft_token_num, + const int input_token_num) { + xpu3::plugin::ComputeOrderKernel<<<1, 1, ctx->xpu_stream>>>( + seq_lens_this_time, + seq_lens_encoder, + base_model_seq_lens_this_time, + base_model_seq_lens_encoder, + accept_nums, + position_map, + output_token_num, + bsz, + actual_draft_token_num, + input_token_num); + return api::SUCCESS; +} + +int compute_order(Context* ctx, + const int* seq_lens_this_time, + const int* seq_lens_encoder, + const int* base_model_seq_lens_this_time, + const int* base_model_seq_lens_encoder, + const int* accept_nums, + int* position_map, + int* output_token_num, + const int bsz, + const int actual_draft_token_num, + const int input_token_num) { + WRAPPER_CHECK_CTX(ctx); + WRAPPER_DUMP_PARAM5(ctx, + seq_lens_this_time, + seq_lens_encoder, + base_model_seq_lens_this_time, + base_model_seq_lens_encoder, + accept_nums); + WRAPPER_DUMP_PARAM5(ctx, + position_map, + output_token_num, + bsz, + actual_draft_token_num, + input_token_num); + WRAPPER_DUMP(ctx); + + WRAPPER_CHECK_PTR(ctx, int, bsz, seq_lens_this_time); + WRAPPER_CHECK_PTR(ctx, int, bsz, seq_lens_encoder); + WRAPPER_CHECK_PTR(ctx, int, bsz, base_model_seq_lens_this_time); + WRAPPER_CHECK_PTR(ctx, int, bsz, base_model_seq_lens_encoder); + WRAPPER_CHECK_PTR(ctx, int, bsz, accept_nums); + WRAPPER_CHECK_PTR(ctx, int, input_token_num, position_map); + WRAPPER_CHECK_PTR(ctx, int, 1, output_token_num); + + if (ctx->dev().type() == api::kCPU) { + return cpu_wrapper(ctx, + seq_lens_this_time, + seq_lens_encoder, + base_model_seq_lens_this_time, + base_model_seq_lens_encoder, + accept_nums, + position_map, + output_token_num, + bsz, + actual_draft_token_num, + input_token_num); + } else if (ctx->dev().type() == api::kXPU3) { + return xpu3_wrapper(ctx, + seq_lens_this_time, + seq_lens_encoder, + base_model_seq_lens_this_time, + base_model_seq_lens_encoder, + accept_nums, + position_map, + output_token_num, + bsz, + actual_draft_token_num, + input_token_num); + } + WRAPPER_UNIMPLEMENTED(ctx); +} + +} // namespace plugin +} // namespace api +} // namespace xpu +} // namespace baidu diff --git a/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/compute_self_order.cpp b/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/compute_self_order.cpp new file mode 100644 index 000000000..a197d858b --- /dev/null +++ b/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/compute_self_order.cpp @@ -0,0 +1,133 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "xpu/plugin.h" +#include "xpu/refactor/impl_public/wrapper_check.h" + +namespace xpu3 { +namespace plugin { +__attribute__((global)) void ComputeSelfOrderKernel( + const int* last_seq_lens_this_time, + const int* seq_lens_this_time, + const int64_t* step_idx, + int* src_map, + int* output_token_num, + int bsz); +} // namespace plugin +} // namespace xpu3 + +namespace baidu { +namespace xpu { +namespace api { +namespace plugin { + +static int cpu_wrapper(Context* ctx, + const int* last_seq_lens_this_time, + const int* seq_lens_this_time, + const int64_t* step_idx, + int* src_map, + int* output_token_num, + int bsz) { + int in_offset = 0; + int out_offset = 0; + for (int i = 0; i < bsz; i++) { + int cur_seq_lens_this_time = seq_lens_this_time[i]; + int cur_last_seq_lens_this_time = last_seq_lens_this_time[i]; + + // 1. encoder + if (step_idx[i] == 1 && cur_seq_lens_this_time > 0) { + in_offset += 1; + src_map[out_offset++] = in_offset - 1; + // 2. decoder + } else if (cur_seq_lens_this_time > 0) /* =1 */ { + in_offset += cur_last_seq_lens_this_time; + src_map[out_offset++] = in_offset - 1; + // 3. stop + } else { + // first token end + if (step_idx[i] == 1) { + in_offset += cur_last_seq_lens_this_time > 0 ? 1 : 0; + // normal end + } else { + in_offset += cur_last_seq_lens_this_time; + } + } + } + output_token_num[0] = out_offset; + return api::SUCCESS; +} + +static int xpu3_wrapper(Context* ctx, + const int* last_seq_lens_this_time, + const int* seq_lens_this_time, + const int64_t* step_idx, + int* src_map, + int* output_token_num, + int bsz) { + using XPU_INT64 = typename XPUIndexType<int64_t>::type; + xpu3::plugin::ComputeSelfOrderKernel<<<1, 1, ctx->xpu_stream>>>( + last_seq_lens_this_time, + seq_lens_this_time, + reinterpret_cast<const XPU_INT64*>(step_idx), + src_map, + output_token_num, + bsz); + return api::SUCCESS; +} + +int compute_self_order(Context* ctx, + const int* last_seq_lens_this_time, + const int* seq_lens_this_time, + const int64_t* step_idx, + int* src_map, + int* output_token_num, + int bsz) { + WRAPPER_CHECK_CTX(ctx); + WRAPPER_DUMP_PARAM6(ctx, + last_seq_lens_this_time, + seq_lens_this_time, + step_idx, + src_map, + output_token_num, + bsz); + WRAPPER_DUMP(ctx); + + WRAPPER_CHECK_PTR(ctx, int, bsz, last_seq_lens_this_time); + WRAPPER_CHECK_PTR(ctx, int, bsz, seq_lens_this_time); + WRAPPER_CHECK_PTR(ctx, int64_t, bsz, step_idx); + + if (ctx->dev().type() == api::kCPU) { + return cpu_wrapper(ctx, + last_seq_lens_this_time, + seq_lens_this_time, + step_idx, + src_map, + output_token_num, + bsz); + } else if (ctx->dev().type() == api::kXPU3) { + return xpu3_wrapper(ctx, + last_seq_lens_this_time, + seq_lens_this_time, + step_idx, + src_map, + output_token_num, + bsz); + } + WRAPPER_UNIMPLEMENTED(ctx); +} + +} // namespace plugin +} // namespace api +} // namespace xpu +} // namespace baidu diff --git a/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/draft_model_postprocess.cpp b/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/draft_model_postprocess.cpp new file mode 100644 index 000000000..a62937941 --- /dev/null +++ 
b/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/draft_model_postprocess.cpp @@ -0,0 +1,142 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "xpu/plugin.h" +#include "xpu/refactor/impl_public/wrapper_check.h" + +namespace xpu2 { +namespace plugin { +__attribute__((global)) void draft_model_postprocess( + const int64_t* base_model_draft_tokens, + int* base_model_seq_lens_this_time, + const int* base_model_seq_lens_encoder, + const bool* base_model_stop_flags, + int bsz, + int base_model_draft_token_len); +} // namespace plugin +} // namespace xpu2 + +namespace xpu3 { +namespace plugin { +__attribute__((global)) void draft_model_postprocess( + const int64_t* base_model_draft_tokens, + int* base_model_seq_lens_this_time, + const int* base_model_seq_lens_encoder, + const bool* base_model_stop_flags, + int bsz, + int base_model_draft_token_len); +} // namespace plugin +} // namespace xpu3 + +namespace baidu { +namespace xpu { +namespace api { +namespace plugin { +static int cpu_wrapper( + Context* ctx, + const int64_t* + base_model_draft_tokens, // size = [bsz, base_model_draft_token_len] + int* base_model_seq_lens_this_time, // size = [bsz] + const int* base_model_seq_lens_encoder, // size = [bsz] + const bool* base_model_stop_flags, // size = [bsz] + int bsz, + int base_model_draft_token_len) { + // 遍历每个样本 + for (int tid = 0; tid < bsz; ++tid) { + if (!base_model_stop_flags[tid] && base_model_seq_lens_encoder[tid] == 0) { + // 获取当前样本的草稿token指针 + const int64_t* base_model_draft_tokens_now = + base_model_draft_tokens + tid * base_model_draft_token_len; + // 计算有效token数量(非-1的token) + int token_num = 0; + for (int i = 0; i < base_model_draft_token_len; ++i) { + if (base_model_draft_tokens_now[i] != -1) { + token_num++; + } + } + // 更新序列长度 + base_model_seq_lens_this_time[tid] = token_num; + } else if (base_model_stop_flags[tid]) { + // 已停止的样本序列长度为0 + base_model_seq_lens_this_time[tid] = 0; + } + } + return api::SUCCESS; +} + +static int xpu3_wrapper(Context* ctx, + const int64_t* base_model_draft_tokens, + int* base_model_seq_lens_this_time, + const int* base_model_seq_lens_encoder, + const bool* base_model_stop_flags, + int bsz, + int base_model_draft_token_len) { + xpu3::plugin::draft_model_postprocess<<ncluster(), 64, ctx->xpu_stream>>>( + reinterpret_cast(base_model_draft_tokens), + base_model_seq_lens_this_time, + base_model_seq_lens_encoder, + base_model_stop_flags, + bsz, + base_model_draft_token_len); + return api::SUCCESS; +} + +int draft_model_postprocess(Context* ctx, + const int64_t* base_model_draft_tokens, + int* base_model_seq_lens_this_time, + const int* base_model_seq_lens_encoder, + const bool* base_model_stop_flags, + int bsz, + int base_model_draft_token_len) { + WRAPPER_CHECK_CTX(ctx); + WRAPPER_DUMP_PARAM6(ctx, + base_model_draft_tokens, + base_model_seq_lens_this_time, + base_model_seq_lens_encoder, + base_model_stop_flags, + bsz, + base_model_draft_token_len); + WRAPPER_DUMP(ctx); 
+ WRAPPER_CHECK_PTR( + ctx, int64_t, bsz * base_model_draft_token_len, base_model_draft_tokens); + WRAPPER_CHECK_PTR(ctx, int, bsz, base_model_seq_lens_encoder); + WRAPPER_CHECK_PTR(ctx, bool, bsz, base_model_stop_flags); + + if (ctx->dev().type() == api::kCPU) { + return cpu_wrapper(ctx, + base_model_draft_tokens, + base_model_seq_lens_this_time, + base_model_seq_lens_encoder, + base_model_stop_flags, + bsz, + base_model_draft_token_len); + } + if (ctx->dev().type() == api::kXPU3) { + return xpu3_wrapper(ctx, + base_model_draft_tokens, + base_model_seq_lens_this_time, + base_model_seq_lens_encoder, + base_model_stop_flags, + bsz, + base_model_draft_token_len); + } + WRAPPER_UNIMPLEMENTED(ctx); +} + +// template int draft_model_postprocess( +// Context*, const int64_t*, int*, const int*, const bool*, int, int); +} // namespace plugin +} // namespace api +} // namespace xpu +} // namespace baidu diff --git a/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/draft_model_preprocess.cpp b/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/draft_model_preprocess.cpp new file mode 100644 index 000000000..9ca1f2224 --- /dev/null +++ b/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/draft_model_preprocess.cpp @@ -0,0 +1,392 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
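+
+// draft_model_preprocess refreshes the draft model's per-request state
+// (sequence lengths, stop flags, step indices, and the first-token input
+// ids) from the base model's latest verification results before the next
+// speculative decoding round; cpu_wrapper below is the reference
+// implementation of the device kernel.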
+ +#include "xpu/plugin.h" +#include "xpu/refactor/impl/launch_strategy.h" +#include "xpu/refactor/impl_public/wrapper_check.h" +#include "xpu/xdnn.h" + +namespace xpu3 { +namespace plugin { +__attribute__((global)) void draft_model_preprocess( + int64_t* draft_tokens, + int64_t* input_ids, + bool* stop_flags, + int* seq_lens_this_time, + int* seq_lens_encoder, + int* seq_lens_decoder, + int64_t* step_idx, + int* seq_lens_encoder_record, + int* seq_lens_decoder_record, + bool* not_need_stop, + bool* batch_drop, + const int64_t* accept_tokens, + const int* accept_num, + const int* base_model_seq_lens_encoder, + const int* base_model_seq_lens_decoder, + const int64_t* base_model_step_idx, + const bool* base_model_stop_flags, + const bool* base_model_is_block_step, + int64_t* base_model_draft_tokens, + int real_bsz, + int max_draft_token, + int accept_tokens_len, + int draft_tokens_len, + int input_ids_len, + int base_model_draft_tokens_len, + bool truncate_first_token, + bool splitwise_prefill); +} // namespace plugin +} // namespace xpu3 + +namespace xpu2 { +namespace plugin {} // namespace plugin +} // namespace xpu2 + +namespace baidu { +namespace xpu { +namespace api { +namespace plugin { + +static int cpu_wrapper(api::Context* ctx, + int64_t* draft_tokens, + int64_t* input_ids, + bool* stop_flags, + int* seq_lens_this_time, + int* seq_lens_encoder, + int* seq_lens_decoder, + int64_t* step_idx, + int* seq_lens_encoder_record, + int* seq_lens_decoder_record, + bool* not_need_stop, + bool* batch_drop, + const int64_t* accept_tokens, + const int* accept_num, + const int* base_model_seq_lens_encoder, + const int* base_model_seq_lens_decoder, + const int64_t* base_model_step_idx, + const bool* base_model_stop_flags, + const bool* base_model_is_block_step, + int64_t* base_model_draft_tokens, + int real_bsz, + int max_draft_token, + int accept_tokens_len, + int draft_tokens_len, + int input_ids_len, + int base_model_draft_tokens_len, + bool truncate_first_token, + bool splitwise_prefill) { + int64_t not_stop_flag_sum = 0; + int64_t not_stop_flag = 0; + for (int tid = 0; tid < real_bsz; tid++) { + if (splitwise_prefill) { + int base_model_step_idx_now = base_model_step_idx[tid]; + auto* input_ids_now = input_ids + tid * input_ids_len; + auto* accept_tokens_now = accept_tokens + tid * accept_tokens_len; + // printf("bid: %d, base_model_step_idx_now: %d seq_lens_encoder_record: + // %d\n", tid, base_model_step_idx_now, seq_lens_encoder_record[tid]); + if (base_model_step_idx_now == 1 && seq_lens_encoder_record[tid] > 0) { + not_stop_flag = 1; + int seq_len_encoder_record = seq_lens_encoder_record[tid]; + seq_lens_encoder[tid] = seq_len_encoder_record; + seq_lens_encoder_record[tid] = -1; + stop_flags[tid] = false; + int64_t base_model_first_token = accept_tokens_now[0]; + int position = seq_len_encoder_record; + if (truncate_first_token) { + input_ids_now[position - 1] = base_model_first_token; + seq_lens_this_time[tid] = seq_len_encoder_record; + } else { + input_ids_now[position] = base_model_first_token; + seq_lens_this_time[tid] = seq_len_encoder_record + 1; + } + } else { + stop_flags[tid] = true; + seq_lens_this_time[tid] = 0; + seq_lens_decoder[tid] = 0; + seq_lens_encoder[tid] = 0; + not_stop_flag = 0; + } + not_stop_flag_sum += not_stop_flag; + } else { + auto base_model_step_idx_now = base_model_step_idx[tid]; + auto* accept_tokens_now = accept_tokens + tid * accept_tokens_len; + auto* draft_tokens_now = draft_tokens + tid * draft_tokens_len; + auto accept_num_now = accept_num[tid]; + 
auto* input_ids_now = input_ids + tid * input_ids_len;
+      auto* base_model_draft_tokens_now =
+          base_model_draft_tokens + tid * base_model_draft_tokens_len;
+      for (int i = 1; i < base_model_draft_tokens_len; i++) {
+        base_model_draft_tokens_now[i] = -1;
+      }
+      if (base_model_stop_flags[tid] && base_model_is_block_step[tid]) {
+        batch_drop[tid] = true;
+        stop_flags[tid] = true;
+      }
+
+      if (!(base_model_stop_flags[tid] || batch_drop[tid])) {
+        not_stop_flag = 1;
+        // 1. first token
+
+        if (base_model_step_idx_now == 0) {
+          seq_lens_this_time[tid] = 0;
+          not_stop_flag = 0;
+        } else if (base_model_step_idx_now == 1 &&
+                   seq_lens_encoder_record[tid] > 0) {
+          // Can be extended to the first few tokens
+          int seq_len_encoder_record = seq_lens_encoder_record[tid];
+          seq_lens_encoder[tid] = seq_len_encoder_record;
+          seq_lens_encoder_record[tid] = -1;
+          seq_lens_decoder[tid] = seq_lens_decoder_record[tid];
+          seq_lens_decoder_record[tid] = 0;
+          stop_flags[tid] = false;
+          int64_t base_model_first_token = accept_tokens_now[0];
+          int position = seq_len_encoder_record;
+          if (truncate_first_token) {
+            input_ids_now[position - 1] = base_model_first_token;
+            seq_lens_this_time[tid] = seq_len_encoder_record;
+          } else {
+            input_ids_now[position] = base_model_first_token;
+            seq_lens_this_time[tid] = seq_len_encoder_record + 1;
+          }
+        } else if (accept_num_now <=
+                   max_draft_token) /*Accept partial draft tokens*/ {
+          // Base model rejected some tokens: roll the decoder state back.
+          if (stop_flags[tid]) {
+            stop_flags[tid] = false;
+            seq_lens_decoder[tid] = base_model_seq_lens_decoder[tid];
+            step_idx[tid] = base_model_step_idx[tid];
+          } else {
+            seq_lens_decoder[tid] -= max_draft_token - accept_num_now;
+            step_idx[tid] -= max_draft_token - accept_num_now;
+          }
+          int64_t modified_token = accept_tokens_now[accept_num_now - 1];
+          draft_tokens_now[0] = modified_token;
+          seq_lens_this_time[tid] = 1;
+        } else /*Accept all draft tokens*/ {
+          draft_tokens_now[1] = accept_tokens_now[max_draft_token];
+          seq_lens_this_time[tid] = 2;
+        }
+      } else {
+        stop_flags[tid] = true;
+        seq_lens_this_time[tid] = 0;
+        seq_lens_decoder[tid] = 0;
+        seq_lens_encoder[tid] = 0;
+      }
+      not_stop_flag_sum += not_stop_flag;
+    }
+  }
+  not_need_stop[0] = not_stop_flag_sum > 0;
+  return api::SUCCESS;
+}
+
+static int xpu3_wrapper(api::Context* ctx,
+                        int64_t* draft_tokens,
+                        int64_t* input_ids,
+                        bool* stop_flags,
+                        int* seq_lens_this_time,
+                        int* seq_lens_encoder,
+                        int* seq_lens_decoder,
+                        int64_t* step_idx,
+                        int* seq_lens_encoder_record,
+                        int* seq_lens_decoder_record,
+                        bool* not_need_stop,
+                        bool* batch_drop,
+                        const int64_t* accept_tokens,
+                        const int* accept_num,
+                        const int* base_model_seq_lens_encoder,
+                        const int* base_model_seq_lens_decoder,
+                        const int64_t* base_model_step_idx,
+                        const bool* base_model_stop_flags,
+                        const bool* base_model_is_block_step,
+                        int64_t* base_model_draft_tokens,
+                        int real_bsz,
+                        int max_draft_token,
+                        int accept_tokens_len,
+                        int draft_tokens_len,
+                        int input_ids_len,
+                        int base_model_draft_tokens_len,
+                        bool truncate_first_token,
+                        bool splitwise_prefill) {
+  using XPU_INT64 = typename XPUIndexType<int64_t>::type;
+
+  // NOTE: do not change this launch configuration; the kernel uses GSM.
+  xpu3::plugin::draft_model_preprocess<<<1, 64, ctx->xpu_stream>>>(
+      reinterpret_cast<XPU_INT64*>(draft_tokens),
+      reinterpret_cast<XPU_INT64*>(input_ids),
+      stop_flags,
+      seq_lens_this_time,
+      seq_lens_encoder,
+      seq_lens_decoder,
+      reinterpret_cast<XPU_INT64*>(step_idx),
+      seq_lens_encoder_record,
+      seq_lens_decoder_record,
+      not_need_stop,
+      batch_drop,
+      reinterpret_cast<const XPU_INT64*>(accept_tokens),
+      accept_num,
+      base_model_seq_lens_encoder,
+      base_model_seq_lens_decoder,
+      reinterpret_cast<const XPU_INT64*>(base_model_step_idx),
+      base_model_stop_flags,
+      base_model_is_block_step,
+      reinterpret_cast<XPU_INT64*>(base_model_draft_tokens),
+      real_bsz,
+      max_draft_token,
+      accept_tokens_len,
+      draft_tokens_len,
+      input_ids_len,
+      base_model_draft_tokens_len,
+      truncate_first_token,
+      splitwise_prefill);
+  return api::SUCCESS;
+}
+
+int draft_model_preprocess(api::Context* ctx,
+                           int64_t* draft_tokens,
+                           int64_t* input_ids,
+                           bool* stop_flags,
+                           int* seq_lens_this_time,
+                           int* seq_lens_encoder,
+                           int* seq_lens_decoder,
+                           int64_t* step_idx,
+                           int* seq_lens_encoder_record,
+                           int* seq_lens_decoder_record,
+                           bool* not_need_stop,
+                           bool* batch_drop,
+                           const int64_t* accept_tokens,
+                           const int* accept_num,
+                           const int* base_model_seq_lens_encoder,
+                           const int* base_model_seq_lens_decoder,
+                           const int64_t* base_model_step_idx,
+                           const bool* base_model_stop_flags,
+                           const bool* base_model_is_block_step,
+                           int64_t* base_model_draft_tokens,
+                           int real_bsz,
+                           int max_draft_token,
+                           int accept_tokens_len,
+                           int draft_tokens_len,
+                           int input_ids_len,
+                           int base_model_draft_tokens_len,
+                           bool truncate_first_token,
+                           bool splitwise_prefill) {
+  WRAPPER_CHECK_CTX(ctx);
+  WRAPPER_DUMP_FUNCTION_T1(ctx, "draft_model_preprocess", int64_t);
+  WRAPPER_DUMP_PARAM6(ctx,
+                      draft_tokens,
+                      input_ids,
+                      stop_flags,
+                      seq_lens_this_time,
+                      seq_lens_encoder,
+                      seq_lens_decoder);
+  WRAPPER_DUMP_PARAM5(ctx,
+                      step_idx,
+                      seq_lens_encoder_record,
+                      seq_lens_decoder_record,
+                      not_need_stop,
+                      batch_drop);
+  WRAPPER_DUMP_PARAM3(
+      ctx, accept_tokens, accept_num, base_model_seq_lens_encoder);
+  WRAPPER_DUMP_PARAM3(ctx,
+                      base_model_seq_lens_decoder,
+                      base_model_step_idx,
+                      base_model_stop_flags);
+  WRAPPER_DUMP_PARAM3(
+      ctx, base_model_is_block_step, base_model_draft_tokens, real_bsz);
+  WRAPPER_DUMP_PARAM3(
+      ctx, max_draft_token, accept_tokens_len, draft_tokens_len);
+  WRAPPER_DUMP_PARAM3(
+      ctx, input_ids_len, base_model_draft_tokens_len, truncate_first_token);
+  WRAPPER_DUMP_PARAM1(ctx, splitwise_prefill);
+  WRAPPER_DUMP(ctx);
+
+  WRAPPER_CHECK_PTR(ctx, int, real_bsz, seq_lens_this_time);
+  WRAPPER_CHECK_PTR(ctx, int64_t, real_bsz * accept_tokens_len, accept_tokens);
+  WRAPPER_CHECK_PTR(ctx, int64_t, real_bsz * input_ids_len, input_ids);
+  WRAPPER_CHECK_PTR(ctx, int64_t, real_bsz * draft_tokens_len, draft_tokens);
+  WRAPPER_CHECK_PTR(ctx,
+                    int64_t,
+                    real_bsz * base_model_draft_tokens_len,
+                    base_model_draft_tokens);
+
+  WRAPPER_ASSERT_GT(ctx, real_bsz, 0);
+  WRAPPER_ASSERT_LT(ctx, accept_tokens_len, 128);
+
+  if (ctx->dev().type() == api::kCPU) {
+    return cpu_wrapper(ctx,
+                       draft_tokens,
+                       input_ids,
+                       stop_flags,
+                       seq_lens_this_time,
+                       seq_lens_encoder,
+                       seq_lens_decoder,
+                       step_idx,
+                       seq_lens_encoder_record,
+                       seq_lens_decoder_record,
+                       not_need_stop,
+                       batch_drop,
+                       accept_tokens,
+                       accept_num,
+                       base_model_seq_lens_encoder,
+                       base_model_seq_lens_decoder,
+                       base_model_step_idx,
+                       base_model_stop_flags,
+                       base_model_is_block_step,
+                       base_model_draft_tokens,
+                       real_bsz,
+                       max_draft_token,
+                       accept_tokens_len,
+                       draft_tokens_len,
+                       input_ids_len,
+                       base_model_draft_tokens_len,
+                       truncate_first_token,
+                       splitwise_prefill);
+  }
+  if (ctx->dev().type() == api::kXPU3) {
+    return xpu3_wrapper(ctx,
+                        draft_tokens,
+                        input_ids,
+                        stop_flags,
+                        seq_lens_this_time,
+                        seq_lens_encoder,
+                        seq_lens_decoder,
+                        step_idx,
+                        seq_lens_encoder_record,
+                        seq_lens_decoder_record,
+                        not_need_stop,
+                        batch_drop,
+                        accept_tokens,
+                        accept_num,
+                        base_model_seq_lens_encoder,
+                        base_model_seq_lens_decoder,
+                        base_model_step_idx,
+                        base_model_stop_flags,
+                        base_model_is_block_step,
+                        base_model_draft_tokens,
+                        real_bsz,
+                        max_draft_token,
+                        accept_tokens_len,
+                        draft_tokens_len,
+                        input_ids_len,
+                        base_model_draft_tokens_len,
+                        truncate_first_token,
+                        splitwise_prefill);
+  }
+  WRAPPER_UNIMPLEMENTED(ctx);
+}
+
+}  // namespace plugin
+}  // namespace api
+}  // namespace xpu
+}  // namespace baidu
diff --git a/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/draft_model_update.cpp b/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/draft_model_update.cpp
new file mode 100644
index 000000000..3fdaece7e
--- /dev/null
+++ b/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/draft_model_update.cpp
@@ -0,0 +1,324 @@
+// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "xpu/plugin.h"
+#include "xpu/refactor/impl_public/wrapper_check.h"
+
+namespace xpu3 {
+namespace plugin {
+__attribute__((global)) void draft_model_update(
+    const int64_t* inter_next_tokens,
+    int64_t* draft_tokens,
+    int64_t* pre_ids,
+    int* seq_lens_this_time,
+    int* seq_lens_encoder,
+    int* seq_lens_decoder,
+    int64_t* step_idx,
+    const int* output_cum_offsets,
+    bool* stop_flags,
+    bool* not_need_stop,
+    const int64_t* max_dec_len,
+    const int64_t* end_ids,
+    int64_t* base_model_draft_tokens,
+    const int bsz,
+    const int max_draft_token,
+    const int pre_id_length,
+    const int max_base_model_draft_token,
+    const int end_ids_len,
+    const int max_seq_len,
+    const int substep,
+    const bool prefill_one_step_stop);
+}  // namespace plugin
+}  // namespace xpu3
+
+namespace baidu {
+namespace xpu {
+namespace api {
+namespace plugin {
+
+bool is_in_end(int64_t token, const int64_t* end_ids, int end_ids_len) {
+  for (int i = 0; i < end_ids_len; ++i) {
+    if (end_ids[i] == token) {
+      return true;
+    }
+  }
+  return false;
+}
+
+static int cpu_wrapper(Context* ctx,
+                       const int64_t* inter_next_tokens,
+                       int64_t* draft_tokens,
+                       int64_t* pre_ids,
+                       int* seq_lens_this_time,
+                       int* seq_lens_encoder,
+                       int* seq_lens_decoder,
+                       int64_t* step_idx,
+                       const int* output_cum_offsets,
+                       bool* stop_flags,
+                       bool* not_need_stop,
+                       const int64_t* max_dec_len,
+                       const int64_t* end_ids,
+                       int64_t* base_model_draft_tokens,
+                       const int bsz,
+                       const int max_draft_token,
+                       const int pre_id_length,
+                       const int max_base_model_draft_token,
+                       const int end_ids_len,
+                       const int max_seq_len,
+                       const int substep,
+                       const bool prefill_one_step_stop) {
+  int64_t stop_sum = 0;
+
+  // Walk every request in the batch.
+  for (int tid = 0; tid < bsz; ++tid) {
+    auto* draft_token_now = draft_tokens + tid * max_draft_token;
+    auto* pre_ids_now = pre_ids + tid * pre_id_length;
+    auto* base_model_draft_tokens_now =
+        base_model_draft_tokens + tid * max_base_model_draft_token;
+    const int next_tokens_start_id =
+        tid * max_seq_len - output_cum_offsets[tid];
+    auto* next_tokens_start = inter_next_tokens + next_tokens_start_id;
+    auto seq_len_this_time = seq_lens_this_time[tid];
+    auto seq_len_encoder = seq_lens_encoder[tid];
+    auto seq_len_decoder = seq_lens_decoder[tid];
+
+    int64_t stop_flag_now_int = 0;
+
+    // 1. update step_idx && seq_lens_dec
+    if (!stop_flags[tid]) {
+      int64_t token_this_time = -1;
+      // decoder step
+      if (seq_len_decoder > 0 && seq_len_encoder <= 0) {
+        seq_lens_decoder[tid] += seq_len_this_time;
+        token_this_time = next_tokens_start[seq_len_this_time - 1];
+        draft_token_now[0] = next_tokens_start[seq_len_this_time - 1];
+        base_model_draft_tokens_now[substep + 1] = token_this_time;
+        for (int i = 0; i < seq_len_this_time; ++i) {
+          pre_ids_now[step_idx[tid] + 1 + i] = next_tokens_start[i];
+        }
+        step_idx[tid] += seq_len_this_time;
+
+      } else {
+        token_this_time = next_tokens_start[0];
+        seq_lens_decoder[tid] = seq_len_encoder + seq_len_decoder;
+        seq_lens_encoder[tid] = 0;
+        pre_ids_now[1] = token_this_time;
+        step_idx[tid] += 1;
+        draft_token_now[0] = token_this_time;
+        base_model_draft_tokens_now[substep + 1] = token_this_time;
+      }
+
+      // multi_end
+      if (is_in_end(token_this_time, end_ids, end_ids_len) ||
+          prefill_one_step_stop) {
+        stop_flags[tid] = true;
+        stop_flag_now_int = 1;
+        // max_dec_len
+      } else if (step_idx[tid] >= max_dec_len[tid]) {
+        stop_flags[tid] = true;
+        draft_token_now[seq_len_this_time - 1] = end_ids[0];
+        base_model_draft_tokens_now[substep + 1] = end_ids[0];
+        stop_flag_now_int = 1;
+      }
+
+    } else {
+      draft_token_now[0] = -1;
+      base_model_draft_tokens_now[substep + 1] = -1;
+      stop_flag_now_int = 1;
+    }
+
+    // 2. set end
+    if (!stop_flags[tid]) {
+      seq_lens_this_time[tid] = 1;
+    } else {
+      seq_lens_this_time[tid] = 0;
+      seq_lens_encoder[tid] = 0;
+    }
+
+    stop_sum += stop_flag_now_int;
+  }
+
+  // Equivalent to the BlockReduce sum in the CUDA implementation.
+  not_need_stop[0] = stop_sum < bsz;
+  return SUCCESS;
+}
+
+static int xpu2or3_wrapper(Context* ctx,
+                           const int64_t* inter_next_tokens,
+                           int64_t* draft_tokens,
+                           int64_t* pre_ids,
+                           int* seq_lens_this_time,
+                           int* seq_lens_encoder,
+                           int* seq_lens_decoder,
+                           int64_t* step_idx,
+                           const int* output_cum_offsets,
+                           bool* stop_flags,
+                           bool* not_need_stop,
+                           const int64_t* max_dec_len,
+                           const int64_t* end_ids,
+                           int64_t* base_model_draft_tokens,
+                           const int bsz,
+                           const int max_draft_token,
+                           const int pre_id_length,
+                           const int max_base_model_draft_token,
+                           const int end_ids_len,
+                           const int max_seq_len,
+                           const int substep,
+                           const bool prefill_one_step_stop) {
+  ctx_guard RAII_GUARD(ctx);
+  using XPU_INT64 = typename XPUIndexType<int64_t>::type;
+  xpu3::plugin::draft_model_update<<<1, 64, ctx->xpu_stream>>>(
+      reinterpret_cast<const XPU_INT64*>(inter_next_tokens),
+      reinterpret_cast<XPU_INT64*>(draft_tokens),
+      reinterpret_cast<XPU_INT64*>(pre_ids),
+      seq_lens_this_time,
+      seq_lens_encoder,
+      seq_lens_decoder,
+      reinterpret_cast<XPU_INT64*>(step_idx),
+      output_cum_offsets,
+      stop_flags,
+      not_need_stop,
+      reinterpret_cast<const XPU_INT64*>(max_dec_len),
+      reinterpret_cast<const XPU_INT64*>(end_ids),
+      reinterpret_cast<XPU_INT64*>(base_model_draft_tokens),
+      bsz,
+      max_draft_token,
+      pre_id_length,
+      max_base_model_draft_token,
+      end_ids_len,
+      max_seq_len,
+      substep,
+      prefill_one_step_stop);
+
+  return api::SUCCESS;
+}
+
+int draft_model_update(Context* ctx,
+                       const int64_t* inter_next_tokens,
+                       int64_t* draft_tokens,
+                       int64_t* pre_ids,
+                       int* seq_lens_this_time,
+                       int* seq_lens_encoder,
+                       int* seq_lens_decoder,
+                       int64_t* step_idx,
+                       const int* output_cum_offsets,
+                       bool* stop_flags,
+                       bool* not_need_stop,
+                       const int64_t* max_dec_len,
+                       const int64_t* end_ids,
+                       int64_t* base_model_draft_tokens,
+                       const int bsz,
+                       const int max_draft_token,
+                       const int pre_id_length,
+                       const int max_base_model_draft_token,
+
const int end_ids_len, + const int max_seq_len, + const int substep, + const bool prefill_one_step_stop) { + WRAPPER_CHECK_CTX(ctx); + WRAPPER_DUMP_FUNCTION_T1(ctx, "draft_model_update", int); + WRAPPER_DUMP_PARAM6(ctx, + inter_next_tokens, + draft_tokens, + pre_ids, + seq_lens_this_time, + seq_lens_encoder, + seq_lens_decoder); + WRAPPER_DUMP_PARAM6(ctx, + step_idx, + output_cum_offsets, + stop_flags, + not_need_stop, + max_dec_len, + end_ids); + WRAPPER_DUMP_PARAM6(ctx, + base_model_draft_tokens, + bsz, + max_draft_token, + pre_id_length, + max_base_model_draft_token, + end_ids_len); + WRAPPER_DUMP_PARAM3(ctx, max_seq_len, substep, prefill_one_step_stop); + WRAPPER_DUMP(ctx); + WRAPPER_CHECK_PTR(ctx, int64_t, bsz * max_seq_len, inter_next_tokens); + WRAPPER_CHECK_PTR(ctx, int64_t, bsz * max_draft_token, draft_tokens); + WRAPPER_CHECK_PTR(ctx, int64_t, bsz * pre_id_length, pre_ids); + WRAPPER_CHECK_PTR(ctx, int, bsz, seq_lens_this_time); + WRAPPER_CHECK_PTR(ctx, int, bsz, seq_lens_encoder); + WRAPPER_CHECK_PTR(ctx, int, bsz, seq_lens_decoder); + WRAPPER_CHECK_PTR(ctx, int64_t, bsz, step_idx); + WRAPPER_CHECK_PTR(ctx, int, bsz, output_cum_offsets); + WRAPPER_CHECK_PTR(ctx, bool, bsz, stop_flags); + WRAPPER_CHECK_PTR(ctx, bool, 1, not_need_stop); + WRAPPER_CHECK_PTR(ctx, int64_t, bsz, max_dec_len); + WRAPPER_CHECK_PTR(ctx, int64_t, end_ids_len, end_ids); + WRAPPER_CHECK_PTR( + ctx, int64_t, bsz * max_base_model_draft_token, base_model_draft_tokens); + + if (ctx->dev().type() == api::kCPU) { + return cpu_wrapper(ctx, + inter_next_tokens, + draft_tokens, + pre_ids, + seq_lens_this_time, + seq_lens_encoder, + seq_lens_decoder, + step_idx, + output_cum_offsets, + stop_flags, + not_need_stop, + max_dec_len, + end_ids, + base_model_draft_tokens, + bsz, + max_draft_token, + pre_id_length, + max_base_model_draft_token, + end_ids_len, + max_seq_len, + substep, + prefill_one_step_stop); + } + if (ctx->dev().type() == api::kXPU3) { + return xpu2or3_wrapper(ctx, + inter_next_tokens, + draft_tokens, + pre_ids, + seq_lens_this_time, + seq_lens_encoder, + seq_lens_decoder, + step_idx, + output_cum_offsets, + stop_flags, + not_need_stop, + max_dec_len, + end_ids, + base_model_draft_tokens, + bsz, + max_draft_token, + pre_id_length, + max_base_model_draft_token, + end_ids_len, + max_seq_len, + substep, + prefill_one_step_stop); + } + + WRAPPER_UNIMPLEMENTED(ctx); +} + +} // namespace plugin +} // namespace api +} // namespace xpu +} // namespace baidu diff --git a/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/mtp_free_and_dispatch_block.cpp b/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/mtp_free_and_dispatch_block.cpp new file mode 100644 index 000000000..07bdf5665 --- /dev/null +++ b/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/mtp_free_and_dispatch_block.cpp @@ -0,0 +1,254 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
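+
+// mtp_free_and_dispatch_block recycles the KV-cache blocks of stopped or
+// dropped requests back into the free list, then hands fresh blocks to
+// requests whose next decode step would overflow their block table; when
+// the free list runs short, the request holding the most decoder blocks is
+// evicted (see the while loop in cpu_wrapper below).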
+
+#include
+#include
+#include "xpu/plugin.h"
+#include "xpu/refactor/impl_public/wrapper_check.h"
+
+namespace xpu2 {
+namespace plugin {}  // namespace plugin
+}  // namespace xpu2
+
+namespace xpu3 {
+namespace plugin {
+
+__attribute__((global)) void mtp_free_and_dispatch_block(
+    bool *base_model_stop_flags,
+    bool *stop_flags,
+    bool *batch_drop,
+    int *seq_lens_this_time,
+    int *seq_lens_decoder,
+    int *block_tables,
+    int *encoder_block_lens,
+    int *used_list_len,
+    int *free_list,
+    int *free_list_len,
+    const int bsz,
+    const int block_size,
+    const int block_num_per_seq,
+    const int max_draft_tokens);
+
+}  // namespace plugin
+}  // namespace xpu3
+
+namespace baidu {
+namespace xpu {
+namespace api {
+namespace plugin {
+
+static int cpu_wrapper(Context *ctx,
+                       bool *base_model_stop_flags,
+                       bool *stop_flags,
+                       bool *batch_drop,
+                       int *seq_lens_this_time,
+                       int *seq_lens_decoder,
+                       int *block_tables,
+                       int *encoder_block_lens,
+                       int *used_list_len,
+                       int *free_list,
+                       int *free_list_len,
+                       const int bsz,
+                       const int block_size,
+                       const int block_num_per_seq,
+                       const int max_draft_tokens) {
+  int need_block_len = 0;
+  int need_block_list[640];
+  for (int tid = 0; tid < bsz; tid++) {
+    need_block_list[tid] = 0;
+    int *block_table_now = block_tables + tid * block_num_per_seq;
+    if (base_model_stop_flags[tid] || batch_drop[tid]) {
+      // Reclaim this request's decoder blocks into the free list.
+      const int encoder_block_len = encoder_block_lens[tid];
+      const int decoder_used_len = used_list_len[tid];
+      if (decoder_used_len > 0) {
+        for (int i = 0; i < decoder_used_len; i++) {
+          free_list[free_list_len[0] + i] =
+              block_table_now[encoder_block_len + i];
+          block_table_now[encoder_block_len + i] = -1;
+        }
+        free_list_len[0] += decoder_used_len;
+        encoder_block_lens[tid] = 0;
+        used_list_len[tid] = 0;
+      }
+    }
+  }
+  for (int tid = 0; tid < bsz; tid++) {
+    int *block_table_now = block_tables + tid * block_num_per_seq;
+    int max_possible_block_idx =
+        (seq_lens_decoder[tid] + max_draft_tokens + 1) / block_size;
+    if (!base_model_stop_flags[tid] && !batch_drop[tid] &&
+        max_possible_block_idx < block_num_per_seq &&
+        block_table_now[max_possible_block_idx] == -1) {
+      need_block_list[need_block_len] = tid;
+      need_block_len++;
+    }
+  }
+  // Scan from bid 0 and evict the request holding the most decoder blocks
+  // until enough free blocks are available.
+  while (need_block_len > free_list_len[0]) {
+    int max_used_list_len_id = 0;
+    int max_used_list_len = 0;
+    for (int i = 0; i < bsz; i++) {
+      if (!base_model_stop_flags[i] && used_list_len[i] > max_used_list_len) {
+        max_used_list_len = used_list_len[i];
+        max_used_list_len_id = i;
+      }
+    }
+    const int encoder_block_len = encoder_block_lens[max_used_list_len_id];
+    int *block_table_now =
+        block_tables + max_used_list_len_id * block_num_per_seq;
+    for (int i = 0; i < max_used_list_len; i++) {
+      free_list[free_list_len[0] + i] = block_table_now[encoder_block_len + i];
+      block_table_now[encoder_block_len + i] = -1;
+    }
+    stop_flags[max_used_list_len_id] = true;
+    batch_drop[max_used_list_len_id] = true;
+    seq_lens_this_time[max_used_list_len_id] = 0;
+    seq_lens_decoder[max_used_list_len_id] = 0;
+    used_list_len[max_used_list_len_id] = 0;
+    free_list_len[0] += max_used_list_len;
+  }
+  for (int tid = 0; tid < need_block_len; tid++) {
+    const int need_block_id = need_block_list[tid];
+    // Must check batch_drop here, not stop_flags.
+    if (!batch_drop[need_block_id]) {
+      used_list_len[need_block_id] += 1;
+      int *block_table_now = block_tables + need_block_id * block_num_per_seq;
+      block_table_now[(seq_lens_decoder[need_block_id] + max_draft_tokens +
+                       1) /
+                      block_size] = free_list[free_list_len[0] - 1];
+      free_list_len[0] -= 1;
+    }
+  }
+  return api::SUCCESS;
+}
+
+static int xpu2or3_wrapper(Context *ctx,
+                           bool *base_model_stop_flags,
+                           bool *stop_flags,
+                           bool *batch_drop,
+                           int *seq_lens_this_time,
+                           int *seq_lens_decoder,
+                           int *block_tables,
+                           int *encoder_block_lens,
+                           int *used_list_len,
+                           int *free_list,
+                           int *free_list_len,
+                           const int bsz,
+                           const int block_size,
+                           const int block_num_per_seq,
+                           const int max_draft_tokens) {
+  using XPU_INT64 = typename XPUIndexType<int64_t>::type;
+  bool is_xpu3 = ctx->dev().type() == api::kXPU3;
+  if (!is_xpu3) {
+    WRAPPER_UNIMPLEMENTED(ctx);
+  }
+  auto mtp_free_and_dispatch_block = xpu3::plugin::mtp_free_and_dispatch_block;
+  mtp_free_and_dispatch_block<<<12, 64, ctx->xpu_stream>>>(
+      base_model_stop_flags,
+      stop_flags,
+      batch_drop,
+      seq_lens_this_time,
+      seq_lens_decoder,
+      block_tables,
+      encoder_block_lens,
+      used_list_len,
+      free_list,
+      free_list_len,
+      bsz,
+      block_size,
+      block_num_per_seq,
+      max_draft_tokens);
+  return api::SUCCESS;
+}
+
+int mtp_free_and_dispatch_block(Context *ctx,
+                                bool *base_model_stop_flags,
+                                bool *stop_flags,
+                                bool *batch_drop,
+                                int *seq_lens_this_time,
+                                int *seq_lens_decoder,
+                                int *block_tables,
+                                int *encoder_block_lens,
+                                int *used_list_len,
+                                int *free_list,
+                                int *free_list_len,
+                                const int bsz,
+                                const int block_size,
+                                const int block_num_per_seq,
+                                const int max_draft_tokens) {
+  WRAPPER_CHECK_CTX(ctx);
+  WRAPPER_DUMP_FUNCTION_T1(ctx, "mtp_free_and_dispatch_block", float);
+  WRAPPER_DUMP_PARAM6(ctx,
+                      base_model_stop_flags,
+                      stop_flags,
+                      batch_drop,
+                      seq_lens_this_time,
+                      seq_lens_decoder,
+                      block_tables);
+  WRAPPER_DUMP_PARAM4(
+      ctx, encoder_block_lens, used_list_len, free_list, free_list_len);
+  WRAPPER_DUMP_PARAM4(
+      ctx, bsz, block_size, block_num_per_seq, max_draft_tokens);
+  WRAPPER_ASSERT_LE(ctx, bsz, 640);
+  WRAPPER_CHECK_PTR(ctx, bool, bsz, base_model_stop_flags);
+  WRAPPER_CHECK_PTR(ctx, bool, bsz, stop_flags);
+  WRAPPER_CHECK_PTR(ctx, bool, bsz, batch_drop);
+  WRAPPER_CHECK_PTR(ctx, int, bsz, seq_lens_this_time);
+  WRAPPER_CHECK_PTR(ctx, int, bsz, seq_lens_decoder);
+  WRAPPER_CHECK_PTR(ctx, int, bsz * block_num_per_seq, block_tables);
+  WRAPPER_CHECK_PTR(ctx, int, bsz, encoder_block_lens);
+  WRAPPER_CHECK_PTR(ctx, int, bsz, used_list_len);
+  WRAPPER_CHECK_PTR(ctx, int, 1, free_list_len);
+  WRAPPER_DUMP(ctx);
+  if (ctx->dev().type() == api::kCPU) {
+    return cpu_wrapper(ctx,
+                       base_model_stop_flags,
+                       stop_flags,
+                       batch_drop,
+                       seq_lens_this_time,
+                       seq_lens_decoder,
+                       block_tables,
+                       encoder_block_lens,
+                       used_list_len,
+                       free_list,
+                       free_list_len,
+                       bsz,
+                       block_size,
+                       block_num_per_seq,
+                       max_draft_tokens);
+  }
+  if (ctx->dev().type() == api::kXPU2 || ctx->dev().type() == api::kXPU3) {
+    return xpu2or3_wrapper(ctx,
+                           base_model_stop_flags,
+                           stop_flags,
+                           batch_drop,
+                           seq_lens_this_time,
+                           seq_lens_decoder,
+                           block_tables,
+                           encoder_block_lens,
+                           used_list_len,
+                           free_list,
+                           free_list_len,
+                           bsz,
+                           block_size,
+                           block_num_per_seq,
+                           max_draft_tokens);
+  }
+  WRAPPER_UNIMPLEMENTED(ctx);
+}
+
+}  // namespace plugin
+}  // namespace api
+}  // namespace xpu
+}  // namespace baidu
diff --git a/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/rebuild_hidden_states.cpp b/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/rebuild_hidden_states.cpp
new file mode 100644
index 000000000..42ae04c2e
--- /dev/null
+++ b/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/rebuild_hidden_states.cpp
@@ -0,0 +1,101 @@
+// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "xpu/plugin.h"
+#include "xpu/refactor/impl_public/wrapper_check.h"
+
+namespace xpu3 {
+namespace plugin {
+template <typename T>
+__attribute__((global)) void rebuildHiddenStatesKernel(const T* input,
+                                                       const int* position_map,
+                                                       T* output,
+                                                       int dim_embed,
+                                                       int elem_cnt);
+}  // namespace plugin
+}  // namespace xpu3
+
+namespace baidu {
+namespace xpu {
+namespace api {
+namespace plugin {
+
+template <typename T>
+static int cpu_wrapper(Context* ctx,
+                       const T* input,
+                       const int* position_map,
+                       T* output,
+                       int dim_embed,
+                       int elem_cnt) {
+  for (int elem_id = 0; elem_id < elem_cnt; elem_id++) {
+    int ori_token_idx = elem_id / dim_embed;
+    int token_idx = position_map[ori_token_idx];
+    int offset = elem_id % dim_embed;
+    if (token_idx >= 0) {
+      output[token_idx * dim_embed + offset] =
+          input[ori_token_idx * dim_embed + offset];
+    }
+  }
+  return api::SUCCESS;
+}
+
+template <typename T>
+static int xpu3_wrapper(Context* ctx,
+                        const T* input,
+                        const int* position_map,
+                        T* output,
+                        int dim_embed,
+                        int elem_cnt) {
+  xpu3::plugin::rebuildHiddenStatesKernel<T>
+      <<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
+          input, position_map, output, dim_embed, elem_cnt);
+  return api::SUCCESS;
+}
+
+template <typename T>
+int rebuild_hidden_states(Context* ctx,
+                          const T* input,
+                          const int* position_map,
+                          T* output,
+                          int dim_embed,
+                          int elem_cnt) {
+  WRAPPER_CHECK_CTX(ctx);
+  WRAPPER_DUMP_FUNCTION_T1(ctx, "rebuild_hidden_states", T);
+  WRAPPER_DUMP_PARAM5(ctx, input, position_map, output, dim_embed, elem_cnt);
+  WRAPPER_DUMP(ctx);
+
+  WRAPPER_ASSERT_GT(ctx, dim_embed, 0);
+  WRAPPER_ASSERT_GT(ctx, elem_cnt, 0);
+
+  if (ctx->dev().type() == api::kCPU) {
+    return cpu_wrapper<T>(
+        ctx, input, position_map, output, dim_embed, elem_cnt);
+  } else if (ctx->dev().type() == api::kXPU3) {
+    return xpu3_wrapper<T>(
+        ctx, input, position_map, output, dim_embed, elem_cnt);
+  }
+  WRAPPER_UNIMPLEMENTED(ctx);
+  return api::SUCCESS;
+}
+
+template int rebuild_hidden_states(
+    Context*, const bfloat16*, const int*, bfloat16*, int, int);
+template int rebuild_hidden_states(
+    Context*, const float*, const int*, float*, int, int);
+template int rebuild_hidden_states(
+    Context*, const float16*, const int*, float16*, int, int);
+}  // namespace plugin
+}  // namespace api
+}  // namespace xpu
+}  // namespace baidu
diff --git a/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/rebuild_self_hidden_states.cpp b/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/rebuild_self_hidden_states.cpp
new file mode 100644
index 000000000..a0a06682c
--- /dev/null
+++ b/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/rebuild_self_hidden_states.cpp
@@ -0,0 +1,94 @@
+// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "xpu/plugin.h"
+#include "xpu/refactor/impl_public/wrapper_check.h"
+
+namespace xpu3 {
+namespace plugin {
+template <typename T>
+__attribute__((global)) void rebuildSelfHiddenStatesKernel(
+    const T* input, int* src_map, T* output, int dim_embed, int elem_cnt);
+}  // namespace plugin
+}  // namespace xpu3
+
+namespace baidu {
+namespace xpu {
+namespace api {
+namespace plugin {
+
+template <typename T>
+static int cpu_wrapper(Context* ctx,
+                       const T* input,
+                       int* src_map,
+                       T* output,
+                       int dim_embed,
+                       int elem_cnt) {
+  for (int elem_id = 0; elem_id < elem_cnt; elem_id++) {
+    int output_token_idx = elem_id / dim_embed;
+    int input_token_idx = src_map[output_token_idx];
+    int offset = elem_id % dim_embed;
+    output[output_token_idx * dim_embed + offset] =
+        input[input_token_idx * dim_embed + offset];
+  }
+  return api::SUCCESS;
+}
+
+template <typename T>
+static int xpu3_wrapper(Context* ctx,
+                        const T* input,
+                        int* src_map,
+                        T* output,
+                        int dim_embed,
+                        int elem_cnt) {
+  xpu3::plugin::rebuildSelfHiddenStatesKernel<T>
+      <<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
+          input, src_map, output, dim_embed, elem_cnt);
+  return api::SUCCESS;
+}
+
+template <typename T>
+int rebuild_self_hidden_states(Context* ctx,
+                               const T* input,
+                               int* src_map,
+                               T* output,
+                               int dim_embed,
+                               int elem_cnt) {
+  WRAPPER_CHECK_CTX(ctx);
+  WRAPPER_DUMP_FUNCTION_T1(ctx, "rebuild_self_hidden_states", T);
+  WRAPPER_DUMP_PARAM5(ctx, input, src_map, output, dim_embed, elem_cnt);
+  WRAPPER_DUMP(ctx);
+
+  WRAPPER_CHECK_PTR(ctx, T, elem_cnt, output);
+  WRAPPER_ASSERT_GT(ctx, dim_embed, 0);
+  WRAPPER_ASSERT_GT(ctx, elem_cnt, 0);
+
+  if (ctx->dev().type() == api::kCPU) {
+    return cpu_wrapper<T>(ctx, input, src_map, output, dim_embed, elem_cnt);
+  } else if (ctx->dev().type() == api::kXPU3) {
+    return xpu3_wrapper<T>(ctx, input, src_map, output, dim_embed, elem_cnt);
+  }
+  WRAPPER_UNIMPLEMENTED(ctx);
+}
+
+template int rebuild_self_hidden_states(
+    Context*, const bfloat16*, int*, bfloat16*, int, int);
+template int rebuild_self_hidden_states(
+    Context*, const float*, int*, float*, int, int);
+template int rebuild_self_hidden_states(
+    Context*, const float16*, int*, float16*, int, int);
+}  // namespace plugin
+}  // namespace api
+}  // namespace xpu
+}  // namespace baidu
diff --git a/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/speculate_clear_accept_nums.cpp b/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/speculate_clear_accept_nums.cpp
new file mode 100644
index 000000000..c6503f8d7
--- /dev/null
+++ b/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/speculate_clear_accept_nums.cpp
@@ -0,0 +1,73 @@
+// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and +// limitations under the License. + +#include "xpu/plugin.h" +#include "xpu/refactor/impl_public/wrapper_check.h" + +namespace xpu3 { +namespace plugin { +__attribute__((global)) void speculate_clear_accept_nums( + int* accept_num, const int* seq_lens_decoder, const int max_bsz); +} +} // namespace xpu3 + +namespace baidu { +namespace xpu { +namespace api { +namespace plugin { + +static int cpu_wrapper(Context* ctx, + int* accept_num, + const int* seq_lens_decoder, + const int max_bsz) { + for (int i = 0; i < max_bsz; i++) { + accept_num[i] = seq_lens_decoder[i] == 0 ? 0 : accept_num[i]; + } + return SUCCESS; +} + +static int xpu2or3_wrapper(Context* ctx, + int* accept_num, + const int* seq_lens_decoder, + const int max_bsz) { + ctx_guard RAII_GUARD(ctx); + xpu3::plugin::speculate_clear_accept_nums<<<1, 64, ctx->xpu_stream>>>( + accept_num, seq_lens_decoder, max_bsz); + + return api::SUCCESS; +} + +int speculate_clear_accept_nums(Context* ctx, + int* accept_num, + const int* seq_lens_decoder, + const int max_bsz) { + WRAPPER_CHECK_CTX(ctx); + WRAPPER_DUMP_FUNCTION_T1(ctx, "speculate_clear_accept_nums", int); + WRAPPER_DUMP_PARAM3(ctx, accept_num, seq_lens_decoder, max_bsz); + WRAPPER_DUMP(ctx); + + if (ctx->dev().type() == api::kCPU) { + return cpu_wrapper(ctx, accept_num, seq_lens_decoder, max_bsz); + } + if (ctx->dev().type() == api::kXPU3) { + return xpu2or3_wrapper(ctx, accept_num, seq_lens_decoder, max_bsz); + } + + WRAPPER_UNIMPLEMENTED(ctx); +} + +} // namespace plugin +} // namespace api +} // namespace xpu +} // namespace baidu diff --git a/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/speculate_free_and_reschedule.cpp b/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/speculate_free_and_reschedule.cpp new file mode 100644 index 000000000..14fbcba77 --- /dev/null +++ b/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/speculate_free_and_reschedule.cpp @@ -0,0 +1,230 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
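+
+// speculate_free_and_reschedule frees blocks and reschedules requests for
+// speculative decoding; only the XPU3 kernel is implemented in this file,
+// and the CPU wrapper below is a stub that reports failure.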
+
+#include
+#include
+#include "xpu/plugin.h"
+#include "xpu/refactor/impl_public/wrapper_check.h"
+
+namespace xpu3 {
+namespace plugin {
+
+__attribute__((global)) void speculate_free_and_reschedule(
+    bool *stop_flags,
+    int *seq_lens_this_time,
+    int *seq_lens_decoder,
+    int *block_tables,
+    int *encoder_block_lens,
+    bool *is_block_step,
+    int *step_block_list,  // [bsz]
+    int *step_len,
+    int *recover_block_list,
+    int *recover_len,
+    int *need_block_list,
+    int *need_block_len,
+    int *used_list_len,
+    int *free_list,
+    int *free_list_len,
+    int64_t *first_token_ids,
+    const int bsz,
+    const int block_size,
+    const int block_num_per_seq,
+    const int max_decoder_block_num,
+    const int max_draft_tokens);
+
+}  // namespace plugin
+}  // namespace xpu3
+
+namespace baidu {
+namespace xpu {
+namespace api {
+namespace plugin {
+
+static int cpu_wrapper(Context *ctx,
+                       bool *stop_flags,
+                       int *seq_lens_this_time,
+                       int *seq_lens_decoder,
+                       int *block_tables,
+                       int *encoder_block_lens,
+                       bool *is_block_step,
+                       int *step_block_list,  // [bsz]
+                       int *step_len,
+                       int *recover_block_list,
+                       int *recover_len,
+                       int *need_block_list,
+                       int *need_block_len,
+                       int *used_list_len,
+                       int *free_list,
+                       int *free_list_len,
+                       int64_t *first_token_ids,
+                       const int bsz,
+                       const int block_size,
+                       const int block_num_per_seq,
+                       const int max_decoder_block_num,
+                       const int max_draft_tokens) {
+  // The CPU reference path is not implemented; only the XPU3 kernel exists.
+  return -1;
+}
+
+static int xpu3_wrapper(Context *ctx,
+                        bool *stop_flags,
+                        int *seq_lens_this_time,
+                        int *seq_lens_decoder,
+                        int *block_tables,
+                        int *encoder_block_lens,
+                        bool *is_block_step,
+                        int *step_block_list,  // [bsz]
+                        int *step_len,
+                        int *recover_block_list,
+                        int *recover_len,
+                        int *need_block_list,
+                        int *need_block_len,
+                        int *used_list_len,
+                        int *free_list,
+                        int *free_list_len,
+                        int64_t *first_token_ids,
+                        const int bsz,
+                        const int block_size,
+                        const int block_num_per_seq,
+                        const int max_decoder_block_num,
+                        const int max_draft_tokens) {
+  using XPU_INT64 = typename XPUIndexType<int64_t>::type;
+  auto speculate_free_and_reschedule =
+      xpu3::plugin::speculate_free_and_reschedule;
+  speculate_free_and_reschedule<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
+      stop_flags,
+      seq_lens_this_time,
+      seq_lens_decoder,
+      block_tables,
+      encoder_block_lens,
+      is_block_step,
+      step_block_list,
+      step_len,
+      recover_block_list,
+      recover_len,
+      need_block_list,
+      need_block_len,
+      used_list_len,
+      free_list,
+      free_list_len,
+      reinterpret_cast<XPU_INT64 *>(first_token_ids),
+      bsz,
+      block_size,
+      block_num_per_seq,
+      max_decoder_block_num,
+      max_draft_tokens);
+  return api::SUCCESS;
+}
+
+int speculate_free_and_reschedule(Context *ctx,
+                                  bool *stop_flags,
+                                  int *seq_lens_this_time,
+                                  int *seq_lens_decoder,
+                                  int *block_tables,
+                                  int *encoder_block_lens,
+                                  bool *is_block_step,
+                                  int *step_block_list,  // [bsz]
+                                  int *step_len,
+                                  int *recover_block_list,
+                                  int *recover_len,
+                                  int *need_block_list,
+                                  int *need_block_len,
+                                  int *used_list_len,
+                                  int *free_list,
+                                  int *free_list_len,
+                                  int64_t *first_token_ids,
+                                  const int bsz,
+                                  const int block_size,
+                                  const int block_num_per_seq,
+                                  const int max_decoder_block_num,
+                                  const int max_draft_tokens) {
+  WRAPPER_CHECK_CTX(ctx);
+  WRAPPER_DUMP_FUNCTION_T1(ctx, "speculate_free_and_reschedule", float);
+  WRAPPER_DUMP_PARAM6(ctx,
+                      stop_flags,
+                      seq_lens_this_time,
+                      seq_lens_decoder,
+                      block_tables,
+                      encoder_block_lens,
+                      is_block_step);
+  WRAPPER_DUMP_PARAM6(ctx,
+                      step_block_list,
+                      step_len,
+                      recover_block_list,
+                      recover_len,
+                      need_block_list,
+                      need_block_len);
+  WRAPPER_DUMP_PARAM4(
+      ctx,
used_list_len, free_list, free_list_len, first_token_ids); + WRAPPER_DUMP_PARAM5(ctx, + bsz, + block_size, + block_num_per_seq, + max_decoder_block_num, + max_draft_tokens); + WRAPPER_DUMP(ctx); + if (ctx->dev().type() == api::kCPU) { + return cpu_wrapper(ctx, + stop_flags, + seq_lens_this_time, + seq_lens_decoder, + block_tables, + encoder_block_lens, + is_block_step, + step_block_list, + step_len, + recover_block_list, + recover_len, + need_block_list, + need_block_len, + used_list_len, + free_list, + free_list_len, + first_token_ids, + bsz, + block_size, + block_num_per_seq, + max_decoder_block_num, + max_draft_tokens); + } + if (ctx->dev().type() == api::kXPU3) { + return xpu3_wrapper(ctx, + stop_flags, + seq_lens_this_time, + seq_lens_decoder, + block_tables, + encoder_block_lens, + is_block_step, + step_block_list, + step_len, + recover_block_list, + recover_len, + need_block_list, + need_block_len, + used_list_len, + free_list, + free_list_len, + first_token_ids, + bsz, + block_size, + block_num_per_seq, + max_decoder_block_num, + max_draft_tokens); + } + WRAPPER_UNIMPLEMENTED(ctx); +} + +} // namespace plugin +} // namespace api +} // namespace xpu +} // namespace baidu diff --git a/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/speculate_get_output_padding_offset.cpp b/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/speculate_get_output_padding_offset.cpp new file mode 100644 index 000000000..a27f3bdde --- /dev/null +++ b/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/speculate_get_output_padding_offset.cpp @@ -0,0 +1,118 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
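+
+// speculate_get_output_padding_offset maps each request's accepted output
+// tokens from the padded [bsz, max_seq_len] layout to a dense layout by
+// recording, per token, the cumulative padding that precedes it.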
+
+#include "xpu/plugin.h"
+#include "xpu/refactor/impl_public/wrapper_check.h"
+
+namespace xpu3 {
+namespace plugin {
+__attribute__((global)) void speculate_get_output_padding_offset(
+    int* output_padding_offset,
+    int* output_cum_offsets,
+    const int* output_cum_offsets_tmp,
+    const int* seq_lens_output,
+    const int bsz,
+    const int max_seq_len);
+}  // namespace plugin
+}  // namespace xpu3
+
+namespace baidu {
+namespace xpu {
+namespace api {
+namespace plugin {
+
+static int cpu_wrapper(Context* ctx,
+                       int* output_padding_offset,
+                       int* output_cum_offsets,
+                       const int* output_cum_offsets_tmp,
+                       const int* seq_lens_output,
+                       const int bsz,
+                       const int max_seq_len) {
+  for (int bi = 0; bi < bsz; bi++) {
+    int cum_offset = 0;
+    if (bi > 0) {
+      cum_offset = output_cum_offsets_tmp[bi - 1];
+    }
+    output_cum_offsets[bi] = cum_offset;
+    for (int token_i = 0; token_i < seq_lens_output[bi]; token_i++) {
+      output_padding_offset[bi * max_seq_len - cum_offset + token_i] =
+          cum_offset;
+    }
+  }
+  return SUCCESS;
+}
+
+static int xpu2or3_wrapper(Context* ctx,
+                           int* output_padding_offset,
+                           int* output_cum_offsets,
+                           const int* output_cum_offsets_tmp,
+                           const int* seq_lens_output,
+                           const int bsz,
+                           const int max_seq_len) {
+  ctx_guard RAII_GUARD(ctx);
+  xpu3::plugin::speculate_get_output_padding_offset<<<ctx->ncluster(),
+                                                      64,
+                                                      ctx->xpu_stream>>>(
+      output_padding_offset,
+      output_cum_offsets,
+      output_cum_offsets_tmp,
+      seq_lens_output,
+      bsz,
+      max_seq_len);
+  return api::SUCCESS;
+}
+
+int speculate_get_output_padding_offset(Context* ctx,
+                                        int* output_padding_offset,
+                                        int* output_cum_offsets,
+                                        const int* output_cum_offsets_tmp,
+                                        const int* seq_lens_output,
+                                        const int bsz,
+                                        const int max_seq_len) {
+  WRAPPER_CHECK_CTX(ctx);
+  WRAPPER_DUMP_FUNCTION_T1(ctx, "speculate_get_output_padding_offset", int);
+  WRAPPER_DUMP_PARAM5(ctx,
+                      output_padding_offset,
+                      output_cum_offsets,
+                      output_cum_offsets_tmp,
+                      seq_lens_output,
+                      max_seq_len);
+  WRAPPER_DUMP(ctx);
+
+  if (ctx->dev().type() == api::kCPU) {
+    return cpu_wrapper(ctx,
+                       output_padding_offset,
+                       output_cum_offsets,
+                       output_cum_offsets_tmp,
+                       seq_lens_output,
+                       bsz,
+                       max_seq_len);
+  }
+  if (ctx->dev().type() == api::kXPU3) {
+    return xpu2or3_wrapper(ctx,
+                           output_padding_offset,
+                           output_cum_offsets,
+                           output_cum_offsets_tmp,
+                           seq_lens_output,
+                           bsz,
+                           max_seq_len);
+  }
+
+  WRAPPER_UNIMPLEMENTED(ctx);
+}
+
+}  // namespace plugin
+}  // namespace api
+}  // namespace xpu
+}  // namespace baidu
diff --git a/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/speculate_get_padding_offset.cpp b/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/speculate_get_padding_offset.cpp
new file mode 100644
index 000000000..a0066e455
--- /dev/null
+++ b/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/speculate_get_padding_offset.cpp
@@ -0,0 +1,295 @@
+// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
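+
+// speculate_remove_padding gathers each request's valid tokens into a dense
+// buffer, taking prefill tokens from input_ids and decode-step tokens from
+// draft_tokens; speculate_get_padding_offset then derives per-token padding
+// offsets plus the cu_seqlens_q/k prefix arrays over per-request lengths.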
+
+#include "xpu/plugin.h"
+#include "xpu/refactor/impl/xdnn_impl.h"
+#include "xpu/refactor/impl_public/wrapper_check.h"
+
+namespace xpu3 {
+namespace plugin {
+
+template <typename T>
+__attribute__((global)) void speculate_remove_padding(
+    T* output_data,
+    const T* input_data,
+    const T* draft_tokens,
+    const int* seq_lens,
+    const int* seq_lens_encoder,
+    const int* cum_offsets,
+    int sequence_length,
+    int max_draft_tokens,
+    int bsz,
+    int token_num_data);
+
+__attribute__((global)) void speculate_get_padding_offset(
+    int* padding_offset,
+    int* cum_offsets_out,
+    int* cu_seqlens_q,
+    int* cu_seqlens_k,
+    const int* cum_offsets,
+    const int* seq_lens,
+    const int max_seq_len,
+    int bsz);
+
+}  // namespace plugin
+}  // namespace xpu3
+
+namespace baidu {
+namespace xpu {
+namespace api {
+namespace plugin {
+
+template <typename T>
+static int cpu_wrapper_remove_padding(Context* ctx,
+                                      T* output_data,
+                                      const T* input_data,
+                                      const T* draft_tokens,
+                                      const int* seq_lens,
+                                      const int* seq_lens_encoder,
+                                      const int* cum_offsets,
+                                      int sequence_length,
+                                      int max_draft_tokens,
+                                      int bsz,
+                                      int token_num_data) {
+  for (int bi = 0; bi < bsz; ++bi) {
+    for (int i = 0; i < seq_lens[bi]; i++) {
+      const int tgt_seq_id = bi * sequence_length - cum_offsets[bi] + i;
+      if (seq_lens_encoder[bi] > 0) {
+        const int src_seq_id = bi * sequence_length + i;
+        output_data[tgt_seq_id] = input_data[src_seq_id];
+      } else {
+        const int src_seq_id = bi * max_draft_tokens + i;
+        output_data[tgt_seq_id] = draft_tokens[src_seq_id];
+      }
+    }
+  }
+  return api::SUCCESS;
+}
+
+static int cpu_wrapper_get_padding_offset(Context* ctx,
+                                          int* padding_offset,
+                                          int* cum_offsets_out,
+                                          int* cu_seqlens_q,
+                                          int* cu_seqlens_k,
+                                          const int* cum_offsets,
+                                          const int* seq_lens,
+                                          const int max_seq_len,
+                                          int bsz) {
+  for (int bi = 0; bi < bsz; ++bi) {
+    int cum_offset = bi == 0 ? 0 : cum_offsets[bi - 1];
+    for (int i = 0; i < seq_lens[bi]; i++) {
+      padding_offset[bi * max_seq_len - cum_offset + i] = cum_offset;
+    }
+    cum_offsets_out[bi] = cum_offset;
+    int cum_seq_len = (bi + 1) * max_seq_len - cum_offsets[bi];
+    cu_seqlens_q[bi + 1] = cum_seq_len;
+    cu_seqlens_k[bi + 1] = cum_seq_len;
+  }
+  return api::SUCCESS;
+}
+
+template <typename T>
+static int xpu3_wrapper_remove_padding(Context* ctx,
+                                       T* output_data,
+                                       const T* input_data,
+                                       const T* draft_tokens,
+                                       const int* seq_lens,
+                                       const int* seq_lens_encoder,
+                                       const int* cum_offsets,
+                                       int sequence_length,
+                                       int max_draft_tokens,
+                                       int bsz,
+                                       int token_num_data) {
+  using XPU_T = typename XPUIndexType<T>::type;
+  xpu3::plugin::speculate_remove_padding<XPU_T>
+      <<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
+          static_cast<XPU_T*>(static_cast<void*>(output_data)),
+          static_cast<const XPU_T*>(static_cast<const void*>(input_data)),
+          static_cast<const XPU_T*>(static_cast<const void*>(draft_tokens)),
+          seq_lens,
+          seq_lens_encoder,
+          cum_offsets,
+          sequence_length,
+          max_draft_tokens,
+          bsz,
+          token_num_data);
+
+  return api::SUCCESS;
+}
+
+static int xpu3_wrapper_get_padding_offset(Context* ctx,
+                                           int* padding_offset,
+                                           int* cum_offsets_out,
+                                           int* cu_seqlens_q,
+                                           int* cu_seqlens_k,
+                                           const int* cum_offsets,
+                                           const int* seq_lens,
+                                           const int max_seq_len,
+                                           int bsz) {
+  xpu3::plugin::speculate_get_padding_offset<<<ctx->ncluster(),
+                                               64,
+                                               ctx->xpu_stream>>>(
+      padding_offset,
+      cum_offsets_out,
+      cu_seqlens_q,
+      cu_seqlens_k,
+      cum_offsets,
+      seq_lens,
+      max_seq_len,
+      bsz);
+  return api::SUCCESS;
+}
+
+template <typename T>
+int speculate_remove_padding(Context* ctx,
+                             T* x_remove_padding,
+                             const T* input_ids,
+                             const T* draft_tokens,
+                             const int* seq_lens,
+                             const int* seq_lens_encoder,
+                             const int* cum_offsets_out,
+                             int seq_length,
+                             int max_draft_tokens,
+                             int bsz,
+                             int token_num_data) {
+  WRAPPER_CHECK_CTX(ctx);
+  WRAPPER_DUMP_FUNCTION_T1(ctx, "speculate_remove_padding", T);
+  WRAPPER_DUMP_PARAM6(ctx,
+                      x_remove_padding,
+                      input_ids,
+                      draft_tokens,
+                      seq_lens,
+                      seq_lens_encoder,
+                      cum_offsets_out);
+  WRAPPER_DUMP_PARAM4(ctx, seq_length, max_draft_tokens, bsz, token_num_data);
+  WRAPPER_DUMP(ctx);
+
+  WRAPPER_CHECK_PTR(ctx, T, bsz * seq_length, input_ids);
+  WRAPPER_CHECK_PTR(ctx, T, bsz * max_draft_tokens, draft_tokens);
+  WRAPPER_CHECK_PTR(ctx, int, bsz, seq_lens);
+  WRAPPER_CHECK_PTR(ctx, int, bsz, seq_lens_encoder);
+  WRAPPER_CHECK_PTR(ctx, int, bsz, cum_offsets_out);
+  WRAPPER_CHECK_PTR(ctx, T, token_num_data, x_remove_padding);
+
+  WRAPPER_ASSERT_GT(ctx, bsz, 0);
+  WRAPPER_ASSERT_GT(ctx, seq_length, 0);
+  WRAPPER_ASSERT_GT(ctx, max_draft_tokens, 0);
+
+  if (ctx->dev().type() == api::kCPU) {
+    return cpu_wrapper_remove_padding(ctx,
+                                      x_remove_padding,
+                                      input_ids,
+                                      draft_tokens,
+                                      seq_lens,
+                                      seq_lens_encoder,
+                                      cum_offsets_out,
+                                      seq_length,
+                                      max_draft_tokens,
+                                      bsz,
+                                      token_num_data);
+  }
+  if (ctx->dev().type() == api::kXPU3) {
+    return xpu3_wrapper_remove_padding(ctx,
+                                       x_remove_padding,
+                                       input_ids,
+                                       draft_tokens,
+                                       seq_lens,
+                                       seq_lens_encoder,
+                                       cum_offsets_out,
+                                       seq_length,
+                                       max_draft_tokens,
+                                       bsz,
+                                       token_num_data);
+  }
+  WRAPPER_UNIMPLEMENTED(ctx);
+}
+
+int speculate_get_padding_offset(Context* ctx,
+                                 int* padding_offset,
+                                 int* cum_offsets_out,
+                                 int* cu_seqlens_q,
+                                 int* cu_seqlens_k,
+                                 const int* cum_offsets,
+                                 const int* seq_lens,
+                                 const int max_seq_len,
+                                 int bsz) {
+  WRAPPER_CHECK_CTX(ctx);
+
+  WRAPPER_DUMP_FUNCTION_T1(ctx, "speculate_get_padding_offset", float);
+  WRAPPER_DUMP_PARAM6(ctx,
+                      padding_offset,
+                      cum_offsets_out,
+                      cu_seqlens_q,
+                      cu_seqlens_k,
+                      cum_offsets,
+                      seq_lens);
+
WRAPPER_DUMP_PARAM2(ctx, max_seq_len, bsz); + WRAPPER_DUMP(ctx); + + WRAPPER_CHECK_PTR(ctx, int, bsz, cum_offsets); + WRAPPER_CHECK_PTR(ctx, int, bsz, seq_lens); + WRAPPER_CHECK_PTR(ctx, int, bsz, cum_offsets_out); + WRAPPER_CHECK_PTR(ctx, int, bsz + 1, cu_seqlens_q); + WRAPPER_CHECK_PTR(ctx, int, bsz + 1, cu_seqlens_k); + + WRAPPER_ASSERT_GT(ctx, bsz, 0); + WRAPPER_ASSERT_GT(ctx, max_seq_len, 0); + + if (ctx->dev().type() == api::kCPU) { + return cpu_wrapper_get_padding_offset(ctx, + padding_offset, + cum_offsets_out, + cu_seqlens_q, + cu_seqlens_k, + cum_offsets, + seq_lens, + max_seq_len, + bsz); + } + if (ctx->dev().type() == api::kXPU3) { + return xpu3_wrapper_get_padding_offset(ctx, + padding_offset, + cum_offsets_out, + cu_seqlens_q, + cu_seqlens_k, + cum_offsets, + seq_lens, + max_seq_len, + bsz); + } + + WRAPPER_UNIMPLEMENTED(ctx); +} + +#define INSTANTIATION_SPECULATE_REMOVE_PADDING(T) \ + template int speculate_remove_padding(Context * ctx, \ + T * x_remove_padding, \ + const T* input_ids, \ + const T* draft_tokens, \ + const int* seq_len, \ + const int* seq_lens_encoder, \ + const int* cum_offsets_out, \ + int seq_length, \ + int max_draft_tokens, \ + int bsz, \ + int token_num_data) + +INSTANTIATION_SPECULATE_REMOVE_PADDING(float); +INSTANTIATION_SPECULATE_REMOVE_PADDING(float16); +INSTANTIATION_SPECULATE_REMOVE_PADDING(bfloat16); +INSTANTIATION_SPECULATE_REMOVE_PADDING(int64_t); + +} // namespace plugin +} // namespace api +} // namespace xpu +} // namespace baidu diff --git a/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/speculate_get_seq_lens_output.cpp b/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/speculate_get_seq_lens_output.cpp new file mode 100644 index 000000000..97da31219 --- /dev/null +++ b/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/speculate_get_seq_lens_output.cpp @@ -0,0 +1,111 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
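The index arithmetic behind the padding-offset wrapper above is easiest to verify with concrete numbers. Below is a standalone restatement of the cpu_wrapper_get_padding_offset logic with invented sequence lengths; it is illustrative only and not part of this patch. cum_offsets[bi] is the inclusive prefix sum of per-sample padding, and cu_seqlens_q ends up holding the token-count prefix sum of the unpadded batch.

// Standalone sketch (not part of the patch): the padding-offset scheme of
// cpu_wrapper_get_padding_offset, with made-up lengths.
#include <cstdio>
#include <vector>

int main() {
  const int bsz = 3, max_seq_len = 8;
  const int seq_lens[] = {3, 5, 2};
  // cum_offsets[bi]: padding accumulated over samples 0..bi (inclusive).
  std::vector<int> cum_offsets(bsz);
  int acc = 0;
  for (int bi = 0; bi < bsz; ++bi) {
    acc += max_seq_len - seq_lens[bi];
    cum_offsets[bi] = acc;
  }
  std::vector<int> padding_offset(bsz * max_seq_len, 0);
  std::vector<int> cu_seqlens_q(bsz + 1, 0);
  for (int bi = 0; bi < bsz; ++bi) {
    const int cum_offset = bi == 0 ? 0 : cum_offsets[bi - 1];
    // Each surviving token records how far it was shifted left.
    for (int i = 0; i < seq_lens[bi]; ++i) {
      padding_offset[bi * max_seq_len - cum_offset + i] = cum_offset;
    }
    cu_seqlens_q[bi + 1] = (bi + 1) * max_seq_len - cum_offsets[bi];
  }
  // Prints "0 3 8 10": token_num = 3 + 5 + 2 = 10.
  for (int bi = 0; bi <= bsz; ++bi) printf("%d ", cu_seqlens_q[bi]);
  printf("\n");
}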
+ +#include "xpu/plugin.h" +#include "xpu/refactor/impl_public/wrapper_check.h" + +namespace xpu3 { +namespace plugin { +__attribute__((global)) void speculate_get_seq_lens_output( + int* seq_lens_output, + const int* seq_lens_this_time, + const int* seq_lens_encoder, + const int* seq_lens_decoder, + const int real_bsz); +} +} // namespace xpu3 + +namespace baidu { +namespace xpu { +namespace api { +namespace plugin { + +static int cpu_wrapper(Context* ctx, + int* seq_lens_output, + const int* seq_lens_this_time, + const int* seq_lens_encoder, + const int* seq_lens_decoder, + const int real_bsz) { + for (int bid = 0; bid < real_bsz; ++bid) { + if (seq_lens_this_time[bid] == 0) { + continue; + } else if (seq_lens_this_time[bid] == 1) { + seq_lens_output[bid] = 1; + } else if (seq_lens_encoder[bid] != 0) { + seq_lens_output[bid] = 1; + } else { + seq_lens_output[bid] = seq_lens_this_time[bid]; + } + } + return SUCCESS; +} + +static int xpu2or3_wrapper(Context* ctx, + int* seq_lens_output, + const int* seq_lens_this_time, + const int* seq_lens_encoder, + const int* seq_lens_decoder, + const int real_bsz) { + ctx_guard RAII_GUARD(ctx); + xpu3::plugin:: + speculate_get_seq_lens_output<<ncluster(), 64, ctx->xpu_stream>>>( + seq_lens_output, + seq_lens_this_time, + seq_lens_encoder, + seq_lens_decoder, + real_bsz); + + return api::SUCCESS; +} + +int speculate_get_seq_lens_output(Context* ctx, + int* seq_lens_output, + const int* seq_lens_this_time, + const int* seq_lens_encoder, + const int* seq_lens_decoder, + const int real_bsz) { + WRAPPER_CHECK_CTX(ctx); + WRAPPER_DUMP_FUNCTION_T1(ctx, "speculate_get_seq_lens_output", int); + WRAPPER_DUMP_PARAM5(ctx, + seq_lens_output, + seq_lens_this_time, + seq_lens_encoder, + seq_lens_decoder, + real_bsz); + WRAPPER_DUMP(ctx); + + if (ctx->dev().type() == api::kCPU) { + return cpu_wrapper(ctx, + seq_lens_output, + seq_lens_this_time, + seq_lens_encoder, + seq_lens_decoder, + real_bsz); + } + if (ctx->dev().type() == api::kXPU3) { + return xpu2or3_wrapper(ctx, + seq_lens_output, + seq_lens_this_time, + seq_lens_encoder, + seq_lens_decoder, + real_bsz); + } + + WRAPPER_UNIMPLEMENTED(ctx); +} + +} // namespace plugin +} // namespace api +} // namespace xpu +} // namespace baidu diff --git a/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/speculate_rebuild_append_padding.cpp b/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/speculate_rebuild_append_padding.cpp new file mode 100644 index 000000000..62391b2fc --- /dev/null +++ b/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/speculate_rebuild_append_padding.cpp @@ -0,0 +1,156 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
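The branchy CPU reference in speculate_get_seq_lens_output above reduces to one rule per sample: an inactive slot is skipped, a prefill sample or a single-token decode step scores exactly one output token, and a speculative decode step scores all seq_lens_this_time tokens. A hypothetical restatement follows (the helper name is invented, and the real kernel leaves skipped slots untouched rather than returning 0):

// Invented helper restating the rule above; not part of the patch.
#include <cassert>

int seq_lens_output_for(int seq_len_this_time, int seq_len_encoder) {
  if (seq_len_this_time == 0) return 0;  // inactive slot (kernel skips it)
  if (seq_len_this_time == 1) return 1;  // plain decode step
  if (seq_len_encoder != 0) return 1;    // prefill emits a single token
  return seq_len_this_time;              // speculative step: score them all
}

int main() {
  assert(seq_lens_output_for(4, 0) == 4);      // verify 4 draft tokens
  assert(seq_lens_output_for(128, 128) == 1);  // prefill
  assert(seq_lens_output_for(1, 0) == 1);      // ordinary decode
}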
+ +#include "xpu/plugin.h" +#include "xpu/refactor/impl_public/wrapper_check.h" + +namespace xpu3 { +namespace plugin { +template +__attribute__((global)) void RebuildAppendPaddingKernel( + const T* full_hidden_states, + const int* cum_offsets, + const int* seq_len_encoder, + const int* seq_len_decoder, + const int* output_padding_offset, + int max_seq_len, + int dim_embed, + int elem_nums, + T* out); +} // namespace plugin +} // namespace xpu3 + +namespace baidu { +namespace xpu { +namespace api { +namespace plugin { + +template +static int cpu_wrapper(Context* ctx, + T* full_hidden_states, + int* cum_offsets, + int* seq_len_encoder, + int* seq_len_decoder, + int* output_padding_offset, + int max_seq_len, + int dim_embed, + int elem_nums, + T* out) { + for (int64_t i = 0; i < elem_nums; ++i) { + int64_t out_token_id = i / dim_embed; + int64_t ori_token_id = out_token_id + output_padding_offset[out_token_id]; + int64_t bi = ori_token_id / max_seq_len; + + int64_t seq_id = 0; + if (seq_len_decoder[bi] == 0 && seq_len_encoder[bi] == 0) { + continue; + } else if (seq_len_encoder[bi] != 0) { + seq_id = seq_len_encoder[bi] - 1; + } + + int64_t input_token_id = ori_token_id - cum_offsets[bi] + seq_id; + int64_t bias_idx = i % dim_embed; + + out[i] = full_hidden_states[input_token_id * dim_embed + bias_idx]; + } + return api::SUCCESS; +} + +template +static int xpu3_wrapper(Context* ctx, + T* full_hidden_states, + int* cum_offsets, + int* seq_len_encoder, + int* seq_len_decoder, + int* output_padding_offset, + int max_seq_len, + int dim_embed, + int elem_nums, + T* out) { + xpu3::plugin::RebuildAppendPaddingKernel + <<ncluster(), 64, ctx->xpu_stream>>>(full_hidden_states, + cum_offsets, + seq_len_encoder, + seq_len_decoder, + output_padding_offset, + max_seq_len, + dim_embed, + elem_nums, + out); + return api::SUCCESS; +} + +template +int speculate_rebuild_append_padding(Context* ctx, + T* full_hidden_states, + int* cum_offsets, + int* seq_len_encoder, + int* seq_len_decoder, + int* output_padding_offset, + int max_seq_len, + int dim_embed, + int elem_nums, + T* out) { + WRAPPER_CHECK_CTX(ctx); + WRAPPER_DUMP_FUNCTION_T1(ctx, "speculate_rebuild_append_padding", T); + WRAPPER_DUMP_PARAM5(ctx, + full_hidden_states, + cum_offsets, + seq_len_encoder, + seq_len_decoder, + output_padding_offset); + WRAPPER_DUMP_PARAM4(ctx, max_seq_len, dim_embed, elem_nums, out); + WRAPPER_DUMP(ctx); + + WRAPPER_CHECK_PTR(ctx, T, elem_nums, out); + WRAPPER_ASSERT_GT(ctx, max_seq_len, 0); + WRAPPER_ASSERT_GT(ctx, dim_embed, 0); + WRAPPER_ASSERT_GT(ctx, elem_nums, 0); + + if (ctx->dev().type() == api::kCPU) { + return cpu_wrapper(ctx, + full_hidden_states, + cum_offsets, + seq_len_encoder, + seq_len_decoder, + output_padding_offset, + max_seq_len, + dim_embed, + elem_nums, + out); + } else if (ctx->dev().type() == api::kXPU3) { + return xpu3_wrapper(ctx, + full_hidden_states, + cum_offsets, + seq_len_encoder, + seq_len_decoder, + output_padding_offset, + max_seq_len, + dim_embed, + elem_nums, + out); + } + WRAPPER_UNIMPLEMENTED(ctx); +} + +template int speculate_rebuild_append_padding( + Context*, bfloat16*, int*, int*, int*, int*, int, int, int, bfloat16*); +template int speculate_rebuild_append_padding( + Context*, float16*, int*, int*, int*, int*, int, int, int, float16*); +template int speculate_rebuild_append_padding( + Context*, float*, int*, int*, int*, int*, int, int, int, float*); +} // namespace plugin +} // namespace api +} // namespace xpu +} // namespace baidu diff --git 
a/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/speculate_set_stop_value_multi_seqs.cpp b/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/speculate_set_stop_value_multi_seqs.cpp
new file mode 100644
index 000000000..d57907966
--- /dev/null
+++ b/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/speculate_set_stop_value_multi_seqs.cpp
@@ -0,0 +1,224 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include
+#include
+
+#include "xpu/plugin.h"
+#include "xpu/refactor/impl_public/wrapper_check.h"
+
+namespace xpu3 {
+namespace plugin {
+__attribute__((global)) void speculate_set_stop_value_multi_seqs(
+    bool* stop_flags,
+    int64_t* accept_tokens,
+    int* accept_nums,
+    const int64_t* pre_ids,
+    const int64_t* step_idx,
+    const int64_t* stop_seqs,
+    const int* stop_seqs_len,
+    const int* seq_lens,
+    const int64_t* end_ids,
+    const int bs,
+    const int accept_tokens_len,
+    const int stop_seqs_bs,
+    const int stop_seqs_max_len,
+    const int pre_ids_len);
+}  // namespace plugin
+}  // namespace xpu3
+
+namespace baidu {
+namespace xpu {
+namespace api {
+namespace plugin {
+
+static int cpu_wrapper(Context* ctx,
+                       bool* stop_flags,
+                       int64_t* accept_tokens,
+                       int* accept_nums,
+                       const int64_t* pre_ids,
+                       const int64_t* step_idx,
+                       const int64_t* stop_seqs,
+                       const int* stop_seqs_len,
+                       const int* seq_lens,
+                       const int64_t* end_ids,
+                       const int bs,
+                       const int accept_tokens_len,
+                       const int stop_seqs_bs,
+                       const int stop_seqs_max_len,
+                       const int pre_ids_len) {
+  for (int bid = 0; bid < bs; ++bid) {
+    const int64_t* pre_ids_now = pre_ids + bid * pre_ids_len;
+    int64_t* accept_tokens_now = accept_tokens + bid * accept_tokens_len;
+    const int accept_num = accept_nums[bid];
+    const int64_t step_idx_now = step_idx[bid];
+    for (int tid = 0; tid < stop_seqs_bs; ++tid) {
+      const int stop_seq_len = stop_seqs_len[tid];
+      if (stop_seq_len <= 0) continue;
+      const int64_t* stop_seq_now = stop_seqs + tid * stop_seqs_max_len;
+      if (!stop_flags[bid]) {
+        int accept_idx = 0;
+        bool is_end = false;
+        // iterate over the candidate match positions
+        for (; accept_idx <= accept_num - 1 && !is_end; accept_idx++) {
+          if (step_idx_now - accept_num + accept_idx + 1 < stop_seq_len) {
+            continue;
+          }
+          // walk one stop sequence, back to front
+          for (int i = stop_seq_len - 1; i >= 0; --i) {
+            int64_t cur_token_idx = -1;
+
+            // use the current offset to decide whether the token lives in
+            // pre_ids or in accept_tokens
+            if (stop_seq_len - 1 - i < accept_idx) {
+              cur_token_idx =
+                  accept_tokens_now[accept_idx - (stop_seq_len - 1 - i) - 1];
+            } else {
+              int pre_ids_idx = step_idx_now - accept_num + accept_idx -
+                                (stop_seq_len - 1 - i);
+              // EC3: special concatenation can leave input_ids without a
+              // trailing special token, i.e. pre_ids[0] may be 23, which
+              // would otherwise cause a spurious stop
+              if (pre_ids_idx <= 0) {
+                break;
+              }
+              cur_token_idx = pre_ids_now[pre_ids_idx];
+            }
+            if (cur_token_idx != stop_seq_now[i]) {
+              break;
+            }
+            if (i == 0) {
+              is_end = true;
+            }
+          }
+        }
+        if (is_end) {
+          accept_nums[bid] = accept_idx;
+          accept_tokens_now[accept_idx - 1] = end_ids[0];
+          stop_flags[bid] = true;
+        }
+      }
+    }
+  }
+
+  return 
api::SUCCESS; +} + +static int xpu2or3_wrapper(Context* ctx, + bool* stop_flags, + int64_t* accept_tokens, + int* accept_nums, + const int64_t* pre_ids, + const int64_t* step_idx, + const int64_t* stop_seqs, + const int* stop_seqs_len, + const int* seq_lens, + const int64_t* end_ids, + const int bs, + const int accept_tokens_len, + const int stop_seqs_bs, + const int stop_seqs_max_len, + const int pre_ids_len) { + using XPU_INT64 = typename XPUIndexType::type; + xpu3::plugin::speculate_set_stop_value_multi_seqs<<ncluster(), + 64, + ctx->xpu_stream>>>( + stop_flags, + reinterpret_cast(accept_tokens), + accept_nums, + reinterpret_cast(pre_ids), + reinterpret_cast(step_idx), + reinterpret_cast(stop_seqs), + stop_seqs_len, + seq_lens, + reinterpret_cast(end_ids), + bs, + accept_tokens_len, + stop_seqs_bs, + stop_seqs_max_len, + pre_ids_len); + return api::SUCCESS; +} + +int speculate_set_stop_value_multi_seqs(Context* ctx, + bool* stop_flags, + int64_t* accept_tokens, + int* accept_nums, + const int64_t* pre_ids, + const int64_t* step_idx, + const int64_t* stop_seqs, + const int* stop_seqs_len, + const int* seq_lens, + const int64_t* end_ids, + const int bs_now, + const int accept_tokens_len, + const int stop_seqs_bs, + const int stop_seqs_max_len, + const int pre_ids_len) { + WRAPPER_CHECK_CTX(ctx); + WRAPPER_DUMP_FUNCTION_T1(ctx, "speculate_set_stop_value_multi_seqs", int64_t); + WRAPPER_DUMP_PARAM3(ctx, stop_flags, accept_tokens, accept_nums); + WRAPPER_DUMP_PARAM6( + ctx, pre_ids, step_idx, stop_seqs, stop_seqs_len, seq_lens, end_ids); + WRAPPER_DUMP_PARAM5(ctx, + bs_now, + accept_tokens_len, + stop_seqs_bs, + stop_seqs_max_len, + pre_ids_len); + WRAPPER_DUMP(ctx); + WRAPPER_CHECK_PTR(ctx, int64_t, bs_now * accept_tokens_len, accept_tokens); + WRAPPER_CHECK_PTR(ctx, int64_t, stop_seqs_bs * stop_seqs_max_len, stop_seqs); + WRAPPER_ASSERT_GT(ctx, bs_now, 0); + + if (ctx->dev().type() == api::kCPU) { + return cpu_wrapper(ctx, + stop_flags, + accept_tokens, + accept_nums, + pre_ids, + step_idx, + stop_seqs, + stop_seqs_len, + seq_lens, + end_ids, + bs_now, + accept_tokens_len, + stop_seqs_bs, + stop_seqs_max_len, + pre_ids_len); + } + if (ctx->dev().type() == api::kXPU3) { + return xpu2or3_wrapper(ctx, + stop_flags, + accept_tokens, + accept_nums, + pre_ids, + step_idx, + stop_seqs, + stop_seqs_len, + seq_lens, + end_ids, + bs_now, + accept_tokens_len, + stop_seqs_bs, + stop_seqs_max_len, + pre_ids_len); + } + WRAPPER_UNIMPLEMENTED(ctx); +} + +} // namespace plugin +} // namespace api +} // namespace xpu +} // namespace baidu diff --git a/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/speculate_set_value_by_flags.cpp b/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/speculate_set_value_by_flags.cpp new file mode 100644 index 000000000..8c7dbfad2 --- /dev/null +++ b/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/speculate_set_value_by_flags.cpp @@ -0,0 +1,157 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#include "xpu/plugin.h" +#include "xpu/refactor/impl_public/wrapper_check.h" + +namespace xpu3 { +namespace plugin { +__attribute__((global)) void speculate_set_value_by_flag_and_id( + int64_t *pre_ids_all, + const int64_t *accept_tokens, + const int *accept_num, + const bool *stop_flags, + const int *seq_lens_encoder, + const int *seq_lens_decoder, + const int64_t *step_idx, + int bs, + int length, + int max_draft_tokens); +} // namespace plugin +} // namespace xpu3 + +namespace baidu { +namespace xpu { +namespace api { +namespace plugin { + +static int cpu_wrapper(Context *ctx, + int64_t *pre_ids_all, // bs * length + const int64_t *accept_tokens, // bs * max_draft_tokens + const int *accept_num, // bs + const bool *stop_flags, + const int *seq_lens_encoder, + const int *seq_lens_decoder, + const int64_t *step_idx, + int bs, + int length, + int max_draft_tokens) { + for (int i = 0; i < bs; i++) { + if (stop_flags[i] || (seq_lens_encoder[i] == 0 && seq_lens_decoder[i] == 0)) + continue; + + int64_t *pre_ids_all_now = pre_ids_all + i * length; + const int64_t *accept_tokens_now = accept_tokens + i * max_draft_tokens; + int accept_num_now = accept_num[i]; + int64_t step_idx_now = step_idx[i]; + + if (step_idx_now >= 0) { + for (int j = 0; j < accept_num_now; j++) { + pre_ids_all_now[step_idx_now - j] = + accept_tokens_now[accept_num_now - 1 - j]; + } + } + } + return SUCCESS; +} + +static int xpu2or3_wrapper(Context *ctx, + int64_t *pre_ids_all, + const int64_t *accept_tokens, + const int *accept_num, + const bool *stop_flags, + const int *seq_lens_encoder, + const int *seq_lens_decoder, + const int64_t *step_idx, + int bs, + int length, + int max_draft_tokens) { + ctx_guard RAII_GUARD(ctx); + using XPU_INT64 = typename XPUIndexType::type; + + xpu3::plugin::speculate_set_value_by_flag_and_id<<ncluster(), + 64, + ctx->xpu_stream>>>( + reinterpret_cast(pre_ids_all), + reinterpret_cast(accept_tokens), + accept_num, + stop_flags, + seq_lens_encoder, + seq_lens_decoder, + reinterpret_cast(step_idx), + bs, + length, + max_draft_tokens); + return api::SUCCESS; +} + +int speculate_set_value_by_flag_and_id(Context *ctx, + int64_t *pre_ids_all, + const int64_t *accept_tokens, + const int *accept_num, + const bool *stop_flags, + const int *seq_lens_encoder, + const int *seq_lens_decoder, + const int64_t *step_idx, + int bs, + int length, + int max_draft_tokens) { + WRAPPER_CHECK_CTX(ctx); + WRAPPER_DUMP_FUNCTION_T1(ctx, "speculate_set_value_by_flag_and_id", int); + WRAPPER_DUMP_PARAM6(ctx, + pre_ids_all, + accept_tokens, + accept_num, + stop_flags, + seq_lens_encoder, + seq_lens_decoder); + WRAPPER_DUMP_PARAM4(ctx, step_idx, bs, length, max_draft_tokens); + WRAPPER_DUMP(ctx); + + WRAPPER_ASSERT_LE(ctx, max_draft_tokens, 500); + + if (ctx->dev().type() == api::kCPU) { + return cpu_wrapper(ctx, + pre_ids_all, + accept_tokens, + accept_num, + stop_flags, + seq_lens_encoder, + seq_lens_decoder, + step_idx, + bs, + length, + max_draft_tokens); + } + if (ctx->dev().type() == api::kXPU3) { + return xpu2or3_wrapper(ctx, + pre_ids_all, + accept_tokens, + accept_num, + stop_flags, + seq_lens_encoder, + seq_lens_decoder, + step_idx, + bs, + length, + max_draft_tokens); + } + + WRAPPER_UNIMPLEMENTED(ctx); +} + +} // namespace plugin +} // namespace api +} // namespace xpu +} // namespace baidu diff --git a/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/speculate_token_penalty_multi_scores.cpp 
b/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/speculate_token_penalty_multi_scores.cpp new file mode 100644 index 000000000..ad607f694 --- /dev/null +++ b/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/speculate_token_penalty_multi_scores.cpp @@ -0,0 +1,512 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include "xpu/plugin.h" +#include "xpu/refactor/impl_public/wrapper_check.h" + +namespace xpu3 { +namespace plugin { + +template +__attribute__((global)) void speculate_min_length_logits_process( + T* logits, + const int64_t* cur_len, + const int64_t* min_len, + const int64_t* eos_token_id, + const int* output_padding_offset, + const int* output_cum_offsets, + const int64_t bs, + const int64_t length, + const int64_t length_id, + const int64_t end_length, + const int64_t token_num, + const int64_t max_seq_len); +__attribute__((global)) void speculate_update_repeat_times( + const int64_t* pre_ids, + const int64_t* cur_len, + int* repeat_times, + const int* output_padding_offset, + const int64_t bs, + const int64_t length, + const int64_t length_id, + const int64_t token_num, + const int64_t max_seq_len); +template +__attribute__((global)) void speculate_update_value_by_repeat_times( + const int* repeat_times, + const T* penalty_scores, + const T* frequency_score, + const T* presence_score, + const float* temperatures, + T* logits, + const int* output_padding_offset, + const int64_t bs, + const int64_t length, + const int64_t token_num, + const int64_t max_seq_len); +template +__attribute__((global)) void speculate_update_value_by_repeat_times_simd( + const int* repeat_times, + const T* penalty_scores, + const T* frequency_score, + const T* presence_score, + const float* temperatures, + T* logits, + const int* output_padding_offset, + const int64_t bs, + const int64_t length, + const int64_t token_num, + const int64_t max_seq_len); +template +__attribute__((global)) void speculate_ban_bad_words( + T* logits, + const int64_t* bad_words_list, + const int* output_padding_offset, + const int64_t bs, + const int64_t length, + const int64_t bad_words_length, + const int64_t token_num, + const int64_t max_seq_len); + +} // namespace plugin +} // namespace xpu3 + +namespace baidu { +namespace xpu { +namespace api { +namespace plugin { + +void update_repeat_times_cpu(const int64_t* pre_ids, + const int64_t* cur_len, + int* repeat_times, + const int* output_padding_offset, + const int64_t bs, + const int64_t length, + const int64_t length_id, + const int64_t token_num, + const int64_t max_seq_len) { + for (int64_t i = 0; i < token_num; i++) { + int64_t bi = (i + output_padding_offset[i]) / max_seq_len; + if (bi < bs && cur_len[bi] >= 0) { + for (int64_t j = 0; j < length_id; j++) { + int64_t id = pre_ids[bi * length_id + j]; + if (id < 0) { + break; + } else if (id >= length) { + continue; + } else { + repeat_times[i * length + id] += 1; + } + } + } + } +} + +void 
ban_bad_words_cpu(float* logits, + const int64_t* bad_words_list, + const int* output_padding_offset, + const int64_t bs, + const int64_t length, + const int64_t bad_words_length, + const int64_t token_num, + const int64_t max_seq_len) { + for (int64_t i = 0; i < token_num; i++) { + int64_t bi = (i + output_padding_offset[i]) / max_seq_len; + if (bi >= bs) { + continue; + } + float* logits_now = logits + i * length; + for (int64_t j = 0; j < bad_words_length; j++) { + int64_t bad_words_token_id = bad_words_list[j]; + if (bad_words_token_id >= length || bad_words_token_id < 0) continue; + logits_now[bad_words_token_id] = -1e10; + } + } +} + +template +static int cpu_wrapper(Context* ctx, + const int64_t* pre_ids, + T* logits, + const T* penalty_scores, + const T* frequency_scores, + const T* presence_scores, + const float* temperatures, + const int64_t* cur_len, + const int64_t* min_len, + const int64_t* eos_token_id, + const int64_t* bad_words, + const int* output_padding_offset, + const int* output_cum_offsets, + const int64_t bs, + const int64_t length, + const int64_t length_id, + const int64_t end_length, + const int64_t length_bad_words, + const int64_t token_num, + const int64_t max_seq_len) { + std::vector logitsfp32(token_num * length); + std::vector penalty_scoresfp32(bs); + std::vector frequency_scoresfp32(bs); + std::vector presence_scoresfp32(bs); + std::vector repeat_times_buffer(token_num * length, 0); + int ret = + api::cast(ctx, logits, logitsfp32.data(), token_num * length); + WRAPPER_ASSERT_SUCCESS(ctx, ret); + ret = api::cast(ctx, penalty_scores, penalty_scoresfp32.data(), bs); + WRAPPER_ASSERT_SUCCESS(ctx, ret); + ret = api::cast( + ctx, frequency_scores, frequency_scoresfp32.data(), bs); + WRAPPER_ASSERT_SUCCESS(ctx, ret); + ret = + api::cast(ctx, presence_scores, presence_scoresfp32.data(), bs); + WRAPPER_ASSERT_SUCCESS(ctx, ret); + for (int64_t i = 0; i < token_num; i++) { + int64_t bi = (i + output_padding_offset[i]) / max_seq_len; + int64_t query_start_token_idx = bi * max_seq_len - output_cum_offsets[bi]; + if (bi < bs && cur_len[bi] >= 0 && + (cur_len[bi] + (i - query_start_token_idx) < min_len[bi])) { + for (int64_t j = 0; j < end_length; j++) { + logitsfp32[i * length + eos_token_id[j]] = + std::is_same::value ? -1e4 : -1e10; + } + } + } + int* repeat_times = &(repeat_times_buffer[0]); + update_repeat_times_cpu(pre_ids, + cur_len, + repeat_times, + output_padding_offset, + bs, + length, + length_id, + token_num, + max_seq_len); + for (int64_t i = 0; i < token_num; i++) { + int64_t bi = (i + output_padding_offset[i]) / max_seq_len; + if (bi >= bs) { + continue; + } + float alpha = penalty_scoresfp32[bi]; + float beta = frequency_scoresfp32[bi]; + float gamma = presence_scoresfp32[bi]; + float temperature = temperatures[bi]; + for (int64_t j = 0; j < length; j++) { + int times = repeat_times[i * length + j]; + float logit_now = logitsfp32[i * length + j]; + if (times != 0) { + logit_now = logit_now < 0 ? 
logit_now * alpha : logit_now / alpha; + logit_now = logit_now - times * beta - gamma; + } + logitsfp32[i * length + j] = logit_now / temperature; + } + } + if (bad_words && length_bad_words > 0) { + ban_bad_words_cpu(logitsfp32.data(), + bad_words, + output_padding_offset, + bs, + length, + length_bad_words, + token_num, + max_seq_len); + } + ret = api::cast(ctx, logitsfp32.data(), logits, token_num * length); + return ret; +} + +template +static int xpu3_wrapper(Context* ctx, + const int64_t* pre_ids, + T* logits, + const T* penalty_scores, + const T* frequency_scores, + const T* presence_scores, + const float* temperatures, + const int64_t* cur_len, + const int64_t* min_len, + const int64_t* eos_token_id, + const int64_t* bad_words, + const int* output_padding_offset, + const int* output_cum_offsets, + const int64_t bs, + const int64_t length, + const int64_t length_id, + const int64_t end_length, + const int64_t length_bad_words, + const int64_t token_num, + const int64_t max_seq_len) { + api::ctx_guard RAII_GUARD(ctx); + using XPU_INT64 = typename XPUIndexType::type; + auto min_length_logits_process_kernel = + xpu3::plugin::speculate_min_length_logits_process; + auto update_repeat_times_kernel = xpu3::plugin::speculate_update_repeat_times; + auto update_value_by_repeat_times_kernel = + xpu3::plugin::speculate_update_value_by_repeat_times; + if (length % 16 == 0) { + update_value_by_repeat_times_kernel = + xpu3::plugin::speculate_update_value_by_repeat_times_simd; + } + auto ban_bad_words_kernel = xpu3::plugin::speculate_ban_bad_words; + + int* repeat_times = RAII_GUARD.alloc_l3_or_gm(token_num * length); + WRAPPER_ASSERT_WORKSPACE(ctx, repeat_times); + int ret = api::constant(ctx, repeat_times, token_num * length, 0); + WRAPPER_ASSERT_SUCCESS(ctx, ret); + + update_repeat_times_kernel<<ncluster(), 64, ctx->xpu_stream>>>( + reinterpret_cast(pre_ids), + reinterpret_cast(cur_len), + repeat_times, + output_padding_offset, + bs, + length, + length_id, + token_num, + max_seq_len); + min_length_logits_process_kernel<<ncluster(), 64, ctx->xpu_stream>>>( + logits, + reinterpret_cast(cur_len), + reinterpret_cast(min_len), + reinterpret_cast(eos_token_id), + output_padding_offset, + output_cum_offsets, + bs, + length, + length_id, + end_length, + token_num, + max_seq_len); + update_value_by_repeat_times_kernel<<ncluster(), 64, ctx->xpu_stream>>>( + repeat_times, + penalty_scores, + frequency_scores, + presence_scores, + temperatures, + logits, + output_padding_offset, + bs, + length, + token_num, + max_seq_len); + + if (bad_words && length_bad_words > 0) { + ban_bad_words_kernel<<ncluster(), 64, ctx->xpu_stream>>>( + logits, + reinterpret_cast(bad_words), + output_padding_offset, + bs, + length, + length_bad_words, + token_num, + max_seq_len); + } + return api::SUCCESS; +} + +template +int speculate_token_penalty_multi_scores(Context* ctx, + const int64_t* pre_ids, + T* logits, + const T* penalty_scores, + const T* frequency_scores, + const T* presence_scores, + const float* temperatures, + const int64_t* cur_len, + const int64_t* min_len, + const int64_t* eos_token_id, + const int64_t* bad_words, + const int* output_padding_offset, + const int* output_cum_offsets, + const int64_t bs, + const int64_t length, + const int64_t length_id, + const int64_t end_length, + const int64_t length_bad_words, + const int64_t token_num, + const int64_t max_seq_len) { + WRAPPER_CHECK_CTX(ctx); + WRAPPER_DUMP_FUNCTION_T1(ctx, "speculate_token_penalty_multi_scores", T); + WRAPPER_DUMP_PARAM6(ctx, + pre_ids, + 
logits, + penalty_scores, + frequency_scores, + presence_scores, + temperatures); + WRAPPER_DUMP_PARAM6(ctx, + cur_len, + min_len, + eos_token_id, + bad_words, + output_padding_offset, + output_cum_offsets); + WRAPPER_DUMP_PARAM4(ctx, bs, length, length_id, end_length); + WRAPPER_DUMP_PARAM3(ctx, length_bad_words, token_num, max_seq_len); + WRAPPER_DUMP(ctx); + // TODO(mayang02) shape check + int64_t pre_ids_len = -1; + int64_t logits_len = -1; + int64_t penalty_scores_len = -1; + int64_t frequency_scores_len = -1; + int64_t presence_scores_len = -1; + int64_t temperatures_len = -1; + int64_t cur_len_len = -1; + int64_t min_len_len = -1; + int64_t eos_token_id_len = -1; + int64_t bad_words_len = -1; + int64_t output_padding_offset_len = -1; + int64_t output_cum_offsets_len = -1; + WRAPPER_ASSERT_LE(ctx, bs, 640); + WRAPPER_CHECK_SHAPE(ctx, &pre_ids_len, {bs, length_id}); + WRAPPER_CHECK_SHAPE(ctx, &logits_len, {token_num, length}); + WRAPPER_CHECK_SHAPE(ctx, &penalty_scores_len, {bs}); + WRAPPER_CHECK_SHAPE(ctx, &frequency_scores_len, {bs}); + WRAPPER_CHECK_SHAPE(ctx, &presence_scores_len, {bs}); + WRAPPER_CHECK_SHAPE(ctx, &temperatures_len, {bs}); + WRAPPER_CHECK_SHAPE(ctx, &cur_len_len, {bs}); + WRAPPER_CHECK_SHAPE(ctx, &min_len_len, {bs}); + WRAPPER_CHECK_SHAPE(ctx, &eos_token_id_len, {end_length}); + WRAPPER_CHECK_SHAPE(ctx, &bad_words_len, {length_bad_words}); + WRAPPER_CHECK_SHAPE(ctx, &output_padding_offset_len, {token_num}); + WRAPPER_CHECK_SHAPE(ctx, &output_cum_offsets_len, {bs}); + WRAPPER_CHECK_PTR(ctx, int64_t, pre_ids_len, pre_ids); + WRAPPER_CHECK_PTR(ctx, T, logits_len, logits); + WRAPPER_CHECK_PTR(ctx, T, penalty_scores_len, penalty_scores); + WRAPPER_CHECK_PTR(ctx, T, frequency_scores_len, frequency_scores); + WRAPPER_CHECK_PTR(ctx, T, presence_scores_len, presence_scores); + WRAPPER_CHECK_PTR(ctx, float, temperatures_len, temperatures); + WRAPPER_CHECK_PTR(ctx, int64_t, cur_len_len, cur_len); + WRAPPER_CHECK_PTR(ctx, int64_t, min_len_len, min_len); + WRAPPER_CHECK_PTR(ctx, int64_t, eos_token_id_len, eos_token_id); + WRAPPER_CHECK_PTR(ctx, int64_t, bad_words_len, bad_words); + WRAPPER_CHECK_PTR(ctx, int, output_padding_offset_len, output_padding_offset); + WRAPPER_CHECK_PTR(ctx, int, output_cum_offsets_len, output_cum_offsets); + if (ctx->dev().type() == api::kCPU) { + return cpu_wrapper(ctx, + pre_ids, + logits, + penalty_scores, + frequency_scores, + presence_scores, + temperatures, + cur_len, + min_len, + eos_token_id, + bad_words, + output_padding_offset, + output_cum_offsets, + bs, + length, + length_id, + end_length, + length_bad_words, + token_num, + max_seq_len); + } + if (ctx->dev().type() == api::kXPU3) { + return xpu3_wrapper(ctx, + pre_ids, + logits, + penalty_scores, + frequency_scores, + presence_scores, + temperatures, + cur_len, + min_len, + eos_token_id, + bad_words, + output_padding_offset, + output_cum_offsets, + bs, + length, + length_id, + end_length, + length_bad_words, + token_num, + max_seq_len); + } + WRAPPER_UNIMPLEMENTED(ctx); +} + +template int speculate_token_penalty_multi_scores( + Context* ctx, + const int64_t* pre_ids, + float* logits, + const float* penalty_scores, + const float* frequency_scores, + const float* presence_scores, + const float* temperatures, + const int64_t* cur_len, + const int64_t* min_len, + const int64_t* eos_token_id, + const int64_t* bad_words, + const int* output_padding_offset, + const int* output_cum_offsets, + const int64_t bs, + const int64_t length, + const int64_t length_id, + const int64_t end_length, + const 
int64_t length_bad_words, + const int64_t token_num, + const int64_t max_seq_len); +template int speculate_token_penalty_multi_scores( + Context* ctx, + const int64_t* pre_ids, + float16* logits, + const float16* penalty_scores, + const float16* frequency_scores, + const float16* presence_scores, + const float* temperatures, + const int64_t* cur_len, + const int64_t* min_len, + const int64_t* eos_token_id, + const int64_t* bad_words, + const int* output_padding_offset, + const int* output_cum_offsets, + const int64_t bs, + const int64_t length, + const int64_t length_id, + const int64_t end_length, + const int64_t length_bad_words, + const int64_t token_num, + const int64_t max_seq_len); +template int speculate_token_penalty_multi_scores( + Context* ctx, + const int64_t* pre_ids, + bfloat16* logits, + const bfloat16* penalty_scores, + const bfloat16* frequency_scores, + const bfloat16* presence_scores, + const float* temperatures, + const int64_t* cur_len, + const int64_t* min_len, + const int64_t* eos_token_id, + const int64_t* bad_words, + const int* output_padding_offset, + const int* output_cum_offsets, + const int64_t bs, + const int64_t length, + const int64_t length_id, + const int64_t end_length, + const int64_t length_bad_words, + const int64_t token_num, + const int64_t max_seq_len); + +} // namespace plugin +} // namespace api +} // namespace xpu +} // namespace baidu diff --git a/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/speculate_update_v3.cpp b/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/speculate_update_v3.cpp new file mode 100644 index 000000000..9feb8cf5f --- /dev/null +++ b/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/speculate_update_v3.cpp @@ -0,0 +1,241 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
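Before the update_v3 wrapper below, it is worth pinning down the per-logit arithmetic used by the token-penalty CPU reference above: a token already seen `times` times has its logit scaled by the repetition penalty (multiplied when negative, divided when positive), then shifted by the frequency and presence penalties, and every logit is finally divided by the temperature. A minimal check with invented numbers (`penalize` is a hypothetical name, not part of the patch):

// Hypothetical sketch of the penalty arithmetic in the CPU reference above.
#include <cstdio>

float penalize(float logit, int times, float alpha /*repetition*/,
               float beta /*frequency*/, float gamma /*presence*/,
               float temperature) {
  if (times != 0) {
    // Scale toward zero for repeated tokens, then apply the linear shifts.
    logit = logit < 0 ? logit * alpha : logit / alpha;
    logit = logit - times * beta - gamma;
  }
  // Temperature scaling applies to every logit, repeated or not.
  return logit / temperature;
}

int main() {
  // Token seen twice: 3.0/1.2 - 2*0.1 - 0.5 = 1.8, then /0.8 = 2.25.
  printf("%f\n", penalize(3.0f, 2, 1.2f, 0.1f, 0.5f, 0.8f));
}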
+
+#include "xpu/plugin.h"
+#include "xpu/refactor/impl_public/wrapper_check.h"
+
+namespace xpu3 {
+namespace plugin {
+template <int BlockSize>
+__attribute__((global)) void speculate_update_v3(
+    int *seq_lens_encoder,          // input  [B_max, ]
+    int *seq_lens_decoder,          // output [B_max, ]
+    bool *not_need_stop,            // output [1,]
+    int64_t *draft_tokens,          // output [B_max, T_max]
+    int *actual_draft_token_nums,   // output [B_max, ]
+    const int64_t *accept_tokens,   // input  [B_max, T_max]
+    const int *accept_num,          // input  [B_max, ]
+    const bool *stop_flags,         // input  [B_max, ]
+    const int *seq_lens_this_time,  // input  [B_real,]
+    const bool *is_block_step,      // input  [B_max, ]
+    const int64_t *stop_nums,       // input  [1,]
+    const int real_bsz,
+    const int max_bsz,
+    const int max_draft_tokens);
+}  // namespace plugin
+}  // namespace xpu3
+
+namespace baidu {
+namespace xpu {
+namespace api {
+namespace plugin {
+
+static int cpu_wrapper(Context *ctx,
+                       int *seq_lens_encoder,          // input  [B_max, ]
+                       int *seq_lens_decoder,          // output [B_max, ]
+                       bool *not_need_stop,            // [1,]
+                       int64_t *draft_tokens,          // [B_max, T_max]
+                       int *actual_draft_token_nums,   // [B_max, ]
+                       const int64_t *accept_tokens,   // [B_max, T_max]
+                       const int *accept_num,          // [B_max, ]
+                       const bool *stop_flags,         // [B_max, ]
+                       const int *seq_lens_this_time,  // [B_real,]
+                       const bool *is_block_step,      // [B_max, ]
+                       const int64_t *stop_nums,       // [1,]
+                       const int real_bsz,
+                       const int max_bsz,
+                       const int max_draft_tokens) {
+  int64_t stop_sum = 0;
+
+  for (int bid = 0; bid < max_bsz; ++bid) {
+    int stop_flag_now_int = 0;
+
+    const bool inactive = (bid >= real_bsz);
+    const bool block_step = (!inactive && is_block_step[bid]);
+
+    if (!block_step && !inactive) {
+      // 1. has this sample already triggered a stop?
+      if (stop_flags[bid]) stop_flag_now_int = 1;
+
+      // 2. when the encoder length is 0, the decoder length can be advanced
+      //    directly
+      if (seq_lens_encoder[bid] == 0) {
+        seq_lens_decoder[bid] += accept_num[bid];
+      }
+
+      // 3. adapt the draft length based on whether every token was accepted
+      if (seq_lens_encoder[bid] == 0 &&  // append mode only
+          seq_lens_this_time[bid] > 1) {
+        int cur_len = actual_draft_token_nums[bid];
+
+        if (accept_num[bid] - 1 == cur_len) {
+          // all accepted: try +2, then +1
+          if (cur_len + 2 <= max_draft_tokens - 1)
+            cur_len += 2;
+          else if (cur_len + 1 <= max_draft_tokens - 1)
+            cur_len += 1;
+          else
+            cur_len = max_draft_tokens - 1;
+        } else {
+          // something was rejected: -1, floor at 1
+          cur_len = std::max(1, cur_len - 1);
+        }
+        actual_draft_token_nums[bid] = cur_len;
+      }
+
+      // 4. settle the outstanding encoder length against the decoder
+      if (seq_lens_encoder[bid] != 0) {
+        seq_lens_decoder[bid] += seq_lens_encoder[bid];
+        const_cast<int *>(seq_lens_encoder)[bid] = 0;  // cast: original
+                                                       // pointer is const
+      }
+
+      // 6. if stopped, zero the decoder length
+      if (stop_flag_now_int) {
+        seq_lens_decoder[bid] = 0;
+      } else {
+        // 5. write back the first token of the next round; in principle only
+        //    the valid draft slots need updating
+        draft_tokens[bid * max_draft_tokens] =
+            accept_tokens[bid * max_draft_tokens + accept_num[bid] - 1];
+      }
+
+    } else if (inactive) {
+      // padding slot: treated directly as stopped
+      stop_flag_now_int = 1;
+    }
+
+    stop_sum += stop_flag_now_int;
+  }
+
+  // 7. write out the global flag
+  not_need_stop[0] = (stop_sum < stop_nums[0]);
+
+  return api::SUCCESS;
+}
+
+static int xpu3_wrapper(Context *ctx,
+                        int *seq_lens_encoder,          // input  [B_max, ]
+                        int *seq_lens_decoder,          // output [B_max, ]
+                        bool *not_need_stop,            // [1,]
+                        int64_t *draft_tokens,          // [B_max, T_max]
+                        int *actual_draft_token_nums,   // [B_max, ]
+                        const int64_t *accept_tokens,   // [B_max, T_max]
+                        const int *accept_num,          // [B_max, ]
+                        const bool *stop_flags,         // [B_max, ]
+                        const int *seq_lens_this_time,  // [B_real,]
+                        const bool *is_block_step,      // [B_max, ]
+                        const int64_t *stop_nums,       // [1,]
+                        const int real_bsz,
+                        const int max_bsz,
+                        const int max_draft_tokens) {
+  constexpr int BlockSize = 512;
+  using XPU_TI = typename XPUIndexType<int64_t>::type;
+  xpu3::plugin::speculate_update_v3<BlockSize>
+      <<<1, 64, ctx->xpu_stream>>>(seq_lens_encoder,
+                                   seq_lens_decoder,
+                                   not_need_stop,
+                                   reinterpret_cast<XPU_TI *>(draft_tokens),
+                                   actual_draft_token_nums,
+                                   (const XPU_TI *)accept_tokens,
+                                   accept_num,
+                                   stop_flags,
+                                   seq_lens_this_time,
+                                   is_block_step,
+                                   (const XPU_TI *)stop_nums,
+                                   real_bsz,
+                                   max_bsz,
+                                   max_draft_tokens);
+  return api::SUCCESS;
+}
+
+int speculate_update_v3(Context *ctx,
+                        int *seq_lens_encoder,          // input  [B_max, ]
+                        int *seq_lens_decoder,          // output [B_max, ]
+                        bool *not_need_stop,            // [1,]
+                        int64_t *draft_tokens,          // [B_max, T_max]
+                        int *actual_draft_token_nums,   // [B_max, ]
+                        const int64_t *accept_tokens,   // [B_max, T_max]
+                        const int *accept_num,          // [B_max, ]
+                        const bool *stop_flags,         // [B_max, ]
+                        const int *seq_lens_this_time,  // [B_real,]
+                        const bool *is_block_step,      // [B_max, ]
+                        const int64_t *stop_nums,       // [1,]
+                        const int real_bsz,
+                        const int max_bsz,
+                        const int max_draft_tokens) {
+  WRAPPER_CHECK_CTX(ctx);
+  WRAPPER_DUMP_FUNCTION_T1(ctx, "speculate_update_v3", int);
+  WRAPPER_DUMP_PARAM4(
+      ctx, seq_lens_encoder, seq_lens_decoder, not_need_stop, draft_tokens);
+  WRAPPER_DUMP_PARAM4(
+      ctx, actual_draft_token_nums, accept_tokens, accept_num, stop_flags);
+  WRAPPER_DUMP_PARAM4(
+      ctx, seq_lens_this_time, is_block_step, stop_nums, real_bsz);
+  WRAPPER_DUMP_PARAM2(ctx, max_bsz, max_draft_tokens);
+  WRAPPER_DUMP(ctx);
+  WRAPPER_ASSERT_GT(ctx, real_bsz, 0);
+  WRAPPER_ASSERT_GT(ctx, max_bsz, 0);
+  WRAPPER_ASSERT_LE(ctx, max_bsz, 512);
+  WRAPPER_ASSERT_GT(ctx, max_draft_tokens, 0);
+  WRAPPER_ASSERT_GE(ctx, max_bsz, real_bsz);
+  WRAPPER_CHECK_PTR(ctx, int, max_bsz, seq_lens_encoder);
+  WRAPPER_CHECK_PTR(ctx, int, max_bsz, seq_lens_decoder);
+  WRAPPER_CHECK_PTR(ctx, bool, 1, not_need_stop);
+  WRAPPER_CHECK_PTR(ctx, int64_t, max_bsz * max_draft_tokens, draft_tokens);
+  WRAPPER_CHECK_PTR(ctx, int, max_bsz, actual_draft_token_nums);
+  WRAPPER_CHECK_PTR(ctx, int64_t, max_bsz * max_draft_tokens, accept_tokens);
+  WRAPPER_CHECK_PTR(ctx, int, max_bsz, accept_num);
+  WRAPPER_CHECK_PTR(ctx, bool, max_bsz, stop_flags);
+  WRAPPER_CHECK_PTR(ctx, int, real_bsz, seq_lens_this_time);
+  WRAPPER_CHECK_PTR(ctx, bool, max_bsz, is_block_step);
+  WRAPPER_CHECK_PTR(ctx, int64_t, 1, stop_nums);
+
+  if (ctx->dev().type() == api::kCPU) {
+    return cpu_wrapper(ctx,
+                       seq_lens_encoder,
+                       seq_lens_decoder,
+                       not_need_stop,
+                       draft_tokens,
+                       actual_draft_token_nums,
+                       accept_tokens,
+                       accept_num,
+                       stop_flags,
+                       seq_lens_this_time,
+                       is_block_step,
+                       stop_nums,
+                       real_bsz,
+                       max_bsz,
+                       max_draft_tokens);
+  } else if (ctx->dev().type() == api::kXPU3) {
+    return xpu3_wrapper(ctx,
+                        seq_lens_encoder,
+                        seq_lens_decoder,
+                        not_need_stop,
+                        draft_tokens,
+                        actual_draft_token_nums,
+                        accept_tokens,
+                        accept_num,
+                        stop_flags,
+                        seq_lens_this_time,
+                        is_block_step, 
+ stop_nums, + real_bsz, + max_bsz, + max_draft_tokens); + } + WRAPPER_UNIMPLEMENTED(ctx); +} + +} // namespace plugin +} // namespace api +} // namespace xpu +} // namespace baidu diff --git a/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/speculate_verify.cpp b/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/speculate_verify.cpp new file mode 100644 index 000000000..c5e3e425b --- /dev/null +++ b/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/speculate_verify.cpp @@ -0,0 +1,543 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "xpu/plugin.h" +#include "xpu/refactor/impl_public/wrapper_check.h" + +namespace xpu3 { +namespace plugin { +typedef uint32_t curandStatePhilox4_32_10_t; + +template +__attribute__((global)) void speculate_verify( + int64_t *accept_tokens, + int *accept_num, + int64_t *step_idx, + bool *stop_flags, + const int *seq_lens_encoder, + const int *seq_lens_decoder, + const int64_t *draft_tokens, + const int *actual_draft_token_nums, + const float *dev_curand_states, + const float *topp, + const int *seq_lens_this_time, + const int64_t *verify_tokens, + const float *verify_scores, + const int64_t *max_dec_len, + const int64_t *end_tokens, + const bool *is_block_step, + const int *output_cum_offsets, + const int *actual_candidate_len, + const int real_bsz, + const int max_draft_tokens, + const int end_length, + const int max_seq_len, + const int max_candidate_len, + const int verify_window, + const bool prefill_one_step_stop); +} // namespace plugin +} // namespace xpu3 + +namespace baidu { +namespace xpu { +namespace api { +namespace plugin { + +static inline bool is_in_end(const int64_t id, + const int64_t *end_ids, + int length) { + bool flag = false; + for (int i = 0; i < length; i++) { + if (id == end_ids[i]) { + return true; + } + } + return flag; +} + +static inline bool is_in(const int64_t *candidates, + const int64_t draft, + const int candidate_len) { + for (int i = 0; i < candidate_len; i++) { + if (draft == candidates[i]) { + return true; + } + } + return false; +} + +static inline unsigned int xorwow(unsigned int &state) { // NOLINT + state ^= state >> 7; + state ^= state << 9; + state ^= state >> 13; + return state; +} + +typedef uint32_t curandStatePhilox4_32_10_t; + +static int64_t topp_sampling_kernel(const int64_t *candidate_ids, + const float *candidate_scores, + const float *dev_curand_states, + const int candidate_len, + const float topp, + int tid) { + // const int tid = core_id(); + float sum_scores = 0.0f; + float rand_top_p = *dev_curand_states * topp; + for (int i = 0; i < candidate_len; i++) { + // printf("debug cpu sample i:%d scores:%f,ids:%ld + // rand_top_p:%f,candidate_len:%d\n", + // i,candidate_scores[i],candidate_ids[i],rand_top_p,candidate_len); + sum_scores += candidate_scores[i]; + if (rand_top_p <= sum_scores) { + return candidate_ids[i]; + } + } + return candidate_ids[0]; +} + +template +static int 
cpu_wrapper(Context *ctx,
+                       int64_t *accept_tokens,
+                       int *accept_num,
+                       int64_t *step_idx,
+                       bool *stop_flags,
+                       const int *seq_lens_encoder,
+                       const int *seq_lens_decoder,
+                       const int64_t *draft_tokens,
+                       const int *actual_draft_token_nums,
+                       const float *dev_curand_states,
+                       const float *topp,
+                       const int *seq_lens_this_time,
+                       const int64_t *verify_tokens,
+                       const float *verify_scores,
+                       const int64_t *max_dec_len,
+                       const int64_t *end_tokens,
+                       const bool *is_block_step,
+                       const int *output_cum_offsets,
+                       const int *actual_candidate_len,
+                       const int real_bsz,
+                       const int max_draft_tokens,
+                       const int end_length,
+                       const int max_seq_len,
+                       const int max_candidate_len,
+                       const int verify_window,
+                       const bool prefill_one_step_stop) {
+  for (int bid = 0; bid < real_bsz; ++bid) {
+    const int start_token_id = bid * max_seq_len - output_cum_offsets[bid];
+    // verify and set stop flags
+    int accept_num_now = 1;
+    int stop_flag_now_int = 0;
+
+    if (!(is_block_step[bid] || bid >= real_bsz)) {
+      // printf("debug cpu bid:%d,start_token_id:%d\n",bid, start_token_id);
+      // printf("bid %d\n", bid);
+
+      if (stop_flags[bid]) {
+        stop_flag_now_int = 1;
+      } else {  // the prefill stage also reaches this branch, but its draft
+                // tokens are zeroed, so it falls straight through to the
+                // final sampling stage
+        auto *verify_tokens_now =
+            verify_tokens + start_token_id * max_candidate_len;
+        auto *draft_tokens_now = draft_tokens + bid * max_draft_tokens;
+        auto *actual_candidate_len_now = actual_candidate_len + start_token_id;
+
+        int i = 0;
+        // printf("seq_lens_this_time[%d]-1: %d \n",bid,
+        //        seq_lens_this_time[bid]-1);
+        for (; i < seq_lens_this_time[bid] - 1; i++) {
+          if (seq_lens_encoder[bid] != 0) {
+            break;
+          }
+          if (USE_TOPK) {
+            if (verify_tokens_now[i * max_candidate_len] ==
+                draft_tokens_now[i + 1]) {
+              // accept_num_now++;
+              step_idx[bid]++;
+              auto accept_token = draft_tokens_now[i + 1];
+              // printf("[USE_TOPK] bid %d Top 1 verify write accept
+              // %d is %lld\n", bid, i, accept_token);
+              accept_tokens[bid * max_draft_tokens + i] = accept_token;
+              if (is_in_end(accept_token, end_tokens, end_length) ||
+                  step_idx[bid] >= max_dec_len[bid]) {
+                stop_flags[bid] = true;
+                stop_flag_now_int = 1;
+                if (step_idx[bid] >= max_dec_len[bid])
+                  accept_tokens[bid * max_draft_tokens + i] = end_tokens[0];
+                // printf("[USE_TOPK] bid %d Top 1 verify write
+                // accept %d is %lld\n", bid, i, accept_token);
+                break;
+              } else {
+                accept_num_now++;
+              }
+            } else {
+              break;
+            }
+          } else {
+            auto actual_candidate_len_value =
+                actual_candidate_len_now[i] > max_candidate_len
+                    ? max_candidate_len
+                    : actual_candidate_len_now[i];
+            if (is_in(verify_tokens_now + i * max_candidate_len,
+                      draft_tokens_now[i + 1],
+                      actual_candidate_len_value)) {
+              // Top P verify
+              // accept_num_now++;
+              step_idx[bid]++;
+              auto accept_token = draft_tokens_now[i + 1];
+              accept_tokens[bid * max_draft_tokens + i] = accept_token;
+
+              if (is_in_end(accept_token, end_tokens, end_length) ||
+                  step_idx[bid] >= max_dec_len[bid]) {
+                stop_flags[bid] = true;
+                stop_flag_now_int = 1;
+                if (step_idx[bid] >= max_dec_len[bid])
+                  accept_tokens[bid * max_draft_tokens + i] = end_tokens[0];
+                // printf("bid %d Top P verify write accept %d is
+                // %lld\n", bid, i, accept_token);
+                break;
+              } else {
+                accept_num_now++;
+              }
+            } else {
+              // TopK verify
+              int ii = i;
+              if (max_candidate_len >= 2 &&
+                  verify_tokens_now[ii * max_candidate_len + 1] ==
+                      draft_tokens_now[ii + 1]) {  // top-2
+                int j = 0;
+                ii += 1;
+                for (; j < verify_window && ii < seq_lens_this_time[bid] - 1;
+                     j++, ii++) {
+                  if (verify_tokens_now[ii * max_candidate_len] !=
+                      draft_tokens_now[ii + 1]) {
+                    break;
+                  }
+                }
+                if (j >= verify_window) {  // accept all
+                  accept_num_now += verify_window + 1;
+                  step_idx[bid] += verify_window + 1;
+                  for (; i < ii; i++) {
+                    auto accept_token = draft_tokens_now[i + 1];
+                    accept_tokens[bid * max_draft_tokens + i] = accept_token;
+                    // printf("bid %d TopK verify write accept %dis "
+                    //        "%lld\n",bid,i,accept_token);
+                    if (is_in_end(accept_token, end_tokens, end_length) ||
+                        step_idx[bid] >= max_dec_len[bid]) {
+                      stop_flags[bid] = true;
+                      stop_flag_now_int = 1;
+                      if (step_idx[bid] >= max_dec_len[bid])
+                        accept_tokens[bid * max_draft_tokens + i] =
+                            end_tokens[0];
+                      // printf("bid %d TopK verify write accept %d is %lld\n",
+                      //        bid, i,end_tokens[0]);
+                      accept_num_now--;
+                      step_idx[bid]--;
+                      break;
+                    }
+                  }
+                }
+              }
+              break;
+            }
+          }
+        }
+        // sampling stage
+        // case 1: draft_token[i+1] was rejected, so pick one token from
+        //         verify_tokens_now[i]
+        // case 2: i == seq_lens_this_time[bid]-1; likewise pick one from
+        //         verify_tokens_now[i], except when the sample has stopped
+        if (!stop_flag_now_int) {
+          int64_t accept_token;
+          const float *verify_scores_now =
+              verify_scores + start_token_id * max_candidate_len;
+          step_idx[bid]++;
+          if (ENABLE_TOPP) {
+            auto actual_candidate_len_value =
+                actual_candidate_len_now[i] > max_candidate_len
+                    ? 
max_candidate_len + : actual_candidate_len_now[i]; + + accept_token = + topp_sampling_kernel(verify_tokens_now + i * max_candidate_len, + verify_scores_now + i * max_candidate_len, + dev_curand_states + i, + actual_candidate_len_value, + topp[bid], + bid); + } else { + accept_token = verify_tokens_now[i * max_candidate_len]; + } + accept_tokens[bid * max_draft_tokens + i] = accept_token; + if (prefill_one_step_stop) { + stop_flags[bid] = true; + } + if (is_in_end(accept_token, end_tokens, end_length) || + step_idx[bid] >= max_dec_len[bid]) { + stop_flags[bid] = true; + stop_flag_now_int = 1; + if (step_idx[bid] >= max_dec_len[bid]) + accept_tokens[bid * max_draft_tokens + i] = end_tokens[0]; + } + } + accept_num[bid] = accept_num_now; + } + } + } + return api::SUCCESS; +} + +template +static int xpu3_wrapper(Context *ctx, + int64_t *accept_tokens, + int *accept_num, + int64_t *step_idx, + bool *stop_flags, + const int *seq_lens_encoder, + const int *seq_lens_decoder, + const int64_t *draft_tokens, + const int *actual_draft_token_nums, + const float *dev_curand_states, + const float *topp, + const int *seq_lens_this_time, + const int64_t *verify_tokens, + const float *verify_scores, + const int64_t *max_dec_len, + const int64_t *end_tokens, + const bool *is_block_step, + const int *output_cum_offsets, + const int *actual_candidate_len, + const int real_bsz, + const int max_draft_tokens, + const int end_length, + const int max_seq_len, + const int max_candidate_len, + const int verify_window, + const bool prefill_one_step_stop) { + using XPU_INT64 = typename XPUIndexType::type; + xpu3::plugin::speculate_verify + <<<1, 64, ctx->xpu_stream>>>( + reinterpret_cast(accept_tokens), + accept_num, + reinterpret_cast(step_idx), + stop_flags, + seq_lens_encoder, + seq_lens_decoder, + reinterpret_cast(draft_tokens), + actual_draft_token_nums, + dev_curand_states, + topp, + seq_lens_this_time, + reinterpret_cast(verify_tokens), + verify_scores, + reinterpret_cast(max_dec_len), + reinterpret_cast(end_tokens), + is_block_step, + output_cum_offsets, + actual_candidate_len, + real_bsz, + max_draft_tokens, + end_length, + max_seq_len, + max_candidate_len, + verify_window, + prefill_one_step_stop); + return api::SUCCESS; +} +template +int speculate_verify(Context *ctx, + int64_t *accept_tokens, + int *accept_num, + int64_t *step_idx, + bool *stop_flags, + const int *seq_lens_encoder, + const int *seq_lens_decoder, + const int64_t *draft_tokens, + const int *actual_draft_token_nums, + const float *dev_curand_states, + const float *topp, + const int *seq_lens_this_time, + const int64_t *verify_tokens, + const float *verify_scores, + const int64_t *max_dec_len, + const int64_t *end_tokens, + const bool *is_block_step, + const int *output_cum_offsets, + const int *actual_candidate_len, + const int real_bsz, + const int max_draft_tokens, + const int end_length, + const int max_seq_len, + const int max_candidate_len, + const int verify_window, + const bool prefill_one_step_stop) { + WRAPPER_CHECK_CTX(ctx); + WRAPPER_DUMP_FUNCTION_T1(ctx, "speculate_verify", int64_t); + WRAPPER_DUMP_PARAM3(ctx, accept_tokens, accept_num, step_idx); + WRAPPER_DUMP_PARAM6(ctx, + stop_flags, + seq_lens_encoder, + seq_lens_decoder, + draft_tokens, + actual_draft_token_nums, + topp); + WRAPPER_DUMP_PARAM5(ctx, + seq_lens_this_time, + verify_tokens, + verify_scores, + max_dec_len, + end_tokens); + WRAPPER_DUMP_PARAM5(ctx, + is_block_step, + output_cum_offsets, + actual_candidate_len, + real_bsz, + max_draft_tokens); + 
WRAPPER_DUMP_PARAM5(ctx, + end_length, + max_seq_len, + max_candidate_len, + verify_window, + prefill_one_step_stop); + WRAPPER_DUMP(ctx); + WRAPPER_CHECK_PTR(ctx, int64_t, real_bsz * max_draft_tokens, accept_tokens); + WRAPPER_CHECK_PTR(ctx, int, real_bsz, accept_num); + WRAPPER_CHECK_PTR(ctx, int64_t, real_bsz, step_idx); + WRAPPER_CHECK_PTR(ctx, bool, real_bsz, stop_flags); + WRAPPER_CHECK_PTR(ctx, int, real_bsz, seq_lens_encoder); + WRAPPER_CHECK_PTR(ctx, int, real_bsz, seq_lens_decoder); + WRAPPER_CHECK_PTR(ctx, int64_t, real_bsz * max_draft_tokens, draft_tokens); + WRAPPER_CHECK_PTR(ctx, int, real_bsz, actual_draft_token_nums); + WRAPPER_CHECK_PTR(ctx, float, real_bsz, dev_curand_states); + WRAPPER_CHECK_PTR(ctx, float, real_bsz, topp); + WRAPPER_CHECK_PTR(ctx, int, real_bsz, seq_lens_this_time); + // WRAPPER_CHECK_PTR(ctx, int64_t, real_bsz, verify_tokens); + // WRAPPER_CHECK_PTR(ctx, float, real_bsz, verify_scores); + WRAPPER_CHECK_PTR(ctx, int64_t, real_bsz, max_dec_len); + WRAPPER_CHECK_PTR(ctx, int64_t, end_length, end_tokens); + WRAPPER_CHECK_PTR(ctx, bool, real_bsz, is_block_step); + WRAPPER_CHECK_PTR(ctx, int, real_bsz, output_cum_offsets); + // WRAPPER_CHECK_PTR(ctx, int, real_bsz, actual_candidate_len); + + // param check sm size limit + WRAPPER_ASSERT_GT(ctx, real_bsz, 0); + WRAPPER_ASSERT_LE(ctx, real_bsz, 1024); + WRAPPER_ASSERT_LE(ctx, real_bsz * max_candidate_len, 2048); + WRAPPER_ASSERT_LE(ctx, verify_window * max_candidate_len, 128); + // int sum = 0; + // for (int i=0;i < real_bsz; i++){ + // sum+= seq_lens_this_time[i]; + // } + // WRAPPER_ASSERT_LE(ctx, sum * max_draft_tokens, 2048); + + if (ctx->dev().type() == api::kCPU) { + return cpu_wrapper(ctx, + accept_tokens, + accept_num, + step_idx, + stop_flags, + seq_lens_encoder, + seq_lens_decoder, + draft_tokens, + actual_draft_token_nums, + dev_curand_states, + topp, + seq_lens_this_time, + verify_tokens, + verify_scores, + max_dec_len, + end_tokens, + is_block_step, + output_cum_offsets, + actual_candidate_len, + real_bsz, + max_draft_tokens, + end_length, + max_seq_len, + max_candidate_len, + verify_window, + prefill_one_step_stop); + } + if (ctx->dev().type() == api::kXPU3) { + return xpu3_wrapper(ctx, + accept_tokens, + accept_num, + step_idx, + stop_flags, + seq_lens_encoder, + seq_lens_decoder, + draft_tokens, + actual_draft_token_nums, + dev_curand_states, + topp, + seq_lens_this_time, + verify_tokens, + verify_scores, + max_dec_len, + end_tokens, + is_block_step, + output_cum_offsets, + actual_candidate_len, + real_bsz, + max_draft_tokens, + end_length, + max_seq_len, + max_candidate_len, + verify_window, + prefill_one_step_stop); + } + WRAPPER_UNIMPLEMENTED(ctx); +} + +#define INSTANTIATE_SPECULATE_VERIFY(ENABLE_TOPP, USE_TOPK) \ + template int \ + baidu::xpu::api::plugin::speculate_verify( \ + baidu::xpu::api::Context *, /* xpu_ctx */ \ + int64_t *, /* accept_tokens */ \ + int *, /* accept_num */ \ + int64_t *, /* step_idx */ \ + bool *, /* stop_flags */ \ + const int *, /* seq_lens_encoder */ \ + const int *, /* seq_lens_decoder */ \ + const int64_t *, /* draft_tokens */ \ + const int *, /* actual_draft_token_nums */ \ + const float *, /* dev_curand_states or topp */ \ + const float *, /* topp or nullptr */ \ + const int *, /* seq_lens_this_time */ \ + const int64_t *, /* verify_tokens */ \ + const float *, /* verify_scores */ \ + const int64_t *, /* max_dec_len */ \ + const int64_t *, /* end_tokens */ \ + const bool *, /* is_block_step */ \ + const int *, /* output_cum_offsets */ \ + const int *, /* 
actual_candidate_len */ \ + int, /* real_bsz */ \ + int, /* max_draft_tokens */ \ + int, /* end_length */ \ + int, /* max_seq_len */ \ + int, /* max_candidate_len */ \ + int, /* verify_window */ \ + bool); /* prefill_one_step_stop */ + +INSTANTIATE_SPECULATE_VERIFY(false, false) +INSTANTIATE_SPECULATE_VERIFY(false, true) +INSTANTIATE_SPECULATE_VERIFY(true, false) +INSTANTIATE_SPECULATE_VERIFY(true, true) + +} // namespace plugin +} // namespace api +} // namespace xpu +} // namespace baidu diff --git a/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/top_p_candidates.cpp b/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/top_p_candidates.cpp new file mode 100644 index 000000000..5b0c489af --- /dev/null +++ b/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/top_p_candidates.cpp @@ -0,0 +1,266 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "xpu/plugin.h" +#include "xpu/refactor/impl_public/wrapper_check.h" + +namespace xpu3 { +namespace plugin { +template <typename T, int TopPBeamTopK> +__attribute__((global)) void top_p_candidates(const T* src, + const T* top_ps, + const int* output_padding_offset, + int64_t* out_id, + T* out_val, + int* actual_candidates_lens, + int vocab_size, + int token_num, + int max_candidate_len, + int max_seq_len); +} // namespace plugin +} // namespace xpu3 + +namespace baidu { +namespace xpu { +namespace api { +namespace plugin { + +template <typename T, int TopPBeamTopK> +static int cpu_wrapper(Context* ctx, + const T* src, + const T* top_ps, + const int* output_padding_offset, + int64_t* out_id, + T* out_val, + int* actual_candidates_lens, + int vocab_size, + int token_num, + int candidate_len, + int max_seq_len) { + int64_t local_out_id[TopPBeamTopK]; + T local_out_val[TopPBeamTopK]; + + for (int64_t i = 0; i < token_num; i++) { + float sum_prob = 0.0f; + for (int j = 0; j < TopPBeamTopK; j++) { + local_out_id[j] = -1; + local_out_val[j] = std::numeric_limits<T>::min(); + } + const T* cur_row_src = src + i * vocab_size; + for (int id = 0; id < vocab_size; id++) { + if (cur_row_src[id] > local_out_val[TopPBeamTopK - 1] || + (cur_row_src[id] == local_out_val[TopPBeamTopK - 1] && + id < local_out_id[TopPBeamTopK - 1])) { + local_out_id[TopPBeamTopK - 1] = id; + local_out_val[TopPBeamTopK - 1] = cur_row_src[id]; + for (int k = TopPBeamTopK - 2; k >= 0; k--) { + if (local_out_val[k + 1] > local_out_val[k] || + (local_out_val[k + 1] == local_out_val[k] && + local_out_id[k + 1] < local_out_id[k])) { + std::swap(local_out_id[k + 1], local_out_id[k]); + std::swap(local_out_val[k + 1], local_out_val[k]); + } + } + } + } + int ori_token_id = i + output_padding_offset[i]; + int bid = ori_token_id / max_seq_len; + float top_p_value = static_cast<float>(top_ps[bid]); + bool set_to_default_val = false; + for (int j = 0; j < TopPBeamTopK; j++) { + if (set_to_default_val) { + out_id[i * candidate_len + j] = 0; + out_val[i * candidate_len + j] = 0; + } else { + out_id[i * candidate_len + j] = local_out_id[j]; + out_val[i * candidate_len + j] = 
local_out_val[j]; + float val = static_cast<float>(local_out_val[j]); + sum_prob += val; + if (sum_prob >= top_p_value) { + actual_candidates_lens[i] = j + 1; + set_to_default_val = true; + } + } + } + } + return api::SUCCESS; +} + +template <typename T, int TopPBeamTopK> +static int xpu3_wrapper(Context* ctx, + const T* src, + const T* top_ps, + const int* output_padding_offset, + int64_t* out_id, + T* out_val, + int* actual_candidates_lens, + int vocab_size, + int token_num, + int candidate_len, + int max_seq_len) { + using XPU_INT64 = typename XPUIndexType<int64_t>::type; + xpu3::plugin::top_p_candidates<T, TopPBeamTopK> + <<<ctx->ncluster(), 64, ctx->xpu_stream>>>( + src, + top_ps, + output_padding_offset, + reinterpret_cast<XPU_INT64*>(out_id), + out_val, + actual_candidates_lens, + vocab_size, + token_num, + candidate_len, + max_seq_len); + return api::SUCCESS; +} + +template <typename T, int TopPBeamTopK> +int top_p_candidates(Context* ctx, + const T* src, + const T* top_ps, + const int* output_padding_offset, + int64_t* out_id, + T* out_val, + int* actual_candidates_lens, + int vocab_size, + int token_num, + int candidate_len, + int max_seq_len) { + WRAPPER_CHECK_CTX(ctx); + WRAPPER_DUMP_FUNCTION_T1(ctx, "top_p_candidates", T); + WRAPPER_DUMP_PARAM5(ctx, src, top_ps, output_padding_offset, out_id, out_val); + WRAPPER_DUMP_PARAM5(ctx, + actual_candidates_lens, + vocab_size, + token_num, + candidate_len, + max_seq_len); + WRAPPER_DUMP(ctx); + + WRAPPER_CHECK_PTR(ctx, T, token_num * vocab_size, src); + WRAPPER_CHECK_PTR(ctx, int, token_num, output_padding_offset); + WRAPPER_CHECK_PTR(ctx, int64_t, token_num * candidate_len, out_id); + WRAPPER_CHECK_PTR(ctx, T, token_num * candidate_len, out_val); + + WRAPPER_ASSERT_GT(ctx, vocab_size, 0); + WRAPPER_ASSERT_GT(ctx, token_num, 0); + WRAPPER_ASSERT_GT(ctx, candidate_len, 0); + WRAPPER_ASSERT_GT(ctx, max_seq_len, 0); + WRAPPER_ASSERT_GT(ctx, TopPBeamTopK, 0); + WRAPPER_ASSERT_LE(ctx, TopPBeamTopK, 10); + + if (ctx->dev().type() == api::kCPU) { + return cpu_wrapper<T, TopPBeamTopK>(ctx, + src, + top_ps, + output_padding_offset, + out_id, + out_val, + actual_candidates_lens, + vocab_size, + token_num, + candidate_len, + max_seq_len); + } else if (ctx->dev().type() == api::kXPU3) { + return xpu3_wrapper<T, TopPBeamTopK>(ctx, + src, + top_ps, + output_padding_offset, + out_id, + out_val, + actual_candidates_lens, + vocab_size, + token_num, + candidate_len, + max_seq_len); + } + WRAPPER_UNIMPLEMENTED(ctx); +} + +#define _XPU_DEF_TOP_P_CANDIDATES_WRAPPER(T, MaxLength) \ + template int top_p_candidates(Context*, \ + const T*, \ + const T*, \ + const int*, \ + int64_t*, \ + T*, \ + int*, \ + int, \ + int, \ + int, \ + int); \ + template int top_p_candidates(Context*, \ + const T*, \ + const T*, \ + const int*, \ + int64_t*, \ + T*, \ + int*, \ + int, \ + int, \ + int, \ + int); \ + template int top_p_candidates(Context*, \ + const T*, \ + const T*, \ + const int*, \ + int64_t*, \ + T*, \ + int*, \ + int, \ + int, \ + int, \ + int); \ + template int top_p_candidates(Context*, \ + const T*, \ + const T*, \ + const int*, \ + int64_t*, \ + T*, \ + int*, \ + int, \ + int, \ + int, \ + int); \ + template int top_p_candidates(Context*, \ + const T*, \ + const T*, \ + const int*, \ + int64_t*, \ + T*, \ + int*, \ + int, \ + int, \ + int, \ + int); \ + template int top_p_candidates(Context*, \ + const T*, \ + const T*, \ + const int*, \ + int64_t*, \ + T*, \ + int*, \ + int, \ + int, \ + int, \ + int); + +_XPU_DEF_TOP_P_CANDIDATES_WRAPPER(bfloat16, 2); +_XPU_DEF_TOP_P_CANDIDATES_WRAPPER(float, 2); +_XPU_DEF_TOP_P_CANDIDATES_WRAPPER(float16, 2); + +} // namespace plugin +} // namespace api +} // namespace xpu 
+} // namespace baidu diff --git a/custom_ops/xpu_ops/src/setup_ops.py b/custom_ops/xpu_ops/src/setup_ops.py index 5ad31e912..88f450916 100755 --- a/custom_ops/xpu_ops/src/setup_ops.py +++ b/custom_ops/xpu_ops/src/setup_ops.py @@ -162,6 +162,11 @@ def xpu_setup_ops(): ] ops = [os.path.join(base_dir, op) for op in ops] + for root, dirs, files in os.walk(os.path.join(base_dir, "ops/mtp_ops")): + for file in files: + if file.endswith(".cc"): + ops.append(os.path.join(root, file)) + include_dirs = [ os.path.join(base_dir, "./"), os.path.join(base_dir, "./plugin/include"), diff --git a/custom_ops/xpu_ops/test/test_draft_model_postprocess.py b/custom_ops/xpu_ops/test/test_draft_model_postprocess.py new file mode 100644 index 000000000..e0920277e --- /dev/null +++ b/custom_ops/xpu_ops/test/test_draft_model_postprocess.py @@ -0,0 +1,93 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import numpy as np +import paddle + +from fastdeploy.model_executor.ops.xpu import draft_model_postprocess + + +def draft_model_postprocess_cpu( + base_model_draft_tokens, # 2D tensor: [bsz, base_model_draft_token_len] + base_model_seq_lens_encoder, # 1D tensor: [bsz] + base_model_stop_flags, # 1D tensor: [bsz] +): + bsz = base_model_draft_tokens.shape[0] + base_model_draft_token_len = base_model_draft_tokens.shape[1] + base_model_seq_lens_this_time = paddle.ones((bsz), dtype=paddle.int32) + # iterate over every sample in the batch + for tid in range(bsz): + if (not base_model_stop_flags[tid]) and (base_model_seq_lens_encoder[tid] == 0): + # fetch the draft tokens of the current sample + base_model_draft_tokens_now = base_model_draft_tokens[tid] + token_num = 0 + for i in range(base_model_draft_token_len): + if base_model_draft_tokens_now[i] != -1: + token_num += 1 + + # update the sequence length + base_model_seq_lens_this_time[tid] = token_num + elif base_model_stop_flags[tid]: + # stopped samples get a sequence length of 0 + base_model_seq_lens_this_time[tid] = 0 + + return [base_model_seq_lens_this_time] + + +def test_draft_model_postprocess(batch_size=1, base_model_draft_token_len=8192): # batch size + paddle.seed(66) + base_model_draft_tokens = paddle.randint( + low=-1, + high=1, + shape=[batch_size, base_model_draft_token_len], + dtype="int64", + ) + # base_model_seq_lens_this_time = paddle.ones((batch_size), dtype=paddle.int32) + base_model_seq_lens_encoder = paddle.randint(low=0, high=2, shape=[batch_size], dtype="int32") + random_floats = paddle.rand(shape=[batch_size]) + base_model_stop_flags = random_floats >= 0.5 + + base_model_seq_lens_this_time = draft_model_postprocess_cpu( + base_model_draft_tokens, # 2D tensor: [bsz, base_model_draft_token_len] + base_model_seq_lens_encoder, # 1D tensor: [bsz] + base_model_stop_flags, + ) + base_model_seq_lens_this_time_xpu = paddle.ones((batch_size), dtype=paddle.int32) + draft_model_postprocess( + base_model_draft_tokens, # 2D tensor: [bsz, base_model_draft_token_len] + base_model_seq_lens_this_time_xpu, # 1D tensor: [bsz] + base_model_seq_lens_encoder, # 1D tensor: [bsz] + base_model_stop_flags, + ) + print("test start") + assert 
np.allclose(base_model_seq_lens_this_time, base_model_seq_lens_this_time_xpu) + print("test passed") + + +def test_enough_cases(): + test_draft_model_postprocess(100, 1024) + test_draft_model_postprocess(1, 11) + test_draft_model_postprocess(1, 8192) + test_draft_model_postprocess(2, 2048) + test_draft_model_postprocess(3, 1023) + test_draft_model_postprocess(4, 2047) + test_draft_model_postprocess(5, 4095) + test_draft_model_postprocess(10, 9191) + test_draft_model_postprocess(20, 618) + test_draft_model_postprocess(30, 703) + test_draft_model_postprocess(100, 1025) + test_draft_model_postprocess(1536, 1026) + + +if __name__ == "__main__": + test_enough_cases() diff --git a/custom_ops/xpu_ops/test/test_draft_model_preprocess.py b/custom_ops/xpu_ops/test/test_draft_model_preprocess.py new file mode 100644 index 000000000..c687bdf30 --- /dev/null +++ b/custom_ops/xpu_ops/test/test_draft_model_preprocess.py @@ -0,0 +1,135 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import paddle + +from fastdeploy.model_executor.ops.xpu import draft_model_preprocess + + +def run_test(device="xpu"): + paddle.seed(2022) + + # Define parameters + bsz = 10 + draft_tokens_len = 4 + input_ids_len = 8 + max_draft_token = 10 + + truncate_first_token = True + splitwise_prefill = False + # Create input tensors + if device == "cpu": + paddle.set_device(device) + + draft_tokens = paddle.randint(0, 100, [bsz, draft_tokens_len], dtype="int64") + input_ids = paddle.randint(0, 100, [bsz, input_ids_len], dtype="int64") + stop_flags = paddle.randint(0, 1, [bsz], dtype="int").cast("bool") + seq_lens_this_time = paddle.randint(0, 100, [bsz], dtype="int32") + seq_lens_encoder = paddle.randint(0, 100, [bsz], dtype="int32") + seq_lens_decoder = paddle.randint(0, 100, [bsz], dtype="int32") + step_idx = paddle.randint(0, 100, [bsz], dtype="int64") + seq_lens_encoder_record = paddle.randint(0, 100, [bsz], dtype="int32") + seq_lens_decoder_record = paddle.randint(0, 100, [bsz], dtype="int32") + not_need_stop = paddle.zeros([1], dtype="bool").cpu() + batch_drop = paddle.zeros([bsz], dtype="bool") + + # Output tensors + accept_tokens = paddle.randint(0, 100, [bsz, 100], dtype="int64") + accept_num = paddle.randint(1, max_draft_token + 5, [bsz], dtype="int32") + base_model_seq_lens_encoder = paddle.randint(0, 100, [bsz], dtype="int32") + base_model_seq_lens_decoder = paddle.randint(0, 100, [bsz], dtype="int32") + base_model_step_idx = paddle.randint(0, 100, [bsz], dtype="int64") + base_model_stop_flags = paddle.zeros([bsz], dtype="bool") + base_model_is_block_step = paddle.zeros([bsz], dtype="bool") + base_model_draft_tokens = paddle.zeros([bsz, max_draft_token], dtype="int64") + # Run the op + outputs = draft_model_preprocess( + draft_tokens, + input_ids, + stop_flags, + seq_lens_this_time, + seq_lens_encoder, + seq_lens_decoder, + step_idx, + seq_lens_encoder_record, + seq_lens_decoder_record, + not_need_stop, + batch_drop, + accept_tokens, + 
accept_num, + base_model_seq_lens_encoder, + base_model_seq_lens_decoder, + base_model_step_idx, + base_model_stop_flags, + base_model_is_block_step, + base_model_draft_tokens, + max_draft_token=max_draft_token, + truncate_first_token=truncate_first_token, + splitwise_prefill=splitwise_prefill, + ) + + # Return results for comparison + results = { + "draft_tokens": draft_tokens.numpy(), + "input_ids": input_ids.numpy(), + "stop_flags": stop_flags.numpy(), + "seq_lens_this_time": seq_lens_this_time.numpy(), + "accept_tokens": accept_tokens.numpy(), + "accept_num": accept_num.numpy(), + "not_need_stop": not_need_stop.numpy(), + "outputs": [x.numpy() for x in outputs], + } + return results + + +def compare_results(cpu_results, xpu_results): + # Compare all outputs + for key in cpu_results: + if key == "outputs": + for i, (cpu_out, xpu_out) in enumerate(zip(cpu_results[key], xpu_results[key])): + np.testing.assert_allclose( + cpu_out, + xpu_out, + rtol=1e-5, + atol=1e-8, + err_msg=f"Output {i} mismatch between CPU and XPU", + ) + else: + np.testing.assert_allclose( + cpu_results[key], + xpu_results[key], + rtol=1e-5, + atol=1e-8, + err_msg=f"{key} mismatch between CPU and XPU", + ) + print("CPU and XPU results match!") + + +def test_draft_model_preprocess(): + + print("Running XPU test...") + xpu_results = run_test("xpu") + + print("Running CPU test...") + cpu_results = run_test("cpu") + + print("Comparing results...") + compare_results(cpu_results, xpu_results) + + print("Test passed!") + + +if __name__ == "__main__": + test_draft_model_preprocess() diff --git a/custom_ops/xpu_ops/test/test_draft_model_update.py b/custom_ops/xpu_ops/test/test_draft_model_update.py new file mode 100644 index 000000000..268e08229 --- /dev/null +++ b/custom_ops/xpu_ops/test/test_draft_model_update.py @@ -0,0 +1,122 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
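+# draft_model_update, like the other MTP ops in this PR, appears to mutate its
+# tensor arguments in place (the test reads its own inputs back after the
+# call). The test therefore seeds both devices identically, runs the op once
+# on XPU and once on CPU, and compares every mutated tensor. A minimal sketch
+# of that comparison pattern; the helper name assert_device_match is ours,
+# purely illustrative, not part of the op's API:
+#
+#     def assert_device_match(cpu_res, xpu_res, rtol=1e-4, atol=1e-5):
+#         for i, (c, x) in enumerate(zip(cpu_res, xpu_res)):
+#             c, x = c.numpy(), x.numpy()
+#             if c.dtype == bool:
+#                 assert np.array_equal(c, x), f"bool mismatch at output {i}"
+#             else:
+#                 assert np.allclose(c, x, rtol=rtol, atol=atol), f"mismatch at output {i}"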
+ + +import numpy as np +import paddle + +from fastdeploy.model_executor.ops.xpu import draft_model_update + + +def run_paddle_test(device="cpu"): + np.random.seed(42) + paddle.seed(42) + if device == "cpu": + paddle.set_device(device) + elif device == "xpu": + paddle.set_device(device) + else: + raise ValueError(f"Invalid device: {device}") + + # set parameters + max_bsz = 128 + max_draft_token = 3 + pre_id_length = 3 + max_seq_len = 100 + max_base_model_draft_token = 4 + substep = 2 + + # create random tensors + inter_next_tokens = paddle.randint(1, 100, shape=(max_bsz, max_seq_len), dtype="int64") + draft_tokens = paddle.randint(1, 100, shape=(max_bsz, max_draft_token), dtype="int64") + pre_ids = paddle.randint(1, 100, shape=(max_bsz, pre_id_length), dtype="int64") + seq_lens_this_time = paddle.randint(1, 2, shape=(max_bsz,), dtype="int32") + seq_lens_encoder = paddle.randint(1, 10, shape=(max_bsz,), dtype="int32") + seq_lens_decoder = paddle.randint(1, 10, shape=(max_bsz,), dtype="int32") + step_idx = paddle.randint(1, 10, shape=(max_bsz,), dtype="int64") + output_cum_offsets = paddle.randint(0, 2, shape=(max_bsz,), dtype="int32") + output_cum_offsets[0] = 0 # make sure the first offset is 0 + stop_flags = paddle.zeros([max_bsz], dtype="bool") + not_need_stop = paddle.zeros([1], dtype="bool") + max_dec_len = paddle.randint(100, 102, shape=(max_bsz,), dtype="int64") + end_ids = paddle.to_tensor([2], dtype="int64") + base_model_draft_tokens = paddle.randint(1, 10, shape=(max_bsz, max_base_model_draft_token), dtype="int64") + + # print tensor shapes (debug) + # print("inter_next_tokens shape:", inter_next_tokens.shape) + # print("draft_tokens shape:", draft_tokens.shape) + # print("pre_ids shape:", pre_ids.shape) + # print("seq_lens_this_time shape:", seq_lens_this_time.shape) + # print("seq_lens_encoder shape:", seq_lens_encoder.shape) + # print("seq_lens_decoder shape:", seq_lens_decoder.shape) + # print("step_idx shape:", step_idx.shape) + # print("output_cum_offsets shape:", output_cum_offsets.shape) + # print("stop_flags shape:", stop_flags.shape) + # print("not_need_stop shape:", not_need_stop.shape) + # print("max_dec_len shape:", max_dec_len.shape) + # print("end_ids shape:", end_ids.shape) + # print("base_model_draft_tokens shape:", base_model_draft_tokens.shape) + + # print("draft_tokens before update:", draft_tokens) + # print("pre_ids before update:", pre_ids) + draft_model_update( + inter_next_tokens, + draft_tokens, + pre_ids, + seq_lens_this_time, + seq_lens_encoder, + seq_lens_decoder, + step_idx, + output_cum_offsets, + stop_flags, + not_need_stop, + max_dec_len, + end_ids, + base_model_draft_tokens, + max_seq_len, + substep, + ) + # print("draft_tokens after update:", draft_tokens) + # print("pre_ids after update:", pre_ids) + return ( + draft_tokens, + pre_ids, + seq_lens_this_time, + seq_lens_encoder, + seq_lens_decoder, + step_idx, + stop_flags, + not_need_stop, + base_model_draft_tokens, + ) + + +if __name__ == "__main__": + res_xpu = run_paddle_test("xpu") + res_cpu = run_paddle_test() + for idx in range(len(res_cpu)): + # convert results to numpy arrays + cpu_arr = res_cpu[idx].numpy() + xpu_arr = res_xpu[idx].numpy() + + # check for boolean dtype + if cpu_arr.dtype == bool: + assert np.array_equal(cpu_arr, xpu_arr), f"boolean results mismatch at index {idx}" + else: + # use a looser tolerance for numeric dtypes + assert np.allclose( + cpu_arr, xpu_arr, rtol=1e-4, atol=1e-5 + ), f"numeric results mismatch at index {idx}, max diff: {np.max(np.abs(cpu_arr - xpu_arr))}" + + print(f"result {idx} passed verification") diff --git a/custom_ops/xpu_ops/test/test_eagle_get_hidden_states.py b/custom_ops/xpu_ops/test/test_eagle_get_hidden_states.py new file 
mode 100644 index 000000000..ac68a53e3 --- /dev/null +++ b/custom_ops/xpu_ops/test/test_eagle_get_hidden_states.py @@ -0,0 +1,104 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import paddle + +from fastdeploy.model_executor.ops.xpu import eagle_get_hidden_states + + +def test_eagle_get_hidden_states(): + bs = np.random.randint(1, 8 + 1, dtype=np.int32) + input_token_num = np.random.randint(2 * 1024, 4 * 1024 + 1, dtype=np.int32) + dim_embed = np.random.randint(1, 4 * 1024 + 1, dtype=np.int32) + actual_draft_token_num = np.random.randint(2, 6, dtype=np.int32) + + seq_lens_this_time = np.random.randint(0, 2, bs, dtype=np.int32) + seq_lens_encoder = np.random.randint(0, input_token_num // bs + 1, bs, dtype=np.int32) + accept_nums = np.random.randint(0, actual_draft_token_num + 1, bs, dtype=np.int32) + base_model_seq_lens_this_time = np.random.randint(0, input_token_num // bs + 1, bs, dtype=np.int32) + base_model_seq_lens_encoder = np.random.randint(0, 2, bs, dtype=np.int32) + # dont care + seq_lens_decoder = np.random.randint(0, input_token_num // bs + 1, bs, dtype=np.int32) + stop_flags = np.random.randint(0, 2, bs, dtype=np.int32) + + seq_lens_this_time_tensor = paddle.to_tensor(seq_lens_this_time, dtype=paddle.int32) + seq_lens_encoder_tensor = paddle.to_tensor(seq_lens_encoder, dtype=paddle.int32) + accept_nums_tensor = paddle.to_tensor(accept_nums, dtype=paddle.int32) + base_model_seq_lens_this_time_tensor = paddle.to_tensor(base_model_seq_lens_this_time, dtype=paddle.int32) + base_model_seq_lens_encoder_tensor = paddle.to_tensor(base_model_seq_lens_encoder, dtype=paddle.int32) + # dont care + seq_lens_decoder_tensor = paddle.to_tensor(seq_lens_decoder, dtype=paddle.int32) + stop_flags_tensor = paddle.to_tensor(stop_flags, dtype=paddle.int32) + + # fp32 test + input = np.random.randint(0, 10, (input_token_num, dim_embed), dtype=np.int32) + input_tensor = paddle.to_tensor(input, dtype=paddle.float32) + cpu_out = eagle_get_hidden_states( + input_tensor.cpu(), + seq_lens_this_time_tensor.cpu(), + seq_lens_encoder_tensor.cpu(), + seq_lens_decoder_tensor.cpu(), + stop_flags_tensor.cpu(), + accept_nums_tensor.cpu(), + base_model_seq_lens_this_time_tensor.cpu(), + base_model_seq_lens_encoder_tensor.cpu(), + actual_draft_token_num, + ) + xpu_out = eagle_get_hidden_states( + input_tensor, + seq_lens_this_time_tensor, + seq_lens_encoder_tensor, + seq_lens_decoder_tensor, + stop_flags_tensor, + accept_nums_tensor, + base_model_seq_lens_this_time_tensor, + base_model_seq_lens_encoder_tensor, + actual_draft_token_num, + ) + assert np.allclose(cpu_out.numpy(), xpu_out.numpy()) + + # bf16/fp16 test + for dtype in [paddle.bfloat16, paddle.float16]: + input = np.random.randint(0, 10, (input_token_num, dim_embed), dtype=np.int16) + input_tensor = paddle.to_tensor(input, dtype=dtype) + cpu_out = eagle_get_hidden_states( + input_tensor.cpu(), + seq_lens_this_time_tensor.cpu(), + seq_lens_encoder_tensor.cpu(), + 
seq_lens_decoder_tensor.cpu(), + stop_flags_tensor.cpu(), + accept_nums_tensor.cpu(), + base_model_seq_lens_this_time_tensor.cpu(), + base_model_seq_lens_encoder_tensor.cpu(), + actual_draft_token_num, + ) + xpu_out = eagle_get_hidden_states( + input_tensor, + seq_lens_this_time_tensor, + seq_lens_encoder_tensor, + seq_lens_decoder_tensor, + stop_flags_tensor, + accept_nums_tensor, + base_model_seq_lens_this_time_tensor, + base_model_seq_lens_encoder_tensor, + actual_draft_token_num, + ) + assert np.allclose(cpu_out.numpy(), xpu_out.numpy()) + + print("All test passed") + + +if __name__ == "__main__": + test_eagle_get_hidden_states() diff --git a/custom_ops/xpu_ops/test/test_eagle_get_self_hidden_states.py b/custom_ops/xpu_ops/test/test_eagle_get_self_hidden_states.py new file mode 100644 index 000000000..2808c95a9 --- /dev/null +++ b/custom_ops/xpu_ops/test/test_eagle_get_self_hidden_states.py @@ -0,0 +1,132 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import paddle + +from fastdeploy.model_executor.ops.xpu import eagle_get_self_hidden_states + + +def computeOrder(last_seq_lens_this_time, seq_lens_this_time, step_idx, src_map, bsz): + in_offset = 0 + out_offset = 0 + for i in range(bsz): + cur_seq_lens_this_time = seq_lens_this_time[i] + cur_last_seq_lens_this_time = last_seq_lens_this_time[i] + + # 1. encoder + if step_idx[i] == 1 and cur_seq_lens_this_time > 0: + in_offset += 1 + src_map[out_offset] = in_offset - 1 + out_offset += 1 + # 2. decoder + elif cur_seq_lens_this_time > 0: + in_offset += cur_last_seq_lens_this_time + src_map[out_offset] = in_offset - 1 + out_offset += 1 + # 3. 
stop + else: + # first token end + if step_idx[i] == 1: + in_offset += 1 if cur_last_seq_lens_this_time > 0 else 0 + # normal end + else: + in_offset += cur_last_seq_lens_this_time + + return (out_offset, src_map) + + +def rebuildSelfHiddenStatesKernel(input, src_map, out, dim_embed, elem_cnt): + print(f"input.shape {input.shape}") + print(f"out.shape {out.shape}") + print(f"elem_cnt {elem_cnt}") + for elem_id in range(elem_cnt): + output_token_idx = elem_id // dim_embed + input_token_idx = src_map[output_token_idx] + offset = elem_id % dim_embed + out[output_token_idx * dim_embed + offset] = input[input_token_idx * dim_embed + offset] + return out + + +def ref_eagle_get_self_hidden_states(input, last_seq_lens_this_time, seq_lens_this_time, step_idx): + input_token_num = input.shape[0] + dim_embed = input.shape[1] + bsz = seq_lens_this_time.shape[0] + src_map = np.full(input_token_num, -1, seq_lens_this_time.dtype) + + output_token_num, src_map = computeOrder(last_seq_lens_this_time, seq_lens_this_time, step_idx, src_map, bsz) + + out = np.full([output_token_num * dim_embed], -1, input.dtype) + + elem_cnt = output_token_num * dim_embed + + out = rebuildSelfHiddenStatesKernel(input, src_map, out, dim_embed, elem_cnt) + out = out.reshape([output_token_num, dim_embed]) + + return out + + +def test_eagle_get_self_hidden_states(): + bs = np.random.randint(1, 8 + 1, dtype=np.int32) + input_token_num = np.random.randint(2 * 1024, 4 * 1024 + 1, dtype=np.int32) + dim_embed = np.random.randint(1, 4 * 1024 + 1, dtype=np.int32) + + last_seq_lens_this_time = np.random.randint(0, input_token_num // bs, bs, dtype=np.int32) + seq_lens_this_time = np.random.randint(0, input_token_num // bs, bs, dtype=np.int32) + step_idx = np.arange(0, bs, dtype=np.int32) + + last_seq_lens_this_time_tensor = paddle.to_tensor(last_seq_lens_this_time, dtype=paddle.int32) + seq_lens_this_time_tensor = paddle.to_tensor(seq_lens_this_time, dtype=paddle.int32) + step_idx_tensor = paddle.to_tensor(step_idx, dtype=paddle.int64) + + # fp32 test + input = np.random.randint(0, 10, (input_token_num, dim_embed), dtype=np.int32) + input_tensor = paddle.to_tensor(input, dtype=paddle.float32) + cpu_out = eagle_get_self_hidden_states( + input_tensor.cpu(), + last_seq_lens_this_time_tensor.cpu(), + seq_lens_this_time_tensor.cpu(), + step_idx_tensor.cpu(), + ) + xpu_out = eagle_get_self_hidden_states( + input_tensor, + last_seq_lens_this_time_tensor, + seq_lens_this_time_tensor, + step_idx_tensor, + ) + assert np.allclose(cpu_out.numpy(), xpu_out.numpy()) + + # bf16/fp16 test + for dtype in [paddle.bfloat16, paddle.float16]: + input = np.random.randint(0, 10, (input_token_num, dim_embed), dtype=np.int16) + input_tensor = paddle.to_tensor(input, dtype=dtype) + cpu_out = eagle_get_self_hidden_states( + input_tensor.cpu(), + last_seq_lens_this_time_tensor.cpu(), + seq_lens_this_time_tensor.cpu(), + step_idx_tensor.cpu(), + ) + xpu_out = eagle_get_self_hidden_states( + input_tensor, + last_seq_lens_this_time_tensor, + seq_lens_this_time_tensor, + step_idx_tensor, + ) + assert np.allclose(cpu_out.numpy(), xpu_out.numpy()) + + print("All test passed") + + +if __name__ == "__main__": + test_eagle_get_self_hidden_states() diff --git a/custom_ops/xpu_ops/test/python/ops/test_get_padding_offset.py b/custom_ops/xpu_ops/test/test_get_padding_offset.py similarity index 100% rename from custom_ops/xpu_ops/test/python/ops/test_get_padding_offset.py rename to custom_ops/xpu_ops/test/test_get_padding_offset.py diff --git 
a/custom_ops/xpu_ops/test/python/ops/test_get_token_penalty_multi_scores.py b/custom_ops/xpu_ops/test/test_get_token_penalty_multi_scores.py similarity index 100% rename from custom_ops/xpu_ops/test/python/ops/test_get_token_penalty_multi_scores.py rename to custom_ops/xpu_ops/test/test_get_token_penalty_multi_scores.py diff --git a/custom_ops/xpu_ops/test/python/ops/test_set_value_by_flags_and_idx.py b/custom_ops/xpu_ops/test/test_set_value_by_flags_and_idx.py similarity index 100% rename from custom_ops/xpu_ops/test/python/ops/test_set_value_by_flags_and_idx.py rename to custom_ops/xpu_ops/test/test_set_value_by_flags_and_idx.py diff --git a/custom_ops/xpu_ops/test/test_speculate_clear_accept_nums.py b/custom_ops/xpu_ops/test/test_speculate_clear_accept_nums.py new file mode 100644 index 000000000..a11824f0d --- /dev/null +++ b/custom_ops/xpu_ops/test/test_speculate_clear_accept_nums.py @@ -0,0 +1,46 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import paddle + +from fastdeploy.model_executor.ops.xpu import speculate_clear_accept_nums + +np.set_printoptions(threshold=np.inf) # set threshold to infinity so full arrays are printed +np.set_printoptions(linewidth=np.inf) # keep each array on a single line (optional) + + +def speculate_clear_accept_nums_np(accept_num, seq_lens_decoder): + for i in range(len(accept_num)): + if seq_lens_decoder[i] == 0: + accept_num[i] = 0 + return accept_num, seq_lens_decoder + + +max_bs = 1024 +accept_num_np = np.random.randint(low=0, high=11, size=[max_bs], dtype="int32") +accept_num_paddle = paddle.to_tensor(accept_num_np) + +seq_lens_decoder_np = np.random.randint(low=0, high=2, size=[max_bs], dtype="int32") +seq_lens_decoder_paddle = paddle.to_tensor(seq_lens_decoder_np) + +a = accept_num_paddle.numpy() +# print((a - accept_num_np).sum()) +assert np.array_equal(a, accept_num_np), "Check failed." +accept_num_np, seq_lens_decoder_np = speculate_clear_accept_nums_np(accept_num_np, seq_lens_decoder_np) +seq_lens_decoder_paddle = speculate_clear_accept_nums(accept_num_paddle, seq_lens_decoder_paddle) +b = accept_num_paddle.numpy() +# print(b) +# print((accept_num_np - b).sum()) +assert np.array_equal(accept_num_np, b), "Check failed." diff --git a/custom_ops/xpu_ops/test/test_speculate_get_output_padding_offset.py b/custom_ops/xpu_ops/test/test_speculate_get_output_padding_offset.py new file mode 100644 index 000000000..f61de62ad --- /dev/null +++ b/custom_ops/xpu_ops/test/test_speculate_get_output_padding_offset.py @@ -0,0 +1,66 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import paddle + +if paddle.is_compiled_with_xpu(): + from fastdeploy.model_executor.ops.xpu import speculate_get_output_padding_offset +else: + from efficientllm.ops.gpu import speculate_get_output_padding_offset + + +def test_speculate_get_output_padding_offset(): + bsz = 256 + max_seq_len = 8192 + + seq_lens_output = np.random.randint(0, 4, size=bsz) + output_token_num = np.sum(seq_lens_output) + + seq_lens_output = paddle.to_tensor(seq_lens_output, dtype="int32") + out_token_num = paddle.sum(seq_lens_output) + output_cum_offsets_tmp = paddle.cumsum(max_seq_len - seq_lens_output) + + output_padding_offset_xpu, output_cum_offsets_xpu = speculate_get_output_padding_offset( + output_cum_offsets_tmp, out_token_num, seq_lens_output, max_seq_len + ) + + output_padding_offset_cpu = [-1] * output_token_num + output_cum_offsets_cpu = [-1] * bsz + + for bi in range(bsz): + cum_offset = 0 if bi == 0 else output_cum_offsets_tmp[bi - 1] + output_cum_offsets_cpu[bi] = cum_offset + for token_i in range(seq_lens_output[bi]): + output_padding_offset_cpu[bi * max_seq_len - cum_offset + token_i] = cum_offset + + # print(f"seq_lens_output: {seq_lens_output}") + # print(f"output_cum_offsets_tmp: {output_cum_offsets_tmp}") + # print(f"output_padding_offset_xpu: {output_padding_offset_xpu}") + # print(f"output_cum_offsets_xpu: {output_cum_offsets_xpu}") + # print(f"output_padding_offset_cpu: {output_padding_offset_cpu}") + # print(f"output_cum_offsets_cpu: {output_cum_offsets_cpu}") + + assert np.array_equal( + output_padding_offset_xpu, output_padding_offset_cpu + ), "output_padding_offset_xpu != output_padding_offset_cpu" + assert np.array_equal( + output_cum_offsets_xpu, output_cum_offsets_cpu + ), "output_cum_offsets_xpu != output_cum_offsets_cpu" + + print("test_speculate_get_output_padding_offset passed!") + + +if __name__ == "__main__": + test_speculate_get_output_padding_offset() diff --git a/custom_ops/xpu_ops/test/test_speculate_get_padding_offset.py b/custom_ops/xpu_ops/test/test_speculate_get_padding_offset.py new file mode 100644 index 000000000..b2001a4f6 --- /dev/null +++ b/custom_ops/xpu_ops/test/test_speculate_get_padding_offset.py @@ -0,0 +1,525 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
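+# speculate_get_padding_offset maps a padded [bsz, max_seq_len] batch onto a
+# contiguous token stream. A small worked example of the reference rule
+# implemented below (numbers are ours, for illustration only): with
+# max_seq_len=4, seq_lens=[2, 3] and cum_offsets=[2, 3] (cumulative padding
+# per batch), token_num is 5 and
+#
+#     padding_offset  = [0, 0, 2, 2, 2]   # cumulative padding seen by each token
+#     cum_offsets_out = [0, 2]            # cum_offsets shifted right by one
+#     cu_seqlens_q = cu_seqlens_k = [0, 2, 5]
+#
+# so token i of batch b lands at index b * max_seq_len - cum_offsets_out[b] + i
+# in the flattened stream, and cu_seqlens hold the per-batch row boundaries.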
+ +import numpy as np +import paddle + +from fastdeploy.model_executor.ops.xpu import speculate_get_padding_offset + +test_failed = False + + +def ref_speculate_get_padding_offset(cum_offsets, seq_lens, max_seq_len, token_num_data): + bsz = seq_lens.shape[0] + + padding_offset = np.zeros([token_num_data], dtype=np.int32) + cum_offsets_out = np.zeros([bsz], dtype=np.int32) + cu_seqlens_q = np.zeros([bsz + 1], dtype=np.int32) + cu_seqlens_k = np.zeros([bsz + 1], dtype=np.int32) + + modified_indices = { + "padding_offset": [], + "cum_offsets_out": [], + "cu_seqlens_q": [], + "cu_seqlens_k": [], + } + + cu_seqlens_q[0] = 0 + cu_seqlens_k[0] = 0 + modified_indices["cu_seqlens_q"].append(0) + modified_indices["cu_seqlens_k"].append(0) + + for bi in range(bsz): + cum_offset = 0 if bi == 0 else cum_offsets[bi - 1] + cum_offsets_out[bi] = cum_offset + modified_indices["cum_offsets_out"].append(bi) + + for i in range(seq_lens[bi]): + idx = bi * max_seq_len - cum_offset + i + if idx >= 0 and idx < token_num_data: + padding_offset[idx] = cum_offset + modified_indices["padding_offset"].append(idx) + + cum_seq_len = (bi + 1) * max_seq_len - cum_offsets[bi] + cu_seqlens_q[bi + 1] = cum_seq_len + cu_seqlens_k[bi + 1] = cum_seq_len + modified_indices["cu_seqlens_q"].append(bi + 1) + modified_indices["cu_seqlens_k"].append(bi + 1) + + return ( + padding_offset, + cum_offsets_out, + cu_seqlens_q, + cu_seqlens_k, + modified_indices, + ) + + +def test_speculate_get_padding_offset(): + global test_failed + print("Testing speculate_get_padding_offset...") + + test_cases = [ + { + "name": "Basic test case", + "bsz": 4, + "max_seq_len": 10, + "token_num_data": 32, + "cum_offsets": np.array([2, 5, 8, 12], dtype=np.int32), + "seq_lens": np.array([8, 5, 7, 6], dtype=np.int32), + "seq_lens_encoder": np.array([1, 0, 1, 0], dtype=np.int32), + }, + { + "name": "Batch copy optimization", + "bsz": 5, + "max_seq_len": 12, + "token_num_data": 50, + "cum_offsets": np.array([1, 4, 8, 13, 19], dtype=np.int32), + "seq_lens": np.array([10, 6, 8, 5, 7], dtype=np.int32), + "seq_lens_encoder": np.array([1, 0, 1, 0, 1], dtype=np.int32), + }, + { + "name": "Boundary conditions", + "bsz": 3, + "max_seq_len": 8, + "token_num_data": 20, + "cum_offsets": np.array([3, 8, 14], dtype=np.int32), + "seq_lens": np.array([4, 3, 2], dtype=np.int32), + "seq_lens_encoder": np.array([1, 0, 1], dtype=np.int32), + }, + { + "name": "Large sequence length", + "bsz": 2, + "max_seq_len": 2000, + "token_num_data": 3000, + "cum_offsets": np.array([100, 500], dtype=np.int32), + "seq_lens": np.array([1800, 1500], dtype=np.int32), + "seq_lens_encoder": np.array([1, 0], dtype=np.int32), + }, + ] + + max_draft_tokens = 4 + all_passed = True + + for i, case in enumerate(test_cases): + print(f" Test case {i+1}: {case['name']}") + + input_ids = np.random.randint(0, 1000, (case["bsz"], case["max_seq_len"]), dtype=np.int64) + draft_tokens = np.random.randint(0, 1000, (case["bsz"], max_draft_tokens), dtype=np.int64) + token_num = np.array([case["token_num_data"]], dtype=np.int64) + + input_ids_tensor = paddle.to_tensor(input_ids) + draft_tokens_tensor = paddle.to_tensor(draft_tokens) + cum_offsets_tensor = paddle.to_tensor(case["cum_offsets"]) + seq_lens_tensor = paddle.to_tensor(case["seq_lens"]) + seq_lens_encoder_tensor = paddle.to_tensor(case["seq_lens_encoder"]) + token_num_tensor = paddle.to_tensor(token_num) + + ( + x_remove_padding, + cum_offsets_out, + padding_offset, + cu_seqlens_q, + cu_seqlens_k, + ) = speculate_get_padding_offset( + input_ids_tensor, + 
draft_tokens_tensor, + cum_offsets_tensor, + token_num_tensor, + seq_lens_tensor, + seq_lens_encoder_tensor, + ) + + ( + ref_padding_offset, + ref_cum_offsets_out, + ref_cu_seqlens_q, + ref_cu_seqlens_k, + modified_indices, + ) = ref_speculate_get_padding_offset( + case["cum_offsets"], + case["seq_lens"], + case["max_seq_len"], + case["token_num_data"], + ) + + output_arrays = { + "padding_offset": padding_offset.numpy(), + "cum_offsets_out": cum_offsets_out.numpy(), + "cu_seqlens_q": cu_seqlens_q.numpy(), + "cu_seqlens_k": cu_seqlens_k.numpy(), + } + + ref_arrays = { + "padding_offset": ref_padding_offset, + "cum_offsets_out": ref_cum_offsets_out, + "cu_seqlens_q": ref_cu_seqlens_q, + "cu_seqlens_k": ref_cu_seqlens_k, + } + + case_passed = True + for key in output_arrays: + modified_pos = modified_indices[key] + if case["name"] == "Large sequence length" and key == "padding_offset": + match_count = sum(1 for pos in modified_pos if output_arrays[key][pos] == ref_arrays[key][pos]) + total_positions = len(modified_pos) + if match_count != total_positions: + case_passed = False + print(f" \033[91m✗ {key}: {match_count}/{total_positions} positions match\033[0m") + else: + print(f" \033[92m✓ {key}: All {total_positions} positions match\033[0m") + else: + match_count = sum(1 for pos in modified_pos if output_arrays[key][pos] == ref_arrays[key][pos]) + if match_count != len(modified_pos): + case_passed = False + print(f" \033[91m✗ {key}: {match_count}/{len(modified_pos)} positions match\033[0m") + else: + print(f" \033[92m✓ {key}: {match_count}/{len(modified_pos)} positions match\033[0m") + + if not case_passed: + all_passed = False + test_failed = True + + if all_passed: + print("\033[92m✓ All speculate_get_padding_offset tests passed\033[0m\n") + else: + print("\033[91m✗ Some speculate_get_padding_offset tests failed\033[0m\n") + + +def test_speculate_get_padding_offset_edge_cases(): + global test_failed + print("Testing speculate_get_padding_offset edge cases...") + + print("Test case 1: Single batch") + bsz = 1 + max_seq_len = 10 + token_num_data = 10 + max_draft_tokens = 3 + + input_ids = np.random.randint(0, 1000, (bsz, max_seq_len), dtype=np.int64) + draft_tokens = np.random.randint(0, 1000, (bsz, max_draft_tokens), dtype=np.int64) + cum_offsets = np.array([3], dtype=np.int32) + seq_lens = np.array([7], dtype=np.int32) + seq_lens_encoder = np.array([1], dtype=np.int32) + token_num = np.array([token_num_data], dtype=np.int64) + + input_ids_tensor = paddle.to_tensor(input_ids) + draft_tokens_tensor = paddle.to_tensor(draft_tokens) + cum_offsets_tensor = paddle.to_tensor(cum_offsets) + seq_lens_tensor = paddle.to_tensor(seq_lens) + seq_lens_encoder_tensor = paddle.to_tensor(seq_lens_encoder) + token_num_tensor = paddle.to_tensor(token_num) + + try: + ( + x_remove_padding, + cum_offsets_out, + padding_offset, + cu_seqlens_q, + cu_seqlens_k, + ) = speculate_get_padding_offset( + input_ids_tensor, + draft_tokens_tensor, + cum_offsets_tensor, + token_num_tensor, + seq_lens_tensor, + seq_lens_encoder_tensor, + ) + print( + f"\033[92m✓ Test case 1 passed, shapes: {[x.shape for x in [x_remove_padding, padding_offset, cum_offsets_out, cu_seqlens_q, cu_seqlens_k]]}\033[0m" + ) + except Exception as e: + print(f"\033[91m✗ Test case 1 failed: {e}\033[0m") + test_failed = True + + print("Test case 2: Large batch") + bsz = 8 + max_seq_len = 16 + token_num_data = 100 + + input_ids = np.random.randint(0, 1000, (bsz, max_seq_len), dtype=np.int64) + draft_tokens = np.random.randint(0, 1000, (bsz, 
max_draft_tokens), dtype=np.int64) + cum_offsets = np.array([1, 3, 6, 10, 15, 21, 28, 36], dtype=np.int32) + seq_lens = np.random.randint(1, max_seq_len, bsz).astype(np.int32) + seq_lens_encoder = np.random.randint(0, 2, bsz).astype(np.int32) + token_num = np.array([token_num_data], dtype=np.int64) + + input_ids_tensor = paddle.to_tensor(input_ids) + draft_tokens_tensor = paddle.to_tensor(draft_tokens) + cum_offsets_tensor = paddle.to_tensor(cum_offsets) + seq_lens_tensor = paddle.to_tensor(seq_lens) + seq_lens_encoder_tensor = paddle.to_tensor(seq_lens_encoder) + token_num_tensor = paddle.to_tensor(token_num) + + try: + ( + x_remove_padding, + cum_offsets_out, + padding_offset, + cu_seqlens_q, + cu_seqlens_k, + ) = speculate_get_padding_offset( + input_ids_tensor, + draft_tokens_tensor, + cum_offsets_tensor, + token_num_tensor, + seq_lens_tensor, + seq_lens_encoder_tensor, + ) + print( + f"\033[92m✓ Test case 2 passed, shapes: {[x.shape for x in [x_remove_padding, padding_offset, cum_offsets_out, cu_seqlens_q, cu_seqlens_k]]}\033[0m" + ) + except Exception as e: + print(f"\033[91m✗ Test case 2 failed: {e}\033[0m") + test_failed = True + + print("Test case 3: Small sequences") + bsz = 3 + max_seq_len = 5 + token_num_data = 12 + + input_ids = np.random.randint(0, 1000, (bsz, max_seq_len), dtype=np.int64) + draft_tokens = np.random.randint(0, 1000, (bsz, max_draft_tokens), dtype=np.int64) + cum_offsets = np.array([1, 2, 4], dtype=np.int32) + seq_lens = np.array([2, 3, 1], dtype=np.int32) + seq_lens_encoder = np.array([1, 0, 1], dtype=np.int32) + token_num = np.array([token_num_data], dtype=np.int64) + + input_ids_tensor = paddle.to_tensor(input_ids) + draft_tokens_tensor = paddle.to_tensor(draft_tokens) + cum_offsets_tensor = paddle.to_tensor(cum_offsets) + seq_lens_tensor = paddle.to_tensor(seq_lens) + seq_lens_encoder_tensor = paddle.to_tensor(seq_lens_encoder) + token_num_tensor = paddle.to_tensor(token_num) + + try: + ( + x_remove_padding, + cum_offsets_out, + padding_offset, + cu_seqlens_q, + cu_seqlens_k, + ) = speculate_get_padding_offset( + input_ids_tensor, + draft_tokens_tensor, + cum_offsets_tensor, + token_num_tensor, + seq_lens_tensor, + seq_lens_encoder_tensor, + ) + print( + f"\033[92m✓ Test case 3 passed, shapes: {[x.shape for x in [x_remove_padding, padding_offset, cum_offsets_out, cu_seqlens_q, cu_seqlens_k]]}\033[0m\n" + ) + except Exception as e: + print(f"\033[91m✗ Test case 3 failed: {e}\033[0m\n") + test_failed = True + + +def test_large_scale(): + global test_failed + print("Testing large scale data...") + + bsz = 32 + max_seq_len = 128 + token_num_data = 2048 + max_draft_tokens = 16 + + input_ids = np.random.randint(0, 1000, (bsz, max_seq_len), dtype=np.int64) + draft_tokens = np.random.randint(0, 1000, (bsz, max_draft_tokens), dtype=np.int64) + cum_offsets = np.cumsum(np.random.randint(1, 20, bsz)).astype(np.int32) + seq_lens = np.random.randint(1, max_seq_len, bsz).astype(np.int32) + seq_lens_encoder = np.random.randint(0, 2, bsz).astype(np.int32) + token_num = np.array([token_num_data], dtype=np.int64) + + input_ids_tensor = paddle.to_tensor(input_ids) + draft_tokens_tensor = paddle.to_tensor(draft_tokens) + cum_offsets_tensor = paddle.to_tensor(cum_offsets) + seq_lens_tensor = paddle.to_tensor(seq_lens) + seq_lens_encoder_tensor = paddle.to_tensor(seq_lens_encoder) + token_num_tensor = paddle.to_tensor(token_num) + + try: + ( + x_remove_padding, + cum_offsets_out, + padding_offset, + cu_seqlens_q, + cu_seqlens_k, + ) = speculate_get_padding_offset( + 
input_ids_tensor, + draft_tokens_tensor, + cum_offsets_tensor, + token_num_tensor, + seq_lens_tensor, + seq_lens_encoder_tensor, + ) + print("\033[92m✓ Large scale speculate_get_padding_offset test passed\033[0m") + print( + f"\033[92m Shapes: {[x.shape for x in [x_remove_padding, padding_offset, cum_offsets_out, cu_seqlens_q, cu_seqlens_k]]}\033[0m\n" + ) + except Exception as e: + print(f"\033[91m✗ Large scale speculate_get_padding_offset test failed: {e}\033[0m\n") + test_failed = True + + +def get_modified_indices_for_consistency_test(cum_offsets, seq_lens, max_seq_len, token_num_data): + bsz = seq_lens.shape[0] + + modified_indices = { + "x_remove_padding": [], + "padding_offset": [], + "cum_offsets_out": [], + "cu_seqlens_q": [], + "cu_seqlens_k": [], + } + + for bi in range(bsz): + modified_indices["cum_offsets_out"].append(bi) + + for i in range(bsz + 1): + modified_indices["cu_seqlens_q"].append(i) + modified_indices["cu_seqlens_k"].append(i) + + for bi in range(bsz): + cum_offset = 0 if bi == 0 else cum_offsets[bi - 1] + for i in range(seq_lens[bi]): + padding_idx = bi * max_seq_len - cum_offset + i + if padding_idx >= 0 and padding_idx < token_num_data: + modified_indices["padding_offset"].append(padding_idx) + + remove_padding_idx = bi * max_seq_len - cum_offsets[bi] + i + if remove_padding_idx >= 0 and remove_padding_idx < token_num_data: + modified_indices["x_remove_padding"].append(remove_padding_idx) + + return modified_indices + + +def test_consistency(): + global test_failed + print("Testing consistency...") + + np.random.seed(42) + + bsz = 4 + max_seq_len = 8 + token_num_data = 24 + max_draft_tokens = 3 + + input_ids = np.random.randint(0, 1000, (bsz, max_seq_len), dtype=np.int64) + draft_tokens = np.random.randint(0, 1000, (bsz, max_draft_tokens), dtype=np.int64) + cum_offsets = np.array([1, 3, 6, 10], dtype=np.int32) + seq_lens = np.array([6, 4, 5, 3], dtype=np.int32) + seq_lens_encoder = np.array([1, 0, 1, 0], dtype=np.int32) + token_num = np.array([token_num_data], dtype=np.int64) + + input_ids_tensor = paddle.to_tensor(input_ids) + draft_tokens_tensor = paddle.to_tensor(draft_tokens) + cum_offsets_tensor = paddle.to_tensor(cum_offsets) + seq_lens_tensor = paddle.to_tensor(seq_lens) + seq_lens_encoder_tensor = paddle.to_tensor(seq_lens_encoder) + token_num_tensor = paddle.to_tensor(token_num) + + modified_indices = get_modified_indices_for_consistency_test(cum_offsets, seq_lens, max_seq_len, token_num_data) + + print("Checking consistency for modified positions only:") + for key, indices in modified_indices.items(): + print(f" {key}: {len(indices)} positions") + + results = [] + for run in range(3): + ( + x_remove_padding, + cum_offsets_out, + padding_offset, + cu_seqlens_q, + cu_seqlens_k, + ) = speculate_get_padding_offset( + input_ids_tensor, + draft_tokens_tensor, + cum_offsets_tensor, + token_num_tensor, + seq_lens_tensor, + seq_lens_encoder_tensor, + ) + results.append( + [ + x_remove_padding.numpy(), + cum_offsets_out.numpy(), + padding_offset.numpy(), + cu_seqlens_q.numpy(), + cu_seqlens_k.numpy(), + ] + ) + + output_names = [ + "x_remove_padding", + "cum_offsets_out", + "padding_offset", + "cu_seqlens_q", + "cu_seqlens_k", + ] + consistent = True + for j, name in enumerate(output_names): + indices = modified_indices[name] if name in modified_indices else [] + + if not indices: + print(f"\033[93m ~ {name}: No modified indices to check\033[0m") + continue + + positions_consistent = True + + for i in range(1, len(results)): + for idx in indices: + if 
results[0][j][idx] != results[i][j][idx]: + consistent = False + positions_consistent = False + print( + f"\033[91m ✗ {name}[{idx}]: Run 1 = {results[0][j][idx]}, Run {i+1} = {results[i][j][idx]}\033[0m" + ) + break + if not positions_consistent: + break + + if positions_consistent: + print(f"\033[92m ✓ {name}: All {len(indices)} modified positions are consistent\033[0m") + + if consistent: + print( + "\033[92m✓ Consistency test passed - results are identical across runs (modified positions only)\033[0m\n" + ) + else: + print("\033[91m✗ Consistency test failed - some modified positions are inconsistent\033[0m\n") + print("Note: This test now only compares positions that the kernel actually modifies,") + print(" ignoring uninitialized values in other positions.\n") + test_failed = True + + +if __name__ == "__main__": + print("=" * 60) + print("Testing Speculate Get Padding Offset Kernels") + print("=" * 60) + + test_speculate_get_padding_offset() + test_speculate_get_padding_offset_edge_cases() + test_large_scale() + test_consistency() + + print("=" * 60) + if test_failed: + print("\033[91mSOME TESTS FAILED! \033[0m") + print("\033[91mPlease check the output above for failed test details.\033[0m") + else: + print("\033[92mALL TESTS PASSED! \033[0m") + print("\033[92mAll speculate_get_padding_offset kernels are working correctly.\033[0m") + print("=" * 60) diff --git a/custom_ops/xpu_ops/test/test_speculate_get_seq_lens_output.py b/custom_ops/xpu_ops/test/test_speculate_get_seq_lens_output.py new file mode 100644 index 000000000..9e948cd73 --- /dev/null +++ b/custom_ops/xpu_ops/test/test_speculate_get_seq_lens_output.py @@ -0,0 +1,105 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
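+# There is no host-side golden reference for speculate_get_seq_lens_output in
+# this file; it is a pure device-consistency test: the same seeded random
+# inputs are run through the op on CPU and on XPU, and the outputs must match
+# exactly. A minimal sketch of that pattern, assuming only the op signature
+# used below (all other names are ours):
+#
+#     def run_on(device, bsz=16, seed=42):
+#         paddle.set_device(device)
+#         paddle.seed(seed)
+#         this_time = paddle.randint(0, 10, shape=(bsz,), dtype="int32")
+#         encoder = paddle.randint(0, 10, shape=(bsz,), dtype="int32")
+#         decoder = paddle.randint(0, 10, shape=(bsz,), dtype="int32")
+#         return speculate_get_seq_lens_output(this_time, encoder, decoder)[0]
+#
+#     assert np.array_equal(run_on("cpu").numpy(), run_on("xpu").numpy())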
+ +import numpy as np +import paddle + +from fastdeploy.model_executor.ops.xpu import speculate_get_seq_lens_output # assumes the op has been built and is importable + + +def run_seq_lens_test(device="cpu"): + """Run the sequence-length test on the given device.""" + paddle.seed(42) + np.random.seed(42) + + if device == "cpu": + paddle.set_device(device) + elif device == "xpu": + paddle.set_device(device) + else: + raise ValueError(f"Invalid device: {device}") + + # create random test data of various sizes + batch_sizes = [1, 4, 16, 64, 128, 192, 256] + results = [] + test_times = 100 + for _ in range(test_times): + for bsz in batch_sizes: + # generate random input tensors + seq_lens_this_time = paddle.randint(0, 10, shape=(bsz,), dtype="int32") + seq_lens_encoder = paddle.randint(0, 10, shape=(bsz,), dtype="int32") + seq_lens_decoder = paddle.randint(0, 10, shape=(bsz,), dtype="int32") + + # record input values for debugging + input_values = [ + seq_lens_this_time.numpy().copy(), + seq_lens_encoder.numpy().copy(), + seq_lens_decoder.numpy().copy(), + ] + # run the op + seq_lens_output = speculate_get_seq_lens_output(seq_lens_this_time, seq_lens_encoder, seq_lens_decoder)[0] + + # collect results + results.append((input_values, seq_lens_output.numpy())) + + return results + + +if __name__ == "__main__": + print("\nRunning XPU test...") + xpu_results = run_seq_lens_test("xpu") + + print("Running CPU test...") + cpu_results = run_seq_lens_test("cpu") + + print("\nComparing results...") + all_pass = True + + # compare results batch by batch + for i, (cpu_data, xpu_data) in enumerate(zip(cpu_results, xpu_results)): + # unpack data + cpu_inputs, cpu_output = cpu_data + xpu_inputs, xpu_output = xpu_data + + # check that the inputs are identical + for j in range(3): + if not np.array_equal(cpu_inputs[j], xpu_inputs[j]): + print(f"Error: batch #{i+1} input {j} differs (CPU vs XPU)") + print(f"CPU input: {cpu_inputs[j]}") + print(f"XPU input: {xpu_inputs[j]}") + all_pass = False + + # check that the outputs are identical + if not np.array_equal(cpu_output, xpu_output): + print(f"\nError: batch #{i+1} outputs differ (CPU vs XPU)") + print(f"CPU output: {cpu_output}") + print(f"XPU output: {xpu_output}") + + # print details of the differences + diff_indices = np.where(cpu_output != xpu_output)[0] + for idx in diff_indices: + print(f"index {idx}: CPU output={cpu_output[idx]}, XPU output={xpu_output[idx]}") + print( + f"corresponding inputs: this_time={cpu_inputs[0][idx]}, " + f"encoder={cpu_inputs[1][idx]}, decoder={cpu_inputs[2][idx]}" + ) + all_pass = False + else: + print(f"Batch #{i+1} results match") + + if all_pass: + print("\nAll tests passed! CPU and XPU results are identical") + else: + print("\nTest failed: inconsistent results found") + exit(1) diff --git a/custom_ops/xpu_ops/test/test_speculate_get_token_penalty_multi_scores.py b/custom_ops/xpu_ops/test/test_speculate_get_token_penalty_multi_scores.py new file mode 100644 index 000000000..e1dd9d3cb --- /dev/null +++ b/custom_ops/xpu_ops/test/test_speculate_get_token_penalty_multi_scores.py @@ -0,0 +1,206 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
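+# speculate_get_token_penalty_multi_scores rewrites logits in place. The exact
+# rule lives in the XPU kernel; the sketch below is our reading of it, inferred
+# from the op's inputs and from the sibling get_token_penalty_multi_scores op
+# family -- treat it as an assumption, not the authoritative definition:
+#
+#     for each batch b and vocab id v that occurs n > 0 times in pre_ids[b]:
+#         l = logits[b][v]
+#         logits[b][v] = l * penalty_scores[b] if l < 0 else l / penalty_scores[b]
+#         logits[b][v] -= frequency_scores[b] * n + presence_scores[b]
+#     logits[b] /= temperatures[b]           # applied to the whole row
+#     logits[b][bad_tokens] = -1e10          # ban listed tokens
+#     if cur_len[b] < min_len[b]:
+#         logits[b][eos_token_id] = -1e10    # forbid early EOS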
+ +import numpy as np +import paddle + +from fastdeploy.model_executor.ops.xpu import speculate_get_token_penalty_multi_scores + +paddle.seed(2023) + + +def allclose_any(a, b, rtol=1e-5, atol=1e-5, equal_nan=False): + """Check that every element satisfies at least one of the two tolerance conditions.""" + condition = (np.abs(a - b) <= atol) | (np.abs(a - b) <= rtol * np.abs(b)) # absolute-error OR relative-error condition + print(f"cond={condition}") + # handle NaN (if requested) + if equal_nan: + nan_mask = np.isnan(a) & np.isnan(b) + condition = condition | nan_mask + # check that all elements satisfy the condition + return np.all(condition) + + +def find_max_diff(arr1, arr2): + """Find the largest element-wise difference between two arrays and its index. + Returns: + max_diff (float): largest absolute difference + index (tuple): position of the maximum + actual_diff (float): signed difference at that position (plus the two elements themselves) + """ + diff = arr1 - arr2 + abs_diff = np.abs(diff) + flat_idx = np.argmax(abs_diff) + idx = np.unravel_index(flat_idx, arr1.shape) + return abs_diff[idx], idx, diff[idx], arr1[idx], arr2[idx] + + +def test_main( + pre_ids, + logits, + penalty_scores, + frequency_scores, + presence_scores, + temperatures, + bad_tokens, + cur_len, + min_len, + eos_token_id, + seq_len_this_time, + output_padding_offset, + output_cum_offsets, + max_seq_len, +): + pre_ids_ref = pre_ids.cpu() + logits_ref = logits.cpu() + penalty_scores_ref = penalty_scores.cpu() + frequency_scores_ref = frequency_scores.cpu() + presence_scores_ref = presence_scores.cpu() + temperatures_ref = temperatures.cpu() + bad_tokens_ref = bad_tokens.cpu() + cur_len_ref = cur_len.cpu() + min_len_ref = min_len.cpu() + eos_token_id_ref = eos_token_id.cpu() + seq_len_this_time_ref = seq_len_this_time.cpu() + output_padding_offset_ref = output_padding_offset.cpu() + output_cum_offsets_ref = output_cum_offsets.cpu() + + speculate_get_token_penalty_multi_scores( + pre_ids, + logits, + penalty_scores, + frequency_scores, + presence_scores, + temperatures, + bad_tokens, + cur_len, + min_len, + eos_token_id, + seq_len_this_time, + output_padding_offset, + output_cum_offsets, + max_seq_len, + ) + speculate_get_token_penalty_multi_scores( + pre_ids_ref, + logits_ref, + penalty_scores_ref, + frequency_scores_ref, + presence_scores_ref, + temperatures_ref, + bad_tokens_ref, + cur_len_ref, + min_len_ref, + eos_token_id_ref, + seq_len_this_time_ref, + output_padding_offset_ref, + output_cum_offsets_ref, + max_seq_len, + ) + logits_ref_np = logits_ref.astype("float32").numpy() + logits_np = logits.astype("float32").numpy() + np.set_printoptions(threshold=10000) + # print(f"logits_ref={logits_ref_np[:50,:100]}") + # print(f"logits={logits_np[:50,:100]}") + + diff_logits = np.sum(np.abs(logits_ref_np - logits_np)) + print("diff_logits\n", diff_logits) + abs_diff, idx, diff, val1, val2 = find_max_diff(logits_ref_np, logits_np) + print(f"abs_diff={abs_diff}, index={idx}, diff={diff}, {val1} vs {val2}") + assert allclose_any(logits_ref_np, logits_np, 1e-5, 1e-5) + # assert np.allclose(logits_ref_np, logits_np, 1e-5, 1e-5) + + +# gtest_speculate_token_penalty_multi_scores(api::kXPU3, "GM", "GM", "GM", "GM", "GM", "GM", "GM", "GM", "GM", "GM", "GM", "GM", +# 84, 100352, 12288, 1, 1, 54, 32768); + + +def main(): + seed = np.random.randint(1, 1e9) + print(f"random seed is {seed}") + np.random.seed(seed) + + bs = 64 + max_seq_len = 32768 # 1024 #2048 #8192 + data_type = "float32" # bfloat16 or float32 + + # prepare output_padding_offset and output_cum_offsets + tokens = [1] * bs + token_num = np.sum(tokens) + print(f"bs={bs}, tokens={tokens}, token_num={token_num}") + output_padding_offset = [] + output_cum_offsets = [0] + opo_offset = 0 + for bid in range(bs): + ts = tokens[bid] + for i in range(ts): + 
output_padding_offset.append(opo_offset) + opo_offset += max_seq_len - ts + output_cum_offsets.append(opo_offset) + output_cum_offsets = output_cum_offsets[:-1] + # print(f"output_padding_offset={output_padding_offset}") + # print(f"output_cum_offsets={output_cum_offsets}") + output_padding_offset = paddle.to_tensor(output_padding_offset, "int32") + output_cum_offsets = paddle.to_tensor(output_cum_offsets, "int32") + + # prepare pre_ids and logits + pre_ids_len = 12288 + # pre_ids_len = np.random.randint(1, 512) + logits_len = 100352 + # print(f"pre_ids_len={pre_ids_len}, logits_len={logits_len}") + pre_ids = np.random.randint(1, logits_len, size=(bs, pre_ids_len)) + negative_start = np.random.randint(1, pre_ids_len + 1, size=(bs)) + print(negative_start) + for i in range(bs): + pre_ids[:, negative_start[i] :] = -1 + pre_ids = paddle.to_tensor(pre_ids).astype("int64") + # logits = paddle.to_tensor( + # np.float32(np.random.random([token_num, logits_len])) + # ).astype(data_type) + logits = paddle.to_tensor(np.float32(np.zeros([token_num, logits_len]))).astype(data_type) + # prepare other params + penalty_scores = paddle.to_tensor(np.random.random([bs])).astype(data_type) + frequency_scores = paddle.to_tensor(np.random.random([bs])).astype(data_type) + presence_scores = paddle.to_tensor(np.random.random([bs])).astype(data_type) + temperatures = paddle.to_tensor(np.random.random([bs])).astype("float32") + bad_tokens = paddle.to_tensor(np.random.randint(0, 101, size=(1))).astype("int64") + cur_len = paddle.to_tensor(np.random.randint(1, 50, size=(bs))).astype("int64") + min_len = paddle.to_tensor(np.random.randint(1, 50, size=(bs))).astype("int64") + eos_token_id = paddle.to_tensor(np.random.randint(1, 101, size=(1))).astype("int64") + seq_len_this_time = paddle.to_tensor( + np.random.randint(0, 1, size=(bs)), "int32" + ) # value of seq_len_this_time is useless + + # test + test_main( + pre_ids, + logits, + penalty_scores, + frequency_scores, + presence_scores, + temperatures, + bad_tokens, + cur_len, + min_len, + eos_token_id, + seq_len_this_time, + output_padding_offset, + output_cum_offsets, + max_seq_len, + ) + + +if __name__ == "__main__": + for i in range(10): + miain() diff --git a/custom_ops/xpu_ops/test/test_speculate_rebuild_append_padding.py b/custom_ops/xpu_ops/test/test_speculate_rebuild_append_padding.py new file mode 100644 index 000000000..9def0b9a5 --- /dev/null +++ b/custom_ops/xpu_ops/test/test_speculate_rebuild_append_padding.py @@ -0,0 +1,132 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
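+
+# A small worked example of the padding bookkeeping this test exercises,
+# using hypothetical numbers (not taken from the test below): with bs=2,
+# max_seq_len=4 and seq_lens=[2, 1],
+#     cum_offsets            = [0, 2, 5]   # 0 prepended to cumsum(max_seq_len - seq_lens)
+#     output_padding_offsets = [0, 0, 2]   # one entry per real (unpadded) token
+# For flattened output token 2: ori_token_id = 2 + 2 = 4, so its batch id is
+# bi = 4 // max_seq_len = 1, and (assuming seq_len_encoder[bi] == 0 with a
+# non-zero decoder length, so seq_id stays 0) the gathered row is
+# input_token_id = ori_token_id - cum_offsets[bi] = 2, which is how the
+# reference implementation below recovers each row.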
+ +import numpy as np +import paddle + +from fastdeploy.model_executor.ops.xpu import speculate_rebuild_append_padding + + +def ref_speculate_rebuild_append_padding( + full_hidden_states, + cum_offsets, + seq_len_encoder, + seq_len_decoder, + output_padding_offset, + max_seq_len, +): + dim_embed = full_hidden_states.shape[1] + output_token_num = output_padding_offset.shape[0] + elem_nums = output_token_num * dim_embed + + out = np.zeros(output_token_num * dim_embed, dtype=full_hidden_states.dtype) + full_hidden_states_flatten = full_hidden_states.flatten() + cum_offsets_flatten = cum_offsets.flatten() + seq_len_encoder_flatten = seq_len_encoder.flatten() + seq_len_decoder_flatten = seq_len_decoder.flatten() + output_padding_offset_flatten = output_padding_offset.flatten() + + for i in range(elem_nums): + out_token_id = i // dim_embed + ori_token_id = out_token_id + output_padding_offset_flatten[out_token_id] + bi = ori_token_id // max_seq_len + + seq_id = 0 + if seq_len_decoder_flatten[bi] == 0 and seq_len_encoder_flatten[bi] == 0: + continue + elif seq_len_encoder_flatten[bi] != 0: + seq_id = seq_len_encoder[bi] - 1 + + input_token_id = ori_token_id - cum_offsets_flatten[bi] + seq_id + bias_idx = i % dim_embed + + out[i] = full_hidden_states_flatten[input_token_id * dim_embed + bias_idx] + out = np.reshape(out, (output_token_num, dim_embed)) + return out + + +def test_speculate_rebuild_append_padding(): + bs = np.random.randint(1, 4 + 1, dtype=np.int32) + max_seq_len = 1 * 1024 + dim_embed = np.random.randint(1, 4 * 1024 + 1, dtype=np.int32) + seq_lens = [] + for _ in range(bs): + seq_lens.append(np.random.randint(1, max_seq_len + 1, dtype=np.int32)) + seq_lens = np.asarray(seq_lens) + cum_offsets = np.cumsum(np.asarray(max_seq_len) - seq_lens) + cum_offsets = np.insert(cum_offsets, 0, 0) + output_padding_offsets = [] + for i in range(bs): + offset = cum_offsets[i] + for j in range(seq_lens[i]): + output_padding_offsets.append(offset) + output_padding_offsets = np.asarray(output_padding_offsets) + # TODO: seq_len_encoder with non-zero element + seq_len_decoder = np.random.randint(0, 2 + 1, bs, dtype=np.int32) + seq_len_encoder_zeros = np.zeros(bs, dtype=np.int32) + + for dtype in [paddle.bfloat16, paddle.float16]: + full_hidden_states = np.random.randint(0, 10, (np.sum(seq_lens), dim_embed), dtype=np.int16) + full_hidden_states_tensor = paddle.to_tensor(full_hidden_states, dtype=dtype) + cum_offsets_tensor = paddle.to_tensor(cum_offsets, dtype=paddle.int32) + seq_len_encoder_zeros_tensor = paddle.to_tensor(seq_len_encoder_zeros, dtype=paddle.int32) + seq_len_decoder_tensor = paddle.to_tensor(seq_len_decoder, dtype=paddle.int32) + output_padding_offsets_tensor = paddle.to_tensor(output_padding_offsets, dtype=paddle.int32) + cpu_out = speculate_rebuild_append_padding( + full_hidden_states_tensor.cpu(), + cum_offsets_tensor.cpu(), + seq_len_encoder_zeros_tensor.cpu(), + seq_len_decoder_tensor.cpu(), + output_padding_offsets_tensor.cpu(), + max_seq_len, + ) + xpu_out = speculate_rebuild_append_padding( + full_hidden_states_tensor, + cum_offsets_tensor, + seq_len_encoder_zeros_tensor, + seq_len_decoder_tensor, + output_padding_offsets_tensor, + max_seq_len, + ) + assert np.allclose(cpu_out.numpy(), xpu_out.numpy()) + for dtype in [paddle.float32]: + full_hidden_states = np.random.randint(0, 10, (np.sum(seq_lens), dim_embed), dtype=np.int32) + full_hidden_states_tensor = paddle.to_tensor(full_hidden_states, dtype=dtype) + cum_offsets_tensor = paddle.to_tensor(cum_offsets, dtype=paddle.int32) + 
seq_len_encoder_zeros_tensor = paddle.to_tensor(seq_len_encoder_zeros, dtype=paddle.int32)
+        seq_len_decoder_tensor = paddle.to_tensor(seq_len_decoder, dtype=paddle.int32)
+        output_padding_offsets_tensor = paddle.to_tensor(output_padding_offsets, dtype=paddle.int32)
+        cpu_out = speculate_rebuild_append_padding(
+            full_hidden_states_tensor.cpu(),
+            cum_offsets_tensor.cpu(),
+            seq_len_encoder_zeros_tensor.cpu(),
+            seq_len_decoder_tensor.cpu(),
+            output_padding_offsets_tensor.cpu(),
+            max_seq_len,
+        )
+        xpu_out = speculate_rebuild_append_padding(
+            full_hidden_states_tensor,
+            cum_offsets_tensor,
+            seq_len_encoder_zeros_tensor,
+            seq_len_decoder_tensor,
+            output_padding_offsets_tensor,
+            max_seq_len,
+        )
+        assert np.allclose(cpu_out.numpy(), xpu_out.numpy())
+
+    print("All tests passed")
+
+
+if __name__ == "__main__":
+    test_speculate_rebuild_append_padding()
diff --git a/custom_ops/xpu_ops/test/test_speculate_set_stop_value_multi_seqs.py b/custom_ops/xpu_ops/test/test_speculate_set_stop_value_multi_seqs.py
new file mode 100644
index 000000000..9c8f5ff28
--- /dev/null
+++ b/custom_ops/xpu_ops/test/test_speculate_set_stop_value_multi_seqs.py
@@ -0,0 +1,307 @@
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
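+
+# Sketch of the semantics checked below, inferred from the expected values in
+# test_basic_functionality (not an authoritative spec of the operator): when
+# the tail of pre_ids (e.g. [..., 3, 4, 5]) matches one of the stop sequences
+# (e.g. [3, 4, 5]), the operator writes end_ids[0] into accept_tokens at slot
+# accept_num - 1 and raises that batch's stop flag; rows without a full match
+# are left untouched.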
+
+import unittest
+
+import numpy as np
+import paddle
+
+from fastdeploy.model_executor.ops.xpu import speculate_set_stop_value_multi_seqs
+
+
+def compare_results(cpu_results, xpu_results):
+    # Compare all outputs
+    for key in cpu_results:
+        if key in ["output_accept_tokens", "output_stop_flags"]:
+            np.testing.assert_array_equal(
+                cpu_results[key],
+                xpu_results[key],
+                err_msg=f"{key} mismatch between CPU and XPU",
+            )
+    print("CPU and XPU results match!")
+
+
+class TestSpeculateSetStopValueMultiSeqs(unittest.TestCase):
+    def setUp(self):
+        self.place = paddle.device.XPUPlace(0)
+
+    def run_op(
+        self,
+        device,
+        accept_tokens,
+        accept_num,
+        pre_ids,
+        step_idx,
+        stop_flags,
+        seq_lens,
+        stop_seqs,
+        stop_seqs_len,
+        end_ids,
+    ):
+        if device == "cpu":
+            accept_tokens = accept_tokens.cpu()
+            accept_num = accept_num.cpu()
+            pre_ids = pre_ids.cpu()
+            step_idx = step_idx.cpu()
+            stop_flags = stop_flags.cpu()
+            seq_lens = seq_lens.cpu()
+            stop_seqs = stop_seqs.cpu()
+            stop_seqs_len = stop_seqs_len.cpu()
+            end_ids = end_ids.cpu()
+
+        accept_tokens_out = accept_tokens.clone()
+        stop_flags_out = stop_flags.clone()
+        speculate_set_stop_value_multi_seqs(
+            accept_tokens_out,
+            accept_num,
+            pre_ids,
+            step_idx,
+            stop_flags_out,
+            seq_lens,
+            stop_seqs,
+            stop_seqs_len,
+            end_ids,
+        )
+
+        # Return results for comparison
+        results = {
+            "accept_tokens": accept_tokens.numpy(),
+            "accept_num": accept_num.numpy(),
+            "pre_ids": pre_ids.numpy(),
+            "step_idx": step_idx.numpy(),
+            "stop_flags": stop_flags.numpy(),
+            "output_accept_tokens": accept_tokens_out.numpy(),
+            "output_stop_flags": stop_flags_out.numpy(),
+        }
+        return results
+
+    def test_basic_functionality(self):
+        # Test basic functionality with one sequence matching a stop sequence
+        accept_tokens = paddle.to_tensor(
+            [
+                [4, 5, 0, 0, 0],  # batch 0
+                [1, 2, 3, 0, 0],  # batch 1 (no match)
+            ],
+            dtype="int64",
+        )
+
+        accept_num = paddle.to_tensor([3, 4], dtype="int32")
+
+        pre_ids = paddle.to_tensor(
+            [
+                [7, 8, 9, 3, 4, 5],  # batch 0
+                [7, 8, 9, 1, 2, 3],  # batch 1
+            ],
+            dtype="int64",
+        )
+
+        step_idx = paddle.to_tensor([6, 6], dtype="int64")  # the last element of pre_ids sits at index 5
+
+        stop_flags = paddle.to_tensor([False, False], dtype="bool")
+        seq_lens = paddle.to_tensor([6, 6], dtype="int32")
+        stop_seqs = paddle.to_tensor(
+            [
+                [3, 4, 5],  # batch 0
+                [0, 0, 0],  # batch 1
+            ],
+            dtype="int64",
+        )
+        stop_seqs_len = paddle.to_tensor([3, 0], dtype="int32")
+        end_ids = paddle.to_tensor([-1], dtype="int64")
+        # Run operator
+        xpu_results = self.run_op(
+            "xpu",
+            accept_tokens,
+            accept_num,
+            pre_ids,
+            step_idx,
+            stop_flags,
+            seq_lens,
+            stop_seqs,
+            stop_seqs_len,
+            end_ids,
+        )
+        cpu_results = self.run_op(
+            "cpu",
+            accept_tokens,
+            accept_num,
+            pre_ids,
+            step_idx,
+            stop_flags,
+            seq_lens,
+            stop_seqs,
+            stop_seqs_len,
+            end_ids,
+        )
+        compare_results(cpu_results, xpu_results)
+
+        # Verify results
+        expected_accept_tokens = np.array([[4, 5, -1, 0, 0], [1, 2, 3, 0, 0]])
+        expected_stop_flags = np.array([True, False])
+
+        np.testing.assert_array_equal(xpu_results["output_accept_tokens"], expected_accept_tokens)
+        np.testing.assert_array_equal(xpu_results["output_stop_flags"], expected_stop_flags)
+
+    def test_no_match(self):
+        # Test case where no stop sequence matches
+        # Input tensors
+        accept_tokens = paddle.to_tensor(
+            [[10, 20, 30, 0, 0], [40, 50, 60, 0, 0]],
+            dtype="int64",
+            place=self.place,
+        )
+        accept_num = paddle.to_tensor([3, 3], dtype="int32", place=self.place)
+        pre_ids = paddle.to_tensor([[1, 2, 3, 4, 5], [6, 7,
8, 9, 10]], dtype="int64", place=self.place) + step_idx = paddle.to_tensor([8, 8], dtype="int64", place=self.place) + stop_flags = paddle.to_tensor([False, False], dtype="bool", place=self.place) + seq_lens = paddle.to_tensor([10, 10], dtype="int32", place=self.place) + + # Stop sequences that don't match + stop_seqs = paddle.to_tensor([[11, 12, 13], [14, 15, 16]], dtype="int64", place=self.place) + stop_seqs_len = paddle.to_tensor([3, 3], dtype="int32", place=self.place) + end_ids = paddle.to_tensor([-1], dtype="int64", place=self.place) + + # Run operator + xpu_results = self.run_op( + "xpu", + accept_tokens, + accept_num, + pre_ids, + step_idx, + stop_flags, + seq_lens, + stop_seqs, + stop_seqs_len, + end_ids, + ) + cpu_results = self.run_op( + "cpu", + accept_tokens, + accept_num, + pre_ids, + step_idx, + stop_flags, + seq_lens, + stop_seqs, + stop_seqs_len, + end_ids, + ) + compare_results(cpu_results, xpu_results) + + # Verify nothing changed + + np.testing.assert_array_equal(xpu_results["output_accept_tokens"], accept_tokens.numpy()) + np.testing.assert_array_equal(xpu_results["output_stop_flags"], stop_flags.numpy()) + + def test_partial_match(self): + # Test case where only part of the sequence matches + # Input tensors + accept_tokens = paddle.to_tensor([[10, 20, 30, 0, 0]], dtype="int64", place=self.place) + accept_num = paddle.to_tensor([3], dtype="int32", place=self.place) + pre_ids = paddle.to_tensor([[1, 2, 3, 4, 5]], dtype="int64", place=self.place) + step_idx = paddle.to_tensor([8], dtype="int64", place=self.place) + stop_flags = paddle.to_tensor([False], dtype="bool", place=self.place) + seq_lens = paddle.to_tensor([10], dtype="int32", place=self.place) + + # Stop sequence that partially matches + stop_seqs = paddle.to_tensor( + [[5, 4, 99]], # Only 5,4 matches (from pre_ids), 99 doesn't + dtype="int64", + place=self.place, + ) + stop_seqs_len = paddle.to_tensor([3], dtype="int32", place=self.place) + end_ids = paddle.to_tensor([-1], dtype="int64", place=self.place) + + # Run operator + xpu_results = self.run_op( + "xpu", + accept_tokens, + accept_num, + pre_ids, + step_idx, + stop_flags, + seq_lens, + stop_seqs, + stop_seqs_len, + end_ids, + ) + cpu_results = self.run_op( + "cpu", + accept_tokens, + accept_num, + pre_ids, + step_idx, + stop_flags, + seq_lens, + stop_seqs, + stop_seqs_len, + end_ids, + ) + compare_results(cpu_results, xpu_results) + + # Verify nothing changed + np.testing.assert_array_equal(xpu_results["output_accept_tokens"], accept_tokens.numpy()) + np.testing.assert_array_equal(xpu_results["output_stop_flags"], stop_flags.numpy()) + + def test_already_stopped(self): + # Test case where sequence is already stopped + # Input tensors + accept_tokens = paddle.to_tensor([[10, 20, 30, 0, 0]], dtype="int64", place=self.place) + accept_num = paddle.to_tensor([3], dtype="int32", place=self.place) + pre_ids = paddle.to_tensor([[1, 2, 3, 4, 5]], dtype="int64", place=self.place) + step_idx = paddle.to_tensor([8], dtype="int64", place=self.place) + stop_flags = paddle.to_tensor([True], dtype="bool", place=self.place) # Already stopped + seq_lens = paddle.to_tensor([10], dtype="int32", place=self.place) + + # Stop sequence that would match + stop_seqs = paddle.to_tensor([[5, 4, 3]], dtype="int64", place=self.place) + stop_seqs_len = paddle.to_tensor([3], dtype="int32", place=self.place) + end_ids = paddle.to_tensor([-1], dtype="int64", place=self.place) + + # Run operator + xpu_results = self.run_op( + "xpu", + accept_tokens, + accept_num, + pre_ids, + step_idx, + 
stop_flags,
+            seq_lens,
+            stop_seqs,
+            stop_seqs_len,
+            end_ids,
+        )
+        cpu_results = self.run_op(
+            "cpu",
+            accept_tokens,
+            accept_num,
+            pre_ids,
+            step_idx,
+            stop_flags,
+            seq_lens,
+            stop_seqs,
+            stop_seqs_len,
+            end_ids,
+        )
+        compare_results(cpu_results, xpu_results)
+
+        # Verify nothing changed
+        np.testing.assert_array_equal(xpu_results["output_accept_tokens"], accept_tokens.numpy())
+        np.testing.assert_array_equal(xpu_results["output_stop_flags"], stop_flags.numpy())
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/custom_ops/xpu_ops/test/test_speculate_set_value_by_flags.py b/custom_ops/xpu_ops/test/test_speculate_set_value_by_flags.py
new file mode 100644
index 000000000..cf05ba79e
--- /dev/null
+++ b/custom_ops/xpu_ops/test/test_speculate_set_value_by_flags.py
@@ -0,0 +1,83 @@
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+import paddle
+
+if paddle.is_compiled_with_xpu():
+    from fastdeploy.model_executor.ops.xpu import speculate_set_value_by_flags_and_idx
+else:
+    from efficientllm.ops.gpu import speculate_set_value_by_flags_and_idx
+
+
+def test_speculate_set_value_by_flags_and_idx():
+    # Scatter accept_tokens into the proper positions of pre_ids.
+    bs = 256
+    length = 8192
+    max_draft_tokens = 4
+
+    pre_ids_all = paddle.to_tensor(np.full((bs, length), -1), dtype="int64")
+
+    accept_tokens = np.random.randint(100, 200, size=(bs, max_draft_tokens))
+    accept_tokens = paddle.to_tensor(accept_tokens, dtype="int64")
+
+    accept_num = np.random.randint(0, max_draft_tokens + 1, size=bs)
+    accept_num = paddle.to_tensor(accept_num, dtype="int32")
+
+    stop_flags = np.random.choice([True, False, False, False], size=bs)
+    stop_flags = paddle.to_tensor(stop_flags, dtype="bool")
+
+    seq_lens_this_time = paddle.to_tensor(np.full((bs), 1), dtype="int32")
+    seq_lens_encoder = paddle.to_tensor(np.full((bs), 0), dtype="int32")
+    seq_lens_decoder = paddle.to_tensor(np.full((bs), 2), dtype="int32")
+
+    step_idx = np.random.randint(max_draft_tokens, length, size=bs)
+    step_idx = paddle.to_tensor(step_idx, dtype="int64")
+
+    out_xpu = speculate_set_value_by_flags_and_idx(
+        pre_ids_all,
+        accept_tokens,
+        accept_num,
+        stop_flags,
+        seq_lens_this_time,
+        seq_lens_encoder,
+        seq_lens_decoder,
+        step_idx,
+    )
+    out_xpu = out_xpu.numpy()
+
+    # CPU reference: write the accepted tokens backwards, ending at step_idx.
+    out_cpu = paddle.to_tensor(np.full((bs, length), -1), dtype="int64")
+    for i in range(bs):
+        if stop_flags[i] or (seq_lens_encoder[i] == 0 and seq_lens_decoder[i] == 0):
+            continue
+        if step_idx[i] >= 0:
+            for j in range(accept_num[i]):
+                out_cpu[i, step_idx[i] - j] = accept_tokens[i, accept_num[i] - 1 - j]
+
+    # print(f"accept_tokens: {accept_tokens}")
+    # print(f"accept_num: {accept_num}")
+    # print(f"stop_flags: {stop_flags}")
+    # print(f"seq_lens_this_time: {seq_lens_this_time}")
+    # print(f"seq_lens_encoder: {seq_lens_encoder}")
+    # print(f"seq_lens_decoder: {seq_lens_decoder}")
+    # print(f"step_idx: {step_idx}")
+    # print(f"out_xpu: {out_xpu}")
+    # print(f"out_cpu: {out_cpu}")
+
+    assert np.array_equal(out_xpu, out_cpu), "out_xpu != out_cpu"
+    print("test_speculate_set_value_by_flags_and_idx passed!")
+
+
+if __name__ == "__main__":
+    test_speculate_set_value_by_flags_and_idx()
diff --git a/custom_ops/xpu_ops/test/test_speculate_update_v3.py b/custom_ops/xpu_ops/test/test_speculate_update_v3.py
new file mode 100644
index 000000000..1ecebc6e7
--- /dev/null
+++ b/custom_ops/xpu_ops/test/test_speculate_update_v3.py
@@ -0,0 +1,210 @@
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+
+# tests/test_speculate_update_v3.py
+import paddle
+
+from fastdeploy.model_executor.ops.xpu import speculate_update_v3
+
+
+# ---------------- NumPy reference implementation ----------------
+def speculate_update_v3_np(
+    seq_lens_encoder,
+    seq_lens_decoder,
+    not_need_stop,
+    draft_tokens,
+    actual_draft_token_nums,
+    accept_tokens,
+    accept_num,
+    stop_flags,
+    seq_lens_this_time,
+    is_block_step,
+    stop_nums,
+):
+    """
+    NumPy reference that mirrors the CPU / CUDA logic exactly (modifies its inputs in place).
+    """
+    stop_sum = 0
+    real_bsz = seq_lens_this_time.shape[0]
+    max_bsz = stop_flags.shape[0]
+    max_draft_tokens = draft_tokens.shape[1]
+
+    for bid in range(max_bsz):
+        stop_flag_now_int = 0
+        inactive = bid >= real_bsz
+        block_step = (not inactive) and is_block_step[bid]
+
+        if (not block_step) and (not inactive):
+
+            if stop_flags[bid]:
+                stop_flag_now_int = 1
+
+            # When the encoder length is 0, add the accepted count straight to the decoder length.
+            if seq_lens_encoder[bid] == 0:
+                seq_lens_decoder[bid] += accept_num[bid]
+
+            # Adapt the draft length.
+            if (seq_lens_encoder[bid] == 0) and (seq_lens_this_time[bid] > 1):
+                cur_len = actual_draft_token_nums[bid]
+                if accept_num[bid] - 1 == cur_len:  # everything accepted
+                    if cur_len + 2 <= max_draft_tokens - 1:
+                        cur_len += 2
+                    elif cur_len + 1 <= max_draft_tokens - 1:
+                        cur_len += 1
+                    else:
+                        cur_len = max_draft_tokens - 1
+                else:  # at least one rejection
+                    cur_len = max(1, cur_len - 1)
+                actual_draft_token_nums[bid] = cur_len
+
+            # Settle the outstanding encoder length.
+            if seq_lens_encoder[bid] != 0:
+                seq_lens_decoder[bid] += seq_lens_encoder[bid]
+                seq_lens_encoder[bid] = 0
+
+            # Write back the first token for the next round.
+            draft_tokens[bid, 0] = accept_tokens[bid, accept_num[bid] - 1]
+
+            # On stop, clear the decoder length.
+            if stop_flag_now_int:
+                seq_lens_decoder[bid] = 0
+
+        elif inactive:
+            stop_flag_now_int = 1  # padding slots count as stopped
+
+        stop_sum += stop_flag_now_int
+
+    # print("stop_sum: ", stop_sum)
+    not_need_stop[0] = stop_sum < stop_nums[0]
+
+    # Return the (mutated) references, for convenience only.
+    return (
+        seq_lens_encoder,
+        seq_lens_decoder,
+        not_need_stop,
+        draft_tokens,
+        actual_draft_token_nums,
+    )
+
+
+# ---------------- random input generation ----------------
+def gen_inputs(
+    max_bsz=512,  # aligned with the CUDA BlockSize
+    max_draft_tokens=16,
+    real_bsz=123,  # tunable, but must be <= max_bsz
+    seed=2022,
+):
+    rng = np.random.default_rng(seed)
+
+    # basic tensors
+    seq_lens_encoder = rng.integers(0, 3, size=max_bsz, dtype=np.int32)
+    seq_lens_decoder = rng.integers(0, 20, size=max_bsz, dtype=np.int32)
+    not_need_stop = rng.integers(0, 1, size=1, dtype=np.bool_)
+    draft_tokens = rng.integers(0, 1000, size=(max_bsz, max_draft_tokens), dtype=np.int64)
+    actual_draft_nums = rng.integers(1, max_draft_tokens, size=max_bsz, dtype=np.int32)
+    accept_tokens = rng.integers(0, 1000, size=(max_bsz, max_draft_tokens), dtype=np.int64)
+    accept_num = rng.integers(1, max_draft_tokens, size=max_bsz, dtype=np.int32)
+    stop_flags = rng.integers(0, 2, size=max_bsz, dtype=np.bool_)
+    is_block_step = rng.integers(0, 2, size=max_bsz, dtype=np.bool_)
+    stop_nums = np.array([5], dtype=np.int64)  # arbitrary threshold
+
+    # seq_lens_this_time only covers the first real_bsz entries
+    seq_lens_this_time = rng.integers(1, max_draft_tokens, size=real_bsz, dtype=np.int32)
+
+    return {
+        "seq_lens_encoder": seq_lens_encoder,
+        "seq_lens_decoder": seq_lens_decoder,
+        "not_need_stop": not_need_stop,
+        "draft_tokens": draft_tokens,
+        "actual_draft_token_nums": actual_draft_nums,
+        "accept_tokens": accept_tokens,
+        "accept_num": accept_num,
+        "stop_flags": stop_flags,
+        "seq_lens_this_time": seq_lens_this_time,
+        "is_block_step": is_block_step,
+        "stop_nums": stop_nums,
+        # real_bsz = real_bsz,
+        # max_bsz = max_bsz,
+        # max_draft_tokens = max_draft_tokens
+    }
+
+
+# ------------------- test body -------------------
+inputs = gen_inputs(max_bsz=512, max_draft_tokens=32, real_bsz=201)
+
+# ---- Paddle side ----
+paddle_inputs = {}
+for k, v in inputs.items():
+    if k in ("real_bsz", "max_bsz", "max_draft_tokens"):
+        paddle_inputs[k] = v  # plain python int
+    else:
+        if k == "not_need_stop":
+            paddle_inputs[k] = paddle.to_tensor(v, place=paddle.CPUPlace())
+        else:
+            # The remaining tensors keep the default place (add place=paddle.CUDAPlace(0) manually to test GPU).
+            paddle_inputs[k] = paddle.to_tensor(v)
+
+# ---- NumPy side ----
+# To guarantee identical initial values, copy the numpy values of the Paddle inputs before passing them to the reference.
+np_inputs = {
+    k: (paddle_inputs[k].numpy().copy() if isinstance(paddle_inputs[k], paddle.Tensor) else paddle_inputs[k])
+    for k in paddle_inputs
+}
+
+# Invoke the custom operator.
+# print("seq_lens_encoder_xpu_before: ", paddle_inputs["seq_lens_encoder"])
+out_pd = speculate_update_v3(**paddle_inputs)
+# print("seq_lens_encoder_xpu_after: ", out_pd[0])
+# print("not_need_stop: ", out_pd[2])
+
+# speculate_update_v3 returns 5 tensors (matching its Outputs)
+(
+    seq_lens_encoder_pd,
+    seq_lens_decoder_pd,
+    not_need_stop_pd,
+    draft_tokens_pd,
+    actual_draft_nums_pd,
+) = out_pd
+
+# print("seq_lens_encoder_np_before: ", np_inputs["seq_lens_encoder"])
+out_np = speculate_update_v3_np(**np_inputs)
+# print("seq_lens_encoder_np_after: ", out_np[0])
+# print("not_need_stop: ", out_np[2])
+
+
+# ---------------- verification ----------------
+names = [
+    "seq_lens_encoder",
+    "seq_lens_decoder",
+    "not_need_stop",
+    "draft_tokens",
+    "actual_draft_token_nums",
+]
+pd_tensors = [
+    seq_lens_encoder_pd,
+    seq_lens_decoder_pd,
+    not_need_stop_pd,
+    draft_tokens_pd,
+    actual_draft_nums_pd,
+]
+
+for name, pd_val, np_val in zip(names, pd_tensors, out_np):
+    pd_arr = pd_val.numpy()
+    ok = np.array_equal(pd_arr, np_val)
+    print(f"{name:25s} equal :", ok)
+
+    # An assert can also be added here when running under pytest:
+    # assert all(np.array_equal(p.numpy(), n) for p,n in zip(pd_tensors, out_np))
diff --git a/custom_ops/xpu_ops/test/test_speculate_verify.py b/custom_ops/xpu_ops/test/test_speculate_verify.py
new file mode 100644
index 000000000..17733a6e2
--- /dev/null
+++ b/custom_ops/xpu_ops/test/test_speculate_verify.py
@@ -0,0 +1,634 @@
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import random
+from typing import List
+
+import numpy as np
+
+# tests/speculate_verify.py
+import paddle
+
+from fastdeploy.model_executor.ops.xpu import speculate_verify
+
+
+def topp_sampling_kernel(candidate_ids, candidate_scores, curand_value, candidate_len, topp, tid=0):
+    """
+    Python simulation of the top-p sample-selection kernel.
+
+    Args:
+    - candidate_ids: [candidate_len] int64 array of candidate tokens
+    - candidate_scores: [candidate_len] float32 array of matching probabilities
+    - curand_value: float in [0, 1), simulating curand_uniform on the GPU
+    - candidate_len: int, number of candidates
+    - topp: float, top-p truncation threshold
+    - tid: simulated thread id, for debugging only (optional)
+
+    Returns:
+    - the sampled token (int64)
+    """
+    rand_top_p = curand_value * topp
+    sum_scores = 0.0
+    for i in range(candidate_len):
+        print(
+            f"debug sample i:{i} scores:{candidate_scores[i]},ids:{candidate_ids[i]},curand_value{curand_value},topp{topp}, value*topp{rand_top_p}"
+        )
+        sum_scores += candidate_scores[i]
+        if rand_top_p <= sum_scores:
+            return candidate_ids[i]
+    return candidate_ids[0]  # fallback (should never be reached in theory)
+
+
+# def is_in_end(id: int, end_ids: np.ndarray, length: int) -> bool:
+#     """
+#     Return whether id occurs among the first `length` elements of end_ids.
+#     """
+#     for i in range(length):
+#         if id == end_ids[i]:
+#             return True
+#     return False
+
+# def is_in(candidates: np.ndarray, draft: int, candidate_len: int) -> bool:
+#     """
+#     Return whether draft occurs among the first `candidate_len` elements of candidates.
+#     """
+#     for i in range(candidate_len):
+#         if draft == candidates[i]:
+#             return True
+#     return False
+
+
+# ---------------- NumPy reference implementation ----------------
+def speculate_verify_np(
+    accept_tokens,
+    accept_num,
+    step_idx,
+    stop_flags,
+    seq_lens_encoder,
+    seq_lens_decoder,
+    draft_tokens,
+    seq_lens_this_time,
+    verify_tokens,
+    verify_scores,
+    max_dec_len,
+    end_tokens,
+    is_block_step,
+    output_cum_offsets,
+    actual_candidate_len,
+    actual_draft_token_nums,
+    topp,
+    max_seq_len,
+    verify_window,
+    enable_topp,
+):
+    def is_in_end(token, end_tokens, end_length):
+        return token in end_tokens[:end_length]
+
+    def is_in(candidate_list, token, length):
+        return token in candidate_list[:length]
+
+    bsz = accept_tokens.shape[0]
+    real_bsz = seq_lens_this_time.shape[0]
+    max_draft_tokens = draft_tokens.shape[1]
+    end_length = end_tokens.shape[0]
+    max_candidate_len = verify_tokens.shape[1]
+    use_topk = False
+    prefill_one_step_stop = False
+
+    # random
+    initial_seed = 0
+    infer_seed: List[int] = [initial_seed] * bsz
+    dev_curand_states: List[float] = []
+
+    # Generate one random number per batch slot.
+    for i in range(bsz):
+        current_seed = infer_seed[i]  # current_seed always equals initial_seed here
+
+        # Create an independent RNG instance from the current seed.
+        # This corresponds to the C++ std::mt19937_64 engine(infer_seed[i]);
+        rng = random.Random(current_seed)
+
+        # Draw a float in [0.0, 1.0) from the independent generator.
+        # This corresponds to the C++ dist(engine);
+        dev_curand_states.append(rng.random())
+    # --- Flatten inside the function ---
+    # Only the multi-dimensional arrays that C++ accesses via pointer arithmetic need flattening.
+    accept_tokens_flat = accept_tokens.reshape(-1)
+    draft_tokens_flat = draft_tokens.reshape(-1)
+    verify_tokens_flat = verify_tokens.reshape(-1)
+    verify_scores_flat = verify_scores.reshape(-1)
+    print(f"DEBUG: accept_tokens_flat shape: {accept_tokens_flat.shape}")
+    print(f"DEBUG: draft_tokens_flat shape: {draft_tokens_flat.shape}")
+    print(f"DEBUG: verify_tokens_flat shape: {verify_tokens_flat.shape}")
+    print(f"DEBUG: verify_scores_flat shape: {verify_scores_flat.shape}")
+    # The other arrays (accept_num, step_idx, stop_flags, end_tokens, dev_curand_states,
+    # actual_candidate_len, seq_lens_encoder, seq_lens_decoder, actual_draft_token_nums,
+    # topp, seq_lens_this_time, max_dec_len, is_block_step, output_cum_offsets) are already
+    # one-dimensional in their original C++ definition and need no extra reshape, so the
+    # original references are used directly. For clarity we assume every parameter that is
+    # not of shape (N, K) is already of shape (N,).
+    print()
+    # Iterate over every sample in the batch.
+    for bid in range(real_bsz):
+        # C++: const int start_token_id = bid * max_seq_len - output_cum_offsets[bid];
+        start_token_id = bid * max_seq_len - output_cum_offsets[bid]
+        accept_num_now = 1
+        stop_flag_now_int = 0
+        print(
+            f"DEBUG: start_token_id: {start_token_id}, max_seq_len: {max_seq_len}, output_cum_offsets[{bid}]: {output_cum_offsets[bid]}"
+        )
+
+        # C++: if (!(is_block_step[bid] || bid >= real_bsz))
+        if not (
+            is_block_step[bid] or bid >= real_bsz
+        ):  # bid >= real_bsz is never true inside this Python loop; kept for parity with the C++
+            if stop_flags[bid]:
+                stop_flag_now_int = 1
+            else:
+                # C++: auto *verify_tokens_now = verify_tokens + start_token_id * max_candidate_len;
+                # Python: verify_tokens_now is a flat view starting at the current batch,
+                # emulating the "base address" after the C++ pointer offset.
+                verify_tokens_now = verify_tokens_flat[start_token_id * max_candidate_len :]  # from the base to the end
+
+                # C++: auto *draft_tokens_now = draft_tokens + bid * max_draft_tokens;
+                # Python: flat view of draft_tokens starting at the current batch
+                draft_tokens_now = draft_tokens_flat[bid * max_draft_tokens :]  # from the base to the end
+
+                # C++: auto *actual_candidate_len_now = actual_candidate_len + start_token_id;
+                # Python: flat view of actual_candidate_len starting at the current batch
+                actual_candidate_len_now = actual_candidate_len[start_token_id:]  # actual_candidate_len is already 1-D
+
+                # C++: int i = 0;
+                i = 0
+
+                # C++: for (; i < seq_lens_this_time[bid] - 1; i++)
+                for loop_i in range(seq_lens_this_time[bid] - 1):  # loop_i is the Python loop variable
+                    i = loop_i  # keep the C++ i updated to the current index on every iteration
+
+                    # C++: if (seq_lens_encoder[bid] != 0)
+                    if seq_lens_encoder[bid] != 0:
+                        break
+
+                    if use_topk:
+                        # C++: if (verify_tokens_now[i * max_candidate_len] == draft_tokens_now[i + 1])
+                        if verify_tokens_now[i * max_candidate_len] == draft_tokens_now[i + 1]:
+                            step_idx[bid] += 1
+                            accept_token = draft_tokens_now[i + 1]
+                            # C++: accept_tokens[bid * max_draft_tokens + i] = accept_token;
+                            accept_tokens_flat[bid * max_draft_tokens + i] = accept_token
+
+                            # C++: if (is_in_end(accept_token, end_tokens, end_length) || step_idx[bid] >= max_dec_len[bid])
+                            if is_in_end(accept_token, end_tokens, end_length) or step_idx[bid] >= max_dec_len[bid]:
+                                stop_flags[bid] = True
+                                stop_flag_now_int = 1
+                                if step_idx[bid] >= max_dec_len[bid]:
+                                    accept_tokens_flat[bid * max_draft_tokens + i] = end_tokens[0]
+                                break
+                            else:
+                                accept_num_now += 1
+                        else:
+                            break
+                    else:  # C++: else (Top P verify)
+                        # C++: auto actual_candidate_len_value = actual_candidate_len_now[i] > max_candidate_len ? max_candidate_len : actual_candidate_len_now[i];
+                        actual_candidate_len_value = min(actual_candidate_len_now[i], max_candidate_len)
+
+                        # C++: if (is_in(verify_tokens_now + i * max_candidate_len, draft_tokens_now[i + 1], actual_candidate_len_value))
+                        # pass a flat view of the current candidates
+                        verify_tokens_current_candidate_view = verify_tokens_now[
+                            i * max_candidate_len : (i + 1) * max_candidate_len
+                        ]
+
+                        if is_in(
+                            verify_tokens_current_candidate_view,
+                            draft_tokens_now[i + 1],
+                            actual_candidate_len_value,
+                        ):
+                            step_idx[bid] += 1
+                            accept_token = draft_tokens_now[i + 1]
+                            accept_tokens_flat[bid * max_draft_tokens + i] = accept_token
+
+                            if is_in_end(accept_token, end_tokens, end_length) or step_idx[bid] >= max_dec_len[bid]:
+                                stop_flags[bid] = True
+                                stop_flag_now_int = 1
+                                if step_idx[bid] >= max_dec_len[bid]:
+                                    accept_tokens_flat[bid * max_draft_tokens + i] = end_tokens[0]
+                                break
+                            else:
+                                accept_num_now += 1
+                        else:
+                            # TopK verify
+                            ii = i  # in C++, ii starts at i
+                            # C++: if (max_candidate_len >= 2 && verify_tokens_now[ii * max_candidate_len + 1] == draft_tokens_now[ii + 1])
+                            if (
+                                max_candidate_len >= 2
+                                and verify_tokens_now[ii * max_candidate_len + 1] == draft_tokens_now[ii + 1]
+                            ):  # top-2
+                                j = 0
+                                ii += 1  # in C++, ii starts checking from the next position
+                                # C++: for (; j < verify_window && ii < seq_lens_this_time[bid] - 1; j++, ii++)
+                                while j < verify_window and ii < seq_lens_this_time[bid] - 1:
+                                    if verify_tokens_now[ii * max_candidate_len] != draft_tokens_now[ii + 1]:
+                                        break
+                                    j += 1
+                                    ii += 1
+
+                                # C++: if (j >= verify_window)
+                                if j >= verify_window:  # accept all
+                                    accept_num_now += verify_window + 1
+                                    step_idx[bid] += verify_window + 1
+                                    # C++: for (; i < ii; i++)
+                                    for k_accepted_idx in range(i, ii):  # i gets updated here
+                                        accept_token = draft_tokens_now[k_accepted_idx + 1]
+                                        accept_tokens_flat[bid * max_draft_tokens + k_accepted_idx] = accept_token
+
+                                        if (
+                                            is_in_end(
+                                                accept_token,
+                                                end_tokens,
+                                                end_length,
+                                            )
+                                            or step_idx[bid] >= max_dec_len[bid]
+                                        ):
+                                            stop_flags[bid] = True
+                                            stop_flag_now_int = 1
+                                            if step_idx[bid] >= max_dec_len[bid]:
+                                                accept_tokens_flat[bid * max_draft_tokens + k_accepted_idx] = (
+                                                    end_tokens[0]
+                                                )
+                                                accept_num_now -= 1
+                                                step_idx[bid] -= 1
+                                            break  # leave the inner accept loop
+                                break  # leave the main verify loop (end of TopK logic, match or not)
+                            # this break matches the is_in else branch (Top-P verify failed and no TopK match either)
+                            break  # leave the main verify loop
+
+                # Sampling phase.
+                # In C++, i keeps its final value after the loop and is used directly for sampling;
+                # likewise in Python, the final value of loop_i was assigned to i.
+
+                if not stop_flag_now_int:
+                    accept_token: int
+
+                    # C++: const float *verify_scores_now = verify_scores + start_token_id * max_candidate_len;
+                    # Python: verify_scores_now is the view of verify_scores starting at start_token_id
+                    verify_scores_now = verify_scores_flat[start_token_id * max_candidate_len :]
+
+                    step_idx[bid] += 1
+
+                    if enable_topp:
+                        # C++: auto actual_candidate_len_value = actual_candidate_len_now[i] > max_candidate_len ? max_candidate_len : actual_candidate_len_now[i];
+                        actual_candidate_len_value = min(actual_candidate_len_now[i], max_candidate_len)
+
+                        # pass flat views of the current candidates
+                        verify_tokens_sampling_view = verify_tokens_now[
+                            i * max_candidate_len : (i + 1) * max_candidate_len
+                        ]
+                        verify_scores_sampling_view = verify_scores_now[
+                            i * max_candidate_len : (i + 1) * max_candidate_len
+                        ]
+
+                        # C++: accept_token = topp_sampling_kernel(...)
+                        accept_token = topp_sampling_kernel(
+                            verify_tokens_sampling_view,
+                            verify_scores_sampling_view,
+                            dev_curand_states[i],  # C++: dev_curand_states + i
+                            actual_candidate_len_value,
+                            topp[bid],  # C++: topp[bid]
+                            bid,  # C++: bid
+                        )
+                    else:
+                        accept_token = int(verify_tokens_now[i * max_candidate_len])
+                    print(
+                        "debug python last accept_token",
+                        accept_token,
+                        "prefill_one_step_stop",
+                        prefill_one_step_stop,
+                    )
+                    # C++: accept_tokens[bid * max_draft_tokens + i] = accept_token;
+                    accept_tokens_flat[bid * max_draft_tokens + i] = accept_token
+
+                    if prefill_one_step_stop:
+                        stop_flags[bid] = True
+
+                    if is_in_end(accept_token, end_tokens, end_length) or step_idx[bid] >= max_dec_len[bid]:
+                        stop_flags[bid] = True
+                        stop_flag_now_int = 1
+                        if step_idx[bid] >= max_dec_len[bid]:
+                            accept_tokens_flat[bid * max_draft_tokens + i] = end_tokens[0]
+
+        accept_num[bid] = accept_num_now
+
+    return accept_tokens, accept_num, step_idx, stop_flags
+
+
+# ---------------- random input generation ----------------
+def gen_speculate_verify_inputs(
+    real_bsz=123,
+    max_draft_tokens=16,
+    max_seq_len=256,
+    max_candidate_len=8,
+    verify_window=2,
+    end_length=4,
+    enable_topp=True,
+    seed=2025,
+):
+    rng = np.random.default_rng(seed)
+
+    # basic inputs
+    seq_lens_encoder = rng.integers(0, 3, size=real_bsz, dtype=np.int32)
+    seq_lens_decoder = rng.integers(1, max_draft_tokens, size=real_bsz, dtype=np.int32)
+    draft_tokens = rng.integers(0, 1000, size=(real_bsz, max_draft_tokens), dtype=np.int64)
+    actual_draft_token_nums = rng.integers(1, max_draft_tokens + 1, size=real_bsz, dtype=np.int32)
+
+    seq_lens_this_time = rng.integers(1, max_seq_len + 1, size=real_bsz, dtype=np.int32)
+    sum_seq_this_time = int(np.sum(seq_lens_this_time))
+    # print("debug param set sum_seq_this_time",sum_seq_this_time)
+    # print("debug param real_bsz * max_draft_tokens < 2k",real_bsz * max_draft_tokens)
+    # print("debug sum_seq_this_time * max_candidate_len < 2k",sum_seq_this_time * max_candidate_len)
+
+    verify_tokens = rng.integers(0, 1000, size=(sum_seq_this_time, max_candidate_len), dtype=np.int64)
+    verify_scores = rng.random(size=(sum_seq_this_time, max_candidate_len)).astype(np.float32)
+
+    max_dec_len = rng.integers(16, 64, size=real_bsz, dtype=np.int64)
+    end_tokens = rng.integers(1, 1000, size=end_length, dtype=np.int64)
+    is_block_step = rng.integers(0, 2, size=real_bsz, dtype=bool)
+
+    # output_cum_offsets = np.zeros_like(seq_lens_this_time)
+    # output_cum_offsets[1:] = np.cumsum(seq_lens_this_time[:-1])
+    blank_lengths = max_seq_len - seq_lens_this_time
+    output_cum_offsets = np.concatenate([[0], np.cumsum(blank_lengths[:-1])])
+    output_cum_offsets = output_cum_offsets.astype("int32")
+    actual_candidate_len = rng.integers(1, max_candidate_len + 1, size=sum_seq_this_time, dtype=np.int32)
+
+    topp = (
+        rng.uniform(0.8, 1.0, size=real_bsz).astype(np.float32)
+        if enable_topp
+        else np.zeros(real_bsz, dtype=np.float32)
+    )
+
+    # outputs (placeholders)
+    accept_tokens = np.zeros((real_bsz, max_draft_tokens), dtype=np.int64)
+    accept_num = np.zeros(real_bsz, dtype=np.int32)
+    step_idx = np.zeros(real_bsz, dtype=np.int64)
+    stop_flags = np.zeros(real_bsz, dtype=bool)
+
+    return {
+        "accept_tokens": accept_tokens,
+        "accept_num": accept_num,
+        "step_idx": step_idx,
+        "stop_flags": stop_flags,
+        "seq_lens_encoder": seq_lens_encoder,
+        "seq_lens_decoder": seq_lens_decoder,
+        "draft_tokens": draft_tokens,
+        "seq_lens_this_time": seq_lens_this_time,
+        "verify_tokens": verify_tokens,
+        "verify_scores": verify_scores,
+        "max_dec_len": max_dec_len,
+        "end_tokens": end_tokens,
+        "is_block_step": is_block_step,
+        "output_cum_offsets": output_cum_offsets,
+        "actual_candidate_len": actual_candidate_len,
+        "actual_draft_token_nums": actual_draft_token_nums,
+        "topp": topp,
+        "max_seq_len": max_seq_len,
+        "verify_window": verify_window,
+        "enable_topp": enable_topp,
+    }
+
+
+# ------------------- test body -------------------
+# ---- Paddle side ----
+def run_speculate_verify_test(
+    real_bsz,
+    max_draft_tokens,
+    max_seq_len,
+    max_candidate_len,
+    verify_window,
+    end_length,
+    enable_topp,
+    seed,
+):
+    inputs = gen_speculate_verify_inputs(
+        real_bsz=real_bsz,
+        max_draft_tokens=max_draft_tokens,
+        max_seq_len=max_seq_len,
+        max_candidate_len=max_candidate_len,
+        verify_window=verify_window,
+        end_length=end_length,
+        enable_topp=enable_topp,
+        seed=seed,
+    )
+
+    paddle_inputs = {}
+
+    print("========= 1 xpu process ==========")
+
+    for k, v in inputs.items():
+        if isinstance(v, (int, bool)):
+            paddle_inputs[k] = v
+            # print(f"{k:<25} type: {type(v).__name__}, value: {v}")
+        else:
+            # paddle_inputs[k] = paddle.to_tensor(v, place=paddle.CPUPlace())
+            paddle_inputs[k] = paddle.to_tensor(v, place=paddle.XPUPlace(0))
+            # print(f"{k:<25} type: Tensor, dtype: {paddle_inputs[k].dtype}, shape: {paddle_inputs[k].shape}")
+
+    out_pd = speculate_verify(**paddle_inputs)
+    (accept_tokens_pd, accept_num_pd, step_idx_pd, stop_flags_pd) = out_pd
+    pd_tensors = [accept_tokens_pd, accept_num_pd, step_idx_pd, stop_flags_pd]
+
+    print("========= 1 end ==========")
+    print("========= 2 python process ==========")
+
+    # np_inputs = {k: (paddle_inputs[k].numpy().copy() if isinstance(paddle_inputs[k], paddle.Tensor)
+    #                  else paddle_inputs[k])
+    #              for k in paddle_inputs}
+
+    # out_np = speculate_verify_np(**np_inputs)
+    # (accept_tokens_np, accept_num_np, step_idx_np, stop_flags_np) = out_np
+    # np_tensors = [accept_tokens_np, accept_num_np, step_idx_np, stop_flags_np]
+
+    print("========= 2 end ==========")
+
+    print("========= 3 (CPU) ==========")
+    paddle_inputs_cpu = {}
+
+    for k, v in inputs.items():  # Reuse the original inputs dict so the data is in its pristine state.
+        if isinstance(v, (int, bool)):
+            paddle_inputs_cpu[k] = v
+            # print(f"{k:<25} type: {type(v).__name__}, value: {v}")
+        else:
+            # Key difference: use paddle.CPUPlace() here.
+            paddle_inputs_cpu[k] = paddle.to_tensor(v, place=paddle.CPUPlace())
+            # print(f"{k:<25} type: Tensor, dtype: {paddle_inputs_cpu[k].dtype}, shape: {paddle_inputs_cpu[k].shape}")
+
+    out_cpu = speculate_verify(**paddle_inputs_cpu)
+    (accept_tokens_cpu, accept_num_cpu, step_idx_cpu, stop_flags_cpu) = out_cpu
+
+    cpu_tensors = [
+        accept_tokens_cpu,
+        accept_num_cpu,
+        step_idx_cpu,
+        stop_flags_cpu,
+    ]
+    print("========= 3 (CPU) end ==========")
+
+    # ---------------- verification ----------------
+    # print("========= python/cpu vs xpu verify ==========")
+
+    # names = ["accept_tokens", "accept_num", "step_idx", "stop_flags"]
+    # for name, pd_val, np_val in zip(names, pd_tensors, np_tensors):
+    #     pd_arr = pd_val.numpy()
+    #     ok = np.array_equal(pd_arr, np_val)
+    #     print(f"{name:20s} equal: {ok}")
+    #     if not ok:
+    #         print(f"{name} mismatch!\nPaddle:\n{pd_arr}\n\nNumPy:\n{np_val}")
+
+    print("========= cpu vs xpu verify ==========")
+
+    names = ["accept_tokens", "accept_num", "step_idx", "stop_flags"]
+
+    for name, pd_val, np_val in zip(names, pd_tensors, cpu_tensors):
+        pd_arr = pd_val.numpy()
+        ok = np.array_equal(pd_arr, np_val)
+        print(f"{name:20s} equal: {ok}")
+        if not ok:
+            # Report the mismatching positions and their values.
+            print(f"{name} mismatch!\nPaddle:\n{pd_arr}\n\nNumPy:\n{np_val}")
+            mismatches = np.where(pd_arr != np_val)
+            for idx in zip(*mismatches):
+                print(f"  idx {idx}: Paddle = {pd_arr[idx]}, NumPy = {np_val[idx]}")
+
+            # Limit the output if there are too many differences.
+            if len(mismatches[0]) > 20:
+                print("  ... (truncated)")
+
+
+# -------------------------------------
+# test cases
+# -------------------------------------
+test_configs = [
+    {
+        "real_bsz": 4,
+        "max_draft_tokens": 3,
+        "max_seq_len": 30,
+        "max_candidate_len": 4,
+        "verify_window": 2,
+        "end_length": 2,
+        "enable_topp": True,
+        "seed": 2025,
+    },
+    {
+        "real_bsz": 77,
+        "max_draft_tokens": 10,
+        "max_seq_len": 12000,
+        "max_candidate_len": 8,
+        "verify_window": 2,
+        "end_length": 4,
+        "enable_topp": True,
+        "seed": 2025,
+    },
+    {
+        "real_bsz": 1,
+        "max_draft_tokens": 2,
+        "max_seq_len": 10,
+        "max_candidate_len": 1,
+        "verify_window": 1,
+        "end_length": 1,
+        "enable_topp": True,
+        "seed": 42,
+    },
+    {
+        "real_bsz": 128,
+        "max_draft_tokens": 7,
+        "max_seq_len": 999,
+        "max_candidate_len": 5,
+        "verify_window": 3,
+        "end_length": 3,
+        "enable_topp": True,
+        "seed": 422,
+    },
+    {
+        "real_bsz": 99,
+        "max_draft_tokens": 5,
+        "max_seq_len": 10,
+        "max_candidate_len": 3,
+        "verify_window": 4,
+        "end_length": 4,
+        "enable_topp": True,
+        "seed": 42,
+    },
+    {
+        "real_bsz": 1,
+        "max_draft_tokens": 9,
+        "max_seq_len": 11,
+        "max_candidate_len": 4,
+        "verify_window": 2,
+        "end_length": 5,
+        "enable_topp": False,
+        "seed": 42,
+    },
+    {
+        "real_bsz": 33,
+        "max_draft_tokens": 5,
+        "max_seq_len": 10111,
+        "max_candidate_len": 5,
+        "verify_window": 2,
+        "end_length": 6,
+        "enable_topp": False,
+        "seed": 42,
+    },
+    {
+        "real_bsz": 6,
+        "max_draft_tokens": 4,
+        "max_seq_len": 10001,
+        "max_candidate_len": 6,
+        "verify_window": 2,
+        "end_length": 7,
+        "enable_topp": False,
+        "seed": 42,
+    },
+    {
+        "real_bsz": 7,
+        "max_draft_tokens": 3,
+        "max_seq_len": 777,
+        "max_candidate_len": 7,
+        "verify_window": 2,
+        "end_length": 5,
+        "enable_topp": False,
+        "seed": 42,
+    },
+    {
+        "real_bsz": 55,
+        "max_draft_tokens": 5,
+        "max_seq_len": 31,
+        "max_candidate_len": 9,
+        "verify_window": 2,
+        "end_length": 3,
+        "enable_topp": False,
+        "seed": 42,
+    },
+]
+
+for i, cfg in enumerate(test_configs):
+    print(f"\n\n======== Running Test Case {i} ========")
+    run_speculate_verify_test(**cfg)
diff --git a/custom_ops/xpu_ops/test/python/ops/test_step.py b/custom_ops/xpu_ops/test/test_step.py
similarity index 100%
rename from custom_ops/xpu_ops/test/python/ops/test_step.py
rename to custom_ops/xpu_ops/test/test_step.py
diff --git a/custom_ops/xpu_ops/test/python/ops/test_stop_generation_multi_ends.py b/custom_ops/xpu_ops/test/test_stop_generation_multi_ends.py
similarity index 100%
rename from custom_ops/xpu_ops/test/python/ops/test_stop_generation_multi_ends.py
rename to custom_ops/xpu_ops/test/test_stop_generation_multi_ends.py
diff --git a/custom_ops/xpu_ops/test/python/ops/test_token_repetition_penalty.py b/custom_ops/xpu_ops/test/test_token_repetition_penalty.py
similarity index 100%
rename from custom_ops/xpu_ops/test/python/ops/test_token_repetition_penalty.py
rename to custom_ops/xpu_ops/test/test_token_repetition_penalty.py
diff --git a/custom_ops/xpu_ops/test/python/ops/test_update_inputs.py b/custom_ops/xpu_ops/test/test_update_inputs.py
similarity index 100%
rename from
custom_ops/xpu_ops/test/python/ops/test_update_inputs.py rename to custom_ops/xpu_ops/test/test_update_inputs.py diff --git a/custom_ops/xpu_ops/test/python/ops/test_weight_quantize_xpu.py b/custom_ops/xpu_ops/test/test_weight_quantize_xpu.py similarity index 100% rename from custom_ops/xpu_ops/test/python/ops/test_weight_quantize_xpu.py rename to custom_ops/xpu_ops/test/test_weight_quantize_xpu.py