diff --git a/custom_ops/xpu_ops/src/ops/recover_decode_task.cc b/custom_ops/xpu_ops/src/ops/recover_decode_task.cc new file mode 100644 index 000000000..34871f0d3 --- /dev/null +++ b/custom_ops/xpu_ops/src/ops/recover_decode_task.cc @@ -0,0 +1,68 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include "paddle/extension.h" +#include "paddle/phi/core/enforce.h" +#include "xpu/plugin.h" + +void RecoverDecodeTask(const paddle::Tensor &stop_flags, + const paddle::Tensor &seq_lens_this_time, + const paddle::Tensor &seq_lens_encoder, + const paddle::Tensor &seq_lens_decoder, + const paddle::Tensor &step_seq_lens_decoder, + const paddle::Tensor &block_tables, + const paddle::Tensor &is_block_step, + const int block_size) { +phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId()); + auto dev_ctx = + paddle::experimental::DeviceContextPool::Instance().Get(place); + auto xpu_ctx = static_cast(dev_ctx); + const int bsz = seq_lens_this_time.shape()[0]; + const int block_num_per_seq = block_tables.shape()[1]; + int r = baidu::xpu::api::plugin::recover_decode_task( + xpu_ctx->x_context(), + const_cast(stop_flags.data()), + const_cast(seq_lens_this_time.data()), + const_cast(seq_lens_encoder.data()), + const_cast(seq_lens_decoder.data()), + const_cast(step_seq_lens_decoder.data()), + const_cast(block_tables.data()), + const_cast(is_block_step.data()), + bsz, + block_num_per_seq, + block_size); + PD_CHECK(r == 0, "baidu::xpu::api::plugin::recover_decode_task failed."); +} + +PD_BUILD_OP(recover_decode_task) + .Inputs({"stop_flags", + "seq_lens_this_time", + "seq_lens_encoder", + "seq_lens_decoder", + "step_seq_lens_decoder", + "block_tables", + "is_block_step"}) + .Attrs({"block_size: int"}) + .Outputs({"seq_lens_this_time_out", + "seq_lens_encoder_out", + "seq_lens_decoder_out", + "stop_flags_out", + "is_block_step_out"}) + .SetInplaceMap({{"seq_lens_this_time", "seq_lens_this_time_out"}, + {"seq_lens_encoder", "seq_lens_encoder_out"}, + {"seq_lens_decoder", "seq_lens_decoder_out"}, + {"stop_flags", "stop_flags_out"}, + {"is_block_step", "is_block_step_out"}}) + .SetKernelFn(PD_KERNEL(RecoverDecodeTask)); diff --git a/custom_ops/xpu_ops/src/ops/update_inputs_v1.cc b/custom_ops/xpu_ops/src/ops/update_inputs_v1.cc new file mode 100644 index 000000000..50dc8d748 --- /dev/null +++ b/custom_ops/xpu_ops/src/ops/update_inputs_v1.cc @@ -0,0 +1,105 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include "paddle/extension.h" +#include "paddle/phi/core/enforce.h" +#include "xpu/plugin.h" + +void UpdateInputesV1(const paddle::Tensor &stop_flags, + const paddle::Tensor ¬_need_stop, // only on cpu + const paddle::Tensor &seq_lens_this_time, + const paddle::Tensor &seq_lens_encoder, + const paddle::Tensor &seq_lens_decoder, + const paddle::Tensor &step_seq_lens_decoder, + const paddle::Tensor &prompt_lens, + const paddle::Tensor &topk_ids, + const paddle::Tensor &input_ids, + const paddle::Tensor &block_tables, + const paddle::Tensor &stop_nums, + const paddle::Tensor &next_tokens, + const paddle::Tensor &is_block_step, + const int block_size) { + phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId()); + auto dev_ctx = + paddle::experimental::DeviceContextPool::Instance().Get(place); + auto xpu_ctx = static_cast(dev_ctx); + + const int max_bsz = stop_flags.shape()[0]; + const int now_bsz = seq_lens_this_time.shape()[0]; + // std::cout << "now_bsz: " << now_bsz << std::endl; + const int input_ids_stride = input_ids.shape()[1]; + const int block_num_per_seq = block_tables.shape()[1]; + auto not_need_stop_gpu = not_need_stop.copy_to(stop_flags.place(), false); + int r = baidu::xpu::api::plugin::update_inputs_v1( + xpu_ctx->x_context(), + const_cast(not_need_stop_gpu.data()), + const_cast(seq_lens_this_time.data()), + const_cast(seq_lens_encoder.data()), + const_cast(seq_lens_decoder.data()), + const_cast(step_seq_lens_decoder.data()), + const_cast(prompt_lens.data()), + const_cast(topk_ids.data()), + const_cast(input_ids.data()), + const_cast(block_tables.data()), + stop_nums.data(), + const_cast(stop_flags.data()), + const_cast(is_block_step.data()), + next_tokens.data(), + now_bsz, + max_bsz, + input_ids_stride, + block_num_per_seq, + block_size); + PD_CHECK(r == 0, "baidu::xpu::api::plugin::update_inputs_kernel_v1 failed."); + auto not_need_stop_cpu = + not_need_stop_gpu.copy_to(not_need_stop.place(), false); + bool *not_need_stop_data = const_cast(not_need_stop.data()); + not_need_stop_data[0] = not_need_stop_cpu.data()[0]; +} + +PD_BUILD_OP(update_inputs_v1) + .Inputs({"stop_flags", + "not_need_stop", + "seq_lens_this_time", + "seq_lens_encoder", + "seq_lens_decoder", + "step_seq_lens_decoder", + "prompt_lens", + "topk_ids", + "input_ids", + "block_tables", + "stop_nums", + "next_tokens", + "is_block_step"}) + .Attrs({"block_size: int"}) + .Outputs({"not_need_stop_out", + "seq_lens_this_time_out", + "seq_lens_encoder_out", + "seq_lens_decoder_out", + "step_seq_lens_decoder_out", + "topk_ids_out", + "input_ids_out", + "stop_flags_out", + "is_block_step_out"}) + .SetInplaceMap({{"not_need_stop", "not_need_stop_out"}, + {"seq_lens_this_time", "seq_lens_this_time_out"}, + {"seq_lens_encoder", "seq_lens_encoder_out"}, + {"seq_lens_decoder", "seq_lens_decoder_out"}, + {"topk_ids", "topk_ids_out"}, + {"input_ids", "input_ids_out"}, + {"stop_flags", "stop_flags_out"}, + {"step_seq_lens_decoder", "step_seq_lens_decoder_out"}, + {"is_block_step", "is_block_step_out"}}) + .SetKernelFn(PD_KERNEL(UpdateInputesV1)); diff --git a/custom_ops/xpu_ops/src/plugin/include/xpu/plugin.h b/custom_ops/xpu_ops/src/plugin/include/xpu/plugin.h index ddf5aab33..ce6262044 100644 --- a/custom_ops/xpu_ops/src/plugin/include/xpu/plugin.h +++ b/custom_ops/xpu_ops/src/plugin/include/xpu/plugin.h @@ -86,6 +86,39 @@ recover_block(Context *ctx, const int block_num_per_seq, const int length, const int pre_id_length); + +DLL_EXPORT int +recover_decode_task(Context *ctx, bool *stop_flags, + int *seq_lens_this_time, + int *seq_lens_encoder, + int *seq_lens_decoder, + int *step_seq_lens_decoder, + int *block_tables, + bool *is_block_step, + const int bsz, + const int block_num_per_seq, + const int block_size); + +DLL_EXPORT int +update_inputs_v1(Context *ctx, bool *not_need_stop, + int *seq_lens_this_time, + int *seq_lens_encoder, + int *seq_lens_decoder, + int *step_seq_lens_decoder, + int64_t *prompt_lens, + int64_t *topk_ids, + int64_t *input_ids, + int *block_tables, + const int64_t *stop_nums, + bool *stop_flags, + bool *is_block_step, + const int64_t *next_tokens, + const int bsz, + const int max_bsz, + const int input_ids_stride, + const int block_num_per_seq, + const int block_size); + template DLL_EXPORT int eb_adjust_batch(Context *ctx, const TX *x, TY *y, diff --git a/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/recover_decode_task.xpu b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/recover_decode_task.xpu new file mode 100644 index 000000000..db6efb4c7 --- /dev/null +++ b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/recover_decode_task.xpu @@ -0,0 +1,41 @@ +#include "xpu/kernel/cluster.h" +#include "xpu/kernel/cluster_partition.h" +#include "xpu/kernel/cluster_primitive.h" + +namespace xpu3 { +namespace plugin { + +__global__ void recover_decode_task(bool *stop_flags, + int *seq_lens_this_time, + int *seq_lens_encoder, + int *seq_lens_decoder, + int *step_seq_lens_decoder, + int *block_tables, + bool *is_block_step, + const int bsz, + const int block_num_per_seq, + const int block_size) { + int cid = core_id(); + int ncores = core_num(); + int clusterid = cluster_id(); + int nclusters = cluster_num(); + int thread_idx = clusterid * ncores + cid; + int nthreads = nclusters * ncores; + // if (clusterid != 0) return; + for (; thread_idx < bsz; thread_idx += nthreads) { + if(is_block_step[thread_idx] == true) { + // int *block_table_now = block_tables + thread_idx * block_num_per_seq; + if (block_tables[thread_idx * block_num_per_seq + step_seq_lens_decoder[thread_idx] / block_size] != -1) { + // can be recovered for decoding + is_block_step[thread_idx] = false; + seq_lens_this_time[thread_idx]= 1; + stop_flags[thread_idx] = false; + seq_lens_encoder[thread_idx] = 0; + seq_lens_decoder[thread_idx] = step_seq_lens_decoder[thread_idx]; + } + } + } +} + +} // namespace plugin +} // namespace xpu3 diff --git a/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/update_inputs_v1.xpu b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/update_inputs_v1.xpu new file mode 100644 index 000000000..8eb87c12d --- /dev/null +++ b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/update_inputs_v1.xpu @@ -0,0 +1,131 @@ +#include "xpu/kernel/cluster.h" +#include "xpu/kernel/cluster_partition.h" +#include "xpu/kernel/cluster_primitive.h" +// #include +// using namespace std; + +#include "xpu/kernel/xtdk_io.h" +#include "xpu/kernel/xtdk.h" + +namespace xpu3 { +namespace plugin { + +__global__ void update_inputs_v1(bool *not_need_stop, + int *seq_lens_this_time, + int *seq_lens_encoder, + int *seq_lens_decoder, + int *step_seq_lens_decoder, + int64_t *prompt_lens, + int64_t *topk_ids, + int64_t *input_ids, + int *block_tables, + const int64_t *stop_nums, + bool *stop_flags, + bool *is_block_step, + const int64_t *next_tokens, + const int bsz, + const int max_bsz, + const int input_ids_stride, + const int block_num_per_seq, + const int block_size) { + + + // std::cout << "seq_lens_this_time " << seq_lens_this_time[0] << std::endl; + int cid = core_id(); + int ncores = core_num(); + int clusterid = cluster_id(); + int nclusters = cluster_num(); + int thread_idx = clusterid * ncores + cid; + if (clusterid != 0) return; + + const int max_bs = 1024; + __shared__ bool stop_flags_sm[max_bs]; + __shared__ int stop_flags_int_sm[max_bs]; + if(cid == 0){ + GM2SM(stop_flags, stop_flags_sm, sizeof(bool) * bsz); + } + sync_all(); + + for(int i = cid; i < bsz; i+= ncores){ + if(i < bsz){ + stop_flags_sm[i] = stop_flags[i]; + stop_flags_int_sm[i] = static_cast(stop_flags_sm[i]); + }else{ + stop_flags_sm[i] = true; + stop_flags_int_sm[i] = 1; + } + if(i= prompt_lens_update){ + seq_len_decoder_update = seq_len_this_time_update + seq_len_decoder_update; + LM2GM(&seq_len_decoder_update, seq_lens_decoder+i, sizeof(int)); + seq_len_this_time_update = 1; + LM2GM(&seq_len_this_time_update, seq_lens_this_time + i, sizeof(int)); + seq_lens_encoder_update = 0; + LM2GM(&seq_lens_encoder_update, seq_lens_encoder + i, sizeof(int)); + int64_t input_ids_update; + GM2LM(next_tokens + i, &input_ids_update, sizeof(int64_t)); + LM2GM(&input_ids_update, input_ids + i * input_ids_stride, sizeof(int64_t)); + // to judge whether block is not enough + if(seq_len_this_time_update != 0 && block_tables[i * block_num_per_seq + seq_len_decoder_update/block_size] == -1){ + is_block_step[i] = true; + seq_len_this_time_update = 0; + LM2GM(&seq_len_this_time_update, seq_lens_this_time + i, sizeof(int)); + stop_flags_sm[i] = true; + SM2GM(stop_flags_sm+i, stop_flags+i, sizeof(bool)); + LM2GM(&seq_len_decoder_update, step_seq_lens_decoder+i, sizeof(int)); + seq_len_decoder_update = 0; + LM2GM(&seq_len_decoder_update, seq_lens_decoder + i, sizeof(int)); + seq_len_decoder_update = 0; + LM2GM(&seq_len_decoder_update, seq_lens_decoder + i, sizeof(int)); + stop_flags_int_sm[i] = 1; + } + }else{ + stop_flags_sm[i] = true; + SM2GM(stop_flags_sm+i, stop_flags+i, sizeof(bool)); + seq_len_this_time_update = 0; + LM2GM(&seq_len_this_time_update, seq_lens_this_time + i, sizeof(int)); + seq_len_decoder_update = 0; + seq_lens_encoder_update = 0; + LM2GM(&seq_len_decoder_update, seq_lens_decoder + i, sizeof(int)); + LM2GM(&seq_lens_encoder_update, seq_lens_encoder + i, sizeof(int)); + int64_t topk_ids_update = -1; + LM2GM(&topk_ids_update, topk_ids + i, sizeof(int64_t)); + stop_flags_int_sm[i] = 1; + } + + } + } + } + sync_all(); + sync_cluster(); + int stop_sum = 0; + if (cid == 0) { + for (int i = 0; i < max_bsz; i++) { + stop_sum += stop_flags_int_sm[i]; + } + // printf("stop_sum : %d\n", stop_sum); + int64_t stop_num; + GM2LM(stop_nums, &stop_num, sizeof(int64_t)); + bool not_need_stop_update = stop_sum < static_cast(stop_num); + mfence_lm(); + LM2GM(¬_need_stop_update, not_need_stop, sizeof(bool)); + } +} + +} // namespace plugin +} // namespace xpu3 diff --git a/custom_ops/xpu_ops/src/plugin/src/wrapper/recover_decode_task.cpp b/custom_ops/xpu_ops/src/plugin/src/wrapper/recover_decode_task.cpp new file mode 100644 index 000000000..1ed700897 --- /dev/null +++ b/custom_ops/xpu_ops/src/plugin/src/wrapper/recover_decode_task.cpp @@ -0,0 +1,107 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "xpu/plugin.h" +#include "xpu/refactor/impl_public/wrapper_check.h" +#include +#include + +namespace xpu3 { +namespace plugin { + +__attribute__((global)) void +recover_decode_task(bool *stop_flags, + int *seq_lens_this_time, + int *seq_lens_encoder, + int *seq_lens_decoder, + int *step_seq_lens_decoder, + int *block_tables, + bool *is_block_step, + const int bsz, + const int block_num_per_seq, + const int block_size); + +} // namespace plugin +} // namespace xpu3 + +namespace baidu { +namespace xpu { +namespace api { +namespace plugin { + +static int xpu3_wrapper(Context *ctx, bool *stop_flags, + int *seq_lens_this_time, + int *seq_lens_encoder, + int *seq_lens_decoder, + int *step_seq_lens_decoder, + int *block_tables, + bool *is_block_step, + const int bsz, + const int block_num_per_seq, + const int block_size) { + using XPU_INT64 = typename XPUIndexType::type; + auto recover_decode_task = xpu3::plugin::recover_decode_task; + recover_decode_task<<ncluster(), 64, ctx->xpu_stream>>>( + stop_flags, + seq_lens_this_time, + seq_lens_encoder, + seq_lens_decoder, + step_seq_lens_decoder, + block_tables, + is_block_step, + bsz, + block_num_per_seq, + block_size); + return api::SUCCESS; +} + +int recover_decode_task(Context *ctx, bool *stop_flags, + int *seq_lens_this_time, + int *seq_lens_encoder, + int *seq_lens_decoder, + int *step_seq_lens_decoder, + int *block_tables, + bool *is_block_step, + const int bsz, + const int block_num_per_seq, + const int block_size) { + WRAPPER_CHECK_CTX(ctx); + WRAPPER_DUMP_FUNCTION_T1(ctx, "recover_decode_task", int); + WRAPPER_DUMP_PARAM5(ctx, stop_flags, seq_lens_this_time, + seq_lens_encoder, seq_lens_decoder, step_seq_lens_decoder); + WRAPPER_DUMP_PARAM2(ctx, block_tables, is_block_step); + WRAPPER_DUMP_PARAM3(ctx, bsz, block_num_per_seq, block_size); + WRAPPER_DUMP(ctx); + if (ctx->dev().type() == api::kCPU) { + assert(false); + } + if (ctx->dev().type() == api::kXPU2 || ctx->dev().type() == api::kXPU3) { + return xpu3_wrapper(ctx, stop_flags, + seq_lens_this_time, + seq_lens_encoder, + seq_lens_decoder, + step_seq_lens_decoder, + block_tables, + is_block_step, + bsz, + block_num_per_seq, + block_size); + } + WRAPPER_UNIMPLEMENTED(ctx); +} + +} // namespace plugin +} // namespace api +} // namespace xpu +} // namespace baidu diff --git a/custom_ops/xpu_ops/src/plugin/src/wrapper/update_inputs_v1.cpp b/custom_ops/xpu_ops/src/plugin/src/wrapper/update_inputs_v1.cpp new file mode 100644 index 000000000..ce97e91d7 --- /dev/null +++ b/custom_ops/xpu_ops/src/plugin/src/wrapper/update_inputs_v1.cpp @@ -0,0 +1,149 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "xpu/plugin.h" +#include "xpu/refactor/impl_public/wrapper_check.h" +#include +#include + +namespace xpu3 { +namespace plugin { + +__attribute__((global)) void +update_inputs_v1(bool *not_need_stop, + int *seq_lens_this_time, + int *seq_lens_encoder, + int *seq_lens_decoder, + int *step_seq_lens_decoder, + int64_t *prompt_lens, + int64_t *topk_ids, + int64_t *input_ids, + int *block_tables, + const int64_t *stop_nums, + bool *stop_flags, + bool *is_block_step, + const int64_t *next_tokens, + const int bsz, + const int max_bsz, + const int input_ids_stride, + const int block_num_per_seq, + const int block_size); + +} // namespace plugin +} // namespace xpu3 + +namespace baidu { +namespace xpu { +namespace api { +namespace plugin { + +static int xpu3_wrapper(Context *ctx, bool *not_need_stop, + int *seq_lens_this_time, + int *seq_lens_encoder, + int *seq_lens_decoder, + int *step_seq_lens_decoder, + int64_t *prompt_lens, + int64_t *topk_ids, + int64_t *input_ids, + int *block_tables, + const int64_t *stop_nums, + bool *stop_flags, + bool *is_block_step, + const int64_t *next_tokens, + const int bsz, + const int max_bsz, + const int input_ids_stride, + const int block_num_per_seq, + const int block_size) { + using XPU_INT64 = typename XPUIndexType::type; + auto update_inputs_v1 = xpu3::plugin::update_inputs_v1; + // kernel 内要做 reduce,只能用 1 个 cluster + update_inputs_v1<<<1, 64, ctx->xpu_stream>>>( + not_need_stop, + seq_lens_this_time, + seq_lens_encoder, + seq_lens_decoder, + step_seq_lens_decoder, + reinterpret_cast(prompt_lens), + reinterpret_cast(topk_ids), + reinterpret_cast(input_ids), + block_tables, + reinterpret_cast(stop_nums), + stop_flags, + is_block_step, + reinterpret_cast(next_tokens), + bsz, + max_bsz, + input_ids_stride, + block_num_per_seq, + block_size); + return api::SUCCESS; +} + +int update_inputs_v1(Context *ctx, bool *not_need_stop, + int *seq_lens_this_time, + int *seq_lens_encoder, + int *seq_lens_decoder, + int *step_seq_lens_decoder, + int64_t *prompt_lens, + int64_t *topk_ids, + int64_t *input_ids, + int *block_tables, + const int64_t *stop_nums, + bool *stop_flags, + bool *is_block_step, + const int64_t *next_tokens, + const int bsz, + const int max_bsz, + const int input_ids_stride, + const int block_num_per_seq, + const int block_size) { + WRAPPER_CHECK_CTX(ctx); + WRAPPER_DUMP_FUNCTION_T1(ctx, "update_inputs_v1", int); + WRAPPER_DUMP_PARAM5(ctx, not_need_stop, seq_lens_this_time, + seq_lens_encoder, seq_lens_decoder, step_seq_lens_decoder); + WRAPPER_DUMP_PARAM5(ctx, prompt_lens, topk_ids, input_ids, block_tables, stop_nums); + WRAPPER_DUMP_PARAM3(ctx, stop_flags, is_block_step, next_tokens); + WRAPPER_DUMP_PARAM5(ctx, bsz, max_bsz, input_ids_stride, block_num_per_seq, block_size); + WRAPPER_DUMP(ctx); + if (ctx->dev().type() == api::kCPU) { + assert(false); + } + if (ctx->dev().type() == api::kXPU2 || ctx->dev().type() == api::kXPU3) { + return xpu3_wrapper(ctx, not_need_stop, + seq_lens_this_time, + seq_lens_encoder, + seq_lens_decoder, + step_seq_lens_decoder, + prompt_lens, + topk_ids, + input_ids, + block_tables, + stop_nums, + stop_flags, + is_block_step, + next_tokens, + bsz, + max_bsz, + input_ids_stride, + block_num_per_seq, + block_size); + } + WRAPPER_UNIMPLEMENTED(ctx); +} + +} // namespace plugin +} // namespace api +} // namespace xpu +} // namespace baidu diff --git a/custom_ops/xpu_ops/src/setup_ops.py b/custom_ops/xpu_ops/src/setup_ops.py index c819cf9d9..5ad31e912 100755 --- a/custom_ops/xpu_ops/src/setup_ops.py +++ b/custom_ops/xpu_ops/src/setup_ops.py @@ -144,6 +144,8 @@ def xpu_setup_ops(): "./ops/get_token_penalty_multi_scores.cc", "./ops/get_padding_offset.cc", "./ops/update_inputs.cc", + "./ops/recover_decode_task.cc", + "./ops/update_inputs_v1.cc", "./ops/get_output.cc", "./ops/step.cc", "./ops/get_infer_param.cc", diff --git a/fastdeploy/worker/xpu_model_runner.py b/fastdeploy/worker/xpu_model_runner.py index c169a02a2..01eff6c7c 100644 --- a/fastdeploy/worker/xpu_model_runner.py +++ b/fastdeploy/worker/xpu_model_runner.py @@ -22,8 +22,9 @@ import numpy as np import paddle from paddle import nn +from fastdeploy import envs from fastdeploy.config import FDConfig -from fastdeploy.engine.request import Request +from fastdeploy.engine.request import Request, RequestType from fastdeploy.model_executor.forward_meta import ForwardMeta, XPUForwardMeta from fastdeploy.model_executor.layers.attention import get_attention_backend from fastdeploy.model_executor.layers.attention.base_attention_backend import ( @@ -33,6 +34,13 @@ from fastdeploy.model_executor.layers.rotary_embedding import get_rope from fastdeploy.model_executor.layers.sample.meta_data import SamplingMetadata from fastdeploy.model_executor.layers.sample.sampler import Sampler from fastdeploy.model_executor.model_loader import get_model_from_loader +from fastdeploy.model_executor.ops.xpu import ( + adjust_batch, + get_infer_param, + get_padding_offset, + recover_decode_task, + update_inputs_v1, +) from fastdeploy.utils import get_logger from fastdeploy.worker.model_runner_base import ModelRunnerBase from fastdeploy.worker.output import ModelOutputData, ModelRunnerOutput @@ -53,11 +61,6 @@ def xpu_pre_process( max_len = input_ids.shape[1] cum_offsets_now = paddle.cumsum(max_len - seq_lens_this_time) token_num = paddle.sum(seq_lens_this_time) - from fastdeploy.model_executor.ops.xpu import ( - adjust_batch, - get_infer_param, - get_padding_offset, - ) ( ids_remove_padding, @@ -111,6 +114,18 @@ def xpu_pre_process( ) = get_infer_param(seq_lens_encoder, seq_lens_decoder) # Adjust batch + # print(f"=========================adjust_batch 更新前=========================") + # print(f"ids_remove_padding : {ids_remove_padding}") + # print(f"cum_offsets : {cum_offsets}") + # print(f"xpu_forward_meta.encoder_seq_lod : {xpu_forward_meta.encoder_seq_lod}") + # print(f"xpu_forward_meta.encoder_batch_idx: {xpu_forward_meta.encoder_batch_idx}") + # print(f"xpu_forward_meta.decoder_batch_idx : {xpu_forward_meta.decoder_batch_idx}") + # print(f"xpu_forward_meta.encoder_seq_lod_cpu : {xpu_forward_meta.encoder_seq_lod_cpu}") + # print(f"xpu_forward_meta.encoder_batch_idx_cpu : {xpu_forward_meta.encoder_batch_idx_cpu}") + # print(f"xpu_forward_meta.decoder_batch_idx_cpu : {xpu_forward_meta.decoder_batch_idx_cpu}") + # print(f"xpu_forward_meta.enc_batch : {xpu_forward_meta.encoder_batch_map}") + # print(f"xpu_forward_meta.dec_batch : {xpu_forward_meta.decoder_batch_map}") + adjusted_input = adjust_batch( ids_remove_padding.reshape([-1, 1]), cum_offsets, @@ -125,6 +140,17 @@ def xpu_pre_process( None, # output_padding_offset -1, # max_input_length ) + # print(f"=========================adjust_batch 更新后=========================") + # print(f"ids_remove_padding : {ids_remove_padding}") + # print(f"cum_offsets : {cum_offsets}") + # print(f"xpu_forward_meta.encoder_seq_lod : {xpu_forward_meta.encoder_seq_lod}") + # print(f"xpu_forward_meta.encoder_batch_idx: {xpu_forward_meta.encoder_batch_idx}") + # print(f"xpu_forward_meta.decoder_batch_idx : {xpu_forward_meta.decoder_batch_idx}") + # print(f"xpu_forward_meta.encoder_seq_lod_cpu : {xpu_forward_meta.encoder_seq_lod_cpu}") + # print(f"xpu_forward_meta.encoder_batch_idx_cpu : {xpu_forward_meta.encoder_batch_idx_cpu}") + # print(f"xpu_forward_meta.decoder_batch_idx_cpu : {xpu_forward_meta.decoder_batch_idx_cpu}") + # print(f"xpu_forward_meta.enc_batch : {xpu_forward_meta.encoder_batch_map}") + adjusted_input = adjusted_input.squeeze(1) share_inputs["ids_remove_padding"] = adjusted_input @@ -160,7 +186,9 @@ def xpu_process_output( def xpu_post_process( sampled_token_ids: paddle.Tensor, model_output: ModelOutputData, - skip_save_output: bool, + share_inputs: Dict[str, paddle.Tensor], + block_size: int = 64, + skip_save_output: bool = False, ) -> None: """ """ from fastdeploy.model_executor.ops.xpu import ( @@ -194,17 +222,66 @@ def xpu_post_process( # 2. Update the input buffer of the model with paddle.framework._no_check_dy2st_diff(): - update_inputs( - model_output.stop_flags, - model_output.not_need_stop, - model_output.seq_lens_this_time, - model_output.seq_lens_encoder, - model_output.seq_lens_decoder, - model_output.input_ids, - model_output.stop_nums, - sampled_token_ids, - model_output.is_block_step, - ) + if envs.ENABLE_V1_KVCACHE_SCHEDULER and not skip_save_output: + + # print(f"============================================update_inputs_v1 更新前=========================================") + # print(f"model_output.stop_flags : {model_output.stop_flags}") + # print(f"model_output.not_need_stop : {model_output.not_need_stop}") + # print(f"model_output.seq_lens_this_time : {model_output.seq_lens_this_time}") + # print(f"model_output.seq_lens_encoder : {model_output.seq_lens_encoder}") + # print(f"model_output.seq_lens_decoder : {model_output.seq_lens_decoder}") + # print(f"share_inputs['step_seq_lens_decoder'] : {share_inputs['step_seq_lens_decoder']}") + # print(f"share_inputs['prompt_lens'] : {share_inputs['prompt_lens']}") + # print(f"sampled_token_ids : {sampled_token_ids}") + # print(f"model_output.input_ids : {model_output.input_ids}") + # print(f"model_output.stop_nums : {model_output.stop_nums}") + # print(f"model_output.next_tokens : {model_output.next_tokens}") + # print(f"model_output.is_block_step : {model_output.is_block_step}") + # print(f"share_inputs['block_tables'] : {share_inputs['block_tables']}") + # print(f"block_size : {block_size}") + update_inputs_v1( + model_output.stop_flags, + model_output.not_need_stop, + model_output.seq_lens_this_time, + model_output.seq_lens_encoder, + model_output.seq_lens_decoder, + share_inputs["step_seq_lens_decoder"], + share_inputs["prompt_lens"], + sampled_token_ids, + model_output.input_ids, + share_inputs["block_tables"], + model_output.stop_nums, + model_output.next_tokens, + model_output.is_block_step, + block_size, + ) + # print(f"============================================update_inputs_v1 更新后=========================================") + # print(f"model_output.stop_flags : {model_output.stop_flags}") + # print(f"model_output.not_need_stop : {model_output.not_need_stop}") + # print(f"model_output.seq_lens_this_time : {model_output.seq_lens_this_time}") + # print(f"model_output.seq_lens_encoder : {model_output.seq_lens_encoder}") + # print(f"model_output.seq_lens_decoder : {model_output.seq_lens_decoder}") + # print(f"share_inputs['step_seq_lens_decoder'] : {share_inputs['step_seq_lens_decoder']}") + # print(f"share_inputs['prompt_lens'] : {share_inputs['prompt_lens']}") + # print(f"sampled_token_ids : {sampled_token_ids}") + # print(f"model_output.input_ids : {model_output.input_ids}") + # print(f"model_output.stop_nums : {model_output.stop_nums}") + # print(f"model_output.next_tokens : {model_output.next_tokens}") + # print(f"model_output.is_block_step : {model_output.is_block_step}") + # print(f"share_inputs['block_tables'] : {share_inputs['block_tables']}") + # print(f"block_size : {block_size}") + else: + update_inputs( + model_output.stop_flags, + model_output.not_need_stop, + model_output.seq_lens_this_time, + model_output.seq_lens_encoder, + model_output.seq_lens_decoder, + model_output.input_ids, + model_output.stop_nums, + sampled_token_ids, + model_output.is_block_step, + ) # 3. Transmit the model's output and stop generation signal via message queue. # In the future, we will abandon this approach. if not skip_save_output: @@ -290,6 +367,96 @@ class XPUModelRunner(ModelRunnerBase): # Forward meta store the global meta information of the forward self.forward_meta: ForwardMeta = None + def insert_tasks_v1(self, req_dicts: List[Request]): + """ + Process scheduler output tasks, used when ENABLE_V1_KVCACHE_SCHEDULER=1 + """ + # NOTE(luotingdan): Lazy initialize kv cache + if "caches" not in self.share_inputs: + self.initialize_kv_cache() + + req_len = len(req_dicts) + has_prefill_task = False + for i in range(req_len): + request = req_dicts[i] + idx = request.idx + if request.task_type.value == RequestType.PREFILL.value: # prefill task + logger.debug(f"Handle prefill request {request} at idx {idx}") + prefill_start_index = request.prefill_start_index + prefill_end_index = request.prefill_end_index + length = prefill_end_index - prefill_start_index + input_ids = request.prompt_token_ids + request.output_token_ids + self.share_inputs["input_ids"][idx : idx + 1, :length] = np.array( + input_ids[prefill_start_index:prefill_end_index] + ) + encoder_block_num = len(request.block_tables) + self.share_inputs["encoder_block_lens"][idx : idx + 1] = encoder_block_num + self.share_inputs["block_tables"][idx : idx + 1, :] = -1 + self.share_inputs["block_tables"][idx : idx + 1, :encoder_block_num] = np.array( + request.block_tables, dtype="int32" + ) + self.share_inputs["stop_flags"][idx : idx + 1] = False + self.share_inputs["seq_lens_decoder"][idx : idx + 1] = prefill_start_index + self.share_inputs["seq_lens_this_time"][idx : idx + 1] = length + self.share_inputs["seq_lens_encoder"][idx : idx + 1] = length + self.share_inputs["step_seq_lens_decoder"][idx : idx + 1] = 0 + self.share_inputs["prompt_lens"][idx : idx + 1] = len(input_ids) + self.share_inputs["is_block_step"][idx : idx + 1] = False + self.share_inputs["step_idx"][idx : idx + 1] = ( + len(request.output_token_ids) if prefill_end_index >= len(input_ids) else 0 + ) + has_prefill_task = True + elif request.task_type.value == RequestType.DECODE.value: # decode task + logger.debug(f"Handle decode request {request} at idx {idx}") + encoder_block_num = len(request.block_tables) + self.share_inputs["encoder_block_lens"][idx : idx + 1] = encoder_block_num + self.share_inputs["block_tables"][idx : idx + 1, :] = -1 + self.share_inputs["block_tables"][idx : idx + 1, :encoder_block_num] = np.array( + request.block_tables, dtype="int32" + ) + continue + else: # preempted task + logger.debug(f"Handle preempted request {request} at idx {idx}") + self.share_inputs["block_tables"][idx : idx + 1, :] = -1 + self.share_inputs["stop_flags"][idx : idx + 1] = True + self.share_inputs["seq_lens_this_time"][idx : idx + 1] = 0 + self.share_inputs["seq_lens_decoder"][idx : idx + 1] = 0 + self.share_inputs["seq_lens_encoder"][idx : idx + 1] = 0 + self.share_inputs["is_block_step"][idx : idx + 1] = False + continue + + if len(request.eos_token_ids) < self.parallel_config.eos_tokens_lens: + request.eos_token_ids.append(request.eos_token_ids[0]) + self.share_inputs["eos_token_id"][:] = np.array(request.eos_token_ids, dtype="int64").reshape(-1, 1) + + self.share_inputs["top_p"][idx : idx + 1] = request.get("top_p", 0.7) + self.share_inputs["temperature"][idx : idx + 1] = request.get("temperature", 0.95) + self.share_inputs["penalty_score"][idx : idx + 1] = request.get("repetition_penalty", 1.0) + self.share_inputs["frequency_score"][idx : idx + 1] = request.get("frequency_penalty", 0.0) + self.share_inputs["presence_score"][idx : idx + 1] = request.get("presence_penalty", 0.0) + + self.share_inputs["min_dec_len"][idx : idx + 1] = request.get("min_tokens", 1) + self.share_inputs["max_dec_len"][idx : idx + 1] = request.get( + "max_tokens", self.model_config.max_model_len + ) + + self.share_inputs["first_token_ids"][idx : idx + 1] = self.share_inputs["input_ids"][idx : idx + 1, :1] + self.share_inputs["ori_seq_lens_encoder"][idx : idx + 1] = length + + if request.get("seed") is not None: + self.share_inputs["infer_seed"][idx : idx + 1] = request.get("seed") + + if request.get("stop_token_ids") is not None and request.get("stop_seqs_len") is not None: + stop_seqs_num = len(request.get("stop_seqs_len")) + for i in range(stop_seqs_num, self.model_config.max_stop_seqs_num): + request.stop_seqs_len.append(0) + self.share_inputs["stop_seqs_len"][:] = np.array(request.stop_seqs_len, dtype="int32") + self.share_inputs["stop_seqs"][:stop_seqs_num, : len(request.get("stop_token_ids")[0])] = np.array( + request.get("stop_token_ids"), dtype="int64" + ) + if has_prefill_task: + self.share_inputs["not_need_stop"][0] = True + def process_prefill_inputs(self, req_dicts: List[Request]): """Process inputs for prefill tasks and update share_inputs buffer""" req_len = len(req_dicts) @@ -392,6 +559,8 @@ class XPUModelRunner(ModelRunnerBase): self.share_inputs["seq_lens_encoder"] = paddle.full([max_num_seqs, 1], 0, dtype="int32") self.share_inputs["seq_lens_decoder"] = paddle.full([max_num_seqs, 1], 0, dtype="int32") self.share_inputs["step_seq_lens_encoder"] = paddle.full([max_num_seqs, 1], 0, dtype="int32") + self.share_inputs["step_seq_lens_decoder"] = paddle.full([max_num_seqs, 1], 0, dtype="int32") + self.share_inputs["prompt_lens"] = paddle.full([max_num_seqs, 1], 0, dtype="int64") self.share_inputs["step_idx"] = paddle.full([max_num_seqs, 1], 0, dtype="int64") self.share_inputs["not_need_stop"] = paddle.full( [1], False, dtype="bool" @@ -455,8 +624,19 @@ class XPUModelRunner(ModelRunnerBase): dtype="int32", ) - def _prepare_inputs(self) -> None: + def _prepare_inputs(self, is_dummy_run=False) -> None: """prepare the model inputs""" + if envs.ENABLE_V1_KVCACHE_SCHEDULER and not is_dummy_run: + recover_decode_task( + self.share_inputs["stop_flags"], + self.share_inputs["seq_lens_this_time"], + self.share_inputs["seq_lens_encoder"], + self.share_inputs["seq_lens_decoder"], + self.share_inputs["step_seq_lens_decoder"], + self.share_inputs["block_tables"], + self.share_inputs["is_block_step"], + self.parallel_config.block_size, + ) self.forward_meta = xpu_pre_process( self.share_inputs["input_ids"], self.share_inputs["seq_lens_this_time"], @@ -655,7 +835,7 @@ class XPUModelRunner(ModelRunnerBase): intermediate_tensors: """ # 1. Prepare inputs of model and decoder. - self._prepare_inputs() + self._prepare_inputs(is_dummy_run=is_dummy_run) # 2. Padding inputs for cuda grph @@ -699,6 +879,8 @@ class XPUModelRunner(ModelRunnerBase): xpu_post_process( sampled_token_ids=sampler_output.sampled_token_ids, model_output=model_output_data, + share_inputs=self.share_inputs, + block_size=self.parallel_config.block_size, skip_save_output=is_dummy_run, ) diff --git a/fastdeploy/worker/xpu_worker.py b/fastdeploy/worker/xpu_worker.py index 7c935f78e..82e239202 100644 --- a/fastdeploy/worker/xpu_worker.py +++ b/fastdeploy/worker/xpu_worker.py @@ -20,6 +20,7 @@ from typing import List, Optional import paddle from paddle import nn +from fastdeploy import envs from fastdeploy.config import FDConfig from fastdeploy.engine.request import Request from fastdeploy.utils import get_logger @@ -154,7 +155,10 @@ class XpuWorker(WorkerBase): TODO(gongshaotian):The scheduler should schedule the handling of prefill, and workers and modelrunners should not perceive it. """ - self.model_runner.process_prefill_inputs(req_dicts=req_dicts) + if envs.ENABLE_V1_KVCACHE_SCHEDULER: + self.model_runner.insert_tasks_v1(req_dicts=req_dicts) + else: + self.model_runner.process_prefill_inputs(req_dicts=req_dicts) def check_health(self) -> bool: """ """