From ec6811f6486a95bb54ca447c4047fece021ed8ee Mon Sep 17 00:00:00 2001 From: lizan1999 <55830407+lizan1999@users.noreply.github.com> Date: Fri, 19 Dec 2025 10:20:38 +0800 Subject: [PATCH] support token num = 0 (#5635) Co-authored-by: lizan1999 Co-authored-by: cmcamdy <1027740945@qq.com> Co-authored-by: Jiaxin Sui <95567040+plusNew001@users.noreply.github.com> --- .../xpu_ops/src/ops/get_padding_offset.cc | 29 ++++++++++--------- .../speculate_get_output_padding_offset.cc | 21 +++++++------- 2 files changed, 27 insertions(+), 23 deletions(-) diff --git a/custom_ops/xpu_ops/src/ops/get_padding_offset.cc b/custom_ops/xpu_ops/src/ops/get_padding_offset.cc index 7c1824372..8178e8cdd 100644 --- a/custom_ops/xpu_ops/src/ops/get_padding_offset.cc +++ b/custom_ops/xpu_ops/src/ops/get_padding_offset.cc @@ -39,19 +39,22 @@ std::vector GetPaddingOffset(const paddle::Tensor &input_ids, paddle::full({bsz + 1}, 0, paddle::DataType::INT32, input_ids.place()); auto cu_seqlens_k = paddle::full({bsz + 1}, 0, paddle::DataType::INT32, input_ids.place()); - int r = baidu::xpu::api::plugin::get_padding_offset( - xpu_ctx->x_context(), - batch_id_per_token.data(), - cum_offsets_out.data(), - cu_seqlens_q.data(), - cu_seqlens_k.data(), - x_remove_padding.data(), - input_ids.data(), - cum_offsets.data(), - seq_len.data(), - seq_length, - bsz); - PD_CHECK(r == 0, "baidu::xpu::api::plugin::get_padding_offset failed."); + if (token_num_data > 0) { + int r = baidu::xpu::api::plugin::get_padding_offset( + xpu_ctx->x_context(), + batch_id_per_token.data(), + cum_offsets_out.data(), + cu_seqlens_q.data(), + cu_seqlens_k.data(), + x_remove_padding.data(), + input_ids.data(), + cum_offsets.data(), + seq_len.data(), + seq_length, + bsz); + PD_CHECK(r == 0, "baidu::xpu::api::plugin::get_padding_offset failed."); + } + return {x_remove_padding, cum_offsets_out, batch_id_per_token, diff --git a/custom_ops/xpu_ops/src/ops/mtp/speculate_get_output_padding_offset.cc b/custom_ops/xpu_ops/src/ops/mtp/speculate_get_output_padding_offset.cc index 31d0e1fac..6ad030d50 100644 --- a/custom_ops/xpu_ops/src/ops/mtp/speculate_get_output_padding_offset.cc +++ b/custom_ops/xpu_ops/src/ops/mtp/speculate_get_output_padding_offset.cc @@ -44,16 +44,17 @@ std::vector SpeculateGetOutputPaddingOffset( output_cum_offsets_tmp.place()); auto output_cum_offsets = output_cum_offsets_tmp.copy_to(output_cum_offsets_tmp.place(), false); - - int r = baidu::xpu::api::plugin::speculate_get_output_padding_offset( - ctx, - output_padding_offset.mutable_data(), - output_cum_offsets.mutable_data(), - output_cum_offsets_tmp.data(), - seq_lens_output.data(), - bsz, - max_seq_len); - PD_CHECK(r == 0, "speculate_clear_accept_nums_kernel failed."); + if (cpu_out_token_num.data()[0] > 0) { + int r = baidu::xpu::api::plugin::speculate_get_output_padding_offset( + ctx, + output_padding_offset.mutable_data(), + output_cum_offsets.mutable_data(), + output_cum_offsets_tmp.data(), + seq_lens_output.data(), + bsz, + max_seq_len); + PD_CHECK(r == 0, "speculate_clear_accept_nums_kernel failed."); + } return {output_padding_offset, output_cum_offsets}; }