diff --git a/custom_ops/xpu_ops/src/ops/adjust_batch.cc b/custom_ops/xpu_ops/src/ops/adjust_batch.cc index d263d2cae..fb3b31688 100644 --- a/custom_ops/xpu_ops/src/ops/adjust_batch.cc +++ b/custom_ops/xpu_ops/src/ops/adjust_batch.cc @@ -18,38 +18,49 @@ #include "utility/helper.h" #include "xpu/plugin.h" +#ifndef PD_BUILD_STATIC_OP +#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name) +#endif + template std::vector AdjustBatchKernel( const paddle::Tensor &x, // [token_num, dim_embed] const paddle::Tensor &cum_offsets, // [bsz, 1] const paddle::Tensor &encoder_seq_lod, + const paddle::Tensor &decoder_seq_lod, const paddle::Tensor &encoder_batch_idx, const paddle::Tensor &decoder_batch_idx, const paddle::Tensor &encoder_seq_lod_cpu, + const paddle::Tensor &decoder_seq_lod_cpu, const paddle::Tensor &encoder_batch_idx_cpu, const paddle::Tensor &decoder_batch_idx_cpu, - const paddle::Tensor &enc_batch_tensor, - const paddle::Tensor &dec_batch_tensor, + const paddle::Tensor &len_info_cpu, const paddle::optional &output_padding_offset, int max_input_length) { phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId()); auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place); - auto xpu_ctx = static_cast(dev_ctx); + auto ctx = static_cast(dev_ctx)->x_context(); PD_CHECK(x.dtype() == T); PD_CHECK(x.dims().size() == 2); - + if (x.is_cpu()) { + ctx = new baidu::xpu::api::Context(baidu::xpu::api::kCPU); + } using XPUType = typename XPUTypeTrait::DataType>::Type; using data_t = typename PDTraits::data_t; const int token_num = x.dims()[0]; const int dim = x.dims()[1]; const int bsz = cum_offsets.shape()[0]; - int enc_batch = enc_batch_tensor.data()[0]; - int dec_batch = dec_batch_tensor.data()[0]; + int enc_batch = len_info_cpu.data()[0]; + int dec_batch = len_info_cpu.data()[1]; baidu::xpu::api::VectorParam encoder_seqs_lods_vp{ const_cast(encoder_seq_lod_cpu.data()), enc_batch + 1, const_cast(encoder_seq_lod.data())}; + baidu::xpu::api::VectorParam decoder_seqs_lods_vp{ + const_cast(decoder_seq_lod_cpu.data()), + dec_batch + 1, + const_cast(decoder_seq_lod.data())}; baidu::xpu::api::VectorParam encoder_batch_map_vp{ const_cast(encoder_batch_idx_cpu.data()), enc_batch, @@ -59,13 +70,14 @@ std::vector AdjustBatchKernel( dec_batch, const_cast(decoder_batch_idx.data())}; - auto out = paddle::full({token_num, dim}, -2, x.type(), x.place()); + auto out = paddle::empty({token_num, dim}, x.type(), x.place()); int r = baidu::xpu::api::plugin::eb_adjust_batch( - xpu_ctx->x_context(), + ctx, reinterpret_cast(x.data()), reinterpret_cast(out.data()), encoder_seqs_lods_vp, + decoder_seqs_lods_vp, encoder_batch_map_vp, decoder_batch_map_vp, dim); @@ -76,13 +88,14 @@ using AdjustBatchKernelFuncPtr = std::vector (*)( const paddle::Tensor &x, // [token_num, dim_embed] const paddle::Tensor &cum_offsets, // [bsz, 1] const paddle::Tensor &encoder_seq_lod, + const paddle::Tensor &decoder_seq_lod, const paddle::Tensor &encoder_batch_idx, const paddle::Tensor &decoder_batch_idx, const paddle::Tensor &encoder_seq_lod_cpu, + const paddle::Tensor &decoder_seq_lod_cpu, const paddle::Tensor &encoder_batch_idx_cpu, const paddle::Tensor &decoder_batch_idx_cpu, - const paddle::Tensor &enc_batch_tensor, - const paddle::Tensor &dec_batch_tensor, + const paddle::Tensor &len_info_cpu, const paddle::optional &output_padding_offset, int max_input_length); @@ -90,13 +103,14 @@ std::vector AdjustBatch( const paddle::Tensor &x, // [token_num, dim_embed] const paddle::Tensor &cum_offsets, // [bsz, 1] const 
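A note on the new adjust_batch inputs in the hunk above: the two single-element tensors enc_batch_tensor and dec_batch_tensor are folded into one host tensor, len_info_cpu, whose first two entries carry the encoder and decoder batch counts, and each lod is handed to the plugin as an XDNN VectorParam pairing a host copy with its device copy. The sketch below is illustrative only; VectorParamLikeInt is a stand-in for the real baidu::xpu::api::VectorParam, and the helper names are mine.

```cpp
#include <vector>

// Stand-in for the XDNN VectorParam<int> used above: a host copy, its length,
// and the matching device copy of the same array.
struct VectorParamLikeInt {
  int* cpu;  // e.g. encoder_seq_lod_cpu host data
  int len;   // enc_batch + 1 entries for a lod
  int* xpu;  // e.g. encoder_seq_lod device data
};

// A lod is the cumulative token count per sequence:
// lens = {3, 1, 4}  ->  lod = {0, 3, 4, 8}.
std::vector<int> build_seq_lod(const std::vector<int>& lens) {
  std::vector<int> lod(lens.size() + 1, 0);
  for (size_t i = 0; i < lens.size(); ++i) lod[i + 1] = lod[i] + lens[i];
  return lod;
}

// len_info_cpu replaces the two scalar tensors: index 0 holds enc_batch,
// index 1 holds dec_batch, exactly how the kernel above reads it.
inline int enc_batch_of(const int* len_info_cpu) { return len_info_cpu[0]; }
inline int dec_batch_of(const int* len_info_cpu) { return len_info_cpu[1]; }
```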
paddle::Tensor &encoder_seq_lod, + const paddle::Tensor &decoder_seq_lod, const paddle::Tensor &encoder_batch_idx, const paddle::Tensor &decoder_batch_idx, const paddle::Tensor &encoder_seq_lod_cpu, + const paddle::Tensor &decoder_seq_lod_cpu, const paddle::Tensor &encoder_batch_idx_cpu, const paddle::Tensor &decoder_batch_idx_cpu, - const paddle::Tensor &enc_batch_tensor, - const paddle::Tensor &dec_batch_tensor, + const paddle::Tensor &len_info_cpu, const paddle::optional &output_padding_offset, int max_input_length) { AdjustBatchKernelFuncPtr func = nullptr; @@ -108,12 +122,12 @@ std::vector AdjustBatch( case paddle::DataType::FLOAT16: func = &AdjustBatchKernel; break; - case paddle::DataType::FLOAT32: - func = &AdjustBatchKernel; - break; case paddle::DataType::INT64: func = &AdjustBatchKernel; break; + case paddle::DataType::FLOAT32: + func = &AdjustBatchKernel; + break; default: PD_THROW("Unsupported data type: ", x.dtype()); } @@ -121,13 +135,14 @@ std::vector AdjustBatch( return func(x, cum_offsets, encoder_seq_lod, + decoder_seq_lod, encoder_batch_idx, decoder_batch_idx, encoder_seq_lod_cpu, + decoder_seq_lod_cpu, encoder_batch_idx_cpu, decoder_batch_idx_cpu, - enc_batch_tensor, - dec_batch_tensor, + len_info_cpu, output_padding_offset, max_input_length); } @@ -136,13 +151,14 @@ std::vector> AdjustBatchInferShape( const std::vector &x_shape, const std::vector &cum_offsets_shape, const std::vector &encoder_seq_lod_shape, + const std::vector &decoder_seq_lod_shape, const std::vector &encoder_batch_idx_shape, const std::vector &decoder_batch_idx_shape, const std::vector &encoder_seq_lod_cpu_shape, + const std::vector &decoder_seq_lod_cpu_shape, const std::vector &encoder_batch_idx_cpu_shape, const std::vector &decoder_batch_idx_cpu_shape, - const std::vector &enc_batch_tensor_shape, - const std::vector &dec_batch_tensor_shape, + const std::vector &len_info_cpu_shape, const paddle::optional> &output_padding_offset_shape) { if (output_padding_offset_shape) { PD_THROW("speculative decoding is not supported in XPU."); @@ -156,28 +172,30 @@ std::vector AdjustBatchInferDtype( const paddle::DataType &x_dtype, const paddle::DataType &cum_offsets_dtype, const paddle::DataType &encoder_seq_lod_dtype, + const paddle::DataType &decoder_seq_lod_dtype, const paddle::DataType &encoder_batch_idx_dtype, const paddle::DataType &decoder_batch_idx_dtype, const paddle::DataType &encoder_seq_lod_cpu_dtype, + const paddle::DataType &decoder_seq_lod_cpu_dtype, const paddle::DataType &encoder_batch_idx_cpu_dtype, const paddle::DataType &decoder_batch_idx_cpu_dtype, - const paddle::DataType &enc_batch_tensor_dtype, - const paddle::DataType &dec_batch_tensor_dtype, + const paddle::DataType &len_info_cpu_dtype, const paddle::optional &output_padding_offset_dtype) { return {x_dtype}; } -PD_BUILD_OP(adjust_batch) +PD_BUILD_STATIC_OP(adjust_batch) .Inputs({"x", "cum_offsets", "encoder_seq_lod", + "decoder_seq_lod", "encoder_batch_idx", "decoder_batch_idx", "encoder_seq_lod_cpu", + "decoder_seq_lod_cpu", "encoder_batch_idx_cpu", "decoder_batch_idx_cpu", - "enc_batch_tensor", - "dec_batch_tensor", + "len_info_cpu", paddle::Optional("output_padding_offset")}) .Outputs({"out"}) .Attrs({"max_input_length: int"}) diff --git a/custom_ops/xpu_ops/src/ops/block_attn.cc b/custom_ops/xpu_ops/src/ops/block_attn.cc index 6153a77dd..c9e3313f2 100644 --- a/custom_ops/xpu_ops/src/ops/block_attn.cc +++ b/custom_ops/xpu_ops/src/ops/block_attn.cc @@ -722,7 +722,6 @@ std::vector BlockAttnKernel( : quant_v_scale_inv, nullptr, // 
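For readers new to this op: as I read the indexing in the updated eb_adjust_batch kernel later in this patch, adjust_batch takes tokens laid out in the original, interleaved batch order and regroups them so that all encoder (prefill) tokens come first, followed by all decoder tokens; the batch maps record each sequence's original batch slot. A small host-side reference of that permutation, purely illustrative (names and the int token payload are mine, not the production path):

```cpp
#include <cassert>
#include <vector>

// Illustrative host reference of the regrouping adjust_batch performs:
// src holds tokens in original batch order (encoder and decoder sequences
// interleaved), dst holds every encoder token first, then every decoder token.
std::vector<int> adjust_batch_ref(const std::vector<int>& enc_lens,
                                  const std::vector<int>& dec_lens,
                                  const std::vector<int>& enc_batch_map,  // original slot of each encoder seq
                                  const std::vector<int>& dec_batch_map,  // original slot of each decoder seq
                                  const std::vector<int>& src) {
  const int total_batch = static_cast<int>(enc_lens.size() + dec_lens.size());
  std::vector<int> slot_len(total_batch, 0);
  for (size_t e = 0; e < enc_lens.size(); ++e) slot_len[enc_batch_map[e]] = enc_lens[e];
  for (size_t d = 0; d < dec_lens.size(); ++d) slot_len[dec_batch_map[d]] = dec_lens[d];

  std::vector<int> slot_start(total_batch + 1, 0);  // token offset of each slot in src
  for (int b = 0; b < total_batch; ++b) slot_start[b + 1] = slot_start[b] + slot_len[b];

  std::vector<int> dst;
  dst.reserve(src.size());
  for (size_t e = 0; e < enc_lens.size(); ++e)  // encoder tokens first
    for (int i = 0; i < enc_lens[e]; ++i) dst.push_back(src[slot_start[enc_batch_map[e]] + i]);
  for (size_t d = 0; d < dec_lens.size(); ++d)  // then decoder tokens
    for (int i = 0; i < dec_lens[d]; ++i) dst.push_back(src[slot_start[dec_batch_map[d]] + i]);
  return dst;
}

int main() {
  // slots: 0 = decoder (2 tokens), 1 = encoder (3 tokens), 2 = decoder (1 token)
  std::vector<int> src = {0, 1, 2, 3, 4, 5};  // token ids in original order
  auto dst = adjust_batch_ref({3}, {2, 1}, {1}, {0, 2}, src);
  assert((dst == std::vector<int>{2, 3, 4, 0, 1, 5}));
}
```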
o_maxptr param.head_dim); // vo_head_dim - PD_CHECK(0, "speculative_attention unimplemented"); PD_CHECK(ret == api::SUCCESS, "xfa::speculative_attention_decoder failed."); if (!Eq_len) { diff --git a/custom_ops/xpu_ops/src/ops/gather_next_token.cc b/custom_ops/xpu_ops/src/ops/gather_next_token.cc index bc875b372..9a35f91f9 100644 --- a/custom_ops/xpu_ops/src/ops/gather_next_token.cc +++ b/custom_ops/xpu_ops/src/ops/gather_next_token.cc @@ -13,107 +13,169 @@ // limitations under the License. #include +#include #include "paddle/extension.h" #include "xpu/plugin.h" + +#ifndef PD_BUILD_STATIC_OP +#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name) +#endif + std::vector GatherNextToken( - const paddle::Tensor &tmp_out, // [token_num, dim_embed] - const paddle::Tensor &cum_offsets, // [bsz, 1] - const paddle::Tensor &encoder_seq_lod, - const paddle::Tensor &encoder_batch_map, - const paddle::Tensor &decoder_batch_map, - const paddle::Tensor &encoder_seq_lod_cpu, - const paddle::Tensor &encoder_batch_map_cpu, - const paddle::Tensor &decoder_batch_map_cpu, - const paddle::Tensor &enc_batch_tensor, - const paddle::Tensor &dec_batch_tensor, - const paddle::optional &output_padding_offset, - int max_input_length) { + const paddle::Tensor& x, // [token_num, dim_embed] + const paddle::Tensor& cum_offsets, // [bsz, 1] + const paddle::Tensor& encoder_seq_lod, + const paddle::Tensor& decoder_seq_lod, + const paddle::Tensor& encoder_batch_map, + const paddle::Tensor& decoder_batch_map, + const paddle::Tensor& encoder_seq_lod_cpu, + const paddle::Tensor& decoder_seq_lod_cpu, + const paddle::Tensor& encoder_batch_map_cpu, + const paddle::Tensor& decoder_batch_map_cpu, + const paddle::Tensor& len_info_cpu, + const paddle::optional& output_padding_offset, + int max_bsz) { phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId()); auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place); - auto xpu_ctx = static_cast(dev_ctx); + auto ctx = static_cast(dev_ctx)->x_context(); + if (x.is_cpu()) { + ctx = new baidu::xpu::api::Context(baidu::xpu::api::kCPU); + } using XPUType = typename XPUTypeTrait::Type; // only support bfloat16 typedef paddle::bfloat16 data_t; - const int dim = tmp_out.dims()[1]; - const int bsz = cum_offsets.shape()[0]; - int enc_batch = enc_batch_tensor.data()[0]; - int dec_batch = dec_batch_tensor.data()[0]; - + const int dim = x.dims()[1]; + const int token_num = x.shape()[0]; + int bsz = cum_offsets.shape()[0]; + int enc_batch = len_info_cpu.data()[0]; + int dec_batch = len_info_cpu.data()[1]; + if (max_bsz > 0) { + PD_CHECK(encoder_batch_map_cpu.data()[enc_batch - 1] <= max_bsz, + "encoder_batch_map_cpu check failed"); + PD_CHECK(decoder_batch_map_cpu.data()[dec_batch - 1] <= max_bsz, + "decoder_batch_map_cpu check failed"); + bsz = max_bsz; + } baidu::xpu::api::VectorParam encoder_seqs_lods_vp{ - const_cast(encoder_seq_lod_cpu.data()), + const_cast(encoder_seq_lod_cpu.data()), enc_batch + 1, - const_cast(encoder_seq_lod.data())}; + const_cast(encoder_seq_lod.data())}; + baidu::xpu::api::VectorParam decoder_seqs_lods_vp{ + const_cast(decoder_seq_lod_cpu.data()), + dec_batch + 1, + const_cast(decoder_seq_lod.data())}; baidu::xpu::api::VectorParam encoder_batch_map_vp{ - const_cast(encoder_batch_map_cpu.data()), + const_cast(encoder_batch_map_cpu.data()), enc_batch, - const_cast(encoder_batch_map.data())}; + const_cast(encoder_batch_map.data())}; baidu::xpu::api::VectorParam decoder_batch_map_vp{ - const_cast(decoder_batch_map_cpu.data()), + 
const_cast(decoder_batch_map_cpu.data()), dec_batch, - const_cast(decoder_batch_map.data())}; + const_cast(decoder_batch_map.data())}; - auto out = paddle::full({bsz, dim}, -2, tmp_out.type(), tmp_out.place()); + paddle::Tensor out; + if (output_padding_offset) { + int need_delete_token_num = 0; + if (enc_batch > 0) { + need_delete_token_num = + encoder_seq_lod_cpu.data()[enc_batch] - enc_batch; + } + out = paddle::empty( + {token_num - need_delete_token_num, dim}, x.type(), x.place()); + } else { + out = paddle::empty({bsz, dim}, x.type(), x.place()); + } + if (x.shape()[0] == 0) { + return {out}; + } - int r = baidu::xpu::api::plugin::eb_gather_next_token( - xpu_ctx->x_context(), - reinterpret_cast(tmp_out.data()), - reinterpret_cast(out.data()), - encoder_seqs_lods_vp, - encoder_batch_map_vp, - decoder_batch_map_vp, - dim); + if (enc_batch <= 0) { + out = x.copy_to(x.place(), false); + } else { + if (output_padding_offset) { + int r = + baidu::xpu::api::plugin::eb_mtp_gather_next_token( + ctx, + reinterpret_cast(x.data()), + reinterpret_cast(out.data()), + encoder_seqs_lods_vp, + decoder_seqs_lods_vp, + encoder_batch_map_vp, + decoder_batch_map_vp, + dim); + PD_CHECK(r == 0, "xpu::plugin::gather_next_token failed."); + } else { + int r = baidu::xpu::api::plugin::eb_gather_next_token( + ctx, + reinterpret_cast(x.data()), + reinterpret_cast(out.data()), + encoder_seqs_lods_vp, + encoder_batch_map_vp, + decoder_batch_map_vp, + dim); + PD_CHECK(r == 0, "xpu::plugin::gather_next_token failed."); + } + } return {out}; } std::vector> GatherNextTokenInferShape( - const std::vector &tmp_out_shape, - const std::vector &cum_offsets_shape, - const std::vector &encoder_seq_lod_shape, - const std::vector &encoder_batch_map_shape, - const std::vector &decoder_batch_map_shape, - const std::vector &encoder_seq_lod_cpu_shape, - const std::vector &encoder_batch_map_cpu_shape, - const std::vector &decoder_batch_map_cpu_shape, - const std::vector &enc_batch_tensor_shape, - const std::vector &dec_batch_tensor_shape, - const paddle::optional> &output_padding_offset_shape) { - if (output_padding_offset_shape) { - PD_THROW("speculative decoding is not supported in XPU."); - } + const std::vector& x_shape, + const std::vector& cum_offsets_shape, + const std::vector& encoder_seq_lod_shape, + const std::vector& decoder_seq_lod_shape, + const std::vector& encoder_batch_map_shape, + const std::vector& decoder_batch_map_shape, + const std::vector& encoder_seq_lod_cpu_shape, + const std::vector& decoder_seq_lod_cpu_shape, + const std::vector& encoder_batch_map_cpu_shape, + const std::vector& decoder_batch_map_cpu_shape, + const std::vector& len_info_cpu_shape, + const paddle::optional>& output_padding_offset_shape) { + // if (output_padding_offset_shape) { + // PD_THROW("speculative decoding is not supported in XPU."); + // } int64_t bsz = cum_offsets_shape[0]; - int64_t dim_embed = tmp_out_shape[1]; - return {{bsz, dim_embed}}; + int64_t dim_embed = x_shape[1]; + if (output_padding_offset_shape) { + return {{-1, dim_embed}}; + } else { + int64_t bsz = cum_offsets_shape[0]; + return {{bsz, dim_embed}}; + } } std::vector GatherNextTokenInferDtype( - const paddle::DataType &tmp_out_dtype, - const paddle::DataType &cum_offsets_dtype, - const paddle::DataType &encoder_seq_lod_dtype, - const paddle::DataType &encoder_batch_map_dtype, - const paddle::DataType &decoder_batch_map_dtype, - const paddle::DataType &encoder_seq_lod_cpu_dtype, - const paddle::DataType &encoder_batch_map_cpu_dtype, - const paddle::DataType 
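A quick reading aid for the allocation logic in GatherNextToken above: without output_padding_offset the op gathers one row per batch slot (bsz from cum_offsets, or max_bsz when that attribute is positive); with it, every decoder token is kept but each encoder sequence contributes only its final token, so encoder_seq_lod_cpu[enc_batch] - enc_batch rows are dropped. A host-side sketch of that arithmetic (the function name is mine):

```cpp
#include <cassert>
#include <vector>

// Mirrors the output-row computation in GatherNextToken (illustrative).
int gather_next_token_rows(int token_num, int bsz, int enc_batch,
                           const std::vector<int>& encoder_seq_lod_cpu,
                           bool has_output_padding_offset) {
  if (!has_output_padding_offset) return bsz;  // one gathered row per batch slot
  int need_delete = 0;
  if (enc_batch > 0) {
    // keep only the last token of each encoder (prefill) sequence
    need_delete = encoder_seq_lod_cpu[enc_batch] - enc_batch;
  }
  return token_num - need_delete;
}

int main() {
  // two encoder sequences (3 and 2 tokens) plus 4 decoder tokens = 9 tokens
  std::vector<int> enc_lod = {0, 3, 5};
  assert(gather_next_token_rows(9, 4, 2, enc_lod, true) == 6);   // 9 - (5 - 2)
  assert(gather_next_token_rows(9, 4, 2, enc_lod, false) == 4);  // bsz rows
}
```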
&decoder_batch_map_cpu_dtype, - const paddle::DataType &enc_batch_tensor_dtype, - const paddle::DataType &dec_batch_tensor_dtype, - const paddle::optional &output_padding_offset_dtype) { - return {tmp_out_dtype}; + const paddle::DataType& x_dtype, + const paddle::DataType& cum_offsets_dtype, + const paddle::DataType& encoder_seq_lod_dtype, + const paddle::DataType& decoder_seq_lod_dtype, + const paddle::DataType& encoder_batch_map_dtype, + const paddle::DataType& decoder_batch_map_dtype, + const paddle::DataType& encoder_seq_lod_cpu_dtype, + const paddle::DataType& decoder_seq_lod_cpu_dtype, + const paddle::DataType& encoder_batch_map_cpu_dtype, + const paddle::DataType& decoder_batch_map_cpu_dtype, + const paddle::DataType& len_info_cpu_dtype, + const paddle::optional& output_padding_offset_dtype) { + return {x_dtype}; } -PD_BUILD_OP(gather_next_token) - .Inputs({"tmp_out", +PD_BUILD_STATIC_OP(gather_next_token) + .Inputs({"x", "cum_offsets", "encoder_seq_lod", + "decoder_seq_lod", "encoder_batch_map", "decoder_batch_map", "encoder_seq_lod_cpu", + "decoder_seq_lod_cpu", "encoder_batch_map_cpu", "decoder_batch_map_cpu", - "enc_batch_tensor", - "dec_batch_tensor", + "len_info_cpu", paddle::Optional("output_padding_offset")}) .Outputs({"out"}) - .Attrs({"max_input_length: int"}) + .Attrs({"max_bsz: int"}) .SetKernelFn(PD_KERNEL(GatherNextToken)) .SetInferShapeFn(PD_INFER_SHAPE(GatherNextTokenInferShape)) .SetInferDtypeFn(PD_INFER_DTYPE(GatherNextTokenInferDtype)); diff --git a/custom_ops/xpu_ops/src/ops/mtp/draft_model_preprocess.cc b/custom_ops/xpu_ops/src/ops/mtp/draft_model_preprocess.cc index ec501a790..a4cf8e687 100644 --- a/custom_ops/xpu_ops/src/ops/mtp/draft_model_preprocess.cc +++ b/custom_ops/xpu_ops/src/ops/mtp/draft_model_preprocess.cc @@ -29,21 +29,23 @@ void DraftModelPreprocess(const paddle::Tensor& draft_tokens, const paddle::Tensor& seq_lens_encoder, const paddle::Tensor& seq_lens_decoder, const paddle::Tensor& step_idx, - const paddle::Tensor& seq_lens_encoder_record, - const paddle::Tensor& seq_lens_decoder_record, const paddle::Tensor& not_need_stop, + const paddle::Tensor& is_block_step, const paddle::Tensor& batch_drop, + const paddle::Tensor& pre_ids, const paddle::Tensor& accept_tokens, const paddle::Tensor& accept_num, + const paddle::Tensor& base_model_seq_lens_this_time, const paddle::Tensor& base_model_seq_lens_encoder, const paddle::Tensor& base_model_seq_lens_decoder, const paddle::Tensor& base_model_step_idx, const paddle::Tensor& base_model_stop_flags, const paddle::Tensor& base_model_is_block_step, const paddle::Tensor& base_model_draft_tokens, - const int max_draft_token, + const int num_model_step, const bool truncate_first_token, - const bool splitwise_prefill) { + const bool splitwise_prefill, + const bool kvcache_scheduler_v1) { phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId()); auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place); api::Context* ctx = static_cast(dev_ctx)->x_context(); @@ -54,6 +56,8 @@ void DraftModelPreprocess(const paddle::Tensor& draft_tokens, int accept_tokens_len = accept_tokens.shape()[1]; int input_ids_len = input_ids.shape()[1]; int draft_tokens_len = draft_tokens.shape()[1]; + int pre_ids_len = pre_ids.shape()[1]; + constexpr int BlockSize = 512; int base_model_draft_tokens_len = base_model_draft_tokens.shape()[1]; auto not_need_stop_gpu = not_need_stop.copy_to(seq_lens_this_time.place(), false); @@ -67,12 +71,13 @@ void DraftModelPreprocess(const paddle::Tensor& draft_tokens, 
const_cast(seq_lens_encoder.data()), const_cast(seq_lens_decoder.data()), const_cast(step_idx.data()), - const_cast(seq_lens_encoder_record.data()), - const_cast(seq_lens_decoder_record.data()), const_cast(not_need_stop_gpu.data()), + const_cast(is_block_step.data()), const_cast(batch_drop.data()), + const_cast(pre_ids.data()), accept_tokens.data(), accept_num.data(), + base_model_seq_lens_this_time.data(), base_model_seq_lens_encoder.data(), base_model_seq_lens_decoder.data(), base_model_step_idx.data(), @@ -80,13 +85,16 @@ void DraftModelPreprocess(const paddle::Tensor& draft_tokens, base_model_is_block_step.data(), const_cast(base_model_draft_tokens.data()), real_bsz, - max_draft_token, + num_model_step, accept_tokens_len, draft_tokens_len, input_ids_len, base_model_draft_tokens_len, + pre_ids_len, truncate_first_token, - splitwise_prefill); + splitwise_prefill, + kvcache_scheduler_v1); + PD_CHECK(r == 0, "xpu::plugin::draft_model_preprocess failed."); auto not_need_stop_cpu = not_need_stop_gpu.copy_to(not_need_stop.place(), false); @@ -102,12 +110,13 @@ PD_BUILD_STATIC_OP(draft_model_preprocess) "seq_lens_encoder", "seq_lens_decoder", "step_idx", - "seq_lens_encoder_record", - "seq_lens_decoder_record", "not_need_stop", + "is_block_step", "batch_drop", + "pre_ids", "accept_tokens", "accept_num", + "base_model_seq_lens_this_time", "base_model_seq_lens_encoder", "base_model_seq_lens_decoder", "base_model_step_idx", @@ -123,11 +132,11 @@ PD_BUILD_STATIC_OP(draft_model_preprocess) "step_idx_out", "not_need_stop_out", "batch_drop_out", - "seq_lens_encoder_record_out", - "seq_lens_decoder_record_out"}) - .Attrs({"max_draft_token: int", + "pre_ids_out"}) + .Attrs({"num_model_step: int", "truncate_first_token: bool", - "splitwise_prefill: bool"}) + "splitwise_prefill: bool", + "kvcache_scheduler_v1: bool"}) .SetInplaceMap({{"draft_tokens", "draft_tokens_out"}, {"input_ids", "input_ids_out"}, {"stop_flags", "stop_flags_out"}, @@ -137,6 +146,5 @@ PD_BUILD_STATIC_OP(draft_model_preprocess) {"step_idx", "step_idx_out"}, {"not_need_stop", "not_need_stop_out"}, {"batch_drop", "batch_drop_out"}, - {"seq_lens_encoder_record", "seq_lens_encoder_record_out"}, - {"seq_lens_decoder_record", "seq_lens_decoder_record_out"}}) + {"pre_ids", "pre_ids_out"}}) .SetKernelFn(PD_KERNEL(DraftModelPreprocess)); diff --git a/custom_ops/xpu_ops/src/ops/mtp/speculate_get_padding_offset.cc b/custom_ops/xpu_ops/src/ops/mtp/speculate_get_padding_offset.cc index 1cf14b810..f22dc7aaa 100644 --- a/custom_ops/xpu_ops/src/ops/mtp/speculate_get_padding_offset.cc +++ b/custom_ops/xpu_ops/src/ops/mtp/speculate_get_padding_offset.cc @@ -43,6 +43,8 @@ std::vector SpeculateGetPaddingOffset( {token_num_data}, paddle::DataType::INT64, input_ids.place()); auto padding_offset = paddle::empty( {token_num_data}, paddle::DataType::INT32, input_ids.place()); + auto batch_id_per_token = paddle::empty( + {token_num_data}, paddle::DataType::INT32, input_ids.place()); auto cu_seqlens_q = paddle::empty({bsz + 1}, paddle::DataType::INT32, input_ids.place()); auto cu_seqlens_k = @@ -57,7 +59,7 @@ std::vector SpeculateGetPaddingOffset( int r = baidu::xpu::api::plugin::speculate_get_padding_offset( xpu_ctx->x_context(), - padding_offset.data(), + batch_id_per_token.data(), cum_offsets_out.data(), cu_seqlens_q.data(), cu_seqlens_k.data(), @@ -83,7 +85,7 @@ std::vector SpeculateGetPaddingOffset( return {x_remove_padding, cum_offsets_out, - padding_offset, + batch_id_per_token, cu_seqlens_q, cu_seqlens_k}; // , enc_token_num, dec_token_num}; } @@ -123,7 
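On the speculate_get_padding_offset change above: the output formerly named padding_offset is renamed batch_id_per_token, matching the convention used elsewhere in FastDeploy where each un-padded token records the batch row it came from. A hedged host-side sketch of what such a mapping looks like, assuming the name means what it says (the kernel body itself is not shown in this hunk):

```cpp
#include <cassert>
#include <vector>

// Illustrative: one entry per un-padded token, holding the batch index that
// token belongs to, derived from the per-batch token counts.
std::vector<int> build_batch_id_per_token(const std::vector<int>& seq_lens) {
  std::vector<int> batch_id;
  for (int b = 0; b < static_cast<int>(seq_lens.size()); ++b)
    batch_id.insert(batch_id.end(), seq_lens[b], b);
  return batch_id;
}

int main() {
  // batch 0 has 3 tokens, batch 1 has 1, batch 2 has 2
  assert((build_batch_id_per_token({3, 1, 2}) ==
          std::vector<int>{0, 0, 0, 1, 1, 2}));
}
```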
+125,7 @@ PD_BUILD_STATIC_OP(speculate_get_padding_offset) "seq_lens_encoder"}) .Outputs({"x_remove_padding", "cum_offsets_out", - "padding_offset", + "batch_id_per_token", "cu_seqlens_q", "cu_seqlens_k"}) .SetKernelFn(PD_KERNEL(SpeculateGetPaddingOffset)) diff --git a/custom_ops/xpu_ops/src/ops/mtp/speculate_save_output.cc b/custom_ops/xpu_ops/src/ops/mtp/speculate_save_output.cc index 60764b26a..a8e61c708 100644 --- a/custom_ops/xpu_ops/src/ops/mtp/speculate_save_output.cc +++ b/custom_ops/xpu_ops/src/ops/mtp/speculate_save_output.cc @@ -35,7 +35,7 @@ void SpeculateSaveWithOutputMsg(const paddle::Tensor& accept_tokens, const paddle::Tensor& not_need_stop, int64_t rank_id, int msg_queue_id, - int save_each_rank) { + bool save_each_rank) { // printf("enter save output"); if (!save_each_rank && rank_id > 0) { return; diff --git a/custom_ops/xpu_ops/src/ops/mtp/speculate_step_paddle.cc b/custom_ops/xpu_ops/src/ops/mtp/speculate_step_paddle.cc new file mode 100644 index 000000000..d8b113fb8 --- /dev/null +++ b/custom_ops/xpu_ops/src/ops/mtp/speculate_step_paddle.cc @@ -0,0 +1,187 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include "paddle/extension.h" +#include "paddle/phi/core/enforce.h" +#include "xpu/plugin.h" + +#ifndef PD_BUILD_STATIC_OP +#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name) +#endif + +void SpeculateStepPaddle( + const paddle::Tensor &stop_flags, + const paddle::Tensor &seq_lens_this_time, + const paddle::Tensor &ori_seq_lens_encoder, + const paddle::Tensor &seq_lens_encoder, + const paddle::Tensor &seq_lens_decoder, + const paddle::Tensor &block_tables, // [bsz, block_num_per_seq] + const paddle::Tensor &encoder_block_lens, + const paddle::Tensor &is_block_step, + const paddle::Tensor &step_block_list, + const paddle::Tensor &step_lens, + const paddle::Tensor &recover_block_list, + const paddle::Tensor &recover_lens, + const paddle::Tensor &need_block_list, + const paddle::Tensor &need_block_len, + const paddle::Tensor &used_list_len, + const paddle::Tensor &free_list, + const paddle::Tensor &free_list_len, + const paddle::Tensor &input_ids, + const paddle::Tensor &pre_ids, + const paddle::Tensor &step_idx, + const paddle::Tensor &next_tokens, + const paddle::Tensor &first_token_ids, + const paddle::Tensor &accept_num, + const int block_size, + const int encoder_decoder_block_num, + const int max_draft_tokens) { + namespace api = baidu::xpu::api; + phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId()); + auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place); + auto xpu_ctx = static_cast(dev_ctx); + api::Context *ctx = xpu_ctx->x_context(); + if (seq_lens_this_time.is_cpu()) { + ctx = new api::Context(api::kCPU); + } + const int bsz = seq_lens_this_time.shape()[0]; + PADDLE_ENFORCE_LE( + bsz, + 640, + phi::errors::InvalidArgument( + "Only support bsz <= 640, but received bsz is %d", bsz)); + const int block_num_per_seq = 
block_tables.shape()[1]; + const int length = input_ids.shape()[1]; + const int pre_id_length = pre_ids.shape()[1]; + const int max_decoder_block_num = pre_id_length / block_size; + int r = baidu::xpu::api::plugin::speculate_free_and_dispatch_block( + ctx, + const_cast(stop_flags.data()), + const_cast(seq_lens_this_time.data()), + const_cast(seq_lens_decoder.data()), + const_cast(block_tables.data()), + const_cast(encoder_block_lens.data()), + const_cast(is_block_step.data()), + const_cast(step_block_list.data()), + const_cast(step_lens.data()), + const_cast(recover_block_list.data()), + const_cast(recover_lens.data()), + const_cast(need_block_list.data()), + const_cast(need_block_len.data()), + const_cast(used_list_len.data()), + const_cast(free_list.data()), + const_cast(free_list_len.data()), + const_cast(first_token_ids.data()), + const_cast(accept_num.data()), + bsz, + block_size, + block_num_per_seq, + max_decoder_block_num, + max_draft_tokens); + PD_CHECK(r == 0, "speculate_free_and_dispatch_block failed."); + auto recover_lens_cpu = recover_lens.copy_to(paddle::CPUPlace(), false); + int recover_lens_cpu_data = recover_lens_cpu.data()[0]; + if (recover_lens_cpu_data > 0) { + r = baidu::xpu::api::plugin::speculate_recover_block( + ctx, + const_cast(recover_block_list.data()), + const_cast(recover_lens.data()), + const_cast(stop_flags.data()), + const_cast(seq_lens_this_time.data()), + ori_seq_lens_encoder.data(), + const_cast(seq_lens_encoder.data()), + seq_lens_decoder.data(), + const_cast(block_tables.data()), + const_cast(free_list.data()), + const_cast(free_list_len.data()), + const_cast(input_ids.data()), + pre_ids.data(), + step_idx.data(), + encoder_block_lens.data(), + used_list_len.data(), + next_tokens.data(), + first_token_ids.data(), + bsz, + block_num_per_seq, + length, + pre_id_length); + PD_CHECK(r == 0, "speculate_recover_block failed."); + } +} + +PD_BUILD_STATIC_OP(speculate_step_paddle) + .Inputs({"stop_flags", + "seq_lens_this_time", + "ori_seq_lens_encoder", + "seq_lens_encoder", + "seq_lens_decoder", + "block_tables", + "encoder_block_lens", + "is_block_step", + "step_block_list", + "step_lens", + "recover_block_list", + "recover_lens", + "need_block_list", + "need_block_len", + "used_list_len", + "free_list", + "free_list_len", + "input_ids", + "pre_ids", + "step_idx", + "next_tokens", + "first_token_ids", + "accept_num"}) + .Attrs({"block_size: int", + "encoder_decoder_block_num: int", + "max_draft_tokens: int"}) + .Outputs({"stop_flags_out", + "seq_lens_this_time_out", + "seq_lens_encoder_out", + "seq_lens_decoder_out", + "block_tables_out", + "encoder_block_lens_out", + "is_block_step_out", + "step_block_list_out", + "step_lens_out", + "recover_block_list_out", + "recover_lens_out", + "need_block_list_out", + "need_block_len_out", + "used_list_len_out", + "free_list_out", + "free_list_len_out", + "input_ids_out", + "first_token_ids_out"}) + .SetInplaceMap({{"stop_flags", "stop_flags_out"}, + {"seq_lens_this_time", "seq_lens_this_time_out"}, + {"seq_lens_encoder", "seq_lens_encoder_out"}, + {"seq_lens_decoder", "seq_lens_decoder_out"}, + {"block_tables", "block_tables_out"}, + {"encoder_block_lens", "encoder_block_lens_out"}, + {"is_block_step", "is_block_step_out"}, + {"step_block_list", "step_block_list_out"}, + {"step_lens", "step_lens_out"}, + {"recover_block_list", "recover_block_list_out"}, + {"recover_lens", "recover_lens_out"}, + {"need_block_list", "need_block_list_out"}, + {"need_block_len", "need_block_len_out"}, + {"used_list_len", 
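The new speculate_step_paddle wrapper above mirrors the existing step_paddle flow: it first frees and re-dispatches KV-cache blocks, then recovers any sequences that were parked (recover_lens > 0). The block arithmetic it relies on is plain ceil-division bookkeeping; a small illustrative sketch (helper names are mine):

```cpp
#include <cassert>

// Blocks needed to hold seq_len tokens when the KV cache is paged in
// fixed-size blocks (ceil division).
int blocks_for(int seq_len, int block_size) {
  return (seq_len + block_size - 1) / block_size;
}

int main() {
  const int block_size = 64;
  const int pre_id_length = 2048;  // pre_ids.shape()[1]
  // max_decoder_block_num is derived exactly as in the wrapper above.
  const int max_decoder_block_num = pre_id_length / block_size;
  assert(max_decoder_block_num == 32);
  assert(blocks_for(1, block_size) == 1);
  assert(blocks_for(65, block_size) == 2);
}
```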
"used_list_len_out"}, + {"free_list", "free_list_out"}, + {"free_list_len", "free_list_len_out"}, + {"input_ids", "input_ids_out"}, + {"first_token_ids", "first_token_ids_out"}}) + .SetKernelFn(PD_KERNEL(SpeculateStepPaddle)); diff --git a/custom_ops/xpu_ops/src/ops/mtp/speculate_verify.cc b/custom_ops/xpu_ops/src/ops/mtp/speculate_verify.cc index 53b5b90dc..59df0f0f2 100644 --- a/custom_ops/xpu_ops/src/ops/mtp/speculate_verify.cc +++ b/custom_ops/xpu_ops/src/ops/mtp/speculate_verify.cc @@ -45,7 +45,10 @@ void SpeculateVerify(const paddle::Tensor &accept_tokens, const paddle::Tensor &topp, int max_seq_len, int verify_window, - bool enable_topp) { + bool enable_topp, + bool benchmark_mode, + bool accept_all_drafts) { + // TODO(chenhuan09):support accept_all_drafts auto bsz = accept_tokens.shape()[0]; int real_bsz = seq_lens_this_time.shape()[0]; auto max_draft_tokens = draft_tokens.shape()[1]; @@ -133,7 +136,8 @@ void SpeculateVerify(const paddle::Tensor &accept_tokens, max_seq_len, max_candidate_len, verify_window, - prefill_one_step_stop); + prefill_one_step_stop, + benchmark_mode); } else { baidu::xpu::api::plugin::speculate_verify( ctx, @@ -161,7 +165,8 @@ void SpeculateVerify(const paddle::Tensor &accept_tokens, max_seq_len, max_candidate_len, verify_window, - prefill_one_step_stop); + prefill_one_step_stop, + benchmark_mode); } } else { if (enable_topp) { @@ -191,7 +196,8 @@ void SpeculateVerify(const paddle::Tensor &accept_tokens, max_seq_len, max_candidate_len, verify_window, - prefill_one_step_stop); + prefill_one_step_stop, + benchmark_mode); } else { baidu::xpu::api::plugin::speculate_verify( ctx, @@ -219,7 +225,8 @@ void SpeculateVerify(const paddle::Tensor &accept_tokens, max_seq_len, max_candidate_len, verify_window, - prefill_one_step_stop); + prefill_one_step_stop, + benchmark_mode); } } } @@ -246,7 +253,11 @@ PD_BUILD_STATIC_OP(speculate_verify) "accept_num_out", "step_idx_out", "stop_flags_out"}) - .Attrs({"max_seq_len: int", "verify_window: int", "enable_topp: bool"}) + .Attrs({"max_seq_len: int", + "verify_window: int", + "enable_topp: bool", + "benchmark_mode: bool", + "accept_all_drafts: bool"}) .SetInplaceMap({{"accept_tokens", "accept_tokens_out"}, {"accept_num", "accept_num_out"}, {"step_idx", "step_idx_out"}, diff --git a/custom_ops/xpu_ops/src/ops/pybind/pybind.cc b/custom_ops/xpu_ops/src/ops/pybind/pybind.cc index 79f89df37..0400aa02d 100644 --- a/custom_ops/xpu_ops/src/ops/pybind/pybind.cc +++ b/custom_ops/xpu_ops/src/ops/pybind/pybind.cc @@ -37,13 +37,14 @@ std::vector AdjustBatch( const paddle::Tensor& x, // [token_num, dim_embed] const paddle::Tensor& cum_offsets, // [bsz, 1] const paddle::Tensor& encoder_seq_lod, + const paddle::Tensor& decoder_seq_lod, const paddle::Tensor& encoder_batch_idx, const paddle::Tensor& decoder_batch_idx, const paddle::Tensor& encoder_seq_lod_cpu, + const paddle::Tensor& decoder_seq_lod_cpu, const paddle::Tensor& encoder_batch_idx_cpu, const paddle::Tensor& decoder_batch_idx_cpu, - const paddle::Tensor& enc_batch_tensor, - const paddle::Tensor& dec_batch_tensor, + const paddle::Tensor& len_info_cpu, const paddle::optional& output_padding_offset, int max_input_length); @@ -264,7 +265,9 @@ void SpeculateVerify(const paddle::Tensor& accept_tokens, const paddle::Tensor& topp, int max_seq_len, int verify_window, - bool enable_topp); + bool enable_topp, + bool benchmark_mode, + bool accept_all_drafts); void SpeculateClearAcceptNums(const paddle::Tensor& accept_num, const paddle::Tensor& seq_lens_decoder); @@ -285,21 +288,23 @@ void 
DraftModelPreprocess(const paddle::Tensor& draft_tokens, const paddle::Tensor& seq_lens_encoder, const paddle::Tensor& seq_lens_decoder, const paddle::Tensor& step_idx, - const paddle::Tensor& seq_lens_encoder_record, - const paddle::Tensor& seq_lens_decoder_record, const paddle::Tensor& not_need_stop, + const paddle::Tensor& is_block_step, const paddle::Tensor& batch_drop, + const paddle::Tensor& pre_ids, const paddle::Tensor& accept_tokens, const paddle::Tensor& accept_num, + const paddle::Tensor& base_model_seq_lens_this_time, const paddle::Tensor& base_model_seq_lens_encoder, const paddle::Tensor& base_model_seq_lens_decoder, const paddle::Tensor& base_model_step_idx, const paddle::Tensor& base_model_stop_flags, const paddle::Tensor& base_model_is_block_step, const paddle::Tensor& base_model_draft_tokens, - const int max_draft_token, + const int num_model_step, const bool truncate_first_token, - const bool splitwise_prefill); + const bool splitwise_prefill, + const bool kvcache_scheduler_v1); void DraftModelPostprocess(const paddle::Tensor& base_model_draft_tokens, const paddle::Tensor& base_model_seq_lens_this_time, @@ -324,18 +329,19 @@ std::vector EagleGetSelfHiddenStates( const paddle::Tensor& step_idx); std::vector GatherNextToken( - const paddle::Tensor& tmp_out, // [token_num, dim_embed] + const paddle::Tensor& x, // [token_num, dim_embed] const paddle::Tensor& cum_offsets, // [bsz, 1] const paddle::Tensor& encoder_seq_lod, + const paddle::Tensor& decoder_seq_lod, const paddle::Tensor& encoder_batch_map, const paddle::Tensor& decoder_batch_map, const paddle::Tensor& encoder_seq_lod_cpu, + const paddle::Tensor& decoder_seq_lod_cpu, const paddle::Tensor& encoder_batch_map_cpu, const paddle::Tensor& decoder_batch_map_cpu, - const paddle::Tensor& enc_batch_tensor, - const paddle::Tensor& dec_batch_tensor, + const paddle::Tensor& len_info_cpu, const paddle::optional& output_padding_offset, - int max_input_length); + int max_bsz); std::vector GetImgBoundaries( const paddle::Tensor& task_input_ids, @@ -436,6 +442,34 @@ void MTPStepPaddle( const int block_size, const int max_draft_tokens); +void SpeculateStepPaddle( + const paddle::Tensor& stop_flags, + const paddle::Tensor& seq_lens_this_time, + const paddle::Tensor& ori_seq_lens_encoder, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& seq_lens_decoder, + const paddle::Tensor& block_tables, // [bsz, block_num_per_seq] + const paddle::Tensor& encoder_block_lens, + const paddle::Tensor& is_block_step, + const paddle::Tensor& step_block_list, + const paddle::Tensor& step_lens, + const paddle::Tensor& recover_block_list, + const paddle::Tensor& recover_lens, + const paddle::Tensor& need_block_list, + const paddle::Tensor& need_block_len, + const paddle::Tensor& used_list_len, + const paddle::Tensor& free_list, + const paddle::Tensor& free_list_len, + const paddle::Tensor& input_ids, + const paddle::Tensor& pre_ids, + const paddle::Tensor& step_idx, + const paddle::Tensor& next_tokens, + const paddle::Tensor& first_token_ids, + const paddle::Tensor& accept_num, + const int block_size, + const int encoder_decoder_block_num, + const int max_draft_tokens); + void SaveOutMmsgStatic(const paddle::Tensor& x, const paddle::Tensor& not_need_stop, int64_t rank_id, @@ -542,13 +576,14 @@ PYBIND11_MODULE(fastdeploy_ops, m) { py::arg("x"), py::arg("cum_offsets"), py::arg("encoder_seq_lod"), + py::arg("decoder_seq_lod"), py::arg("encoder_batch_idx"), py::arg("decoder_batch_idx"), py::arg("encoder_seq_lod_cpu"), + 
py::arg("decoder_seq_lod_cpu"), py::arg("encoder_batch_idx_cpu"), py::arg("decoder_batch_idx_cpu"), - py::arg("enc_batch_tensor"), - py::arg("dec_batch_tensor"), + py::arg("len_info_cpu"), py::arg("output_padding_offset"), py::arg("max_input_length"), "adjust batch in XPU"); @@ -620,21 +655,23 @@ PYBIND11_MODULE(fastdeploy_ops, m) { py::arg("seq_lens_encoder"), py::arg("seq_lens_decoder"), py::arg("step_idx"), - py::arg("seq_lens_encoder_record"), - py::arg("seq_lens_decoder_record"), py::arg("not_need_stop"), + py::arg("is_block_step"), py::arg("batch_drop"), + py::arg("pre_ids"), py::arg("accept_tokens"), py::arg("accept_num"), + py::arg("base_model_seq_lens_this_time"), py::arg("base_model_seq_lens_encoder"), py::arg("base_model_seq_lens_decoder"), py::arg("base_model_step_idx"), py::arg("base_model_stop_flags"), py::arg("base_model_is_block_step"), py::arg("base_model_draft_tokens"), - py::arg("max_draft_token"), + py::arg("num_model_step"), py::arg("truncate_first_token"), py::arg("splitwise_prefill"), + py::arg("kvcache_scheduler_v1"), "Preprocess data for draft model in speculative decoding"); m.def("draft_model_postprocess", @@ -727,18 +764,19 @@ PYBIND11_MODULE(fastdeploy_ops, m) { m.def("gather_next_token", &GatherNextToken, - py::arg("tmp_out"), + py::arg("x"), py::arg("cum_offsets"), py::arg("encoder_seq_lod"), + py::arg("decoder_seq_lod"), py::arg("encoder_batch_map"), py::arg("decoder_batch_map"), py::arg("encoder_seq_lod_cpu"), + py::arg("decoder_seq_lod_cpu"), py::arg("encoder_batch_map_cpu"), py::arg("decoder_batch_map_cpu"), - py::arg("enc_batch_tensor"), - py::arg("dec_batch_tensor"), + py::arg("len_info_cpu"), py::arg("output_padding_offset"), - py::arg("max_input_length"), + py::arg("max_bsz"), "Gather next token for XPU"); m.def("get_img_boundaries", @@ -983,6 +1021,8 @@ PYBIND11_MODULE(fastdeploy_ops, m) { py::arg("max_seq_len"), py::arg("verify_window"), py::arg("enable_topp"), + py::arg("benchmark_mode"), + py::arg("accept_all_drafts"), "Perform speculative verification for decoding"); m.def("speculate_clear_accept_nums", @@ -1104,6 +1144,36 @@ PYBIND11_MODULE(fastdeploy_ops, m) { py::arg("encoder_decoder_block_num"), "Step paddle function"); + m.def("speculate_step_paddle", + &SpeculateStepPaddle, + py::arg("stop_flags"), + py::arg("seq_lens_this_time"), + py::arg("ori_seq_lens_encoder"), + py::arg("seq_lens_encoder"), + py::arg("seq_lens_decoder"), + py::arg("block_tables"), + py::arg("encoder_block_lens"), + py::arg("is_block_step"), + py::arg("step_block_list"), + py::arg("step_lens"), + py::arg("recover_block_list"), + py::arg("recover_lens"), + py::arg("need_block_list"), + py::arg("need_block_len"), + py::arg("used_list_len"), + py::arg("free_list"), + py::arg("free_list_len"), + py::arg("input_ids"), + py::arg("pre_ids"), + py::arg("step_idx"), + py::arg("next_tokens"), + py::arg("first_token_ids"), + py::arg("accept_num"), + py::arg("block_size"), + py::arg("encoder_decoder_block_num"), + py::arg("max_draft_tokens"), + "Step paddle function"); + m.def("text_image_gather_scatter", &TextImageGatherScatter, py::arg("input"), diff --git a/custom_ops/xpu_ops/src/plugin/include/xpu/plugin.h b/custom_ops/xpu_ops/src/plugin/include/xpu/plugin.h index 4e393b868..09a426a31 100644 --- a/custom_ops/xpu_ops/src/plugin/include/xpu/plugin.h +++ b/custom_ops/xpu_ops/src/plugin/include/xpu/plugin.h @@ -75,6 +75,48 @@ DLL_EXPORT int get_padding_offset(Context* ctx, const int max_seq_len, const int bs); +DLL_EXPORT int speculate_get_padding_offset(Context* ctx, + int* 
batch_id_per_token, + int* cum_offsets_out, + int* cu_seqlens_q, + int* cu_seqlens_k, + const int* cum_offsets, + const int* seq_lens, + const int max_seq_len, + int bsz); + +DLL_EXPORT int draft_model_preprocess(api::Context* ctx, + int64_t* draft_tokens, + int64_t* input_ids, + bool* stop_flags, + int* seq_lens_this_time, + int* seq_lens_encoder, + int* seq_lens_decoder, + int64_t* step_idx, + bool* not_need_stop, + bool* is_block_step, + bool* batch_drop, + int64_t* pre_ids, + const int64_t* accept_tokens, + const int* accept_num, + const int* base_model_seq_lens_this_time, + const int* base_model_seq_lens_encoder, + const int* base_model_seq_lens_decoder, + const int64_t* base_model_step_idx, + const bool* base_model_stop_flags, + const bool* base_model_is_block_step, + int64_t* base_model_draft_tokens, + const int bsz, + const int num_model_step, + const int accept_tokens_len, + const int draft_tokens_len, + const int input_ids_len, + const int base_model_draft_tokens_len, + const int pre_ids_len, + const bool truncate_first_token, + const bool splitwise_prefill, + const bool kvcache_scheduler_v1); + DLL_EXPORT int update_inputs(Context* ctx, bool* not_need_stop, int* seq_lens_this_time, @@ -111,6 +153,31 @@ DLL_EXPORT int free_and_dispatch_block(Context* ctx, const int block_num_per_seq, const int max_decoder_block_num); +DLL_EXPORT int speculate_free_and_dispatch_block( + Context* ctx, + bool* stop_flags, + int* seq_lens_this_time, + int* seq_lens_decoder, + int* block_tables, + int* encoder_block_lens, + bool* is_block_step, + int* step_block_list, // [bsz] + int* step_len, + int* recover_block_list, + int* recover_len, + int* need_block_list, + int* need_block_len, + int* used_list_len, + int* free_list, + int* free_list_len, + int64_t* first_token_ids, + int* accept_num, + const int bsz, + const int block_size, + const int block_num_per_seq, + const int max_decoder_block_num, + const int max_draft_tokens); + DLL_EXPORT int recover_block(Context* ctx, int* recover_block_list, // [bsz] int* recover_len, @@ -134,6 +201,29 @@ DLL_EXPORT int recover_block(Context* ctx, const int length, const int pre_id_length); +DLL_EXPORT int speculate_recover_block(Context* ctx, + int* recover_block_list, // [bsz] + int* recover_len, + bool* stop_flags, + int* seq_lens_this_time, + const int* ori_seq_lens_encoder, + int* seq_lens_encoder, + const int* seq_lens_decoder, + int* block_tables, + int* free_list, + int* free_list_len, + int64_t* input_ids, + const int64_t* pre_ids, + const int64_t* step_idx, + const int* encoder_block_lens, + const int* used_list_len, + const int64_t* next_tokens, + const int64_t* first_token_ids, + const int bsz, + const int block_num_per_seq, + const int length, + const int pre_id_length); + DLL_EXPORT int recover_decode_task(Context* ctx, bool* stop_flags, int* seq_lens_this_time, @@ -172,6 +262,7 @@ DLL_EXPORT int eb_adjust_batch( const TX* x, TY* y, VectorParam& encoder_seqs_lods, // NOLINT + VectorParam& decoder_seqs_lods, // NOLINT VectorParam& encoder_batch_map, // NOLINT VectorParam& decoder_batch_map, // NOLINT int64_t hidden_dim); @@ -186,6 +277,17 @@ DLL_EXPORT int eb_gather_next_token( VectorParam& decoder_batch_map, // NOLINT int64_t hidden_dim); +template +DLL_EXPORT int eb_mtp_gather_next_token( + Context* ctx, + const TX* x, + TY* y, + VectorParam& encoder_seqs_lods, // NOLINT + VectorParam& decoder_seqs_lods, // NOLINT + VectorParam& encoder_batch_map, // NOLINT + VectorParam& decoder_batch_map, // NOLINT + int64_t hidden_dim); + template DLL_EXPORT int 
quant2d_per_channel(api::Context* ctx, const TX* x, @@ -305,7 +407,8 @@ DLL_EXPORT int speculate_verify(Context* ctx, const int max_seq_len, const int max_candidate_len, const int verify_window, - const bool prefill_one_step_stop); + const bool prefill_one_step_stop, + const bool benchmark_mode); DLL_EXPORT int speculate_clear_accept_nums(Context* ctx, int* accept_num, @@ -342,35 +445,6 @@ DLL_EXPORT int draft_model_update(Context* ctx, const int substep, const bool prefill_one_step_stop); -DLL_EXPORT int draft_model_preprocess(api::Context* ctx, - int64_t* draft_tokens, - int64_t* input_ids, - bool* stop_flags, - int* seq_lens_this_time, - int* seq_lens_encoder, - int* seq_lens_decoder, - int64_t* step_idx, - int* seq_lens_encoder_record, - int* seq_lens_decoder_record, - bool* not_need_stop, - bool* batch_drop, - const int64_t* accept_tokens, - const int* accept_num, - const int* base_model_seq_lens_encoder, - const int* base_model_seq_lens_decoder, - const int64_t* base_model_step_idx, - const bool* base_model_stop_flags, - const bool* base_model_is_block_step, - int64_t* base_model_draft_tokens, - int real_bsz, - int max_draft_token, - int accept_tokens_len, - int draft_tokens_len, - int input_ids_len, - int base_model_draft_tokens_len, - bool truncate_first_token, - bool splitwise_prefill); - DLL_EXPORT int speculate_set_stop_value_multi_seqs(Context* ctx, bool* stop_flags, int64_t* accept_tokens, @@ -411,16 +485,6 @@ DLL_EXPORT int speculate_remove_padding(Context* ctx, int bsz, int token_num_data); -DLL_EXPORT int speculate_get_padding_offset(Context* ctx, - int* padding_offset, - int* cum_offsets_out, - int* cu_seqlens_q, - int* cu_seqlens_k, - const int* cum_offsets, - const int* seq_lens, - const int max_seq_len, - int bsz); - DLL_EXPORT int compute_self_order(api::Context* ctx, const int* last_seq_lens_this_time, const int* seq_lens_this_time, diff --git a/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/eb_adjust_batch.xpu b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/eb_adjust_batch.xpu index b675785a4..bd791bd98 100644 --- a/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/eb_adjust_batch.xpu +++ b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/eb_adjust_batch.xpu @@ -4,7 +4,7 @@ namespace xpu3 { namespace plugin { #define MAX_LM_SIZE 28672 -// One core has 32KB LM(group LM), MAX_LM_SIZE = (32 - 4)KB / 2 = 30720, 4KB is +// One core has 32KB LM(group LM), MAX_LM_SIZE = (32 - 4)KB / 2 = 30720, 4KB is // the stack space #define MAX_BATCH 512 #define ALIGNMENT 64 @@ -53,6 +53,7 @@ template __global__ void eb_adjust_batch(TX* src, TY* dst, int* encoder_seqs_lods, + int* decoder_seqs_lods, int* encoder_batch_map, int* decoder_batch_map, int en_batch, @@ -61,9 +62,11 @@ __global__ void eb_adjust_batch(TX* src, int tid = core_id() * cluster_num() + cluster_id(); int nthreads = core_num() * cluster_num(); __group_shared__ int local_lods_en[MAX_BATCH + 1]; + __group_shared__ int local_lods_de[MAX_BATCH + 1]; __group_shared__ int local_map_en[MAX_BATCH]; __group_shared__ int local_map_de[MAX_BATCH]; GM2GSM_ASYNC(encoder_seqs_lods, local_lods_en, (en_batch + 1) * sizeof(int)); + GM2GSM_ASYNC(decoder_seqs_lods, local_lods_de, (de_batch + 1) * sizeof(int)); if (en_batch > 0) { GM2GSM_ASYNC(encoder_batch_map, local_map_en, en_batch * sizeof(int)); } @@ -72,7 +75,8 @@ __global__ void eb_adjust_batch(TX* src, } mfence(); int max_encoder_len = local_lods_en[en_batch]; - int seq_sum = max_encoder_len + de_batch; + int max_decoder_len = local_lods_de[de_batch]; + int seq_sum
= max_encoder_len + max_decoder_len; int total_batch = en_batch + de_batch; int start = 0; int end = 0; @@ -82,13 +86,16 @@ __global__ void eb_adjust_batch(TX* src, while (i < end) { if (i >= max_encoder_len) { // dst decode part - int cur_de_bs = i - max_encoder_len; + int cur_de_bs = 0; + get_cur_batch(local_lods_de, de_batch, i - max_encoder_len, cur_de_bs); int cur_en_bs = local_map_de[cur_de_bs] - cur_de_bs; + int cur_len = + min(end, local_lods_de[cur_de_bs + 1] + max_encoder_len) - i; _global_ptr_ TY* cur_dst = dst + i * copy_size; _global_ptr_ TX* cur_src = - src + (cur_de_bs + local_lods_en[cur_en_bs]) * copy_size; - do_memcpy_1d(cur_src, cur_dst, copy_size); - i++; + src + (local_lods_en[cur_en_bs] + i - max_encoder_len) * copy_size; + do_memcpy_1d(cur_src, cur_dst, copy_size * cur_len); + i += cur_len; } else { // dst encode part int cur_en_bs = 0; @@ -97,7 +104,8 @@ __global__ void eb_adjust_batch(TX* src, cur_de_bs = local_map_en[cur_en_bs] - cur_en_bs; int cur_len = min(end, local_lods_en[cur_en_bs + 1]) - i; _global_ptr_ TY* cur_dst = dst + i * copy_size; - _global_ptr_ TX* cur_src = src + (cur_de_bs + i) * copy_size; + _global_ptr_ TX* cur_src = + src + (local_lods_de[cur_de_bs] + i) * copy_size; do_memcpy_1d(cur_src, cur_dst, copy_size * cur_len); i += cur_len; } @@ -108,6 +116,7 @@ __global__ void eb_adjust_batch(TX* src, template __global__ void eb_adjust_batch(TX * src, \ TY * dst, \ int* encoder_seqs_lods, \ + int* decoder_seqs_lods, \ int* encoder_batch_map, \ int* decoder_batch_map, \ int en_batch, \ diff --git a/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/compute_order.xpu b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/compute_order.xpu index 7cd399d09..b8d70544a 100644 --- a/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/compute_order.xpu +++ b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/compute_order.xpu @@ -20,6 +20,7 @@ __global__ void ComputeOrderKernel(const int* seq_lens_this_time, return; } + // 256 * int char lm[6 * 1024]; int buf_size = 6 * 1024 / (6 * sizeof(int)); int* lm_base_model_seq_lens_this_time = (int*)lm; @@ -68,10 +69,7 @@ __global__ void ComputeOrderKernel(const int* seq_lens_this_time, in_offset += write_size; } mfence_lm(); - // 2. base model encoder. Base step=0 - } else if (cur_base_model_seq_lens_encoder != 0) { - // nothing happens - // 3. New end + // 2. Base model stop at last verify-step. 
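Back in the eb_adjust_batch kernel above, the decoder path no longer assumes one token per decoder sequence: with MTP a decoder step can carry several tokens, so a destination row is resolved through the decoder lod with a get_cur_batch-style lookup, just as the encoder path already did. A host-side equivalent of that lookup (illustrative):

```cpp
#include <cassert>
#include <vector>

// Given a lod of cumulative token counts, find which sequence a global token
// index belongs to (host equivalent of the kernel's get_cur_batch lookup).
int seq_index_of(const std::vector<int>& lod, int token_idx) {
  int seq = 0;
  while (seq + 1 < static_cast<int>(lod.size()) && lod[seq + 1] <= token_idx) ++seq;
  return seq;
}

int main() {
  std::vector<int> decoder_lod = {0, 2, 3, 7};  // three decoder seqs: 2, 1 and 4 tokens
  assert(seq_index_of(decoder_lod, 0) == 0);
  assert(seq_index_of(decoder_lod, 2) == 1);
  assert(seq_index_of(decoder_lod, 6) == 2);
}
```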
} else if (cur_base_model_seq_lens_this_time != 0 && cur_seq_lens_this_time == 0) { in_offset += cur_base_model_seq_lens_this_time; @@ -80,27 +78,16 @@ __global__ void ComputeOrderKernel(const int* seq_lens_this_time, cur_seq_lens_this_time == 0) { // nothing happens } else { - if (accept_num <= actual_draft_token_num) { - int position_map_val = out_offset; - LM2GM(&position_map_val, - position_map + in_offset + accept_num - 1, - sizeof(int)); - out_offset++; - in_offset += cur_base_model_seq_lens_this_time; - } else { - int position_map_val_1 = out_offset; - LM2GM(&position_map_val_1, - position_map + in_offset + accept_num - 2, - sizeof(int)); - out_offset++; - int position_map_val_2 = out_offset; - LM2GM(&position_map_val_2, - position_map + in_offset + accept_num - 1, - sizeof(int)); - out_offset++; - in_offset += cur_base_model_seq_lens_this_time; + // accept_num << buf_size, so do not need split + for (int i = 0; i < accept_num; i++) { + lm_position_map[i] = out_offset++; } mfence_lm(); + LM2GM(lm_position_map, + position_map + in_offset, + accept_num * sizeof(int)); + in_offset += cur_base_model_seq_lens_this_time; + mfence_lm(); } } } diff --git a/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/draft_model_preprocess.xpu b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/draft_model_preprocess.xpu index 9471fd096..425dc4b22 100644 --- a/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/draft_model_preprocess.xpu +++ b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/draft_model_preprocess.xpu @@ -13,26 +13,29 @@ __global__ void draft_model_preprocess(int64_t* draft_tokens, int* seq_lens_encoder, int* seq_lens_decoder, int64_t* step_idx, - int* seq_lens_encoder_record, - int* seq_lens_decoder_record, bool* not_need_stop, + bool* is_block_step, bool* batch_drop, + int64_t* pre_ids, const int64_t* accept_tokens, const int* accept_num, + const int* base_model_seq_lens_this_time, const int* base_model_seq_lens_encoder, const int* base_model_seq_lens_decoder, const int64_t* base_model_step_idx, const bool* base_model_stop_flags, const bool* base_model_is_block_step, int64_t* base_model_draft_tokens, - int real_bsz, - int max_draft_token, - int accept_tokens_len, - int draft_tokens_len, - int input_ids_len, - int base_model_draft_tokens_len, - bool truncate_first_token, - bool splitwise_prefill) { + const int bsz, + const int num_model_step, + const int accept_tokens_len, + const int draft_tokens_len, + const int input_ids_len, + const int base_model_draft_tokens_len, + const int pre_ids_len, + const bool truncate_first_token, + const bool splitwise_prefill, + const bool kvcache_scheduler_v1) { int cid = core_id(); int ncores = core_num(); int clusterid = cluster_id(); @@ -46,7 +49,7 @@ __global__ void draft_model_preprocess(int64_t* draft_tokens, int64_t value_fu = -1; if (splitwise_prefill) { - for (; tid < real_bsz; tid += ncores * nclusters) { + for (; tid < bsz; tid += ncores * nclusters) { int64_t base_model_step_idx_now = 0; int seq_lens_encoder_now = 0; int seq_lens_this_time_now = 0; @@ -57,35 +60,25 @@ __global__ void draft_model_preprocess(int64_t* draft_tokens, GM2LM_ASYNC( base_model_step_idx + tid, &base_model_step_idx_now, sizeof(int64_t)); - GM2LM_ASYNC(seq_lens_encoder_record + tid, - &seq_lens_encoder_record_now, - sizeof(int)); + GM2LM_ASYNC(seq_lens_encoder + tid, &seq_lens_encoder_now, sizeof(int)); GM2LM(accept_tokens + tid * accept_tokens_len, &base_model_first_token, sizeof(int64_t)); - - if 
(base_model_step_idx_now == 1 && seq_lens_encoder_record_now > 0) { + if (seq_lens_encoder_now > 0) { not_stop_flag_sm[cid] += 1; - int seq_len_encoder_record = seq_lens_encoder_record_now; - seq_lens_encoder_now = seq_len_encoder_record; - seq_lens_encoder_record_now = -1; stop_flags_now = false; - int position = seq_len_encoder_record; + int position = seq_lens_encoder_now; if (truncate_first_token) { position = position - 1; input_ids_now = base_model_first_token; - seq_lens_this_time_now = seq_len_encoder_record; + seq_lens_this_time_now = seq_lens_encoder_now; } else { input_ids_now = base_model_first_token; - seq_lens_this_time_now = seq_len_encoder_record + 1; + seq_lens_this_time_now = seq_lens_encoder_now + 1; } LM2GM_ASYNC(&input_ids_now, input_ids + tid * input_ids_len + position, sizeof(int64_t)); - LM2GM_ASYNC(&seq_lens_encoder_record_now, - seq_lens_encoder_record + tid, - sizeof(int)); - } else { stop_flags_now = true; seq_lens_this_time_now = 0; @@ -98,21 +91,23 @@ __global__ void draft_model_preprocess(int64_t* draft_tokens, LM2GM(&seq_lens_this_time_now, seq_lens_this_time + tid, sizeof(int)); } } else { - for (; tid < real_bsz; tid += ncores * nclusters) { + for (; tid < bsz; tid += ncores * nclusters) { bool base_model_stop_flags_now = false; bool base_model_is_block_step_now = false; bool batch_drop_now = false; bool stop_flags_now = false; + bool is_block_step_now = false; int seq_lens_this_time_now = 0; - int seq_lens_encoder_record_now = 0; int seq_lens_encoder_now = 0; int seq_lens_decoder_new = 0; - int seq_lens_decoder_record_now = 0; int accept_num_now = 0; int base_model_seq_lens_decoder_now = 0; + int base_model_seq_lens_this_time_now = 0; int64_t step_id_now = 0; int64_t base_model_step_idx_now; + int64_t pre_ids_now; mfence(); + GM2LM_ASYNC(is_block_step + tid, &is_block_step_now, sizeof(bool)); GM2LM_ASYNC(base_model_stop_flags + tid, &base_model_stop_flags_now, sizeof(bool)); @@ -121,12 +116,6 @@ __global__ void draft_model_preprocess(int64_t* draft_tokens, sizeof(bool)); GM2LM_ASYNC(batch_drop + tid, &batch_drop_now, sizeof(bool)); GM2LM_ASYNC(stop_flags + tid, &stop_flags_now, sizeof(bool)); - GM2LM_ASYNC(seq_lens_encoder_record + tid, - &seq_lens_encoder_record_now, - sizeof(int)); - GM2LM_ASYNC(seq_lens_decoder_record + tid, - &seq_lens_decoder_record_now, - sizeof(int)); GM2LM_ASYNC(seq_lens_encoder + tid, &seq_lens_encoder_now, sizeof(int)); GM2LM_ASYNC(seq_lens_decoder + tid, &seq_lens_decoder_new, sizeof(int)); @@ -135,6 +124,9 @@ __global__ void draft_model_preprocess(int64_t* draft_tokens, accept_tokens_len * sizeof(int64_t)); GM2LM_ASYNC(accept_num + tid, &accept_num_now, sizeof(int)); + GM2LM_ASYNC(base_model_seq_lens_this_time + tid, + &base_model_seq_lens_this_time_now, + sizeof(int)); GM2LM_ASYNC(base_model_seq_lens_decoder + tid, &base_model_seq_lens_decoder_now, sizeof(int)); @@ -148,57 +140,67 @@ __global__ void draft_model_preprocess(int64_t* draft_tokens, base_model_draft_tokens + tid * base_model_draft_tokens_len + i, sizeof(int)); } - if (base_model_stop_flags_now && base_model_is_block_step_now) { - batch_drop_now = true; - stop_flags_now = true; + if (kvcache_scheduler_v1) { + if (base_model_stop_flags_now && base_model_is_block_step_now) { + stop_flags_now = true; + is_block_step_now = true; + } + } else { + if (base_model_stop_flags_now && base_model_is_block_step_now) { + batch_drop_now = true; + stop_flags_now = true; + } } if (!(base_model_stop_flags_now || batch_drop_now)) { not_stop_flag_sm[cid] += 1; - if 
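The branch just above is the behavioural difference the new kvcache_scheduler_v1 attribute introduces: when the base model batch is stopped because it is parked for block scheduling, scheduler v1 marks the draft batch as block-stepped so it can resume later, whereas the old path drops it. A host-side paraphrase (illustrative only):

```cpp
// Illustrative paraphrase of the per-batch flag update above.
void update_draft_flags(bool base_model_stop, bool base_model_is_block_step,
                        bool kvcache_scheduler_v1,
                        bool& stop_flag, bool& is_block_step, bool& batch_drop) {
  if (base_model_stop && base_model_is_block_step) {
    stop_flag = true;
    if (kvcache_scheduler_v1) {
      is_block_step = true;  // park the draft batch, resume when the block step ends
    } else {
      batch_drop = true;     // old behaviour: drop the batch entirely
    }
  }
}
```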
(base_model_step_idx_now == 0) { - seq_lens_this_time_now = 0; - not_stop_flag_sm[cid] -= 1; // 因为上面加过,这次减去,符合=0逻辑 - } else if (base_model_step_idx_now == 1 && - seq_lens_encoder_record_now > 0) { - int seq_len_encoder_record = seq_lens_encoder_record_now; - seq_lens_encoder_now = seq_len_encoder_record; - seq_lens_encoder_record_now = -1; - seq_lens_decoder_new = seq_lens_decoder_record_now; - seq_lens_decoder_record_now = 0; + if (seq_lens_encoder_now > 0) { + int seq_len_encoder = seq_lens_encoder_now; stop_flags_now = false; int64_t base_model_first_token = accept_tokens_now[0]; - int position = seq_len_encoder_record; + LM2GM(&base_model_first_token, + pre_ids + tid * pre_ids_len, + sizeof(int64_t)); + int position = seq_len_encoder; if (truncate_first_token) { LM2GM(&base_model_first_token, input_ids + tid * input_ids_len + position - 1, sizeof(int64_t)); - seq_lens_this_time_now = seq_len_encoder_record; + seq_lens_this_time_now = seq_len_encoder; } else { LM2GM(&base_model_first_token, input_ids + tid * input_ids_len + position, sizeof(int64_t)); - seq_lens_this_time_now = seq_len_encoder_record + 1; + seq_lens_this_time_now = seq_len_encoder + 1; + } + } else { + if (kvcache_scheduler_v1) { + if (!base_model_is_block_step_now && is_block_step_now) { + is_block_step_now = false; + } } - } else if (accept_num_now <= max_draft_token) { if (stop_flags_now) { stop_flags_now = false; - seq_lens_decoder_new = base_model_seq_lens_decoder_now; - step_id_now = base_model_step_idx_now; - } else { - seq_lens_decoder_new -= max_draft_token - accept_num_now; - step_id_now -= max_draft_token - accept_num_now; - } - int64_t modified_token = accept_tokens_now[accept_num_now - 1]; - LM2GM(&modified_token, - draft_tokens + tid * draft_tokens_len, - sizeof(int64_t)); - seq_lens_this_time_now = 1; + seq_lens_decoder_new = base_model_seq_lens_decoder_now - + base_model_seq_lens_this_time_now; + step_id_now = + base_model_step_idx_now - base_model_seq_lens_this_time_now; - } else /*Accept all draft tokens*/ { - LM2GM(accept_tokens_now + max_draft_token, - draft_tokens + tid * draft_tokens_len + 1, - sizeof(int64_t)); - seq_lens_this_time_now = 2; + } else { + seq_lens_decoder_new -= num_model_step - 1; + step_id_now -= num_model_step - 1; + } + for (int i = 0; i < accept_num_now; i++) { + const int pre_id_pos = + base_model_step_idx_now - (accept_num_now - i); + LM2GM(accept_tokens_now + i, + draft_tokens + tid * draft_tokens_len + i, + sizeof(int64_t)); + LM2GM(accept_tokens_now + i, + pre_ids + tid * pre_ids_len + pre_id_pos, + sizeof(int64_t)); + } + seq_lens_this_time_now = accept_num_now; } } else { @@ -209,17 +211,11 @@ __global__ void draft_model_preprocess(int64_t* draft_tokens, } LM2GM_ASYNC(&stop_flags_now, stop_flags + tid, sizeof(bool)); LM2GM_ASYNC(&batch_drop_now, batch_drop + tid, sizeof(bool)); - + LM2GM_ASYNC(&is_block_step_now, is_block_step + tid, sizeof(bool)); LM2GM_ASYNC(&seq_lens_decoder_new, seq_lens_decoder + tid, sizeof(int)); LM2GM_ASYNC( &seq_lens_this_time_now, seq_lens_this_time + tid, sizeof(int)); LM2GM_ASYNC(&seq_lens_encoder_now, seq_lens_encoder + tid, sizeof(int)); - LM2GM_ASYNC(&seq_lens_encoder_record_now, - seq_lens_encoder_record + tid, - sizeof(int)); - LM2GM_ASYNC(&seq_lens_decoder_record_now, - seq_lens_decoder_record + tid, - sizeof(int)); LM2GM_ASYNC(&step_id_now, step_idx + tid, sizeof(int64_t)); } } diff --git a/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/draft_model_update.xpu 
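To summarise the decode branch of the rewritten preprocess kernel above: instead of the old accept_num special cases, the draft state is rolled back by a fixed amount (num_model_step - 1, or by the base model's last seq_lens_this_time when the batch had stopped), all accepted tokens are copied into draft_tokens, and the same tokens are recorded in pre_ids at the positions they occupy relative to the base model's step index. A host-side sketch of that bookkeeping (illustrative, single batch, names are mine):

```cpp
#include <cstdint>
#include <vector>

// Illustrative single-batch paraphrase of the decode branch above.
struct DraftState {
  int seq_len_decoder;
  int64_t step_idx;
  int seq_len_this_time;
};

void rollback_and_accept(DraftState& s, bool was_stopped, int num_model_step,
                         int base_seq_len_decoder, int base_seq_len_this_time,
                         int64_t base_step_idx,
                         const std::vector<int64_t>& accept_tokens, int accept_num,
                         std::vector<int64_t>& draft_tokens,
                         std::vector<int64_t>& pre_ids) {
  if (was_stopped) {
    // restart from the base model's state before its last verify step
    s.seq_len_decoder = base_seq_len_decoder - base_seq_len_this_time;
    s.step_idx = base_step_idx - base_seq_len_this_time;
  } else {
    s.seq_len_decoder -= num_model_step - 1;
    s.step_idx -= num_model_step - 1;
  }
  for (int i = 0; i < accept_num; ++i) {
    draft_tokens[i] = accept_tokens[i];
    // accepted token i sits accept_num - i positions behind the base step index
    pre_ids[base_step_idx - (accept_num - i)] = accept_tokens[i];
  }
  s.seq_len_this_time = accept_num;
}
```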
b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/draft_model_update.xpu index 0334995f9..50ba31d61 100644 --- a/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/draft_model_update.xpu +++ b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/draft_model_update.xpu @@ -60,10 +60,8 @@ __global__ void draft_model_update(const int64_t* inter_next_tokens, token_this_time = next_tokens_start[seq_len_this_time - 1]; draft_token_now[0] = next_tokens_start[seq_len_this_time - 1]; base_model_draft_tokens_now[substep + 1] = token_this_time; - for (int i = 0; i < seq_len_this_time; ++i) { - pre_ids_now[step_idx[tid] + 1 + i] = next_tokens_start[i]; - } step_idx[tid] += seq_len_this_time; + pre_ids_now[step_idx[tid]] = token_this_time; } else { token_this_time = next_tokens_start[0]; seq_lens_decoder[tid] = seq_len_encoder + seq_len_decoder; diff --git a/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/eb_mtp_gather_next_token.xpu b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/eb_mtp_gather_next_token.xpu new file mode 100644 index 000000000..522e2911e --- /dev/null +++ b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/eb_mtp_gather_next_token.xpu @@ -0,0 +1,129 @@ +#include "xpu/kernel/cluster.h" +#include "xpu/kernel/cluster_debug.h" +#include "xpu/kernel/cluster_primitive.h" +namespace xpu3 { +namespace plugin { +#define MAX_LM_SIZE 28672 +// One core has 32KB LM(group LM), MAX_LM_SIZE = (32 - 4)KB / 2 = 30720, 4KB is +// the stack space +#define MAX_BATCH 512 +#define ALIGNMENT 64 + +template +static __device__ void do_memcpy_1d(_global_ptr_ TX* src, + _global_ptr_ TY* dst, + int64_t copy_size) { +#ifdef __XPU3__ + constexpr int buf_size = 2048; +#else + constexpr int buf_size = 512; +#endif + __group_shared__ __simd__ float double_lmx[2][buf_size]; + int64_t pingpong = 0; + for (int64_t i = 0; i < copy_size; i += buf_size) { + int real_size = min(buf_size, copy_size - i); + _group_shared_ptr_ float* lmx = double_lmx[pingpong]; + GM2GSM(src + i, lmx, real_size * sizeof(TX)); + if (!xpu_std::is_same::value) { + primitive_cast_gsm( + (_group_shared_ptr_ TX*)lmx, lmx, real_size); + primitive_cast_gsm( + lmx, (_group_shared_ptr_ TY*)lmx, real_size); + } + GSM2GM_ASYNC((_group_shared_ptr_ TY*)lmx, dst + i, real_size * sizeof(TY)); + pingpong = 1 - pingpong; + } + mfence(); +} + +template +__global__ void eb_mtp_gather_next_token(TX* src, + TY* dst, + int* encoder_seqs_lods, + int* decoder_seqs_lods, + int* encoder_batch_map, + int* decoder_batch_map, + int en_batch, + int de_batch, + int64_t copy_size) { + int tid = core_id() * cluster_num() + cluster_id(); + int nthreads = core_num() * cluster_num(); + __group_shared__ int local_lods_en[MAX_BATCH + 1]; + __group_shared__ int local_lods_de[MAX_BATCH + 1]; + __group_shared__ int local_map_en[MAX_BATCH]; + __group_shared__ int local_map_de[MAX_BATCH]; + GM2GSM_ASYNC(encoder_seqs_lods, local_lods_en, (en_batch + 1) * sizeof(int)); + GM2GSM_ASYNC(decoder_seqs_lods, local_lods_de, (de_batch + 1) * sizeof(int)); + if (en_batch > 0) { + GM2GSM_ASYNC(encoder_batch_map, local_map_en, en_batch * sizeof(int)); + } + if (de_batch > 0) { + GM2GSM_ASYNC(decoder_batch_map, local_map_de, de_batch * sizeof(int)); + } + mfence(); + int encoder_len_total = en_batch > 0 ? 
local_lods_en[en_batch] : 0; + int output_len = en_batch + local_lods_de[de_batch]; + int start = 0; + int end = 0; + partition(tid, nthreads, output_len, 1, &start, &end); + for (int i = start; i < end; i++) { + int len = 0; + int enc_idx = 0, dec_idx = 0; + bool is_enc; + while (i >= len) { + if (enc_idx >= en_batch) { + len += local_lods_de[dec_idx + 1] - local_lods_de[dec_idx]; + dec_idx++; + is_enc = false; + continue; + } + if (dec_idx >= de_batch) { + len += 1; + enc_idx++; + is_enc = true; + continue; + } + if (local_map_en[enc_idx] < local_map_de[dec_idx]) { + len += 1; + enc_idx++; + is_enc = true; + } else { + len += local_lods_de[dec_idx + 1] - local_lods_de[dec_idx]; + dec_idx++; + is_enc = false; + } + } + _global_ptr_ TX* cur_src = nullptr; + _global_ptr_ TY* cur_dst = dst + i * copy_size; + if (is_enc) { + cur_src = src + (local_lods_en[enc_idx] - 1) * copy_size; + } else { + cur_src = src + (encoder_len_total + local_lods_de[dec_idx] - (len - i)) * + copy_size; + } + do_memcpy_1d(cur_src, cur_dst, copy_size); + } +} +#define _XPU_DEF__EB_MTP_GATHER_NEXT_TOKEN(TX, TY) \ + template __global__ void eb_mtp_gather_next_token( \ + TX * src, \ + TY * dst, \ + int* encoder_seqs_lods, \ + int* decoder_seqs_lods, \ + int* encoder_batch_map, \ + int* decoder_batch_map, \ + int en_batch, \ + int de_batch, \ + int64_t copy_size); + +_XPU_DEF__EB_MTP_GATHER_NEXT_TOKEN(float16, float16); +_XPU_DEF__EB_MTP_GATHER_NEXT_TOKEN(bfloat16, bfloat16); +_XPU_DEF__EB_MTP_GATHER_NEXT_TOKEN(float, float); +_XPU_DEF__EB_MTP_GATHER_NEXT_TOKEN(float16, float); +_XPU_DEF__EB_MTP_GATHER_NEXT_TOKEN(float, float16); +_XPU_DEF__EB_MTP_GATHER_NEXT_TOKEN(bfloat16, float16); +_XPU_DEF__EB_MTP_GATHER_NEXT_TOKEN(float16, bfloat16); +_XPU_DEF__EB_MTP_GATHER_NEXT_TOKEN(bfloat16, float); +_XPU_DEF__EB_MTP_GATHER_NEXT_TOKEN(float, bfloat16); +} // namespace plugin +} // namespace xpu3 diff --git a/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/speculate_free_and_dispatch_block.xpu b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/speculate_free_and_dispatch_block.xpu new file mode 100644 index 000000000..133e9d798 --- /dev/null +++ b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/speculate_free_and_dispatch_block.xpu @@ -0,0 +1,337 @@ +#include "xpu/kernel/cluster.h" +#include "xpu/kernel/cluster_partition.h" +#include "xpu/kernel/cluster_primitive.h" + +namespace xpu3 { +namespace plugin { + +static __device__ inline int loada_float(_shared_ptr_ const int *ptr) { + int ret; + __asm__ __volatile__("loada.w %0,%1" : "=&r"(ret) : "r"(ptr)); + return ret; +} + +static __device__ inline bool storea_float(_shared_ptr_ int *ptr, int value) { + bool ret; + __asm__ __volatile__("storea.w %0,%1,%2" : "=&r"(ret) : "r"(value), "r"(ptr)); + return ret; +} + +static __device__ int atomic_add(_shared_ptr_ int *ptr, int value) { + bool fail = true; + int old_value; + while (fail) { + old_value = loada_float(ptr); + int new_value = old_value + value; + fail = storea_float(ptr, new_value); + } + return old_value; +} + +static __device__ bool in_need_block_list(const int qid, + _shared_ptr_ int *need_block_list, + const int need_block_len) { + bool res = false; + for (int i = 0; i < need_block_len; i++) { + if (qid == need_block_list[i]) { + need_block_list[i] = -1; + res = true; + break; + } + } + return res; +} + +__global__ void speculate_free_and_dispatch_block( + bool *stop_flags, + int *seq_lens_this_time, + int *seq_lens_decoder, + int *block_tables, + int *encoder_block_lens, + 
bool *is_block_step, + int *step_block_list, // [bsz] + int *step_len, + int *recover_block_list, + int *recover_len, + int *need_block_list, + int *need_block_len, + int *used_list_len, + int *free_list, + int *free_list_len, + int64_t *first_token_ids, + int *accept_num, + const int bsz, + const int block_size, + const int block_num_per_seq, + const int max_decoder_block_num, + const int max_draft_tokens) { + int cid = core_id(); + int ncores = core_num(); + int clusterid = cluster_id(); + if (clusterid != 0 || cid >= bsz) return; + + // assert bsz <= 640 + const int max_bs = 640; + int value_zero = 0; + bool flag_true = true; + + // 128 = seq_len(8192) / block_size(64) + // process at most 128 block_table entries per round + const int block_table_now_len = 128; + int block_table_now[block_table_now_len]; + for (int i = 0; i < block_table_now_len; i++) { + block_table_now[i] = -1; + } + bool stop_flag_lm; + int seq_lens_decoder_lm; + + __shared__ int free_list_len_sm; + // process at most block_table_now_len free_list entries per round + int free_list_now[block_table_now_len]; + __shared__ int need_block_len_sm; + __shared__ int need_block_list_sm[max_bs]; + __shared__ int used_list_len_sm[max_bs]; + __shared__ bool step_max_block_flag; + __shared__ int in_need_block_list_len; + if (cid == 0) { + step_max_block_flag = false; + in_need_block_list_len = 0; + GM2SM_ASYNC(free_list_len, &free_list_len_sm, sizeof(int)); + GM2SM_ASYNC(need_block_len, &need_block_len_sm, sizeof(int)); + mfence(); + if (need_block_len_sm > 0) { + GM2SM_ASYNC( + need_block_list, need_block_list_sm, sizeof(int) * need_block_len_sm); + } + GM2SM_ASYNC(used_list_len, used_list_len_sm, sizeof(int) * bsz); + mfence(); + } + sync_cluster(); + + for (int tid = cid; tid < bsz; tid += ncores) { + bool is_block_step_lm; + int seq_lens_this_time_lm; + mfence(); + GM2LM_ASYNC(stop_flags + tid, &stop_flag_lm, sizeof(bool)); + GM2LM_ASYNC(is_block_step + tid, &is_block_step_lm, sizeof(bool)); + GM2LM_ASYNC(seq_lens_decoder + tid, &seq_lens_decoder_lm, sizeof(int)); + GM2LM_ASYNC(seq_lens_this_time + tid, &seq_lens_this_time_lm, sizeof(int)); + mfence(); + int max_possible_block_idx = + (seq_lens_decoder_lm + max_draft_tokens + 1) / block_size; + if (stop_flag_lm && !is_block_step_lm) { + // reclaim this query's decoder blocks + int64_t first_token_id_lm = -1; + mfence_lm(); + LM2GM(&first_token_id_lm, first_token_ids + tid, sizeof(int64_t)); + int encoder_block_len_lm; + int decoder_used_len_lm = used_list_len_sm[tid]; + GM2LM(encoder_block_lens + tid, &encoder_block_len_lm, sizeof(int)); + if (decoder_used_len_lm > 0) { + const int ori_free_list_len = + atomic_add(&free_list_len_sm, decoder_used_len_lm); + for (int i = 0; i < decoder_used_len_lm; i += block_table_now_len) { + int process_len = min(block_table_now_len, decoder_used_len_lm - i); + GM2LM( + block_tables + tid * block_num_per_seq + encoder_block_len_lm + i, + free_list_now, + process_len * sizeof(int)); + LM2GM(free_list_now, + free_list + ori_free_list_len + i, + process_len * sizeof(int)); + LM2GM( + block_table_now, + block_tables + tid * block_num_per_seq + encoder_block_len_lm + i, + process_len * sizeof(int)); + } + used_list_len_sm[tid] = 0; + mfence(); + LM2GM(&value_zero, encoder_block_lens + tid, sizeof(int)); + } + } else if (seq_lens_this_time_lm != 0 && + max_possible_block_idx < block_num_per_seq) { + int next_block_id; + GM2LM(block_tables + tid * block_num_per_seq + + (seq_lens_decoder_lm + max_draft_tokens + 1) / block_size, + &next_block_id, + sizeof(int)); + if (next_block_id == -1) { + // record the positions that need a new block and their total count + const int
ori_need_block_len = atomic_add(&need_block_len_sm, 1); + need_block_list_sm[ori_need_block_len] = tid; + } + } + } + sync_cluster(); + + bool is_block_step_lm[max_bs]; + int step_len_lm; + int step_block_list_lm[max_bs]; + int recover_len_lm; + int recover_block_list_lm[max_bs]; + if (cid == 0) { + GM2LM_ASYNC(is_block_step, is_block_step_lm, sizeof(bool) * bsz); + GM2LM_ASYNC(step_len, &step_len_lm, sizeof(int)); + GM2LM_ASYNC(step_block_list, step_block_list_lm, sizeof(int) * bsz); + GM2LM_ASYNC(recover_len, &recover_len_lm, sizeof(int)); + GM2LM_ASYNC(recover_block_list, recover_block_list_lm, sizeof(int) * bsz); + mfence(); + } + + if (cid == 0) { + while (need_block_len_sm > free_list_len_sm) { + // schedule blocks: reclaim from the query with the largest used_list_len, repeating until need_block_len is satisfied; queries already decoding in their last block are not stepped (they will finish soon) + int max_used_list_len_id = 0; + int max_used_list_len = 0; + for (int i = 0; i < bsz; i++) { + if (!is_block_step_lm[i] && + (step_max_block_flag || + used_list_len_sm[i] != max_decoder_block_num) && + (used_list_len_sm[i] > max_used_list_len)) { + max_used_list_len_id = i; + max_used_list_len = used_list_len_sm[i]; + } + } + + if (max_used_list_len == 0) { + step_max_block_flag = true; + } else { + int encoder_block_len; + GM2LM(encoder_block_lens + max_used_list_len_id, + &encoder_block_len, + sizeof(int)); + for (int i = 0; i < max_used_list_len; i += block_table_now_len) { + int process_len = min(block_table_now_len, max_used_list_len - i); + GM2LM(block_tables + max_used_list_len_id * block_num_per_seq + + encoder_block_len + i, + free_list_now, + process_len * sizeof(int)); + LM2GM(free_list_now, + free_list + free_list_len_sm + i, + process_len * sizeof(int)); + LM2GM(block_table_now, + block_tables + max_used_list_len_id * block_num_per_seq + + encoder_block_len + i, + process_len * sizeof(int)); + } + step_block_list_lm[step_len_lm] = max_used_list_len_id; + int need_block_len_all = need_block_len_sm + in_need_block_list_len; + if (in_need_block_list( + max_used_list_len_id, need_block_list_sm, need_block_len_all)) { + need_block_len_sm--; + in_need_block_list_len++; + } + step_len_lm++; + free_list_len_sm += max_used_list_len; + LM2GM_ASYNC( + &flag_true, stop_flags + max_used_list_len_id, sizeof(bool)); + is_block_step_lm[max_used_list_len_id] = true; + LM2GM_ASYNC(&value_zero, + seq_lens_this_time + max_used_list_len_id, + sizeof(int)); + LM2GM_ASYNC( + &value_zero, seq_lens_decoder + max_used_list_len_id, sizeof(int)); + // Note(@wufeisheng): if accept num stayed non-zero for a stepped query, + // save_output would keep streaming its output in the next step, so + // accept num must be reset to 0 here + LM2GM_ASYNC( + &value_zero, accept_num + max_used_list_len_id, sizeof(int)); + mfence(); + } + } + } + sync_cluster(); + + int need_block_len_all = need_block_len_sm + in_need_block_list_len; + for (int tid = cid; tid < need_block_len_all; tid += ncores) { + // assign one block to each position that needs a block + const int need_block_id = need_block_list_sm[tid]; + if (need_block_id != -1) { + GM2LM(stop_flags + need_block_id, &stop_flag_lm, sizeof(bool)); + if (!stop_flag_lm) { + // if the position that needs a block is exactly one that was just released in the previous step, do nothing + used_list_len_sm[need_block_id]++; + const int ori_free_list_len = atomic_add(&free_list_len_sm, -1); + int tmp_seq_lens_decoder; + GM2LM(seq_lens_decoder + need_block_id, + &tmp_seq_lens_decoder, + sizeof(int)); + int free_block_id; + GM2LM(free_list + ori_free_list_len - 1, &free_block_id, sizeof(int)); + LM2GM(&free_block_id, + block_tables + need_block_id * block_num_per_seq +
(tmp_seq_lens_decoder + max_draft_tokens + 1) / block_size, + sizeof(int)); + } + need_block_list_sm[tid] = -1; + } + } + sync_cluster(); + + // compute the query ids that can be recovered + // recover at most max_recover_num queries per round + int max_recover_num = 1; + if (cid == 0 && step_len_lm > 0) { + int ori_free_list_len = free_list_len_sm; + int ori_step_block_id = step_block_list_lm[step_len_lm - 1]; + int tmp_used_len = used_list_len_sm[ori_step_block_id]; + int encoder_block_len_lm; + GM2LM(encoder_block_lens + ori_step_block_id, + &encoder_block_len_lm, + sizeof(int)); + const int max_decoder_block_num_this_seq = + max_decoder_block_num - encoder_block_len_lm; + // allocate one more block than the query held when it was stepped, to avoid immediately recovering a query that was just stepped (e.g. when the reclaimed seq_id is still in need_block_list) + int used_len = tmp_used_len + 1 < max_decoder_block_num_this_seq + ? tmp_used_len + 1 + : max_decoder_block_num_this_seq; + while (step_len_lm > 0 && ori_free_list_len >= used_len && + max_recover_num-- > 0) { + recover_block_list_lm[recover_len_lm] = ori_step_block_id; + is_block_step_lm[ori_step_block_id] = false; + used_list_len_sm[ori_step_block_id] = used_len; + ori_free_list_len -= used_len; + step_block_list_lm[step_len_lm - 1] = -1; + step_len_lm--; + recover_len_lm++; + if (step_len_lm > 0) { + ori_step_block_id = step_block_list_lm[step_len_lm - 1]; + tmp_used_len = used_list_len_sm[ori_step_block_id]; + used_len = tmp_used_len + 1 < max_decoder_block_num_this_seq + ? tmp_used_len + 1 + : max_decoder_block_num_this_seq; + } + } + } + + // TODO(zhupengyang): + // Before the operator: need_block_len is 0, need_block_list is -1 + // After the operator: need_block_len is 0, need_block_list is -1 + // Maybe need_block_len and need_block_list do not need to be updated at all? + int ori_need_block_len; + if (cid == 0) { + ori_need_block_len = need_block_len_sm; + need_block_len_sm = 0; + } + + if (cid == 0) { + mfence(); + LM2GM_ASYNC(step_block_list_lm, step_block_list, sizeof(int) * bsz); + LM2GM_ASYNC(is_block_step_lm, is_block_step, sizeof(bool) * bsz); + LM2GM_ASYNC(&step_len_lm, step_len, sizeof(int)); + LM2GM_ASYNC(&recover_len_lm, recover_len, sizeof(int)); + LM2GM_ASYNC(recover_block_list_lm, recover_block_list, sizeof(int) * bsz); + SM2GM_ASYNC(&free_list_len_sm, free_list_len, sizeof(int)); + SM2GM_ASYNC(&need_block_len_sm, need_block_len, sizeof(int)); + if (ori_need_block_len > 0) { + SM2GM_ASYNC(need_block_list_sm, + need_block_list, + sizeof(int) * ori_need_block_len); + } + SM2GM_ASYNC(used_list_len_sm, used_list_len, sizeof(int) * bsz); + mfence(); + } +} + +} // namespace plugin +} // namespace xpu3 diff --git a/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/speculate_get_padding_offset.xpu b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/speculate_get_padding_offset.xpu index c08d756d7..a1e766d31 100644 --- a/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/speculate_get_padding_offset.xpu +++ b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/speculate_get_padding_offset.xpu @@ -65,7 +65,7 @@ __global__ void speculate_remove_padding(T* output_data, } } -__global__ void speculate_get_padding_offset(int* padding_offset, +__global__ void speculate_get_padding_offset(int* batch_id_per_token, int* cum_offsets_out, int* cu_seqlens_q, int* cu_seqlens_k, @@ -90,8 +90,8 @@ __global__ void speculate_get_padding_offset(int* padding_offset, GM2LM(cum_offsets + bi, &cum_offsets_now_ind, sizeof(int)); for (int i = tid; i < seq_lens_now; i += ncores) { - LM2GM(&cum_offsets_now, - padding_offset + bi * max_seq_len - cum_offsets_now + i, + LM2GM(&bi,
batch_id_per_token + bi * max_seq_len - cum_offsets_now + i, sizeof(int)); } LM2GM(&cum_offsets_now, cum_offsets_out + bi, sizeof(int)); diff --git a/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/speculate_recover_block.xpu b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/speculate_recover_block.xpu new file mode 100644 index 000000000..46d24821d --- /dev/null +++ b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/speculate_recover_block.xpu @@ -0,0 +1,154 @@ +#include "xpu/kernel/cluster.h" +#include "xpu/kernel/cluster_partition.h" +#include "xpu/kernel/cluster_primitive.h" + +namespace xpu3 { +namespace plugin { + +static __device__ inline int loada_float(_shared_ptr_ const int* ptr) { + int ret; + __asm__ __volatile__("loada.w %0,%1" : "=&r"(ret) : "r"(ptr)); + return ret; +} + +static __device__ inline bool storea_float(_shared_ptr_ int* ptr, int value) { + bool ret; + __asm__ __volatile__("storea.w %0,%1,%2" : "=&r"(ret) : "r"(value), "r"(ptr)); + return ret; +} + +static __device__ int atomic_add(_shared_ptr_ int* ptr, int value) { + bool fail = true; + int old_value; + while (fail) { + old_value = loada_float(ptr); + int new_value = old_value + value; + fail = storea_float(ptr, new_value); + } + return old_value; +} + +__global__ void speculate_recover_block(int* recover_block_list, // [bsz] + int* recover_len, + bool* stop_flags, + int* seq_lens_this_time, + const int* ori_seq_lens_encoder, + int* seq_lens_encoder, + const int* seq_lens_decoder, + int* block_tables, + int* free_list, + int* free_list_len, + int64_t* input_ids, + const int64_t* pre_ids, + const int64_t* step_idx, + const int* encoder_block_lens, + const int* used_list_len, + const int64_t* next_tokens, + const int64_t* first_token_ids, + const int bsz, + const int block_num_per_seq, + const int length, + const int pre_id_length) { + int cid = core_id(); + int ncores = core_num(); + int clusterid = cluster_id(); + if (clusterid != 0) return; + + // 128 = seq_len(8192) / block_size(64) + // 每次最多处理block_table数量为128 + const int block_table_now_len = 128; + int block_table_now[block_table_now_len]; + // max_seq_len == length + // max_seq_len == pre_id_length + + // 32k local memory per 4 core on kl2. + // No enough memory for 16382 input_ids. 
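+ // Sizing note (illustrative, not required by the kernel logic): a full input_ids row of roughly 16K int64_t tokens is about 128 KB, far larger than the available local memory, so the recovery loop below stages the pre_ids -> input_ids copy through a small LM buffer of buf_len elements, i.e. roughly ceil(step_idx_now / buf_len) GM2LM/LM2GM round trips per query; for example, assuming step_idx_now = 1000 and buf_len = 256, that is four chunks of 256, 256, 256 and 232 elements (about 2 KB of LM per chunk).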
+ const int buf_len = 256; + int64_t input_ids_now[buf_len]; + + bool flag_false = false; + + __shared__ int free_list_len_sm; + // 每次最多处理free_list数量为block_table_now_len + int free_list_now[block_table_now_len]; + if (cid == 0) { + GM2SM(free_list_len, &free_list_len_sm, sizeof(int)); + } + sync_cluster(); + + int recover_len_lm; + GM2LM(recover_len, &recover_len_lm, sizeof(int)); + + for (int bid = cid; bid < recover_len_lm; bid += ncores) { + int recover_id; + int ori_seq_len_encoder; + int step_idx_now; + int encoder_block_len; + int decoder_used_len; + int64_t next_token; + GM2LM(recover_block_list + bid, &recover_id, sizeof(int)); + GM2LM_ASYNC( + ori_seq_lens_encoder + recover_id, &ori_seq_len_encoder, sizeof(int)); + GM2LM_ASYNC(step_idx + recover_id, &step_idx_now, sizeof(int)); + GM2LM_ASYNC( + encoder_block_lens + recover_id, &encoder_block_len, sizeof(int)); + GM2LM_ASYNC(used_list_len + recover_id, &decoder_used_len, sizeof(int)); + GM2LM_ASYNC(next_tokens + recover_id, &next_token, sizeof(int64_t)); + mfence(); + + int seq_len = ori_seq_len_encoder + step_idx_now; + mfence(); + LM2GM_ASYNC(&seq_len, seq_lens_this_time + recover_id, sizeof(int)); + LM2GM_ASYNC(&seq_len, seq_lens_encoder + recover_id, sizeof(int)); + LM2GM_ASYNC(&flag_false, stop_flags + recover_id, sizeof(bool)); + mfence(); + // // next tokens + // LM2GM_ASYNC(&next_token, + // input_ids + recover_id * length + seq_len - 1, + // sizeof(int64_t)); + // set first prompt token + int64_t first_token_id; + GM2LM(first_token_ids + recover_id, &first_token_id, sizeof(int64_t)); + LM2GM_ASYNC( + &first_token_id, input_ids + recover_id * length, sizeof(int64_t)); + + int ori_free_list_len = atomic_add(&free_list_len_sm, -decoder_used_len); + // 恢复block table + for (int i = 0; i < decoder_used_len; i += block_table_now_len) { + int process_len = min(block_table_now_len, decoder_used_len - i); + GM2LM(free_list + ori_free_list_len - i - process_len, + free_list_now, + process_len * sizeof(int)); + for (int j = 0; j < process_len; j++) { + block_table_now[j] = free_list_now[process_len - 1 - j]; + } + mfence(); + LM2GM( + block_table_now, + block_tables + recover_id * block_num_per_seq + encoder_block_len + i, + process_len * sizeof(int)); + } + // 恢复input_ids + for (int i = 0; i < step_idx_now; i += buf_len) { + int real_len = min(buf_len, step_idx_now - i); + GM2LM(pre_ids + recover_id * pre_id_length + i + 1, + input_ids_now, + sizeof(int64_t) * real_len); + LM2GM(input_ids_now, + input_ids + recover_id * length + ori_seq_len_encoder + i, + sizeof(int64_t) * real_len); + } + mfence(); + } + + if (cid == 0) { + recover_len_lm = 0; + mfence(); + LM2GM_ASYNC(&recover_len_lm, recover_len, sizeof(int)); + SM2GM_ASYNC(&free_list_len_sm, free_list_len, sizeof(int)); + mfence(); + } +} + +} // namespace plugin +} // namespace xpu3 diff --git a/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/speculate_verify.xpu b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/speculate_verify.xpu index 68eb2bd60..26ad38c9f 100644 --- a/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/speculate_verify.xpu +++ b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/speculate_verify.xpu @@ -138,7 +138,8 @@ __global__ void speculate_verify( const int max_candidate_len, // scalar, 每个 verify token // 的最大候选数(用于验证或采样) const int verify_window, // scalar, TopK 验证窗口(允许连续 top1 匹配次数) - const bool prefill_one_step_stop) { + const bool prefill_one_step_stop, + const bool benchmark_mode) { const int cid = 
core_id(); const int64_t tid = cluster_id() * core_num() + core_id(); const int64_t nthreads = cluster_num() * core_num(); @@ -161,6 +162,9 @@ __global__ void speculate_verify( // printf("seq_lens_this_time[%d]-1: %d \n",bid, // seq_lens_this_time[bid]-1); for (; i < seq_lens_this_time[bid] - 1; i++) { + if (benchmark_mode) { + break; + } if (seq_lens_encoder[bid] != 0) { break; } @@ -326,7 +330,8 @@ __global__ void speculate_verify( int max_seq_len, \ int max_candidate_len, \ int verify_window, \ - bool prefill_one_step_stop); + bool prefill_one_step_stop, \ + bool benchmark_mode); SPECULATE_VERIFY_INSTANTIATE(true, true) SPECULATE_VERIFY_INSTANTIATE(true, false) SPECULATE_VERIFY_INSTANTIATE(false, true) diff --git a/custom_ops/xpu_ops/src/plugin/src/wrapper/eb_adjust_batch.cpp b/custom_ops/xpu_ops/src/plugin/src/wrapper/eb_adjust_batch.cpp index 94f235213..4a4ff43c2 100644 --- a/custom_ops/xpu_ops/src/plugin/src/wrapper/eb_adjust_batch.cpp +++ b/custom_ops/xpu_ops/src/plugin/src/wrapper/eb_adjust_batch.cpp @@ -23,6 +23,7 @@ template __attribute__((global)) void eb_adjust_batch(TX *src, TY *dst, int *encoder_seqs_lods, + int *decoder_seqs_lods, int *encoder_batch_map, int *decoder_batch_map, int en_batch, @@ -41,6 +42,7 @@ static int cpu_wrapper(api::Context *ctx, const TX *x, TY *y, const int *encoder_seqs_lods, + const int *decoder_seqs_lods, const int *encoder_batch_map, const int *decoder_batch_map, int en_batch, @@ -56,11 +58,12 @@ static int cpu_wrapper(api::Context *ctx, // get copy size && src_offset int cpy_m = 0; if (de_batch > 0 && decoder_batch_map[de_idx] == i) { - cpy_m = 1; - ret = api::cast(ctx, - x + cur_offset * hidden_dim, - y + (encoder_len_total + de_idx) * hidden_dim, - cpy_m * hidden_dim); + cpy_m = decoder_seqs_lods[de_idx + 1] - decoder_seqs_lods[de_idx]; + ret = api::cast( + ctx, + x + cur_offset * hidden_dim, + y + (encoder_len_total + decoder_seqs_lods[de_idx]) * hidden_dim, + cpy_m * hidden_dim); WRAPPER_ASSERT_SUCCESS(ctx, ret); de_idx++; } @@ -84,6 +87,7 @@ static int xpu3_wrapper(api::Context *ctx, const TX *x, TY *y, api::VectorParam &encoder_seqs_lods, // NOLINT + api::VectorParam &decoder_seqs_lods, // NOLINT api::VectorParam &encoder_batch_map, // NOLINT api::VectorParam &decoder_batch_map, // NOLINT int en_batch, @@ -98,6 +102,7 @@ static int xpu3_wrapper(api::Context *ctx, reinterpret_cast(const_cast(x)), reinterpret_cast(y), encoder_seqs_lods.xpu, + decoder_seqs_lods.xpu, encoder_batch_map.xpu, decoder_batch_map.xpu, en_batch, @@ -111,6 +116,7 @@ int eb_adjust_batch(api::Context *ctx, const TX *x, TY *y, api::VectorParam &encoder_seqs_lods, // NOLINT + api::VectorParam &decoder_seqs_lods, // NOLINT api::VectorParam &encoder_batch_map, // NOLINT api::VectorParam &decoder_batch_map, // NOLINT int64_t hidden_dim) { @@ -119,28 +125,35 @@ int eb_adjust_batch(api::Context *ctx, // if (dev_id ==0) { // ctx->set_debug_level(0xA1); // } - + // std::cout << decoder_seqs_lods.cpu[0] << " " << decoder_seqs_lods.cpu[1] << + // std::endl; WRAPPER_CHECK_CTX(ctx); WRAPPER_DUMP_FUNCTION_T2(ctx, "eb_adjust_batch", TX, TY); WRAPPER_DUMP_PARAM6(ctx, x, y, encoder_seqs_lods, + decoder_seqs_lods, encoder_batch_map, - decoder_batch_map, - hidden_dim); + decoder_batch_map); + WRAPPER_DUMP_PARAM1(ctx, hidden_dim); WRAPPER_DUMP(ctx); int encoder_batch = encoder_batch_map.len; - int total_batch = encoder_batch + decoder_batch_map.len; + int decoder_batch = decoder_batch_map.len; + int total_batch = encoder_batch + decoder_batch; int max_encoder_lod = 
encoder_seqs_lods.cpu[encoder_batch]; - int m = max_encoder_lod + decoder_batch_map.len; + int max_decoder_lod = decoder_seqs_lods.cpu[decoder_batch]; + int m = max_encoder_lod + max_decoder_lod; WRAPPER_CHECK_PTR(ctx, TX, m * hidden_dim, x); WRAPPER_CHECK_PTR(ctx, TY, m * hidden_dim, y); WRAPPER_ASSERT_GT(ctx, hidden_dim, 0); // check VectorParam WRAPPER_ASSERT_EQ(ctx, encoder_seqs_lods.len, encoder_batch_map.len + 1); + WRAPPER_ASSERT_EQ(ctx, decoder_seqs_lods.len, decoder_batch_map.len + 1); WRAPPER_ASSERT_GE(ctx, encoder_seqs_lods.cpu[0], 0); WRAPPER_ASSERT_LE(ctx, encoder_seqs_lods.cpu[0], max_encoder_lod); + WRAPPER_ASSERT_GE(ctx, decoder_seqs_lods.cpu[0], 0); + WRAPPER_ASSERT_LE(ctx, decoder_seqs_lods.cpu[0], max_decoder_lod); for (int i = 0; i < encoder_batch_map.len; ++i) { WRAPPER_ASSERT_GE(ctx, encoder_batch_map.cpu[i], 0); WRAPPER_ASSERT_LT(ctx, encoder_batch_map.cpu[i], total_batch) @@ -150,12 +163,15 @@ int eb_adjust_batch(api::Context *ctx, for (int i = 0; i < decoder_batch_map.len; ++i) { WRAPPER_ASSERT_GE(ctx, decoder_batch_map.cpu[i], 0); WRAPPER_ASSERT_LT(ctx, decoder_batch_map.cpu[i], total_batch) + WRAPPER_ASSERT_GE(ctx, decoder_seqs_lods.cpu[i + 1], 0); + WRAPPER_ASSERT_LE(ctx, decoder_seqs_lods.cpu[i + 1], max_decoder_lod); } if (ctx->dev().type() == api::kCPU) { return cpu_wrapper(ctx, x, y, encoder_seqs_lods.cpu, + decoder_seqs_lods.cpu, encoder_batch_map.cpu, decoder_batch_map.cpu, encoder_batch_map.len, @@ -166,6 +182,8 @@ int eb_adjust_batch(api::Context *ctx, api::ctx_guard RAII_GUARD(ctx); api::VectorParam encoder_seqs_lods_xpu = encoder_seqs_lods.to_xpu(RAII_GUARD); + api::VectorParam decoder_seqs_lods_xpu = + decoder_seqs_lods.to_xpu(RAII_GUARD); api::VectorParam encoder_batch_map_xpu = encoder_batch_map.to_xpu(RAII_GUARD); api::VectorParam decoder_batch_map_xpu = @@ -174,6 +192,7 @@ int eb_adjust_batch(api::Context *ctx, x, y, encoder_seqs_lods_xpu, + decoder_seqs_lods_xpu, encoder_batch_map_xpu, decoder_batch_map_xpu, encoder_batch_map.len, @@ -190,6 +209,7 @@ int eb_adjust_batch(api::Context *ctx, api::VectorParam &, \ api::VectorParam &, \ api::VectorParam &, \ + api::VectorParam &, \ int64_t); INSTANTIATION_EB_ADJUST_BATCH(float16, float16); diff --git a/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/draft_model_preprocess.cpp b/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/draft_model_preprocess.cpp index 9ca1f2224..3a9273aee 100644 --- a/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/draft_model_preprocess.cpp +++ b/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/draft_model_preprocess.cpp @@ -27,26 +27,29 @@ __attribute__((global)) void draft_model_preprocess( int* seq_lens_encoder, int* seq_lens_decoder, int64_t* step_idx, - int* seq_lens_encoder_record, - int* seq_lens_decoder_record, bool* not_need_stop, + bool* is_block_step, bool* batch_drop, + int64_t* pre_ids, const int64_t* accept_tokens, const int* accept_num, + const int* base_model_seq_lens_this_time, const int* base_model_seq_lens_encoder, const int* base_model_seq_lens_decoder, const int64_t* base_model_step_idx, const bool* base_model_stop_flags, const bool* base_model_is_block_step, int64_t* base_model_draft_tokens, - int real_bsz, - int max_draft_token, - int accept_tokens_len, - int draft_tokens_len, - int input_ids_len, - int base_model_draft_tokens_len, - bool truncate_first_token, - bool splitwise_prefill); + const int bsz, + const int num_model_step, + const int accept_tokens_len, + const int draft_tokens_len, + const int input_ids_len, + const int 
base_model_draft_tokens_len, + const int pre_ids_len, + const bool truncate_first_token, + const bool splitwise_prefill, + const bool kvcache_scheduler_v1); } // namespace plugin } // namespace xpu3 @@ -67,49 +70,47 @@ static int cpu_wrapper(api::Context* ctx, int* seq_lens_encoder, int* seq_lens_decoder, int64_t* step_idx, - int* seq_lens_encoder_record, - int* seq_lens_decoder_record, bool* not_need_stop, + bool* is_block_step, bool* batch_drop, + int64_t* pre_ids, const int64_t* accept_tokens, const int* accept_num, + const int* base_model_seq_lens_this_time, const int* base_model_seq_lens_encoder, const int* base_model_seq_lens_decoder, const int64_t* base_model_step_idx, const bool* base_model_stop_flags, const bool* base_model_is_block_step, int64_t* base_model_draft_tokens, - int real_bsz, - int max_draft_token, - int accept_tokens_len, - int draft_tokens_len, - int input_ids_len, - int base_model_draft_tokens_len, - bool truncate_first_token, - bool splitwise_prefill) { + const int bsz, + const int num_model_step, + const int accept_tokens_len, + const int draft_tokens_len, + const int input_ids_len, + const int base_model_draft_tokens_len, + const int pre_ids_len, + const bool truncate_first_token, + const bool splitwise_prefill, + const bool kvcache_scheduler_v1) { int64_t not_stop_flag_sum = 0; int64_t not_stop_flag = 0; - for (int tid = 0; tid < real_bsz; tid++) { + for (int tid = 0; tid < bsz; tid++) { if (splitwise_prefill) { - int base_model_step_idx_now = base_model_step_idx[tid]; auto* input_ids_now = input_ids + tid * input_ids_len; auto* accept_tokens_now = accept_tokens + tid * accept_tokens_len; - // printf("bid: %d, base_model_step_idx_now: %d seq_lens_encoder_record: - // %d\n", tid, base_model_step_idx_now, seq_lens_encoder_record[tid]); - if (base_model_step_idx_now == 1 && seq_lens_encoder_record[tid] > 0) { + if (seq_lens_encoder[tid] > 0) { not_stop_flag = 1; - int seq_len_encoder_record = seq_lens_encoder_record[tid]; - seq_lens_encoder[tid] = seq_len_encoder_record; - seq_lens_encoder_record[tid] = -1; + int seq_len_encoder = seq_lens_encoder[tid]; stop_flags[tid] = false; int64_t base_model_first_token = accept_tokens_now[0]; - int position = seq_len_encoder_record; + int position = seq_len_encoder; if (truncate_first_token) { input_ids_now[position - 1] = base_model_first_token; - seq_lens_this_time[tid] = seq_len_encoder_record; + seq_lens_this_time[tid] = seq_len_encoder; } else { input_ids_now[position] = base_model_first_token; - seq_lens_this_time[tid] = seq_len_encoder_record + 1; + seq_lens_this_time[tid] = seq_len_encoder + 1; } } else { stop_flags[tid] = true; @@ -120,63 +121,77 @@ static int cpu_wrapper(api::Context* ctx, } not_stop_flag_sum += not_stop_flag; } else { - auto base_model_step_idx_now = base_model_step_idx[tid]; auto* accept_tokens_now = accept_tokens + tid * accept_tokens_len; auto* draft_tokens_now = draft_tokens + tid * draft_tokens_len; auto accept_num_now = accept_num[tid]; auto* input_ids_now = input_ids + tid * input_ids_len; auto* base_model_draft_tokens_now = base_model_draft_tokens + tid * base_model_draft_tokens_len; + auto base_model_seq_len_decoder = base_model_seq_lens_decoder[tid]; + const int32_t base_model_seq_len_this_time = + base_model_seq_lens_this_time[tid]; + auto* pre_ids_now = pre_ids + tid * pre_ids_len; for (int i = 1; i < base_model_draft_tokens_len; i++) { base_model_draft_tokens_now[i] = -1; } - if (base_model_stop_flags[tid] && base_model_is_block_step[tid]) { - batch_drop[tid] = true; - stop_flags[tid] 
= true; + if (kvcache_scheduler_v1) { + if (base_model_stop_flags[tid] && base_model_is_block_step[tid]) { + stop_flags[tid] = true; + is_block_step[tid] = true; + // Need to continue infer + } + } else { + if (base_model_stop_flags[tid] && base_model_is_block_step[tid]) { + batch_drop[tid] = true; + stop_flags[tid] = true; + } } if (!(base_model_stop_flags[tid] || batch_drop[tid])) { not_stop_flag = 1; - // 1. first token - - if (base_model_step_idx_now == 0) { - seq_lens_this_time[tid] = 0; - not_stop_flag = 0; - } else if (base_model_step_idx_now == 1 && - seq_lens_encoder_record[tid] > 0) { + // prefill generation + if (seq_lens_encoder[tid] > 0) { // Can be extended to first few tokens - int seq_len_encoder_record = seq_lens_encoder_record[tid]; - seq_lens_encoder[tid] = seq_len_encoder_record; - seq_lens_encoder_record[tid] = -1; - seq_lens_decoder[tid] = seq_lens_decoder_record[tid]; - seq_lens_decoder_record[tid] = 0; + int seq_len_encoder = seq_lens_encoder[tid]; stop_flags[tid] = false; int64_t base_model_first_token = accept_tokens_now[0]; - int position = seq_len_encoder_record; + pre_ids_now[0] = base_model_first_token; + int position = seq_len_encoder; if (truncate_first_token) { input_ids_now[position - 1] = base_model_first_token; - seq_lens_this_time[tid] = seq_len_encoder_record; + seq_lens_this_time[tid] = seq_len_encoder; } else { input_ids_now[position] = base_model_first_token; - seq_lens_this_time[tid] = seq_len_encoder_record + 1; + seq_lens_this_time[tid] = seq_len_encoder + 1; + } + } else { // decode generation + if (kvcache_scheduler_v1) { + // 3. try to recover mtp infer in V1 mode + if (!base_model_is_block_step[tid] && is_block_step[tid]) { + is_block_step[tid] = false; + } } - } else if (accept_num_now <= - max_draft_token) /*Accept partial draft tokens*/ { - // Base Model reject stop if (stop_flags[tid]) { stop_flags[tid] = false; - seq_lens_decoder[tid] = base_model_seq_lens_decoder[tid]; - step_idx[tid] = base_model_step_idx[tid]; + // TODO: check + seq_lens_decoder[tid] = + base_model_seq_len_decoder - base_model_seq_len_this_time; + step_idx[tid] = + base_model_step_idx[tid] - base_model_seq_len_this_time; } else { - seq_lens_decoder[tid] -= max_draft_token - accept_num_now; - step_idx[tid] -= max_draft_token - accept_num_now; + // 2: Last base model generated token and first MTP + // token + seq_lens_decoder[tid] -= num_model_step - 1; + step_idx[tid] -= num_model_step - 1; } - int64_t modified_token = accept_tokens_now[accept_num_now - 1]; - draft_tokens_now[0] = modified_token; - seq_lens_this_time[tid] = 1; - } else /*Accept all draft tokens*/ { - draft_tokens_now[1] = accept_tokens_now[max_draft_token]; - seq_lens_this_time[tid] = 2; + for (int i = 0; i < accept_num_now; i++) { + draft_tokens_now[i] = accept_tokens_now[i]; + const int pre_id_pos = + base_model_step_idx[tid] - (accept_num_now - i); + const int64_t accept_token = accept_tokens_now[i]; + pre_ids_now[pre_id_pos] = accept_token; + } + seq_lens_this_time[tid] = accept_num_now; } } else { stop_flags[tid] = true; @@ -199,26 +214,29 @@ static int xpu3_wrapper(api::Context* ctx, int* seq_lens_encoder, int* seq_lens_decoder, int64_t* step_idx, - int* seq_lens_encoder_record, - int* seq_lens_decoder_record, bool* not_need_stop, + bool* is_block_step, bool* batch_drop, + int64_t* pre_ids, const int64_t* accept_tokens, const int* accept_num, + const int* base_model_seq_lens_this_time, const int* base_model_seq_lens_encoder, const int* base_model_seq_lens_decoder, const int64_t* 
base_model_step_idx, const bool* base_model_stop_flags, const bool* base_model_is_block_step, int64_t* base_model_draft_tokens, - int real_bsz, - int max_draft_token, - int accept_tokens_len, - int draft_tokens_len, - int input_ids_len, - int base_model_draft_tokens_len, - bool truncate_first_token, - bool splitwise_prefill) { + const int bsz, + const int num_model_step, + const int accept_tokens_len, + const int draft_tokens_len, + const int input_ids_len, + const int base_model_draft_tokens_len, + const int pre_ids_len, + const bool truncate_first_token, + const bool splitwise_prefill, + const bool kvcache_scheduler_v1) { using XPU_INT64 = typename XPUIndexType::type; // NOTE: Don't change 16 to 64, because kernel use gsm @@ -230,26 +248,29 @@ static int xpu3_wrapper(api::Context* ctx, seq_lens_encoder, seq_lens_decoder, reinterpret_cast(step_idx), - seq_lens_encoder_record, - seq_lens_decoder_record, not_need_stop, + is_block_step, batch_drop, + reinterpret_cast(pre_ids), reinterpret_cast(accept_tokens), accept_num, + base_model_seq_lens_this_time, base_model_seq_lens_encoder, base_model_seq_lens_decoder, reinterpret_cast(base_model_step_idx), base_model_stop_flags, base_model_is_block_step, reinterpret_cast(base_model_draft_tokens), - real_bsz, - max_draft_token, + bsz, + num_model_step, accept_tokens_len, draft_tokens_len, input_ids_len, base_model_draft_tokens_len, + pre_ids_len, truncate_first_token, - splitwise_prefill); + splitwise_prefill, + kvcache_scheduler_v1); return api::SUCCESS; } @@ -261,26 +282,29 @@ int draft_model_preprocess(api::Context* ctx, int* seq_lens_encoder, int* seq_lens_decoder, int64_t* step_idx, - int* seq_lens_encoder_record, - int* seq_lens_decoder_record, bool* not_need_stop, + bool* is_block_step, bool* batch_drop, + int64_t* pre_ids, const int64_t* accept_tokens, const int* accept_num, + const int* base_model_seq_lens_this_time, const int* base_model_seq_lens_encoder, const int* base_model_seq_lens_decoder, const int64_t* base_model_step_idx, const bool* base_model_stop_flags, const bool* base_model_is_block_step, int64_t* base_model_draft_tokens, - int real_bsz, - int max_draft_token, - int accept_tokens_len, - int draft_tokens_len, - int input_ids_len, - int base_model_draft_tokens_len, - bool truncate_first_token, - bool splitwise_prefill) { + const int bsz, + const int num_model_step, + const int accept_tokens_len, + const int draft_tokens_len, + const int input_ids_len, + const int base_model_draft_tokens_len, + const int pre_ids_len, + const bool truncate_first_token, + const bool splitwise_prefill, + const bool kvcache_scheduler_v1) { WRAPPER_CHECK_CTX(ctx); WRAPPER_DUMP_FUNCTION_T1(ctx, "draft_model_preprocess", int64_t); WRAPPER_DUMP_PARAM6(ctx, @@ -290,37 +314,34 @@ int draft_model_preprocess(api::Context* ctx, seq_lens_this_time, seq_lens_encoder, seq_lens_decoder); - WRAPPER_DUMP_PARAM5(ctx, - step_idx, - seq_lens_encoder_record, - seq_lens_decoder_record, - not_need_stop, - batch_drop); + WRAPPER_DUMP_PARAM5( + ctx, step_idx, not_need_stop, is_block_step, batch_drop, pre_ids); WRAPPER_DUMP_PARAM3( ctx, accept_tokens, accept_num, base_model_seq_lens_encoder); - WRAPPER_DUMP_PARAM3(ctx, + WRAPPER_DUMP_PARAM4(ctx, + base_model_seq_lens_encoder, base_model_seq_lens_decoder, base_model_step_idx, base_model_stop_flags); WRAPPER_DUMP_PARAM3( - ctx, base_model_is_block_step, base_model_draft_tokens, real_bsz); - WRAPPER_DUMP_PARAM3( - ctx, max_draft_token, accept_tokens_len, draft_tokens_len); - WRAPPER_DUMP_PARAM3( - ctx, input_ids_len, 
base_model_draft_tokens_len, truncate_first_token); - WRAPPER_DUMP_PARAM1(ctx, splitwise_prefill); + ctx, base_model_is_block_step, base_model_draft_tokens, bsz); + WRAPPER_DUMP_PARAM3(ctx, num_model_step, accept_tokens_len, draft_tokens_len); + WRAPPER_DUMP_PARAM4(ctx, + input_ids_len, + base_model_draft_tokens_len, + pre_ids_len, + truncate_first_token); + WRAPPER_DUMP_PARAM2(ctx, splitwise_prefill, kvcache_scheduler_v1); WRAPPER_DUMP(ctx); - WRAPPER_CHECK_PTR(ctx, int, real_bsz, seq_lens_this_time); - WRAPPER_CHECK_PTR(ctx, int64_t, real_bsz * accept_tokens_len, accept_tokens); - WRAPPER_CHECK_PTR(ctx, int64_t, real_bsz * input_ids_len, input_ids); - WRAPPER_CHECK_PTR(ctx, int64_t, real_bsz * draft_tokens_len, draft_tokens); - WRAPPER_CHECK_PTR(ctx, - int64_t, - real_bsz * base_model_draft_tokens_len, - base_model_draft_tokens); + WRAPPER_CHECK_PTR(ctx, int, bsz, seq_lens_this_time); + WRAPPER_CHECK_PTR(ctx, int64_t, bsz * accept_tokens_len, accept_tokens); + WRAPPER_CHECK_PTR(ctx, int64_t, bsz * input_ids_len, input_ids); + WRAPPER_CHECK_PTR(ctx, int64_t, bsz * draft_tokens_len, draft_tokens); + WRAPPER_CHECK_PTR( + ctx, int64_t, bsz * base_model_draft_tokens_len, base_model_draft_tokens); - WRAPPER_ASSERT_GT(ctx, real_bsz, 0); + WRAPPER_ASSERT_GT(ctx, bsz, 0); WRAPPER_ASSERT_LT(ctx, accept_tokens_len, 128); if (ctx->dev().type() == api::kCPU) { @@ -332,26 +353,29 @@ int draft_model_preprocess(api::Context* ctx, seq_lens_encoder, seq_lens_decoder, step_idx, - seq_lens_encoder_record, - seq_lens_decoder_record, not_need_stop, + is_block_step, batch_drop, + pre_ids, accept_tokens, accept_num, + base_model_seq_lens_this_time, base_model_seq_lens_encoder, base_model_seq_lens_decoder, base_model_step_idx, base_model_stop_flags, base_model_is_block_step, base_model_draft_tokens, - real_bsz, - max_draft_token, + bsz, + num_model_step, accept_tokens_len, draft_tokens_len, input_ids_len, base_model_draft_tokens_len, + pre_ids_len, truncate_first_token, - splitwise_prefill); + splitwise_prefill, + kvcache_scheduler_v1); } if (ctx->dev().type() == api::kXPU3) { return xpu3_wrapper(ctx, @@ -362,26 +386,29 @@ int draft_model_preprocess(api::Context* ctx, seq_lens_encoder, seq_lens_decoder, step_idx, - seq_lens_encoder_record, - seq_lens_decoder_record, not_need_stop, + is_block_step, batch_drop, + pre_ids, accept_tokens, accept_num, + base_model_seq_lens_this_time, base_model_seq_lens_encoder, base_model_seq_lens_decoder, base_model_step_idx, base_model_stop_flags, base_model_is_block_step, base_model_draft_tokens, - real_bsz, - max_draft_token, + bsz, + num_model_step, accept_tokens_len, draft_tokens_len, input_ids_len, base_model_draft_tokens_len, + pre_ids_len, truncate_first_token, - splitwise_prefill); + splitwise_prefill, + kvcache_scheduler_v1); } WRAPPER_UNIMPLEMENTED(ctx); } diff --git a/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/eb_mtp_gather_next_token.cpp b/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/eb_mtp_gather_next_token.cpp new file mode 100644 index 000000000..f817baf89 --- /dev/null +++ b/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/eb_mtp_gather_next_token.cpp @@ -0,0 +1,227 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "xpu/plugin.h" +#include "xpu/refactor/impl/launch_strategy.h" +#include "xpu/refactor/impl_public/wrapper_check.h" +#include "xpu/xdnn.h" + +namespace xpu3 { +namespace plugin { +template +__attribute__((global)) void eb_mtp_gather_next_token(TX *src, + TY *dst, + int *encoder_seqs_lods, + int *decoder_seqs_lods, + int *encoder_batch_map, + int *decoder_batch_map, + int en_batch, + int de_batch, + int64_t copy_size); +} // namespace plugin +} // namespace xpu3 + +namespace baidu { +namespace xpu { +namespace api { +namespace plugin { +template +static int cpu_wrapper(api::Context *ctx, + const TX *x, + TY *y, + const int *encoder_seqs_lods, + const int *decoder_seqs_lods, + const int *encoder_batch_map, + const int *decoder_batch_map, + int en_batch, + int de_batch, + int64_t hidden_dim) { + int ret = 0; + int encoder_len_total = encoder_seqs_lods[en_batch]; + int decoder_len_total = decoder_seqs_lods[de_batch]; + int output_token_num = en_batch + decoder_len_total; + for (int i = 0; i < output_token_num; i++) { + int len = 0; + int enc_idx = 0, dec_idx = 0; + bool is_enc; + while (i >= len) { + if (enc_idx >= en_batch) { + len += decoder_seqs_lods[dec_idx + 1] - decoder_seqs_lods[dec_idx]; + dec_idx++; + is_enc = false; + continue; + } + if (dec_idx >= de_batch) { + len += 1; + enc_idx++; + is_enc = true; + continue; + } + if ((encoder_batch_map[enc_idx] < decoder_batch_map[dec_idx])) { + len += 1; + enc_idx++; + is_enc = true; + } else { + len += decoder_seqs_lods[dec_idx + 1] - decoder_seqs_lods[dec_idx]; + dec_idx++; + is_enc = false; + } + } + const TX *src = nullptr; + if (is_enc) { + src = x + (encoder_seqs_lods[enc_idx] - 1) * hidden_dim; + } else { + src = x + (encoder_len_total + decoder_seqs_lods[dec_idx] - (len - i)) * + hidden_dim; + } + ret = api::cast(ctx, src, y + i * hidden_dim, hidden_dim); + WRAPPER_ASSERT_SUCCESS(ctx, ret); + } + return api::SUCCESS; +} +template +static int xpu3_wrapper(api::Context *ctx, + const TX *x, + TY *y, + api::VectorParam &encoder_seqs_lods, // NOLINT + api::VectorParam &decoder_seqs_lods, // NOLINT + api::VectorParam &encoder_batch_map, // NOLINT + api::VectorParam &decoder_batch_map, // NOLINT + int en_batch, + int de_batch, + int64_t hidden_dim) { + auto eb_mtp_gather_next_token_kernel = + xpu3::plugin::eb_mtp_gather_next_token; + // NOTE: Don't change 16 to 64, because kernel use gsm + eb_mtp_gather_next_token_kernel<<ncluster(), 16, ctx->xpu_stream>>>( + const_cast(x), + y, + encoder_seqs_lods.xpu, + decoder_seqs_lods.xpu, + encoder_batch_map.xpu, + decoder_batch_map.xpu, + en_batch, + de_batch, + hidden_dim); + return api::SUCCESS; +} + +template +int eb_mtp_gather_next_token( + api::Context *ctx, + const TX *x, + TY *y, + api::VectorParam &encoder_seqs_lods, // NOLINT + api::VectorParam &decoder_seqs_lods, // NOLINT + api::VectorParam &encoder_batch_map, // NOLINT + api::VectorParam &decoder_batch_map, // NOLINT + int64_t hidden_dim) { + WRAPPER_CHECK_CTX(ctx); + WRAPPER_DUMP_FUNCTION_T2(ctx, "eb_mtp_gather_next_token", TX, TY); + WRAPPER_DUMP_PARAM6(ctx, + x, + y, + encoder_seqs_lods, + decoder_seqs_lods, 
+ encoder_batch_map, + decoder_batch_map); + WRAPPER_DUMP_PARAM1(ctx, hidden_dim); + WRAPPER_DUMP(ctx); + int encoder_batch = encoder_batch_map.len; + int decoder_batch = decoder_batch_map.len; + int max_encoder_lod = encoder_seqs_lods.cpu[encoder_batch]; + int max_decoder_lod = decoder_seqs_lods.cpu[decoder_batch]; + int m = encoder_seqs_lods.cpu[encoder_batch] + + decoder_seqs_lods.cpu[decoder_batch]; + int out_m = encoder_batch + decoder_seqs_lods.cpu[decoder_batch]; + WRAPPER_CHECK_PTR(ctx, TX, m * hidden_dim, x); + WRAPPER_CHECK_PTR(ctx, TY, out_m * hidden_dim, y); + WRAPPER_ASSERT_GT(ctx, hidden_dim, 0); + // check VectorParam + WRAPPER_ASSERT_EQ(ctx, encoder_seqs_lods.len, encoder_batch_map.len + 1); + WRAPPER_ASSERT_EQ(ctx, decoder_seqs_lods.len, decoder_batch_map.len + 1); + WRAPPER_ASSERT_GE(ctx, encoder_seqs_lods.cpu[0], 0); + WRAPPER_ASSERT_LE(ctx, encoder_seqs_lods.cpu[0], max_encoder_lod); + WRAPPER_ASSERT_GE(ctx, decoder_seqs_lods.cpu[0], 0); + WRAPPER_ASSERT_LE(ctx, decoder_seqs_lods.cpu[0], max_decoder_lod); + // 注意: encoder/decoder的batch + // map数值上有可能大于batch,因为复原后的batch排布有可能是稀疏的,所以这里只做非负检查 + for (int i = 0; i < encoder_batch_map.len; ++i) { + WRAPPER_ASSERT_GE(ctx, encoder_batch_map.cpu[i], 0); + WRAPPER_ASSERT_GE(ctx, encoder_seqs_lods.cpu[i + 1], 0); + WRAPPER_ASSERT_LE(ctx, encoder_seqs_lods.cpu[i + 1], max_encoder_lod); + } + for (int i = 0; i < decoder_batch_map.len; ++i) { + WRAPPER_ASSERT_GE(ctx, decoder_batch_map.cpu[i], 0); + WRAPPER_ASSERT_GE(ctx, decoder_seqs_lods.cpu[i + 1], 0); + WRAPPER_ASSERT_LE(ctx, decoder_seqs_lods.cpu[i + 1], max_decoder_lod); + } + if (ctx->dev().type() == api::kCPU) { + return cpu_wrapper(ctx, + x, + y, + encoder_seqs_lods.cpu, + decoder_seqs_lods.cpu, + encoder_batch_map.cpu, + decoder_batch_map.cpu, + encoder_batch_map.len, + decoder_batch_map.len, + hidden_dim); + } + if (ctx->dev().type() == api::kXPU3) { + api::ctx_guard RAII_GUARD(ctx); + api::VectorParam encoder_seqs_lods_xpu = + encoder_seqs_lods.to_xpu(RAII_GUARD); + api::VectorParam decoder_seqs_lods_xpu = + decoder_seqs_lods.to_xpu(RAII_GUARD); + api::VectorParam encoder_batch_map_xpu = + encoder_batch_map.to_xpu(RAII_GUARD); + api::VectorParam decoder_batch_map_xpu = + decoder_batch_map.to_xpu(RAII_GUARD); + return xpu3_wrapper(ctx, + x, + y, + encoder_seqs_lods_xpu, + decoder_seqs_lods_xpu, + encoder_batch_map_xpu, + decoder_batch_map_xpu, + encoder_batch_map.len, + decoder_batch_map.len, + hidden_dim); + } + WRAPPER_UNIMPLEMENTED(ctx); +} +#define INSTANTIATION_EB_MTP_GATHER_NEXT_TOKEN(TX, TY) \ + template int eb_mtp_gather_next_token(api::Context *, \ + const TX *, \ + TY *, \ + api::VectorParam &, \ + api::VectorParam &, \ + api::VectorParam &, \ + api::VectorParam &, \ + int64_t); + +INSTANTIATION_EB_MTP_GATHER_NEXT_TOKEN(float16, float16); +INSTANTIATION_EB_MTP_GATHER_NEXT_TOKEN(bfloat16, bfloat16); +INSTANTIATION_EB_MTP_GATHER_NEXT_TOKEN(float, float); +INSTANTIATION_EB_MTP_GATHER_NEXT_TOKEN(float16, float); +INSTANTIATION_EB_MTP_GATHER_NEXT_TOKEN(float, float16); +INSTANTIATION_EB_MTP_GATHER_NEXT_TOKEN(bfloat16, float16); +INSTANTIATION_EB_MTP_GATHER_NEXT_TOKEN(float16, bfloat16); +INSTANTIATION_EB_MTP_GATHER_NEXT_TOKEN(bfloat16, float); +INSTANTIATION_EB_MTP_GATHER_NEXT_TOKEN(float, bfloat16); +} // namespace plugin +} // namespace api +} // namespace xpu +} // namespace baidu diff --git a/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/speculate_free_and_dispatch_block.cpp 
b/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/speculate_free_and_dispatch_block.cpp new file mode 100644 index 000000000..cefe893e9 --- /dev/null +++ b/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/speculate_free_and_dispatch_block.cpp @@ -0,0 +1,340 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include "xpu/plugin.h" +#include "xpu/refactor/impl_public/wrapper_check.h" + +namespace xpu3 { +namespace plugin { + +__attribute__((global)) void speculate_free_and_dispatch_block( + bool *stop_flags, + int *seq_lens_this_time, + int *seq_lens_decoder, + int *block_tables, + int *encoder_block_lens, + bool *is_block_step, + int *step_block_list, // [bsz] + int *step_len, + int *recover_block_list, + int *recover_len, + int *need_block_list, + int *need_block_len, + int *used_list_len, + int *free_list, + int *free_list_len, + int64_t *first_token_ids, + int *accept_num, + const int bsz, + const int block_size, + const int block_num_per_seq, + const int max_decoder_block_num, + const int max_draft_tokens); + +} // namespace plugin +} // namespace xpu3 + +namespace baidu { +namespace xpu { +namespace api { +namespace plugin { + +static int cpu_wrapper(Context *ctx, + bool *stop_flags, + int *seq_lens_this_time, + int *seq_lens_decoder, + int *block_tables, + int *encoder_block_lens, + bool *is_block_step, + int *step_block_list, // [bsz] + int *step_len, + int *recover_block_list, + int *recover_len, + int *need_block_list, + int *need_block_len, + int *used_list_len, + int *free_list, + int *free_list_len, + int64_t *first_token_ids, + int *accept_num, + const int bsz, + const int block_size, + const int block_num_per_seq, + const int max_decoder_block_num, + const int max_draft_tokens) { + for (int i = 0; i < bsz; i++) { + int *block_table_now = block_tables + i * block_num_per_seq; + int max_possible_block_idx = + (seq_lens_decoder[i] + max_draft_tokens + 1) / block_size; + if (stop_flags[i] && !is_block_step[i]) { + // 回收block块 + first_token_ids[i] = -1; + const int encoder_block_len = encoder_block_lens[i]; + const int decoder_used_len = used_list_len[i]; + if (decoder_used_len > 0) { + const int ori_free_list_len = free_list_len[0]; + free_list_len[0] += decoder_used_len; + for (int j = 0; j < decoder_used_len; j++) { + free_list[ori_free_list_len + j] = + block_table_now[encoder_block_len + j]; + block_table_now[encoder_block_len + j] = -1; + } + encoder_block_lens[i] = 0; + used_list_len[i] = 0; + } + } else if (seq_lens_this_time[i] != 0 && + max_possible_block_idx < block_num_per_seq && + block_table_now[(seq_lens_decoder[i] + max_draft_tokens + 1) / + block_size] == -1) { + // 统计需要分配block的位置和总数 + const int ori_need_block_len = need_block_len[0]; + need_block_len[0] += 1; + need_block_list[ori_need_block_len] = i; + } + } + + while (need_block_len[0] > free_list_len[0]) { + // 调度block,根据used_list_len从大到小回收block,直到满足need_block_len + int max_used_list_len_id = 0; + int 
max_used_list_len = 0; + for (int i = 0; i < bsz; i++) { + const int used_block_num = !is_block_step[i] ? used_list_len[i] : 0; + if (used_block_num > max_used_list_len) { + max_used_list_len_id = i; + max_used_list_len = used_block_num; + } + } + + const int encoder_block_len = encoder_block_lens[max_used_list_len_id]; + int *block_table_now = + block_tables + max_used_list_len_id * block_num_per_seq; + for (int i = 0; i < max_used_list_len; i++) { + free_list[free_list_len[0] + i] = block_table_now[encoder_block_len + i]; + block_table_now[encoder_block_len + i] = -1; + } + step_block_list[step_len[0]] = max_used_list_len_id; + step_len[0] += 1; + free_list_len[0] += max_used_list_len; + stop_flags[max_used_list_len_id] = true; + is_block_step[max_used_list_len_id] = true; + seq_lens_this_time[max_used_list_len_id] = 0; + seq_lens_decoder[max_used_list_len_id] = 0; + accept_num[max_used_list_len_id] = 0; + } + + // 为需要block的位置分配block,每个位置分配一个block + for (int i = 0; i < bsz; i++) { + if (i < need_block_len[0]) { + const int need_block_id = need_block_list[i]; + if (!stop_flags[need_block_id]) { + // 如果需要的位置正好是上一步中被释放的位置,不做处理 + used_list_len[need_block_id] += 1; + const int ori_free_list_len = free_list_len[0]; + free_list_len[0]--; + int *block_table_now = block_tables + need_block_id * block_num_per_seq; + block_table_now[(seq_lens_decoder[need_block_id] + max_draft_tokens + + 1) / + block_size] = free_list[ori_free_list_len - 1]; + } + need_block_list[i] = -1; + } + } + + // 计算可以复原的query id + int ori_step_len = step_len[0]; + if (ori_step_len > 0) { + int ori_free_list_len = free_list_len[0]; + int ori_step_block_id = step_block_list[ori_step_len - 1]; + int tmp_used_len = used_list_len[ori_step_block_id]; + // 比之前调度时多分配一个block,防止马上恢复刚调度的query(比如回收的seq_id在need_block_list中) + int used_len = + tmp_used_len < max_decoder_block_num ? tmp_used_len + 1 : tmp_used_len; + if (ori_step_len > 0 && ori_free_list_len >= used_len) { + recover_block_list[recover_len[0]] = ori_step_block_id; + is_block_step[ori_step_block_id] = false; + used_list_len[ori_step_block_id] = used_len; + ori_free_list_len -= used_len; + step_block_list[ori_step_len - 1] = -1; + step_len[0] -= 1; + recover_len[0] += 1; + ori_step_len = step_len[0]; + if (ori_step_len > 0) { + ori_step_block_id = step_block_list[ori_step_len - 1]; + tmp_used_len = used_list_len[ori_step_block_id]; + used_len = tmp_used_len < max_decoder_block_num ? 
tmp_used_len + 1 + : tmp_used_len; + } + } + need_block_len[0] = 0; + } + return api::SUCCESS; +} + +static int xpu3_wrapper(Context *ctx, + bool *stop_flags, + int *seq_lens_this_time, + int *seq_lens_decoder, + int *block_tables, + int *encoder_block_lens, + bool *is_block_step, + int *step_block_list, // [bsz] + int *step_len, + int *recover_block_list, + int *recover_len, + int *need_block_list, + int *need_block_len, + int *used_list_len, + int *free_list, + int *free_list_len, + int64_t *first_token_ids, + int *accept_num, + const int bsz, + const int block_size, + const int block_num_per_seq, + const int max_decoder_block_num, + const int max_draft_tokens) { + using XPU_INT64 = typename XPUIndexType::type; + auto speculate_free_and_dispatch_block_kernel = + xpu3::plugin::speculate_free_and_dispatch_block; + speculate_free_and_dispatch_block_kernel<<ncluster(), + 64, + ctx->xpu_stream>>>( + stop_flags, + seq_lens_this_time, + seq_lens_decoder, + block_tables, + encoder_block_lens, + is_block_step, + step_block_list, + step_len, + recover_block_list, + recover_len, + need_block_list, + need_block_len, + used_list_len, + free_list, + free_list_len, + reinterpret_cast(first_token_ids), + accept_num, + bsz, + block_size, + block_num_per_seq, + max_decoder_block_num, + max_draft_tokens); + return api::SUCCESS; +} + +int speculate_free_and_dispatch_block(Context *ctx, + bool *stop_flags, + int *seq_lens_this_time, + int *seq_lens_decoder, + int *block_tables, + int *encoder_block_lens, + bool *is_block_step, + int *step_block_list, // [bsz] + int *step_len, + int *recover_block_list, + int *recover_len, + int *need_block_list, + int *need_block_len, + int *used_list_len, + int *free_list, + int *free_list_len, + int64_t *first_token_ids, + int *accept_num, + const int bsz, + const int block_size, + const int block_num_per_seq, + const int max_decoder_block_num, + const int max_draft_tokens) { + WRAPPER_CHECK_CTX(ctx); + WRAPPER_DUMP_FUNCTION_T1(ctx, "speculate_free_and_dispatch_block", float); + WRAPPER_DUMP_PARAM6(ctx, + stop_flags, + seq_lens_this_time, + seq_lens_decoder, + block_tables, + encoder_block_lens, + is_block_step); + WRAPPER_DUMP_PARAM6(ctx, + step_block_list, + step_len, + recover_block_list, + recover_len, + need_block_list, + need_block_len); + WRAPPER_DUMP_PARAM4( + ctx, used_list_len, free_list, free_list_len, first_token_ids); + WRAPPER_DUMP_PARAM4( + ctx, bsz, block_size, block_num_per_seq, max_decoder_block_num); + WRAPPER_DUMP(ctx); + if (ctx->dev().type() == api::kCPU) { + return cpu_wrapper(ctx, + stop_flags, + seq_lens_this_time, + seq_lens_decoder, + block_tables, + encoder_block_lens, + is_block_step, + step_block_list, + step_len, + recover_block_list, + recover_len, + need_block_list, + need_block_len, + used_list_len, + free_list, + free_list_len, + first_token_ids, + accept_num, + bsz, + block_size, + block_num_per_seq, + max_decoder_block_num, + max_draft_tokens); + } + if (ctx->dev().type() == api::kXPU2 || ctx->dev().type() == api::kXPU3) { + return xpu3_wrapper(ctx, + stop_flags, + seq_lens_this_time, + seq_lens_decoder, + block_tables, + encoder_block_lens, + is_block_step, + step_block_list, + step_len, + recover_block_list, + recover_len, + need_block_list, + need_block_len, + used_list_len, + free_list, + free_list_len, + first_token_ids, + accept_num, + bsz, + block_size, + block_num_per_seq, + max_decoder_block_num, + max_draft_tokens); + } + WRAPPER_UNIMPLEMENTED(ctx); +} + +} // namespace plugin +} // namespace api +} // namespace xpu +} // 
namespace baidu diff --git a/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/speculate_get_padding_offset.cpp b/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/speculate_get_padding_offset.cpp index a0066e455..fe4096cac 100644 --- a/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/speculate_get_padding_offset.cpp +++ b/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/speculate_get_padding_offset.cpp @@ -33,7 +33,7 @@ __attribute__((global)) void speculate_remove_padding( int token_num_data); __attribute__((global)) void speculate_get_padding_offset( - int* padding_offset, + int* batch_id_per_token, int* cum_offsets_out, int* cu_seqlens_q, int* cu_seqlens_k, @@ -78,7 +78,7 @@ static int cpu_wrapper_remove_padding(Context* ctx, } static int cpu_wrapper_get_padding_offset(Context* ctx, - int* padding_offset, + int* batch_id_per_token, int* cum_offsets_out, int* cu_seqlens_q, int* cu_seqlens_k, @@ -89,7 +89,7 @@ static int cpu_wrapper_get_padding_offset(Context* ctx, for (int bi = 0; bi < bsz; ++bi) { int cum_offset = bi == 0 ? 0 : cum_offsets[bi - 1]; for (int i = 0; i < seq_lens[bi]; i++) { - padding_offset[bi * max_seq_len - cum_offset + i] = cum_offset; + batch_id_per_token[bi * max_seq_len - cum_offset + i] = bi; } cum_offsets_out[bi] = cum_offset; int cum_seq_len = (bi + 1) * max_seq_len - cum_offsets[bi]; @@ -129,7 +129,7 @@ static int xpu3_wrapper_remove_padding(Context* ctx, } static int xpu3_wrapper_get_padding_offset(Context* ctx, - int* padding_offset, + int* batch_id_per_token, int* cum_offsets_out, int* cu_seqlens_q, int* cu_seqlens_k, @@ -139,7 +139,7 @@ static int xpu3_wrapper_get_padding_offset(Context* ctx, int bsz) { xpu3::plugin:: speculate_get_padding_offset<<ncluster(), 64, ctx->xpu_stream>>>( - padding_offset, + batch_id_per_token, cum_offsets_out, cu_seqlens_q, cu_seqlens_k, @@ -215,7 +215,7 @@ int speculate_remove_padding(Context* ctx, } int speculate_get_padding_offset(Context* ctx, - int* padding_offset, + int* batch_id_per_token, int* cum_offsets_out, int* cu_seqlens_q, int* cu_seqlens_k, @@ -227,7 +227,7 @@ int speculate_get_padding_offset(Context* ctx, WRAPPER_DUMP_FUNCTION_T1(ctx, "speculate_get_padding_offset", float); WRAPPER_DUMP_PARAM6(ctx, - padding_offset, + batch_id_per_token, cum_offsets_out, cu_seqlens_q, cu_seqlens_k, @@ -247,7 +247,7 @@ int speculate_get_padding_offset(Context* ctx, if (ctx->dev().type() == api::kCPU) { return cpu_wrapper_get_padding_offset(ctx, - padding_offset, + batch_id_per_token, cum_offsets_out, cu_seqlens_q, cu_seqlens_k, @@ -258,7 +258,7 @@ int speculate_get_padding_offset(Context* ctx, } if (ctx->dev().type() == api::kXPU3) { return xpu3_wrapper_get_padding_offset(ctx, - padding_offset, + batch_id_per_token, cum_offsets_out, cu_seqlens_q, cu_seqlens_k, diff --git a/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/speculate_recover_block.cpp b/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/speculate_recover_block.cpp new file mode 100644 index 000000000..2996325c8 --- /dev/null +++ b/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/speculate_recover_block.cpp @@ -0,0 +1,258 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include "xpu/plugin.h" +#include "xpu/refactor/impl_public/wrapper_check.h" + +namespace xpu3 { +namespace plugin { + +__attribute__((global)) void speculate_recover_block( + int *recover_block_list, // [bsz] + int *recover_len, + bool *stop_flags, + int *seq_lens_this_time, + const int *ori_seq_lens_encoder, + int *seq_lens_encoder, + const int *seq_lens_decoder, + int *block_tables, + int *free_list, + int *free_list_len, + int64_t *input_ids, + const int64_t *pre_ids, + const int64_t *step_idx, + const int *encoder_block_lens, + const int *used_list_len, + const int64_t *next_tokens, + const int64_t *first_token_ids, + const int bsz, + const int block_num_per_seq, + const int length, + const int pre_id_length); + +} // namespace plugin +} // namespace xpu3 + +namespace baidu { +namespace xpu { +namespace api { +namespace plugin { + +static int cpu_wrapper(Context *ctx, + int *recover_block_list, // [bsz] + int *recover_len, + bool *stop_flags, + int *seq_lens_this_time, + const int *ori_seq_lens_encoder, + int *seq_lens_encoder, + const int *seq_lens_decoder, + int *block_tables, + int *free_list, + int *free_list_len, + int64_t *input_ids, + const int64_t *pre_ids, + const int64_t *step_idx, + const int *encoder_block_lens, + const int *used_list_len, + const int64_t *next_tokens, + const int64_t *first_token_ids, + const int bsz, + const int block_num_per_seq, + const int length, + const int pre_id_length) { + for (int bid = 0; bid < recover_len[0]; bid++) { + const int recover_id = recover_block_list[bid]; + const int ori_seq_len_encoder = ori_seq_lens_encoder[recover_id]; + const int step_idx_now = step_idx[recover_id]; + const int seq_len = ori_seq_len_encoder + step_idx_now; + const int encoder_block_len = encoder_block_lens[recover_id]; + const int decoder_used_len = used_list_len[recover_id]; + int *block_table_now = block_tables + recover_id * block_num_per_seq; + int64_t *input_ids_now = input_ids + recover_id * length; + const int64_t *pre_ids_now = pre_ids + recover_id * pre_id_length; + + seq_lens_this_time[recover_id] = seq_len; + seq_lens_encoder[recover_id] = seq_len; + stop_flags[recover_id] = false; + // input_ids_now[seq_len - 1] = next_tokens[recover_id]; // next tokens + input_ids_now[0] = first_token_ids[recover_id]; // set first prompt token + int ori_free_list_len = free_list_len[0]; + free_list_len[0] -= decoder_used_len; + + // Restore the block table + for (int i = 0; i < decoder_used_len; i++) { + block_table_now[encoder_block_len + i] = + free_list[ori_free_list_len - i - 1]; + } + // Restore input_ids from pre_ids + for (int i = 0; i < step_idx_now; i++) { + input_ids_now[ori_seq_len_encoder + i] = pre_ids_now[i + 1]; + } + } + recover_len[0] = 0; + return api::SUCCESS; +} + +static int xpu3_wrapper(Context *ctx, + int *recover_block_list, // [bsz] + int *recover_len, + bool *stop_flags, + int *seq_lens_this_time, + const int *ori_seq_lens_encoder, + int *seq_lens_encoder, + const int *seq_lens_decoder, + int *block_tables, + int *free_list, + int *free_list_len, + int64_t *input_ids, + const int64_t *pre_ids, + const int64_t *step_idx, + const 
int *encoder_block_lens, + const int *used_list_len, + const int64_t *next_tokens, + const int64_t *first_token_ids, + const int bsz, + const int block_num_per_seq, + const int length, + const int pre_id_length) { + using XPU_INT64 = typename XPUIndexType::type; + auto recover_block_kernel = xpu3::plugin::speculate_recover_block; + recover_block_kernel<<ncluster(), 64, ctx->xpu_stream>>>( + recover_block_list, // [bsz] + recover_len, + stop_flags, + seq_lens_this_time, + ori_seq_lens_encoder, + seq_lens_encoder, + seq_lens_decoder, + block_tables, + free_list, + free_list_len, + reinterpret_cast(input_ids), + reinterpret_cast(pre_ids), + reinterpret_cast(step_idx), + encoder_block_lens, + used_list_len, + reinterpret_cast(next_tokens), + reinterpret_cast(first_token_ids), + bsz, + block_num_per_seq, + length, + pre_id_length); + return api::SUCCESS; +} + +int speculate_recover_block(Context *ctx, + int *recover_block_list, // [bsz] + int *recover_len, + bool *stop_flags, + int *seq_lens_this_time, + const int *ori_seq_lens_encoder, + int *seq_lens_encoder, + const int *seq_lens_decoder, + int *block_tables, + int *free_list, + int *free_list_len, + int64_t *input_ids, + const int64_t *pre_ids, + const int64_t *step_idx, + const int *encoder_block_lens, + const int *used_list_len, + const int64_t *next_tokens, + const int64_t *first_token_ids, + const int bsz, + const int block_num_per_seq, + const int length, + const int pre_id_length) { + WRAPPER_CHECK_CTX(ctx); + WRAPPER_DUMP_FUNCTION_T1(ctx, "speculate_recover_block", float); + WRAPPER_DUMP_PARAM6(ctx, + recover_block_list, + recover_len, + stop_flags, + seq_lens_this_time, + ori_seq_lens_encoder, + seq_lens_encoder); + WRAPPER_DUMP_PARAM6(ctx, + seq_lens_decoder, + block_tables, + free_list, + free_list_len, + input_ids, + pre_ids); + WRAPPER_DUMP_PARAM5(ctx, + step_idx, + encoder_block_lens, + used_list_len, + next_tokens, + first_token_ids); + WRAPPER_DUMP_PARAM4(ctx, bsz, block_num_per_seq, length, pre_id_length); + WRAPPER_DUMP(ctx); + if (ctx->dev().type() == api::kCPU) { + return cpu_wrapper(ctx, + recover_block_list, // [bsz] + recover_len, + stop_flags, + seq_lens_this_time, + ori_seq_lens_encoder, + seq_lens_encoder, + seq_lens_decoder, + block_tables, + free_list, + free_list_len, + input_ids, + pre_ids, + step_idx, + encoder_block_lens, + used_list_len, + next_tokens, + first_token_ids, + bsz, + block_num_per_seq, + length, + pre_id_length); + } + if (ctx->dev().type() == api::kXPU2 || ctx->dev().type() == api::kXPU3) { + return xpu3_wrapper(ctx, + recover_block_list, // [bsz] + recover_len, + stop_flags, + seq_lens_this_time, + ori_seq_lens_encoder, + seq_lens_encoder, + seq_lens_decoder, + block_tables, + free_list, + free_list_len, + input_ids, + pre_ids, + step_idx, + encoder_block_lens, + used_list_len, + next_tokens, + first_token_ids, + bsz, + block_num_per_seq, + length, + pre_id_length); + } + WRAPPER_UNIMPLEMENTED(ctx); +} + +} // namespace plugin +} // namespace api +} // namespace xpu +} // namespace baidu diff --git a/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/speculate_verify.cpp b/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/speculate_verify.cpp index c5e3e425b..3989ce8de 100644 --- a/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/speculate_verify.cpp +++ b/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/speculate_verify.cpp @@ -48,7 +48,8 @@ __attribute__((global)) void speculate_verify( const int max_seq_len, const int max_candidate_len, const int verify_window, - const bool 
prefill_one_step_stop); + const bool prefill_one_step_stop, + const bool benchmark_mode); } // namespace plugin } // namespace xpu3 @@ -136,14 +137,15 @@ static int cpu_wrapper(Context *ctx, const int max_seq_len, const int max_candidate_len, const int verify_window, - const bool prefill_one_step_stop) { + const bool prefill_one_step_stop, + const bool benchmark_mode) { for (int bid = 0; bid < real_bsz; ++bid) { - const int start_token_id = bid * max_seq_len - output_cum_offsets[bid]; // verify and set stop flags int accept_num_now = 1; int stop_flag_now_int = 0; if (!(is_block_step[bid] || bid >= real_bsz)) { + const int start_token_id = bid * max_seq_len - output_cum_offsets[bid]; // printf("debug cpu bid:%d,start_token_id:%d\n",bid, start_token_id); // printf("bid %d\n", bid); @@ -160,6 +162,9 @@ static int cpu_wrapper(Context *ctx, // printf("seq_lens_this_time[%d]-1: %d \n",bid, // seq_lens_this_time[bid]-1); for (; i < seq_lens_this_time[bid] - 1; i++) { + if (benchmark_mode) { + break; + } if (seq_lens_encoder[bid] != 0) { break; } @@ -326,7 +331,8 @@ static int xpu3_wrapper(Context *ctx, const int max_seq_len, const int max_candidate_len, const int verify_window, - const bool prefill_one_step_stop) { + const bool prefill_one_step_stop, + const bool benchmark_mode) { using XPU_INT64 = typename XPUIndexType::type; xpu3::plugin::speculate_verify <<<1, 64, ctx->xpu_stream>>>( @@ -354,7 +360,8 @@ static int xpu3_wrapper(Context *ctx, max_seq_len, max_candidate_len, verify_window, - prefill_one_step_stop); + prefill_one_step_stop, + benchmark_mode); return api::SUCCESS; } template @@ -383,7 +390,8 @@ int speculate_verify(Context *ctx, const int max_seq_len, const int max_candidate_len, const int verify_window, - const bool prefill_one_step_stop) { + const bool prefill_one_step_stop, + const bool benchmark_mode) { WRAPPER_CHECK_CTX(ctx); WRAPPER_DUMP_FUNCTION_T1(ctx, "speculate_verify", int64_t); WRAPPER_DUMP_PARAM3(ctx, accept_tokens, accept_num, step_idx); @@ -406,12 +414,13 @@ int speculate_verify(Context *ctx, actual_candidate_len, real_bsz, max_draft_tokens); - WRAPPER_DUMP_PARAM5(ctx, + WRAPPER_DUMP_PARAM6(ctx, end_length, max_seq_len, max_candidate_len, verify_window, - prefill_one_step_stop); + prefill_one_step_stop, + benchmark_mode); WRAPPER_DUMP(ctx); WRAPPER_CHECK_PTR(ctx, int64_t, real_bsz * max_draft_tokens, accept_tokens); WRAPPER_CHECK_PTR(ctx, int, real_bsz, accept_num); @@ -469,7 +478,8 @@ int speculate_verify(Context *ctx, max_seq_len, max_candidate_len, verify_window, - prefill_one_step_stop); + prefill_one_step_stop, + benchmark_mode); } if (ctx->dev().type() == api::kXPU3) { return xpu3_wrapper(ctx, @@ -497,7 +507,8 @@ int speculate_verify(Context *ctx, max_seq_len, max_candidate_len, verify_window, - prefill_one_step_stop); + prefill_one_step_stop, + benchmark_mode); } WRAPPER_UNIMPLEMENTED(ctx); } @@ -530,7 +541,8 @@ int speculate_verify(Context *ctx, int, /* max_seq_len */ \ int, /* max_candidate_len */ \ int, /* verify_window */ \ - bool); /* prefill_one_step_stop */ + bool, \ + bool); /* prefill_one_step_stop */ INSTANTIATE_SPECULATE_VERIFY(false, false) INSTANTIATE_SPECULATE_VERIFY(false, true) diff --git a/custom_ops/xpu_ops/test/test_adjust_batch_and_gather_next_token.py b/custom_ops/xpu_ops/test/test_adjust_batch_and_gather_next_token.py new file mode 100644 index 000000000..758dff17e --- /dev/null +++ b/custom_ops/xpu_ops/test/test_adjust_batch_and_gather_next_token.py @@ -0,0 +1,180 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +import paddle + +from fastdeploy.model_executor.ops.xpu import ( + adjust_batch, + gather_next_token, + get_infer_param, +) + + +def _run_test_base(seq_lens_this_time_data, output_padding_offset): + """ + Shared test helper containing the logic common to both scenarios. + """ + seq_lens_encoder = paddle.to_tensor([100, 0, 0, 0, 120, 140, 0], dtype="int32") + seq_lens_decoder = paddle.to_tensor([0, 5, 0, 25, 64, 0, 128], dtype="int32") + seq_lens_this_time = paddle.to_tensor(seq_lens_this_time_data, dtype="int32") + + bsz = seq_lens_this_time.shape[0] + cum_offsets = paddle.zeros(bsz, dtype="int32") + block_table = paddle.arange(0, 56, dtype="int32").reshape((bsz, 8)) + + infer_params = get_infer_param(seq_lens_encoder, seq_lens_decoder, seq_lens_this_time, block_table, 64) + + ( + encoder_batch_map, + decoder_batch_map, + encoder_batch_idx, + decoder_batch_idx, + encoder_seq_lod, + decoder_seq_lod, + _, + _, + _, + _, + _, + encoder_batch_map_cpu, + decoder_batch_map_cpu, + encoder_batch_idx_cpu, + decoder_batch_idx_cpu, + encoder_seq_lod_cpu, + decoder_seq_lod_cpu, + _, + _, + _, + _, + len_info_cpu, + ) = infer_params + + token_num = seq_lens_this_time.sum().cpu().item() + hidden_dim = 8192 + row_indices = paddle.arange(token_num, dtype="int32") + row_indices_bf16 = row_indices.astype("bfloat16") + input_tensor = paddle.unsqueeze(row_indices_bf16, axis=1).expand(shape=[token_num, hidden_dim]) + + # Test adjust_batch + adjusted_output = adjust_batch( + input_tensor, + cum_offsets, + encoder_seq_lod, + decoder_seq_lod, + encoder_batch_idx, + decoder_batch_idx, + encoder_seq_lod_cpu, + decoder_seq_lod_cpu, + encoder_batch_idx_cpu, + decoder_batch_idx_cpu, + len_info_cpu, + None, # output_padding_offset + -1, # max_input_length + ) + + adjusted_output_cpu = adjust_batch( + input_tensor.cpu(), + cum_offsets, + encoder_seq_lod, + decoder_seq_lod, + encoder_batch_idx, + decoder_batch_idx, + encoder_seq_lod_cpu, + decoder_seq_lod_cpu, + encoder_batch_idx_cpu, + decoder_batch_idx_cpu, + len_info_cpu, + None, # output_padding_offset + -1, # max_input_length + ) + + # Use np.testing instead of a bare assert for friendlier error messages + adjusted_output_np = adjusted_output.astype("float32").cpu().numpy() + adjusted_output_cpu_np = adjusted_output_cpu.astype("float32").cpu().numpy() + np.testing.assert_allclose(adjusted_output_np, adjusted_output_cpu_np, err_msg="adjust_batch check failed!") + + # Test gather_next_token + gather_out = gather_next_token( + adjusted_output, + cum_offsets, + encoder_seq_lod, + decoder_seq_lod, + encoder_batch_map, + decoder_batch_map, + encoder_seq_lod_cpu, + decoder_seq_lod_cpu, + encoder_batch_map_cpu, + decoder_batch_map_cpu, + len_info_cpu, + output_padding_offset, + -1, + ) + + gather_out_cpu = gather_next_token( + adjusted_output.cpu(), + cum_offsets, + encoder_seq_lod, + decoder_seq_lod, + encoder_batch_map, + decoder_batch_map, + encoder_seq_lod_cpu, + decoder_seq_lod_cpu, + encoder_batch_map_cpu, + decoder_batch_map_cpu, + len_info_cpu, + output_padding_offset, + -1, + ) + + gather_out_np = gather_out.astype("float32").cpu().numpy() + gather_out_cpu_np = gather_out_cpu.astype("float32").cpu().numpy() + + if output_padding_offset is not None: + np.testing.assert_allclose(gather_out_np, gather_out_cpu_np, err_msg="gather_next_token check failed!") + else: + for i in range(gather_out_cpu.shape[0]): + if seq_lens_this_time[i] > 0: + np.testing.assert_allclose( + gather_out_np[i], gather_out_cpu_np[i], err_msg=f"gather_next_token check failed at index {i}!" + ) + + +class TestXPUOps(unittest.TestCase): + """Tests for the XPU adjust_batch and gather_next_token ops.""" + + def test_mix_with_mtp(self): + """Mixed-batch scenario with MTP (Multi-Token Prediction).""" + print("\nRunning test: test_mix_with_mtp") + seq_lens_this_time_data = [100, 2, 0, 1, 120, 140, 3] + bsz = len(seq_lens_this_time_data) + output_padding_offset = paddle.zeros(bsz, dtype="int32") + + _run_test_base(seq_lens_this_time_data, output_padding_offset) + print("Test passed for scenario: With MTP") + + def test_mix_without_mtp(self): + """Mixed-batch scenario without MTP (single-token prediction).""" + print("\nRunning test: test_mix_without_mtp") + seq_lens_this_time_data = [100, 1, 0, 1, 120, 140, 1] + output_padding_offset = None # None in the non-MTP scenario + + _run_test_base(seq_lens_this_time_data, output_padding_offset) + print("Test passed for scenario: Without MTP") + + +if __name__ == "__main__": + unittest.main() diff --git a/custom_ops/xpu_ops/test/test_draft_model_preprocess.py b/custom_ops/xpu_ops/test/test_draft_model_preprocess.py index c687bdf30..1348e6fcd 100644 --- a/custom_ops/xpu_ops/test/test_draft_model_preprocess.py +++ b/custom_ops/xpu_ops/test/test_draft_model_preprocess.py @@ -12,50 +12,284 @@ # See the License for the specific language governing permissions and # limitations under the License.
+import unittest + import numpy as np import paddle from fastdeploy.model_executor.ops.xpu import draft_model_preprocess -def run_test(device="xpu"): - paddle.seed(2022) +def process_splitwise_prefill( + draft_tokens, + input_ids, + stop_flags, + seq_lens_this_time, + seq_lens_encoder, + seq_lens_decoder, + step_idx, + not_need_stop, + is_block_step, + batch_drop, + pre_ids, + accept_tokens, + accept_num, + base_model_seq_lens_this_time, + base_model_seq_lens_encoder, + base_model_seq_lens_decoder, + base_model_step_idx, + base_model_stop_flags, + base_model_is_block_step, + base_model_draft_tokens, + bsz, + num_model_step, + base_model_draft_tokens_len, + truncate_first_token, + kvcache_scheduler_v1, +): + not_stop_flag_sum = 0 - # Define parameters - bsz = 10 - draft_tokens_len = 4 - input_ids_len = 8 - max_draft_token = 10 + for tid in range(bsz): + not_stop_flag = 0 + input_ids_now = input_ids[tid] + accept_tokens_now = accept_tokens[tid] + if seq_lens_encoder[tid] > 0: + not_stop_flag = 1 + seq_len_encoder = seq_lens_encoder[tid] + stop_flags[tid] = False + base_model_first_token = accept_tokens_now[0] + position = seq_len_encoder + if truncate_first_token: + input_ids_now[position - 1] = base_model_first_token + seq_lens_this_time[tid] = seq_len_encoder + else: + input_ids_now[position] = base_model_first_token + seq_lens_this_time[tid] = seq_len_encoder + 1 + else: + stop_flags[tid] = True + seq_lens_this_time[tid] = 0 + seq_lens_decoder[tid] = 0 + seq_lens_encoder[tid] = 0 + not_stop_flag = 0 + not_stop_flag_sum = not_stop_flag_sum + not_stop_flag + not_need_stop[0] = not_stop_flag_sum > 0 - truncate_first_token = True - splitwise_prefill = False - # Create input tensors - if device == "cpu": - paddle.set_device(device) - draft_tokens = paddle.randint(0, 100, [bsz, draft_tokens_len], dtype="int64") - input_ids = paddle.randint(0, 100, [bsz, input_ids_len], dtype="int64") - stop_flags = paddle.randint(0, 1, [bsz], dtype="int").cast("bool") - seq_lens_this_time = paddle.randint(0, 100, [bsz], dtype="int32") - seq_lens_encoder = paddle.randint(0, 100, [bsz], dtype="int32") - seq_lens_decoder = paddle.randint(0, 100, [bsz], dtype="int32") - step_idx = paddle.randint(0, 100, [bsz], dtype="int64") - seq_lens_encoder_record = paddle.randint(0, 100, [bsz], dtype="int32") - seq_lens_decoder_record = paddle.randint(0, 100, [bsz], dtype="int32") - not_need_stop = paddle.zeros([1], dtype="bool").cpu() - batch_drop = paddle.zeros([bsz], dtype="bool") +def draft_model_preprocess_kernel( + draft_tokens, + input_ids, + stop_flags, + seq_lens_this_time, + seq_lens_encoder, + seq_lens_decoder, + step_idx, + not_need_stop, + is_block_step, + batch_drop, + pre_ids, + accept_tokens, + accept_num, + base_model_seq_lens_this_time, + base_model_seq_lens_encoder, + base_model_seq_lens_decoder, + base_model_step_idx, + base_model_stop_flags, + base_model_is_block_step, + base_model_draft_tokens, + bsz, + num_model_step, + base_model_draft_tokens_len, + truncate_first_token, + kvcache_scheduler_v1, +): + not_stop_flag_sum = 0 - # Output tensors - accept_tokens = paddle.randint(0, 100, [bsz, 100], dtype="int64") - accept_num = paddle.randint(1, max_draft_token + 5, [bsz], dtype="int32") - base_model_seq_lens_encoder = paddle.randint(0, 100, [bsz], dtype="int32") - base_model_seq_lens_decoder = paddle.randint(0, 100, [bsz], dtype="int32") - base_model_step_idx = paddle.randint(0, 100, [bsz], dtype="int64") - base_model_stop_flags = paddle.zeros([bsz], dtype="bool") - base_model_is_block_step = 
paddle.zeros([bsz], dtype="bool") - base_model_draft_tokens = paddle.zeros([bsz, max_draft_token], dtype="int64") - # Run the op - outputs = draft_model_preprocess( + for tid in range(bsz): + not_stop_flag = 0 + accept_tokens_now = accept_tokens[tid] + draft_tokens_now = draft_tokens[tid] + accept_num_now = accept_num[tid] + input_ids_now = input_ids[tid] + base_model_draft_tokens_now = base_model_draft_tokens[tid] + base_model_seq_len_decoder = base_model_seq_lens_decoder[tid] + base_model_seq_len_this_time = base_model_seq_lens_this_time[tid] + pre_ids_now = pre_ids[tid] + + base_model_draft_tokens_now[1:base_model_draft_tokens_len] = -1 + + if kvcache_scheduler_v1: + if base_model_stop_flags[tid] and base_model_is_block_step[tid]: + stop_flags[tid] = True + is_block_step[tid] = True + # Need to continue infer + else: + if base_model_stop_flags[tid] and base_model_is_block_step[tid]: + batch_drop[tid] = True + stop_flags[tid] = True + + if not (base_model_stop_flags[tid] or batch_drop[tid]): + not_stop_flag = 1 + # 1. first token + if seq_lens_encoder[tid] > 0: + # Can be extended to first few tokens + seq_len_encoder = seq_lens_encoder[tid] + stop_flags[tid] = False + base_model_first_token = accept_tokens_now[0] + pre_ids_now[0] = base_model_first_token + position = seq_len_encoder + if truncate_first_token: + input_ids_now[position - 1] = base_model_first_token + seq_lens_this_time[tid] = seq_len_encoder + else: + input_ids_now[position] = base_model_first_token + seq_lens_this_time[tid] = seq_len_encoder + 1 + else: + if kvcache_scheduler_v1: + # 3. try to recover mtp infer in V1 mode + if not (base_model_is_block_step[tid] and is_block_step[tid]): + is_block_step[tid] = False + + if stop_flags[tid]: + stop_flags[tid] = False + # TODO: check + seq_lens_decoder[tid] = base_model_seq_len_decoder - base_model_seq_len_this_time + step_idx[tid] = base_model_step_idx[tid] - base_model_seq_len_this_time + else: + # 2: Last base model generated token and first MTP token + seq_lens_decoder[tid] -= num_model_step - 1 + step_idx[tid] -= num_model_step - 1 + + for i in range(accept_num_now): + draft_tokens_now[i] = accept_tokens_now[i] + pre_id_pos = base_model_step_idx[tid] - (accept_num_now - i) + accept_token = accept_tokens_now[i] + pre_ids_now[pre_id_pos] = accept_token + + seq_lens_this_time[tid] = accept_num_now + else: + stop_flags[tid] = True + seq_lens_this_time[tid] = 0 + seq_lens_decoder[tid] = 0 + seq_lens_encoder[tid] = 0 + not_stop_flag_sum = not_stop_flag_sum + not_stop_flag + not_need_stop[0] = not_stop_flag_sum > 0 + + +def DispatchRunner( + draft_tokens, + input_ids, + stop_flags, + seq_lens_this_time, + seq_lens_encoder, + seq_lens_decoder, + step_idx, + not_need_stop, + is_block_step, + batch_drop, + pre_ids, + accept_tokens, + accept_num, + base_model_seq_lens_this_time, + base_model_seq_lens_encoder, + base_model_seq_lens_decoder, + base_model_step_idx, + base_model_stop_flags, + base_model_is_block_step, + base_model_draft_tokens, + bsz, + num_model_step, + truncate_first_token, + splitwise_prefill, + kvcache_scheduler_v1, +): + base_model_draft_tokens_len = base_model_draft_tokens.shape[1] + if splitwise_prefill: + process_splitwise_prefill( + draft_tokens, + input_ids, + stop_flags, + seq_lens_this_time, + seq_lens_encoder, + seq_lens_decoder, + step_idx, + not_need_stop, + is_block_step, + batch_drop, + pre_ids, + accept_tokens, + accept_num, + base_model_seq_lens_this_time, + base_model_seq_lens_encoder, + base_model_seq_lens_decoder, + base_model_step_idx, + 
base_model_stop_flags, + base_model_is_block_step, + base_model_draft_tokens, + bsz, + num_model_step, + base_model_draft_tokens_len, + truncate_first_token, + kvcache_scheduler_v1, + ) + else: + draft_model_preprocess_kernel( + draft_tokens, + input_ids, + stop_flags, + seq_lens_this_time, + seq_lens_encoder, + seq_lens_decoder, + step_idx, + not_need_stop, + is_block_step, + batch_drop, + pre_ids, + accept_tokens, + accept_num, + base_model_seq_lens_this_time, + base_model_seq_lens_encoder, + base_model_seq_lens_decoder, + base_model_step_idx, + base_model_stop_flags, + base_model_is_block_step, + base_model_draft_tokens, + bsz, + num_model_step, + base_model_draft_tokens_len, + truncate_first_token, + kvcache_scheduler_v1, + ) + + +def draft_model_preprocess_ref( + draft_tokens, + input_ids, + stop_flags, + seq_lens_this_time, + seq_lens_encoder, + seq_lens_decoder, + step_idx, + not_need_stop, + is_block_step, + batch_drop, + pre_ids, + accept_tokens, + accept_num, + base_model_seq_lens_this_time, + base_model_seq_lens_encoder, + base_model_seq_lens_decoder, + base_model_step_idx, + base_model_stop_flags, + base_model_is_block_step, + base_model_draft_tokens, + num_model_step, + truncate_first_token, + splitwise_prefill, + kvcache_scheduler_v1, +): + real_bsz = seq_lens_this_time.shape[0] + + DispatchRunner( draft_tokens, input_ids, stop_flags, @@ -63,73 +297,110 @@ def run_test(device="xpu"): seq_lens_encoder, seq_lens_decoder, step_idx, - seq_lens_encoder_record, - seq_lens_decoder_record, not_need_stop, + is_block_step, batch_drop, + pre_ids, accept_tokens, accept_num, + base_model_seq_lens_this_time, base_model_seq_lens_encoder, base_model_seq_lens_decoder, base_model_step_idx, base_model_stop_flags, base_model_is_block_step, base_model_draft_tokens, - max_draft_token=max_draft_token, - truncate_first_token=truncate_first_token, - splitwise_prefill=splitwise_prefill, + real_bsz, + num_model_step, + truncate_first_token, + splitwise_prefill, + kvcache_scheduler_v1, ) - # Return results for comparison - results = { - "draft_tokens": draft_tokens.numpy(), - "input_ids": input_ids.numpy(), - "stop_flags": stop_flags.numpy(), - "seq_lens_this_time": seq_lens_this_time.numpy(), - "accept_tokens": accept_tokens.numpy(), - "accept_num": accept_num.numpy(), - "not_need_stop": not_need_stop.numpy(), - "outputs": [x.numpy() for x in outputs], - } - return results +class TestDraftModelPreprocess: + def _run_tests(self): + paddle.seed(2022) -def compare_results(cpu_results, xpu_results): - # Compare all outputs - for key in cpu_results: - if key == "outputs": - for i, (cpu_out, xpu_out) in enumerate(zip(cpu_results[key], xpu_results[key])): - np.testing.assert_allclose( - cpu_out, - xpu_out, - rtol=1e-5, - atol=1e-8, - err_msg=f"Output {i} mismatch between CPU and GPU", - ) - else: - np.testing.assert_allclose( - cpu_results[key], - xpu_results[key], - rtol=1e-5, - atol=1e-8, - err_msg=f"{key} mismatch between CPU and GPU", - ) - print("CPU and GPU results match!") + # Define parameters + bsz = 10 + draft_tokens_len = 4 + input_ids_len = 100 + max_draft_token = 10 + truncate_first_token = True + splitwise_prefill = False -def test_draft_model_preprocess(): + draft_tokens = paddle.randint(0, 100, [bsz, draft_tokens_len], dtype="int64") + input_ids = paddle.randint(0, 100, [bsz, input_ids_len], dtype="int64") + stop_flags = paddle.randint(0, 1, [bsz], dtype="int").cast("bool") + seq_lens_this_time = paddle.randint(0, 100, [bsz], dtype="int32") + seq_lens_encoder = paddle.randint(0, input_ids_len, 
[bsz], dtype="int32") + seq_lens_decoder = paddle.randint(0, input_ids_len, [bsz], dtype="int32") + step_idx = paddle.randint(0, 100, [bsz], dtype="int64") + seq_lens_encoder_record = paddle.randint(0, 100, [bsz], dtype="int32") # noqa: F841 + seq_lens_decoder_record = paddle.randint(0, 100, [bsz], dtype="int32") # noqa: F841 + not_need_stop = paddle.zeros([1], dtype="bool").cpu() + is_block_step = paddle.zeros([bsz], dtype="bool") + batch_drop = paddle.zeros([bsz], dtype="bool") - print("Running XPU test...") - xpu_results = run_test("xpu") + # Output tensors + accept_tokens = paddle.randint(0, 100, [bsz, 100], dtype="int64") + accept_num = paddle.randint(1, max_draft_token + 5, [bsz], dtype="int32") + base_model_seq_lens_encoder = paddle.randint(0, 100, [bsz], dtype="int32") + base_model_seq_lens_decoder = paddle.randint(0, 100, [bsz], dtype="int32") + base_model_step_idx = paddle.randint(0, 100, [bsz], dtype="int64") + base_model_stop_flags = paddle.zeros([bsz], dtype="bool") + base_model_is_block_step = paddle.zeros([bsz], dtype="bool") + base_model_draft_tokens = paddle.zeros([bsz, max_draft_token], dtype="int64") + # Run the op + pre_ids = input_ids.clone() + base_model_seq_lens_this_time = seq_lens_this_time + num_model_step = max_draft_token - print("Running CPU test...") - cpu_results = run_test("cpu") + kvcache_scheduler_v1 = True + inputs = ( + draft_tokens, + input_ids, + stop_flags, + seq_lens_this_time, + seq_lens_encoder, + seq_lens_decoder, + step_idx, + not_need_stop, + is_block_step, + batch_drop, + pre_ids, + accept_tokens, + accept_num, + base_model_seq_lens_this_time, + base_model_seq_lens_encoder, + base_model_seq_lens_decoder, + base_model_step_idx, + base_model_stop_flags, + base_model_is_block_step, + base_model_draft_tokens, + num_model_step, + truncate_first_token, + splitwise_prefill, + kvcache_scheduler_v1, + ) + # inplace modify, need to clone inputs + inputs_clone = [x.clone() if isinstance(x, paddle.Tensor) else x for x in inputs] + draft_model_preprocess_ref(*inputs) + draft_model_preprocess(*inputs_clone) + return inputs, inputs_clone - print("Comparing results...") - compare_results(cpu_results, xpu_results) - - print("Test passed!") + def test_draft_model_preprocess(self): + results1, results2 = self._run_tests() + np.testing.assert_allclose(results1[0], results2[0]) # draft_tokens + np.testing.assert_allclose(results1[1], results2[1]) # input_ids + np.testing.assert_allclose(results1[2], results2[2]) # stop_flags + np.testing.assert_allclose(results1[3], results2[3]) # seq_lens_this_time + np.testing.assert_allclose(results1[11], results2[11]) # accept_tokens + np.testing.assert_allclose(results1[12], results2[12]) # accept_num + np.testing.assert_allclose(results1[7], results2[7]) # not_need_stop if __name__ == "__main__": - test_draft_model_preprocess() + unittest.main() diff --git a/custom_ops/xpu_ops/test/test_speculate_step.py b/custom_ops/xpu_ops/test/test_speculate_step.py new file mode 100644 index 000000000..65414bcff --- /dev/null +++ b/custom_ops/xpu_ops/test/test_speculate_step.py @@ -0,0 +1,312 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import unittest + +import numpy as np +import paddle + +from fastdeploy.model_executor.ops.xpu import speculate_step_paddle + +# Fix the random seeds so the test is reproducible +np.random.seed(2023) +paddle.seed(2023) + + +def generate_test_data(): + """ + Helper that generates the test data. + This logic was converted from a pytest fixture into a plain function called by the test method. + """ + # max_bs = 128 + max_bs = 8 + bs = max_bs + max_seq_len = 8192 + block_size = 64 + block_bs = 8 + block_ratio = 0.75 + max_draft_tokens = 1 + encoder_decoder_block_num = 1 + + # Generate the raw test data (reusing the original logic unchanged) + stop_flags = np.random.randint(0, 2, [max_bs]).astype("bool") + seq_lens_this_time = np.zeros([bs], "int32") + seq_lens_encoder = np.zeros([max_bs], "int32") + seq_lens_decoder = np.zeros([max_bs], "int32") + accept_num = np.random.randint(1, 3, [max_bs]).astype("int32") + for i in range(bs): + seq_lens_decoder[i] = 2 + i * 2 + seq_lens_this_time[i] = 1 + + ori_seq_lens_encoder = np.zeros([max_bs], "int32") + ori_seq_lens_encoder[:] = seq_lens_decoder[:] // 2 + step_idx = (seq_lens_decoder - ori_seq_lens_encoder).astype("int64") + + max_block_num = block_bs * max_seq_len // block_size + free_list_len = int(max_block_num * (1 - block_ratio)) + free_list_len = np.full([1], free_list_len, "int32") + free_list = np.arange( + max_block_num - 1, max_block_num - free_list_len.item() - 1, -1, dtype="int32" # .item() converts to a scalar + ) + encoder_block_lens = np.zeros([max_bs], "int32") + used_list_len = np.zeros([max_bs], "int32") + block_tables = np.full([max_bs, 128], -1, "int32") + encoder_block_id = 0 + + for i in range(bs): + enc_block_num = (ori_seq_lens_encoder[i] + block_size - 1) // block_size + encoder_block_lens[i] = enc_block_num + dec_block_num = (seq_lens_decoder[i] + block_size - 1) // block_size - enc_block_num + used_list_len[i] = dec_block_num + block_tables[i, :enc_block_num] = np.arange(encoder_block_id, encoder_block_id + enc_block_num, 1, "int32") + encoder_block_id += enc_block_num + if dec_block_num > 0: + block_tables[i, enc_block_num : enc_block_num + dec_block_num] = free_list[ + free_list_len[0] - 1 - dec_block_num : free_list_len[0] - 1 + ] + free_list[free_list_len[0] - 1 - dec_block_num : free_list_len[0] - 1] = -1 + free_list_len[0] -= dec_block_num + + assert free_list_len[0] >= 0, "free_list_len should not be negative" + + is_block_step = np.zeros([max_bs], "bool") + is_block_step[:bs] = np.random.randint(0, 2, [bs]).astype("bool") + step_block_list = np.full([max_bs], -1, "int32") + step_lens = np.full([1], 0, "int32") + + for i in range(bs): + if is_block_step[i]: + step_block_list[step_lens[0]] = i + step_lens[0] += 1 + + recover_lens = np.full([1], 0, "int32") + recover_block_list = np.full([max_bs], -1, "int32") + need_block_len = np.full([1], 0, "int32") + need_block_list = np.full([max_bs], -1, "int32") + + input_ids = np.random.randint(0, 1000, [max_bs, max_seq_len], "int64") + pre_ids = np.random.randint(0, 1000, [max_bs, max_seq_len], "int64") + next_tokens = np.random.randint(0, 1000, [max_bs], "int64") + first_token_ids = np.random.randint(0, 1000, [max_bs], "int64") + + paddle.set_device("cpu") + # Convert to paddle tensors (keeping the original logic) + data_cpu = { + "stop_flags": paddle.to_tensor(stop_flags), + "seq_lens_this_time": paddle.to_tensor(seq_lens_this_time), + "seq_lens_encoder": paddle.to_tensor(seq_lens_encoder), + "seq_lens_decoder": paddle.to_tensor(seq_lens_decoder), + "ori_seq_lens_encoder": paddle.to_tensor(ori_seq_lens_encoder), + "block_tables": paddle.to_tensor(block_tables), + "encoder_block_lens": paddle.to_tensor(encoder_block_lens), + "is_block_step": paddle.to_tensor(is_block_step), + "step_block_list": paddle.to_tensor(step_block_list), + "step_lens": paddle.to_tensor(step_lens), + "recover_block_list": paddle.to_tensor(recover_block_list), + "recover_lens": paddle.to_tensor(recover_lens), + "need_block_list": paddle.to_tensor(need_block_list), + "need_block_len": paddle.to_tensor(need_block_len), + "used_list_len": paddle.to_tensor(used_list_len), + "free_list_len": paddle.to_tensor(free_list_len), + "free_list": paddle.to_tensor(free_list), + "input_ids": paddle.to_tensor(input_ids), + "pre_ids": paddle.to_tensor(pre_ids), + "step_idx": paddle.to_tensor(step_idx), + "next_tokens": paddle.to_tensor(next_tokens), + "first_token_ids": paddle.to_tensor(first_token_ids), + "accept_num": paddle.to_tensor(accept_num), + "block_size": block_size, + "encoder_decoder_block_num": encoder_decoder_block_num, + "max_draft_tokens": max_draft_tokens, + } + + paddle.set_device("xpu:0") + data_xpu = { + "stop_flags": paddle.to_tensor(stop_flags), + "seq_lens_this_time": paddle.to_tensor(seq_lens_this_time), + "seq_lens_encoder": paddle.to_tensor(seq_lens_encoder), + "seq_lens_decoder": paddle.to_tensor(seq_lens_decoder), + "ori_seq_lens_encoder": paddle.to_tensor(ori_seq_lens_encoder), + "block_tables": paddle.to_tensor(block_tables), + "encoder_block_lens": paddle.to_tensor(encoder_block_lens), + "is_block_step": paddle.to_tensor(is_block_step), + "step_block_list": paddle.to_tensor(step_block_list), + "step_lens": paddle.to_tensor(step_lens), + "recover_block_list": paddle.to_tensor(recover_block_list), + "recover_lens": paddle.to_tensor(recover_lens), + "need_block_list": paddle.to_tensor(need_block_list), + "need_block_len": paddle.to_tensor(need_block_len), + "used_list_len": paddle.to_tensor(used_list_len), + "free_list_len": paddle.to_tensor(free_list_len), + "free_list": paddle.to_tensor(free_list), + "input_ids": paddle.to_tensor(input_ids), + "pre_ids": paddle.to_tensor(pre_ids), + "step_idx": paddle.to_tensor(step_idx), + "next_tokens": paddle.to_tensor(next_tokens), + "first_token_ids": paddle.to_tensor(first_token_ids), + "accept_num": paddle.to_tensor(accept_num), + "block_size": block_size, + "encoder_decoder_block_num": encoder_decoder_block_num, + "max_draft_tokens": max_draft_tokens, + } + + # Restore the default device so other tests are not affected + paddle.set_device("cpu") + + return data_cpu, data_xpu + + +def speculate_step_paddle_execution(test_data): + """Run speculate_step_paddle on the given inputs and return the in-place updated test data.""" + # Unpack the inputs + stop_flags = test_data["stop_flags"] # modified in place by the op + seq_lens_this_time = test_data["seq_lens_this_time"] + ori_seq_lens_encoder = test_data["ori_seq_lens_encoder"] + seq_lens_encoder = test_data["seq_lens_encoder"] + seq_lens_decoder = test_data["seq_lens_decoder"] + block_tables = test_data["block_tables"] + encoder_block_lens = test_data["encoder_block_lens"] + is_block_step = test_data["is_block_step"] + step_block_list = test_data["step_block_list"] + step_lens = test_data["step_lens"] + recover_block_list = test_data["recover_block_list"] + recover_lens = test_data["recover_lens"] + need_block_list = test_data["need_block_list"] + need_block_len = test_data["need_block_len"] + used_list_len = test_data["used_list_len"] + free_list = test_data["free_list"] + free_list_len = test_data["free_list_len"] + input_ids = test_data["input_ids"] + pre_ids = test_data["pre_ids"] + step_idx = test_data["step_idx"] + next_tokens = test_data["next_tokens"] + first_token_ids = test_data["first_token_ids"] + accept_num = test_data["accept_num"] + block_size = test_data["block_size"] + encoder_decoder_block_num = test_data["encoder_decoder_block_num"] + max_draft_tokens = test_data["max_draft_tokens"] + + # Optional: print key state before the op (enable for debugging) + if os.environ.get("STEP_TEST_DEBUG", "0") == "1": + print("-" * 50 + "before step op" + "-" * 50) + # ... (debug printing omitted for brevity) + + # Run the op under test (the core test step) + speculate_step_paddle( + stop_flags, + seq_lens_this_time, + ori_seq_lens_encoder, + seq_lens_encoder, + seq_lens_decoder, + block_tables, + encoder_block_lens, + is_block_step, + step_block_list, + step_lens, + recover_block_list, + recover_lens, + need_block_list, + need_block_len, + used_list_len, + free_list, + free_list_len, + input_ids, + pre_ids, + step_idx, + next_tokens, + first_token_ids, + accept_num, + block_size, + encoder_decoder_block_num, + max_draft_tokens, + ) + + # Optional: print key state after the op (enable for debugging) + if os.environ.get("STEP_TEST_DEBUG", "0") == "1": + print("-" * 50 + "after step op" + "-" * 50) + # ... (debug printing omitted for brevity) + + return test_data + + +class TestSpeculateStepPaddle(unittest.TestCase): + """ + Test case class based on unittest.TestCase. + Every method whose name starts with 'test_' is picked up as a test case. + """ + + def assert_test_data_equal(self, test_data1, test_data2, rtol=1e-05, atol=1e-08): + """ + Custom assertion comparing the structure and data of two test_data dicts. + Following unittest convention, custom assertions are named with an 'assert' prefix. + """ + # 1. Check that both test_data dicts have exactly the same keys + keys1 = set(test_data1.keys()) + keys2 = set(test_data2.keys()) + self.assertEqual( + keys1, + keys2, + msg=f"test_data keys differ!\nonly in the first: {keys1 - keys2}\nonly in the second: {keys2 - keys1}", + ) + + # 2. Compare the data field by field + for key in keys1: + data1 = test_data1[key] + data2 = test_data2[key] + + # Distinguish paddle Tensors (convert to numpy) from plain scalars/arrays (use directly) + if isinstance(data1, paddle.Tensor): + np1 = data1.detach().cpu().numpy() + else: + np1 = np.asarray(data1) + + if isinstance(data2, paddle.Tensor): + np2 = data2.detach().cpu().numpy() + else: + np2 = np.asarray(data2) + + # 3. Compare the values + if np1.dtype in (np.bool_, np.int8, np.int16, np.int32, np.int64, np.uint8): + # bool/integer dtypes must match exactly + np.testing.assert_array_equal(np1, np2, err_msg=f"field {key} mismatch!") + else: + # floating point: allow differences within rtol/atol + np.testing.assert_allclose(np1, np2, rtol=rtol, atol=atol, err_msg=f"field {key} float mismatch!") + + print("Both test_data dicts match in structure and values.") + + def test_speculate_step_paddle_execution(self): + """ + Main test case: fetch data from generate_test_data, + run the op on CPU and XPU, + and compare the results with the custom assertion. + """ + print("\nRunning test: test_speculate_step_paddle_execution") + + # 1. Generate the test data + data_cpu, data_xpu = generate_test_data() + + # 2. Run the op on both devices + result_xpu = speculate_step_paddle_execution(data_xpu) + result_cpu = speculate_step_paddle_execution(data_cpu) + + # 3. Assert the results match + self.assert_test_data_equal(result_xpu, result_cpu) + + +if __name__ == "__main__": + # Run all test cases via the unittest entry point + unittest.main() diff --git a/custom_ops/xpu_ops/test/test_speculate_update_v3.py b/custom_ops/xpu_ops/test/test_speculate_update_v3.py index 1ecebc6e7..bdea8727d 100644 --- a/custom_ops/xpu_ops/test/test_speculate_update_v3.py +++ b/custom_ops/xpu_ops/test/test_speculate_update_v3.py @@ -12,101 +12,22 @@ # See the License for the specific language governing permissions and # limitations under the License.
-import numpy as np +import unittest -# tests/test_speculate_update_v3.py +import numpy as np import paddle +# 假设这是你的自定义算子 from fastdeploy.model_executor.ops.xpu import speculate_update_v3 -# ---------------- NumPy 参考实现 ---------------- -def speculate_update_v3_np( - seq_lens_encoder, - seq_lens_decoder, - not_need_stop, - draft_tokens, - actual_draft_token_nums, - accept_tokens, - accept_num, - stop_flags, - seq_lens_this_time, - is_block_step, - stop_nums, -): - """ - 完全复现 CPU / CUDA 逻辑的 NumPy 参考版本(就地修改)。 - """ - stop_sum = 0 - real_bsz = seq_lens_this_time.shape[0] - max_bsz = stop_flags.shape[0] - max_draft_tokens = draft_tokens.shape[1] - - for bid in range(max_bsz): - stop_flag_now_int = 0 - inactive = bid >= real_bsz - block_step = (not inactive) and is_block_step[bid] - - if (not block_step) and (not inactive): - - if stop_flags[bid]: - stop_flag_now_int = 1 - - # encoder 长度为 0 时直接累加 decoder - if seq_lens_encoder[bid] == 0: - seq_lens_decoder[bid] += accept_num[bid] - - # draft 长度自适应 - if (seq_lens_encoder[bid] == 0) and (seq_lens_this_time[bid] > 1): - cur_len = actual_draft_token_nums[bid] - if accept_num[bid] - 1 == cur_len: # 全部接受 - if cur_len + 2 <= max_draft_tokens - 1: - cur_len += 2 - elif cur_len + 1 <= max_draft_tokens - 1: - cur_len += 1 - else: - cur_len = max_draft_tokens - 1 - else: # 有拒绝 - cur_len = max(1, cur_len - 1) - actual_draft_token_nums[bid] = cur_len - - # 偿还 encoder 欠账 - if seq_lens_encoder[bid] != 0: - seq_lens_decoder[bid] += seq_lens_encoder[bid] - seq_lens_encoder[bid] = 0 - - # 写回下一轮首 token - draft_tokens[bid, 0] = accept_tokens[bid, accept_num[bid] - 1] - - # 停止则清零 decoder - if stop_flag_now_int: - seq_lens_decoder[bid] = 0 - - elif inactive: - stop_flag_now_int = 1 # padding slot 视为 stop - - stop_sum += stop_flag_now_int - - # print("stop_sum: ", stop_sum) - not_need_stop[0] = stop_sum < stop_nums[0] - - # 返回引用,仅供一致性 - return ( - seq_lens_encoder, - seq_lens_decoder, - not_need_stop, - draft_tokens, - actual_draft_token_nums, - ) - - -# ---------------- 生成随机输入 ---------------- def gen_inputs( max_bsz=512, # 与 CUDA BlockSize 对齐 max_draft_tokens=16, real_bsz=123, # 可自调;须 ≤ max_bsz seed=2022, ): + """生成随机测试输入数据""" rng = np.random.default_rng(seed) # 基本张量 @@ -122,89 +43,91 @@ def gen_inputs( stop_nums = np.array([5], dtype=np.int64) # 阈值随意 # seq_lens_this_time 仅取 real_bsz 长度 - seq_lens_this_time = rng.integers(1, max_draft_tokens, size=real_bsz, dtype=np.int32) + seq_lens_this_time = rng.integers(1, max_draft_tokens + 1, size=real_bsz, dtype=np.int32) - return { - "seq_lens_encoder": seq_lens_encoder, - "seq_lens_decoder": seq_lens_decoder, - "not_need_stop": not_need_stop, - "draft_tokens": draft_tokens, - "actual_draft_token_nums": actual_draft_nums, - "accept_tokens": accept_tokens, - "accept_num": accept_num, - "stop_flags": stop_flags, - "seq_lens_this_time": seq_lens_this_time, - "is_block_step": is_block_step, - "stop_nums": stop_nums, - # real_bsz = real_bsz, - # max_bsz = max_bsz, - # max_draft_tokens = max_draft_tokens + paddle.set_device("xpu:0") + data_xpu = { + "seq_lens_encoder": paddle.to_tensor(seq_lens_encoder), + "seq_lens_decoder": paddle.to_tensor(seq_lens_decoder), + "not_need_stop": paddle.to_tensor(not_need_stop).cpu(), + "draft_tokens": paddle.to_tensor(draft_tokens), + "actual_draft_token_nums": paddle.to_tensor(actual_draft_nums), + "accept_tokens": paddle.to_tensor(accept_tokens), + "accept_num": paddle.to_tensor(accept_num), + "stop_flags": paddle.to_tensor(stop_flags), + "seq_lens_this_time": 
paddle.to_tensor(seq_lens_this_time), + "is_block_step": paddle.to_tensor(is_block_step), + "stop_nums": paddle.to_tensor(stop_nums), } - -# ------------------- 单测主体 ------------------- -inputs = gen_inputs(max_bsz=512, max_draft_tokens=32, real_bsz=201) - -# ---- Paddle 端 ---- -paddle_inputs = {} -for k, v in inputs.items(): - if k in ("real_bsz", "max_bsz", "max_draft_tokens"): - paddle_inputs[k] = v # 纯 python int - else: - if k == "not_need_stop": - paddle_inputs[k] = paddle.to_tensor(v, place=paddle.CPUPlace()) - else: - # 其余张量保持默认 place(想测 GPU 就手动加 place=paddle.CUDAPlace(0)) - paddle_inputs[k] = paddle.to_tensor(v) - -# ---- NumPy 端 ---- -# 为保证初值一致,这里必须复制 Paddle 入参的 numpy 值再传给参考实现 -np_inputs = { - k: (paddle_inputs[k].numpy().copy() if isinstance(paddle_inputs[k], paddle.Tensor) else paddle_inputs[k]) - for k in paddle_inputs -} - -# 调用自定义算子 -# print("seq_lens_encoder_xpu_before: ", paddle_inputs["seq_lens_encoder"]) -out_pd = speculate_update_v3(**paddle_inputs) -# print("seq_lens_encoder_xpu_after: ", out_pd[0]) -# print("not_need_stop: ", out_pd[2]) - -# speculate_update_v3 返回 5 个张量(与 Outputs 对应) -( - seq_lens_encoder_pd, - seq_lens_decoder_pd, - not_need_stop_pd, - draft_tokens_pd, - actual_draft_nums_pd, -) = out_pd - -# print("seq_lens_encoder_np_before: ", np_inputs["seq_lens_encoder"]) -out_np = speculate_update_v3_np(**np_inputs) -# print("seq_lens_encoder_np_after: ", out_np[0]) -# print("not_need_stop: ", out_np[2]) + paddle.set_device("cpu") + data_cpu = { + "seq_lens_encoder": paddle.to_tensor(seq_lens_encoder), + "seq_lens_decoder": paddle.to_tensor(seq_lens_decoder), + "not_need_stop": paddle.to_tensor(not_need_stop), + "draft_tokens": paddle.to_tensor(draft_tokens), + "actual_draft_token_nums": paddle.to_tensor(actual_draft_nums), + "accept_tokens": paddle.to_tensor(accept_tokens), + "accept_num": paddle.to_tensor(accept_num), + "stop_flags": paddle.to_tensor(stop_flags), + "seq_lens_this_time": paddle.to_tensor(seq_lens_this_time), + "is_block_step": paddle.to_tensor(is_block_step), + "stop_nums": paddle.to_tensor(stop_nums), + } + return data_xpu, data_cpu -# ---------------- 校对 ---------------- -names = [ - "seq_lens_encoder", - "seq_lens_decoder", - "not_need_stop", - "draft_tokens", - "actual_draft_token_nums", -] -pd_tensors = [ - seq_lens_encoder_pd, - seq_lens_decoder_pd, - not_need_stop_pd, - draft_tokens_pd, - actual_draft_nums_pd, -] +class TestSpeculateUpdateV3(unittest.TestCase): + """测试 speculate_update_v3 算子""" -for name, pd_val, np_val in zip(names, pd_tensors, out_np): - pd_arr = pd_val.numpy() - ok = np.array_equal(pd_arr, np_val) - print(f"{name:25s} equal :", ok) + def test_op_vs_golden(self, max_bsz=512, max_draft_tokens=16, real_bsz=123): + """ + 核心测试:比较自定义算子的输出与纯 NumPy 参考实现的输出。 + """ + # 1. gen inputs for cpu/xpu + data_xpu, data_cpu = gen_inputs(max_bsz=max_bsz, max_draft_tokens=max_draft_tokens, real_bsz=real_bsz) - # 也可以加 assert,配合 pytest - # assert all(np.array_equal(p.numpy(), n) for p,n in zip(pd_tensors, out_np)) + # 3. run xpu kernel + speculate_update_v3(**data_xpu) + + # 4. run cpu kernel + speculate_update_v3(**data_cpu) + + # 5. 
format outputs + outputs_xpu = [ + data_xpu["seq_lens_encoder"].cpu().numpy(), + data_xpu["seq_lens_decoder"].cpu().numpy(), + data_xpu["not_need_stop"].cpu().numpy(), + data_xpu["draft_tokens"].cpu().numpy(), + data_xpu["actual_draft_token_nums"].cpu().numpy(), + ] + + outputs_cpu = [ + data_cpu["seq_lens_encoder"].numpy(), + data_cpu["seq_lens_decoder"].numpy(), + data_cpu["not_need_stop"].numpy(), + data_cpu["draft_tokens"].numpy(), + data_cpu["actual_draft_token_nums"].numpy(), + ] + output_names = [ + "seq_lens_encoder", + "seq_lens_decoder", + "not_need_stop", + "draft_tokens", + "actual_draft_token_nums", + ] + + # 6. check outputs + for name, pd_out, np_out in zip(output_names, outputs_xpu, outputs_cpu): + with self.subTest(output_name=name): + np.testing.assert_allclose( + pd_out, + np_out, + atol=0, + rtol=1e-6, + err_msg=f"Output mismatch for tensor '{name}'.\nPaddle Output:\n{pd_out}\nGolden Output:\n{np_out}", + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/fastdeploy/model_executor/xpu_pre_and_post_process.py b/fastdeploy/model_executor/xpu_pre_and_post_process.py new file mode 100644 index 000000000..9a2ea16aa --- /dev/null +++ b/fastdeploy/model_executor/xpu_pre_and_post_process.py @@ -0,0 +1,315 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" + +from typing import Dict, Optional + +import paddle + +from fastdeploy import envs +from fastdeploy.model_executor.forward_meta import XPUForwardMeta +from fastdeploy.platforms import current_platform +from fastdeploy.worker.output import ModelOutputData + +if current_platform.is_xpu(): + from fastdeploy.model_executor.ops.xpu import ( + adjust_batch, + gather_next_token, + get_infer_param, + get_padding_offset, + limit_thinking_content_length_v1, + limit_thinking_content_length_v2, + update_inputs_v1, + ) + + +def xpu_pre_process( + input_ids: paddle.Tensor, + seq_lens_this_time: int, + share_inputs: Dict, + use_speculate_method: bool, + block_size: int, + draft_tokens: Optional[paddle.Tensor] = None, + seq_lens_encoder: Optional[paddle.Tensor] = None, + seq_lens_decoder: Optional[paddle.Tensor] = None, + is_profiling: bool = False, +) -> XPUForwardMeta: + """ """ + max_len = input_ids.shape[1] + cum_offsets_now = paddle.cumsum(max_len - seq_lens_this_time, dtype="int32") + token_num = paddle.sum(seq_lens_this_time) + + ( + ids_remove_padding, + cum_offsets, + batch_id_per_token, + cu_seqlens_q, + cu_seqlens_k, + ) = get_padding_offset(input_ids, cum_offsets_now, token_num, seq_lens_this_time) + + share_inputs["ids_remove_padding"] = None # set this after adjust batch + share_inputs["cum_offsets"] = cum_offsets + share_inputs["batch_id_per_token"] = batch_id_per_token + share_inputs["cu_seqlens_q"] = cu_seqlens_q + share_inputs["cu_seqlens_k"] = cu_seqlens_k + + xpu_forward_meta = XPUForwardMeta( + ids_remove_padding=share_inputs["ids_remove_padding"], + rotary_embs=share_inputs["rope_emb"], + attn_backend=None, + seq_lens_encoder=share_inputs["seq_lens_encoder"], + seq_lens_decoder=share_inputs["seq_lens_decoder"], + seq_lens_this_time=share_inputs["seq_lens_this_time"], + cum_offsets=share_inputs["cum_offsets"], + batch_id_per_token=share_inputs["batch_id_per_token"], + cu_seqlens_q=share_inputs["cu_seqlens_q"], + cu_seqlens_k=share_inputs["cu_seqlens_k"], + block_tables=share_inputs["block_tables"], + caches=share_inputs["caches"], + ) + + ( + xpu_forward_meta.encoder_batch_map, + xpu_forward_meta.decoder_batch_map, + xpu_forward_meta.encoder_batch_idx, + xpu_forward_meta.decoder_batch_idx, + xpu_forward_meta.encoder_seq_lod, + xpu_forward_meta.decoder_seq_lod, + xpu_forward_meta.encoder_kv_lod, + xpu_forward_meta.prefix_len, + xpu_forward_meta.decoder_context_len, + xpu_forward_meta.decoder_context_len_cache, + xpu_forward_meta.prefix_block_tables, + xpu_forward_meta.encoder_batch_map_cpu, + xpu_forward_meta.decoder_batch_map_cpu, + xpu_forward_meta.encoder_batch_idx_cpu, + xpu_forward_meta.decoder_batch_idx_cpu, + xpu_forward_meta.encoder_seq_lod_cpu, + xpu_forward_meta.decoder_seq_lod_cpu, + xpu_forward_meta.encoder_kv_lod_cpu, + xpu_forward_meta.prefix_len_cpu, + xpu_forward_meta.decoder_context_len_cpu, + xpu_forward_meta.decoder_context_len_cache_cpu, + xpu_forward_meta.len_info_cpu, + ) = get_infer_param( + seq_lens_encoder, seq_lens_decoder, seq_lens_this_time, xpu_forward_meta.block_tables, block_size + ) + xpu_forward_meta.enc_batch = xpu_forward_meta.len_info_cpu[0] + xpu_forward_meta.dec_batch = xpu_forward_meta.len_info_cpu[1] + xpu_forward_meta.total_enc_len = xpu_forward_meta.len_info_cpu[2] + + adjusted_input = adjust_batch( + ids_remove_padding.reshape([-1, 1]), + cum_offsets, + xpu_forward_meta.encoder_seq_lod, + xpu_forward_meta.decoder_seq_lod, + xpu_forward_meta.encoder_batch_idx, + xpu_forward_meta.decoder_batch_idx, + xpu_forward_meta.encoder_seq_lod_cpu, 
+        xpu_forward_meta.decoder_seq_lod_cpu,
+        xpu_forward_meta.encoder_batch_idx_cpu,
+        xpu_forward_meta.decoder_batch_idx_cpu,
+        xpu_forward_meta.len_info_cpu,
+        None,  # output_padding_offset
+        -1,  # max_input_length
+    )
+
+    adjusted_input = adjusted_input.squeeze(1)
+
+    share_inputs["ids_remove_padding"] = adjusted_input
+    xpu_forward_meta.ids_remove_padding = adjusted_input
+    # Set forward_meta.is_profiling to True to skip init_kv_signal_per_query for attention backends
+    xpu_forward_meta.is_profiling = is_profiling
+    return xpu_forward_meta
+
+
+def xpu_process_output(
+    forward_output,
+    cum_offsets: paddle.Tensor,
+    xpu_forward_meta: XPUForwardMeta,
+    share_inputs,
+) -> paddle.Tensor:
+    """Gather the per-sequence hidden states used for next-token sampling."""
+
+    output_padding_offset = share_inputs.get("output_padding_offset", None)
+
+    hidden_states = gather_next_token(
+        forward_output,
+        cum_offsets,
+        xpu_forward_meta.encoder_seq_lod,
+        xpu_forward_meta.decoder_seq_lod,
+        xpu_forward_meta.encoder_batch_map,
+        xpu_forward_meta.decoder_batch_map,
+        xpu_forward_meta.encoder_seq_lod_cpu,
+        xpu_forward_meta.decoder_seq_lod_cpu,
+        xpu_forward_meta.encoder_batch_map_cpu,
+        xpu_forward_meta.decoder_batch_map_cpu,
+        xpu_forward_meta.len_info_cpu,
+        output_padding_offset,  # output_padding_offset
+        -1,  # max_input_length
+    )
+    return hidden_states
+
+
+def xpu_post_process_normal(
+    sampled_token_ids: paddle.Tensor,
+    model_output: ModelOutputData,
+    share_inputs: Dict[str, paddle.Tensor],
+    block_size: int = 64,
+    skip_save_output: bool = False,
+    think_end_id: Optional[int] = None,
+    line_break_id: Optional[int] = None,
+) -> None:
+    """Update stop flags and model input buffers after sampling, and save output tokens unless skipped."""
+    from fastdeploy.model_executor.ops.xpu import (
+        save_output,
+        set_stop_value_multi_ends,
+        update_inputs,
+    )
+
+    if think_end_id > 0:
+        limit_strategy = envs.FD_LIMIT_THINKING_CONTENT_TRUNCATE_STR
+        max_think_lens = share_inputs["max_think_lens"]
+        step_idx = share_inputs["step_idx"]
+        limit_think_status = share_inputs["limit_think_status"]
+        stop_flags = share_inputs["stop_flags"]
+        eos_token_ids = share_inputs["eos_token_id"]
+        if limit_strategy == "":
+            # for ernie-45-vl
+            limit_thinking_content_length_v1(
+                sampled_token_ids,
+                max_think_lens,
+                step_idx,
+                limit_think_status,
+                stop_flags,
+                eos_token_ids,  # handle eos tokens emitted during thinking due to model quality issues
+                think_end_id,
+            )
+        elif limit_strategy == "\n\n\n":
+            # for ernie-x1
+            assert line_break_id > 0
+            limit_thinking_content_length_v2(
+                sampled_token_ids,
+                max_think_lens,
+                step_idx,
+                limit_think_status,
+                stop_flags,
+                think_end_id,
+                line_break_id,
+            )
+        else:
+            raise NotImplementedError(f"Unsupported {limit_strategy=} for limiting thinking content length.")
+
+    # 1. Set stop value
+    paddle.assign(
+        paddle.where(
+            model_output.stop_flags,
+            model_output.step_idx,
+            model_output.step_idx + 1,
+        ),
+        model_output.step_idx,
+    )
+    length_cond = paddle.greater_equal(model_output.step_idx, model_output.max_dec_len)
+    paddle.assign(
+        paddle.logical_or(model_output.stop_flags, length_cond),
+        model_output.stop_flags,
+    )
+    set_stop_value_multi_ends(
+        sampled_token_ids,
+        model_output.stop_flags,
+        model_output.seq_lens_this_time,
+        model_output.eos_token_id,
+        model_output.next_tokens,
+        False,
+    )  # multi ends
+
+    # 2.
Update the input buffer of the model + with paddle.framework._no_check_dy2st_diff(): + if envs.ENABLE_V1_KVCACHE_SCHEDULER and not skip_save_output: + update_inputs_v1( + model_output.stop_flags, + model_output.not_need_stop, + model_output.seq_lens_this_time, + model_output.seq_lens_encoder, + model_output.seq_lens_decoder, + share_inputs["step_seq_lens_decoder"], + share_inputs["prompt_lens"], + sampled_token_ids, + model_output.input_ids, + share_inputs["block_tables"], + model_output.stop_nums, + model_output.next_tokens, + model_output.is_block_step, + block_size, + ) + else: + update_inputs( + model_output.stop_flags, + model_output.not_need_stop, + model_output.seq_lens_this_time, + model_output.seq_lens_encoder, + model_output.seq_lens_decoder, + model_output.input_ids, + model_output.stop_nums, + sampled_token_ids, + model_output.is_block_step, + ) + # 3. Transmit the model's output and stop generation signal via message queue. + # In the future, we will abandon this approach. + if not skip_save_output: + save_output( + sampled_token_ids, + model_output.not_need_stop, + model_output.mp_rank, + False, # use_ep + ) + + +def step_xpu( + share_inputs: Dict[str, paddle.Tensor], + block_size: int, + enc_dec_block_num: int, +) -> None: + """ + TODO(gongshaotian): normalization name + """ + from fastdeploy.model_executor.ops.xpu import step_paddle + + step_paddle( + share_inputs["stop_flags"], + share_inputs["seq_lens_this_time"], + share_inputs["step_seq_lens_encoder"], + share_inputs["seq_lens_encoder"], + share_inputs["seq_lens_decoder"], + share_inputs["block_tables"], + share_inputs["encoder_block_lens"], + share_inputs["is_block_step"], + share_inputs["step_block_list"], + share_inputs["step_lens"], + share_inputs["recover_block_list"], + share_inputs["recover_lens"], + share_inputs["need_block_list"], + share_inputs["need_block_len"], + share_inputs["used_list_len"], + share_inputs["free_list"], + share_inputs["free_list_len"], + share_inputs["input_ids"], + share_inputs["pre_ids"], + share_inputs["step_idx"], + share_inputs["next_tokens"], + share_inputs["first_token_ids"], + block_size, + enc_dec_block_num, + ) diff --git a/fastdeploy/worker/xpu_model_runner.py b/fastdeploy/worker/xpu_model_runner.py index b60eb8cdf..6338965d2 100644 --- a/fastdeploy/worker/xpu_model_runner.py +++ b/fastdeploy/worker/xpu_model_runner.py @@ -17,7 +17,7 @@ import os import random import time -from typing import Dict, List, Optional +from typing import List, Optional import numpy as np import paddle @@ -28,7 +28,7 @@ from fastdeploy.config import FDConfig from fastdeploy.engine.request import Request, RequestType from fastdeploy.input.ernie4_5_vl_processor import DataProcessor from fastdeploy.inter_communicator import IPCSignal -from fastdeploy.model_executor.forward_meta import ForwardMeta, XPUForwardMeta +from fastdeploy.model_executor.forward_meta import ForwardMeta from fastdeploy.model_executor.graph_optimization.utils import ( profile_run_guard, sot_warmup_guard, @@ -43,17 +43,17 @@ from fastdeploy.model_executor.layers.sample.sampler import Sampler from fastdeploy.model_executor.model_loader import get_model_loader from fastdeploy.model_executor.models.ernie4_5_vl.modeling_resampler import ScatterOp from fastdeploy.model_executor.ops.xpu import ( - adjust_batch, create_kv_signal_sender, destroy_kv_signal_sender, - get_infer_param, - get_padding_offset, - limit_thinking_content_length_v1, - limit_thinking_content_length_v2, recover_decode_task, set_data_ipc, share_external_data, - 
update_inputs_v1, +) +from fastdeploy.model_executor.xpu_pre_and_post_process import ( # xpu_post_process_specualate, # TODO(chenhuan09): add xpu_post_process_specualate + step_xpu, + xpu_post_process_normal, + xpu_pre_process, + xpu_process_output, ) from fastdeploy.utils import get_logger from fastdeploy.worker.model_runner_base import ModelRunnerBase @@ -62,282 +62,6 @@ from fastdeploy.worker.output import ModelOutputData, ModelRunnerOutput logger = get_logger("xpu_model_runner", "xpu_model_runner.log") -def xpu_pre_process( - input_ids: paddle.Tensor, - seq_lens_this_time: int, - share_inputs: Dict, - use_speculate_method: bool, - block_size: int, - draft_tokens: Optional[paddle.Tensor] = None, - seq_lens_encoder: Optional[paddle.Tensor] = None, - seq_lens_decoder: Optional[paddle.Tensor] = None, - is_profiling: bool = False, -) -> XPUForwardMeta: - """ """ - max_len = input_ids.shape[1] - cum_offsets_now = paddle.cumsum(max_len - seq_lens_this_time, dtype="int32") - token_num = paddle.sum(seq_lens_this_time) - - ( - ids_remove_padding, - cum_offsets, - batch_id_per_token, - cu_seqlens_q, - cu_seqlens_k, - ) = get_padding_offset(input_ids, cum_offsets_now, token_num, seq_lens_this_time) - - share_inputs["ids_remove_padding"] = None # set this after adjust batch - share_inputs["cum_offsets"] = cum_offsets - share_inputs["batch_id_per_token"] = batch_id_per_token - share_inputs["cu_seqlens_q"] = cu_seqlens_q - share_inputs["cu_seqlens_k"] = cu_seqlens_k - - xpu_forward_meta = XPUForwardMeta( - ids_remove_padding=share_inputs["ids_remove_padding"], - rotary_embs=share_inputs["rope_emb"], - attn_backend=None, - seq_lens_encoder=share_inputs["seq_lens_encoder"], - seq_lens_decoder=share_inputs["seq_lens_decoder"], - seq_lens_this_time=share_inputs["seq_lens_this_time"], - cum_offsets=share_inputs["cum_offsets"], - batch_id_per_token=share_inputs["batch_id_per_token"], - cu_seqlens_q=share_inputs["cu_seqlens_q"], - cu_seqlens_k=share_inputs["cu_seqlens_k"], - block_tables=share_inputs["block_tables"], - caches=share_inputs["caches"], - ) - - ( - xpu_forward_meta.encoder_batch_map, - xpu_forward_meta.decoder_batch_map, - xpu_forward_meta.encoder_batch_idx, - xpu_forward_meta.decoder_batch_idx, - xpu_forward_meta.encoder_seq_lod, - xpu_forward_meta.decoder_seq_lod, - xpu_forward_meta.encoder_kv_lod, - xpu_forward_meta.prefix_len, - xpu_forward_meta.decoder_context_len, - xpu_forward_meta.decoder_context_len_cache, - xpu_forward_meta.prefix_block_tables, - xpu_forward_meta.encoder_batch_map_cpu, - xpu_forward_meta.decoder_batch_map_cpu, - xpu_forward_meta.encoder_batch_idx_cpu, - xpu_forward_meta.decoder_batch_idx_cpu, - xpu_forward_meta.encoder_seq_lod_cpu, - xpu_forward_meta.decoder_seq_lod_cpu, - xpu_forward_meta.encoder_kv_lod_cpu, - xpu_forward_meta.prefix_len_cpu, - xpu_forward_meta.decoder_context_len_cpu, - xpu_forward_meta.decoder_context_len_cache_cpu, - xpu_forward_meta.len_info_cpu, - ) = get_infer_param( - seq_lens_encoder, seq_lens_decoder, seq_lens_this_time, xpu_forward_meta.block_tables, block_size - ) - xpu_forward_meta.enc_batch = xpu_forward_meta.len_info_cpu[0] - xpu_forward_meta.dec_batch = xpu_forward_meta.len_info_cpu[1] - xpu_forward_meta.total_enc_len = xpu_forward_meta.len_info_cpu[2] - - adjusted_input = adjust_batch( - ids_remove_padding.reshape([-1, 1]), - cum_offsets, - xpu_forward_meta.encoder_seq_lod, - xpu_forward_meta.encoder_batch_idx, - xpu_forward_meta.decoder_batch_idx, - xpu_forward_meta.encoder_seq_lod_cpu, - xpu_forward_meta.encoder_batch_idx_cpu, - 
xpu_forward_meta.decoder_batch_idx_cpu, - xpu_forward_meta.enc_batch, - xpu_forward_meta.dec_batch, - None, # output_padding_offset - -1, # max_input_length - ) - - adjusted_input = adjusted_input.squeeze(1) - - share_inputs["ids_remove_padding"] = adjusted_input - xpu_forward_meta.ids_remove_padding = adjusted_input - # Set forward_meta.is_profiling to True to skip init_kv_signal_per_query for attention backends - xpu_forward_meta.is_profiling = is_profiling - return xpu_forward_meta - - -def xpu_process_output( - forward_output, - cum_offsets: paddle.Tensor, - xpu_forward_meta: XPUForwardMeta, -) -> paddle.Tensor: - """ """ - from fastdeploy.model_executor.ops.xpu import gather_next_token - - hiddden_states = gather_next_token( - forward_output, - cum_offsets, - xpu_forward_meta.encoder_seq_lod, - xpu_forward_meta.encoder_batch_map, - xpu_forward_meta.decoder_batch_map, - xpu_forward_meta.encoder_seq_lod_cpu, - xpu_forward_meta.encoder_batch_map_cpu, - xpu_forward_meta.decoder_batch_map_cpu, - xpu_forward_meta.enc_batch, - xpu_forward_meta.dec_batch, - None, # output_padding_offset - -1, # max_input_length - ) - return hiddden_states - - -def xpu_post_process( - sampled_token_ids: paddle.Tensor, - model_output: ModelOutputData, - share_inputs: Dict[str, paddle.Tensor], - block_size: int = 64, - skip_save_output: bool = False, - think_end_id: int = None, - line_break_id: int = None, -) -> None: - """ """ - from fastdeploy.model_executor.ops.xpu import ( - save_output, - set_stop_value_multi_ends, - update_inputs, - ) - - if think_end_id > 0: - limit_strategy = envs.FD_LIMIT_THINKING_CONTENT_TRUNCATE_STR - max_think_lens = share_inputs["max_think_lens"] - step_idx = share_inputs["step_idx"] - limit_think_status = share_inputs["limit_think_status"] - stop_flags = share_inputs["stop_flags"] - eos_token_ids = share_inputs["eos_token_id"] - if limit_strategy == "": - # for ernie-45-vl - limit_thinking_content_length_v1( - sampled_token_ids, - max_think_lens, - step_idx, - limit_think_status, - stop_flags, - eos_token_ids, # 处理由于模型效果问题导致思考过程中输出eos token的问题 - think_end_id, - ) - elif limit_strategy == "\n\n\n": - # for ernie-x1 - assert line_break_id > 0 - limit_thinking_content_length_v2( - sampled_token_ids, - max_think_lens, - step_idx, - limit_think_status, - stop_flags, - think_end_id, - line_break_id, - ) - else: - raise NotImplementedError(f"Not support {limit_strategy=} for limit thinking content length.") - - # 1. Set stop value - paddle.assign( - paddle.where( - model_output.stop_flags, - model_output.step_idx, - model_output.step_idx + 1, - ), - model_output.step_idx, - ) - length_cond = paddle.greater_equal(model_output.step_idx, model_output.max_dec_len) - paddle.assign( - paddle.logical_or(model_output.stop_flags, length_cond), - model_output.stop_flags, - ) - set_stop_value_multi_ends( - sampled_token_ids, - model_output.stop_flags, - model_output.seq_lens_this_time, - model_output.eos_token_id, - model_output.next_tokens, - False, - ) # multi ends - - # 2. 
Update the input buffer of the model - with paddle.framework._no_check_dy2st_diff(): - if envs.ENABLE_V1_KVCACHE_SCHEDULER and not skip_save_output: - update_inputs_v1( - model_output.stop_flags, - model_output.not_need_stop, - model_output.seq_lens_this_time, - model_output.seq_lens_encoder, - model_output.seq_lens_decoder, - share_inputs["step_seq_lens_decoder"], - share_inputs["prompt_lens"], - sampled_token_ids, - model_output.input_ids, - share_inputs["block_tables"], - model_output.stop_nums, - model_output.next_tokens, - model_output.is_block_step, - block_size, - ) - else: - update_inputs( - model_output.stop_flags, - model_output.not_need_stop, - model_output.seq_lens_this_time, - model_output.seq_lens_encoder, - model_output.seq_lens_decoder, - model_output.input_ids, - model_output.stop_nums, - sampled_token_ids, - model_output.is_block_step, - ) - # 3. Transmit the model's output and stop generation signal via message queue. - # In the future, we will abandon this approach. - if not skip_save_output: - save_output( - sampled_token_ids, - model_output.not_need_stop, - model_output.mp_rank, - False, # use_ep - ) - - -def step_paddle( - share_inputs: Dict[str, paddle.Tensor], - block_size: int, - enc_dec_block_num: int, -) -> None: - """ - TODO(gongshaotian): normalization name - """ - from fastdeploy.model_executor.ops.xpu import step_paddle - - step_paddle( - share_inputs["stop_flags"], - share_inputs["seq_lens_this_time"], - share_inputs["step_seq_lens_encoder"], - share_inputs["seq_lens_encoder"], - share_inputs["seq_lens_decoder"], - share_inputs["block_tables"], - share_inputs["encoder_block_lens"], - share_inputs["is_block_step"], - share_inputs["step_block_list"], - share_inputs["step_lens"], - share_inputs["recover_block_list"], - share_inputs["recover_lens"], - share_inputs["need_block_list"], - share_inputs["need_block_len"], - share_inputs["used_list_len"], - share_inputs["free_list"], - share_inputs["free_list_len"], - share_inputs["input_ids"], - share_inputs["pre_ids"], - share_inputs["step_idx"], - share_inputs["next_tokens"], - share_inputs["first_token_ids"], - block_size, - enc_dec_block_num, - ) - - class XPUModelRunner(ModelRunnerBase): """ """ @@ -1212,8 +936,9 @@ class XPUModelRunner(ModelRunnerBase): forward_meta=self.forward_meta, ) - hidden_states = xpu_process_output(model_output, self.share_inputs["cum_offsets"], self.forward_meta) - + hidden_states = xpu_process_output( + model_output, self.share_inputs["cum_offsets"], self.forward_meta, self.share_inputs + ) # 4. Compute logits, Sample logits = self.model.compute_logits(hidden_states) sampler_output = self.sampler(logits, self.sampling_metadata) @@ -1247,7 +972,7 @@ class XPUModelRunner(ModelRunnerBase): stop_token_ids=self.share_inputs["stop_seqs"], stop_seqs_len=self.share_inputs["stop_seqs_len"], ) - xpu_post_process( + xpu_post_process_normal( sampled_token_ids=sampler_output.sampled_token_ids, model_output=model_output_data, share_inputs=self.share_inputs, @@ -1260,7 +985,7 @@ class XPUModelRunner(ModelRunnerBase): # 7. Updata 'infer_seed' and step_paddle() self.share_inputs["infer_seed"].add_(self.infer_seed_increment) self.share_inputs["infer_seed"][:] %= self.MAX_INFER_SEED - step_paddle( + step_xpu( self.share_inputs, self.cache_config.block_size, self.cache_config.enc_dec_block_num,