Support token num = 0: skip the XPU padding-offset kernel calls when there are no tokens (#5635)

Co-authored-by: lizan1999 <lizan03@baidu.com>
Co-authored-by: cmcamdy <1027740945@qq.com>
Co-authored-by: Jiaxin Sui <95567040+plusNew001@users.noreply.github.com>
This commit is contained in:
lizan1999
2025-12-19 10:20:38 +08:00
committed by GitHub
parent d657455616
commit ec6811f648
2 changed files with 27 additions and 23 deletions

View File

@@ -39,19 +39,22 @@ std::vector<paddle::Tensor> GetPaddingOffset(const paddle::Tensor &input_ids,
paddle::full({bsz + 1}, 0, paddle::DataType::INT32, input_ids.place());
auto cu_seqlens_k =
paddle::full({bsz + 1}, 0, paddle::DataType::INT32, input_ids.place());
int r = baidu::xpu::api::plugin::get_padding_offset(
xpu_ctx->x_context(),
batch_id_per_token.data<int>(),
cum_offsets_out.data<int>(),
cu_seqlens_q.data<int>(),
cu_seqlens_k.data<int>(),
x_remove_padding.data<int64_t>(),
input_ids.data<int64_t>(),
cum_offsets.data<int>(),
seq_len.data<int>(),
seq_length,
bsz);
PD_CHECK(r == 0, "baidu::xpu::api::plugin::get_padding_offset failed.");
if (token_num_data > 0) {
int r = baidu::xpu::api::plugin::get_padding_offset(
xpu_ctx->x_context(),
batch_id_per_token.data<int>(),
cum_offsets_out.data<int>(),
cu_seqlens_q.data<int>(),
cu_seqlens_k.data<int>(),
x_remove_padding.data<int64_t>(),
input_ids.data<int64_t>(),
cum_offsets.data<int>(),
seq_len.data<int>(),
seq_length,
bsz);
PD_CHECK(r == 0, "baidu::xpu::api::plugin::get_padding_offset failed.");
}
return {x_remove_padding,
cum_offsets_out,
batch_id_per_token,

View File

@@ -44,16 +44,17 @@ std::vector<paddle::Tensor> SpeculateGetOutputPaddingOffset(
output_cum_offsets_tmp.place());
auto output_cum_offsets =
output_cum_offsets_tmp.copy_to(output_cum_offsets_tmp.place(), false);
int r = baidu::xpu::api::plugin::speculate_get_output_padding_offset(
ctx,
output_padding_offset.mutable_data<int>(),
output_cum_offsets.mutable_data<int>(),
output_cum_offsets_tmp.data<int>(),
seq_lens_output.data<int>(),
bsz,
max_seq_len);
PD_CHECK(r == 0, "speculate_clear_accept_nums_kernel failed.");
if (cpu_out_token_num.data<int64_t>()[0] > 0) {
int r = baidu::xpu::api::plugin::speculate_get_output_padding_offset(
ctx,
output_padding_offset.mutable_data<int>(),
output_cum_offsets.mutable_data<int>(),
output_cum_offsets_tmp.data<int>(),
seq_lens_output.data<int>(),
bsz,
max_seq_len);
PD_CHECK(r == 0, "speculate_clear_accept_nums_kernel failed.");
}
return {output_padding_offset, output_cum_offsets};
}