diff --git a/custom_ops/xpu_ops/src/ops/get_padding_offset.cc b/custom_ops/xpu_ops/src/ops/get_padding_offset.cc index 7c1824372..8178e8cdd 100644 --- a/custom_ops/xpu_ops/src/ops/get_padding_offset.cc +++ b/custom_ops/xpu_ops/src/ops/get_padding_offset.cc @@ -39,19 +39,22 @@ std::vector GetPaddingOffset(const paddle::Tensor &input_ids, paddle::full({bsz + 1}, 0, paddle::DataType::INT32, input_ids.place()); auto cu_seqlens_k = paddle::full({bsz + 1}, 0, paddle::DataType::INT32, input_ids.place()); - int r = baidu::xpu::api::plugin::get_padding_offset( - xpu_ctx->x_context(), - batch_id_per_token.data(), - cum_offsets_out.data(), - cu_seqlens_q.data(), - cu_seqlens_k.data(), - x_remove_padding.data(), - input_ids.data(), - cum_offsets.data(), - seq_len.data(), - seq_length, - bsz); - PD_CHECK(r == 0, "baidu::xpu::api::plugin::get_padding_offset failed."); + if (token_num_data > 0) { + int r = baidu::xpu::api::plugin::get_padding_offset( + xpu_ctx->x_context(), + batch_id_per_token.data(), + cum_offsets_out.data(), + cu_seqlens_q.data(), + cu_seqlens_k.data(), + x_remove_padding.data(), + input_ids.data(), + cum_offsets.data(), + seq_len.data(), + seq_length, + bsz); + PD_CHECK(r == 0, "baidu::xpu::api::plugin::get_padding_offset failed."); + } + return {x_remove_padding, cum_offsets_out, batch_id_per_token, diff --git a/custom_ops/xpu_ops/src/ops/mtp/speculate_get_output_padding_offset.cc b/custom_ops/xpu_ops/src/ops/mtp/speculate_get_output_padding_offset.cc index 31d0e1fac..6ad030d50 100644 --- a/custom_ops/xpu_ops/src/ops/mtp/speculate_get_output_padding_offset.cc +++ b/custom_ops/xpu_ops/src/ops/mtp/speculate_get_output_padding_offset.cc @@ -44,16 +44,17 @@ std::vector SpeculateGetOutputPaddingOffset( output_cum_offsets_tmp.place()); auto output_cum_offsets = output_cum_offsets_tmp.copy_to(output_cum_offsets_tmp.place(), false); - - int r = baidu::xpu::api::plugin::speculate_get_output_padding_offset( - ctx, - output_padding_offset.mutable_data(), - output_cum_offsets.mutable_data(), - output_cum_offsets_tmp.data(), - seq_lens_output.data(), - bsz, - max_seq_len); - PD_CHECK(r == 0, "speculate_clear_accept_nums_kernel failed."); + if (cpu_out_token_num.data()[0] > 0) { + int r = baidu::xpu::api::plugin::speculate_get_output_padding_offset( + ctx, + output_padding_offset.mutable_data(), + output_cum_offsets.mutable_data(), + output_cum_offsets_tmp.data(), + seq_lens_output.data(), + bsz, + max_seq_len); + PD_CHECK(r == 0, "speculate_clear_accept_nums_kernel failed."); + } return {output_padding_offset, output_cum_offsets}; }