mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
support token num = 0 (#5635)
Co-authored-by: lizan1999 <lizan03@baidu.com> Co-authored-by: cmcamdy <1027740945@qq.com> Co-authored-by: Jiaxin Sui <95567040+plusNew001@users.noreply.github.com>
This commit is contained in:
@@ -39,19 +39,22 @@ std::vector<paddle::Tensor> GetPaddingOffset(const paddle::Tensor &input_ids,
|
||||
paddle::full({bsz + 1}, 0, paddle::DataType::INT32, input_ids.place());
|
||||
auto cu_seqlens_k =
|
||||
paddle::full({bsz + 1}, 0, paddle::DataType::INT32, input_ids.place());
|
||||
int r = baidu::xpu::api::plugin::get_padding_offset(
|
||||
xpu_ctx->x_context(),
|
||||
batch_id_per_token.data<int>(),
|
||||
cum_offsets_out.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
cu_seqlens_k.data<int>(),
|
||||
x_remove_padding.data<int64_t>(),
|
||||
input_ids.data<int64_t>(),
|
||||
cum_offsets.data<int>(),
|
||||
seq_len.data<int>(),
|
||||
seq_length,
|
||||
bsz);
|
||||
PD_CHECK(r == 0, "baidu::xpu::api::plugin::get_padding_offset failed.");
|
||||
if (token_num_data > 0) {
|
||||
int r = baidu::xpu::api::plugin::get_padding_offset(
|
||||
xpu_ctx->x_context(),
|
||||
batch_id_per_token.data<int>(),
|
||||
cum_offsets_out.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
cu_seqlens_k.data<int>(),
|
||||
x_remove_padding.data<int64_t>(),
|
||||
input_ids.data<int64_t>(),
|
||||
cum_offsets.data<int>(),
|
||||
seq_len.data<int>(),
|
||||
seq_length,
|
||||
bsz);
|
||||
PD_CHECK(r == 0, "baidu::xpu::api::plugin::get_padding_offset failed.");
|
||||
}
|
||||
|
||||
return {x_remove_padding,
|
||||
cum_offsets_out,
|
||||
batch_id_per_token,
|
||||
|
||||
@@ -44,16 +44,17 @@ std::vector<paddle::Tensor> SpeculateGetOutputPaddingOffset(
|
||||
output_cum_offsets_tmp.place());
|
||||
auto output_cum_offsets =
|
||||
output_cum_offsets_tmp.copy_to(output_cum_offsets_tmp.place(), false);
|
||||
|
||||
int r = baidu::xpu::api::plugin::speculate_get_output_padding_offset(
|
||||
ctx,
|
||||
output_padding_offset.mutable_data<int>(),
|
||||
output_cum_offsets.mutable_data<int>(),
|
||||
output_cum_offsets_tmp.data<int>(),
|
||||
seq_lens_output.data<int>(),
|
||||
bsz,
|
||||
max_seq_len);
|
||||
PD_CHECK(r == 0, "speculate_clear_accept_nums_kernel failed.");
|
||||
if (cpu_out_token_num.data<int64_t>()[0] > 0) {
|
||||
int r = baidu::xpu::api::plugin::speculate_get_output_padding_offset(
|
||||
ctx,
|
||||
output_padding_offset.mutable_data<int>(),
|
||||
output_cum_offsets.mutable_data<int>(),
|
||||
output_cum_offsets_tmp.data<int>(),
|
||||
seq_lens_output.data<int>(),
|
||||
bsz,
|
||||
max_seq_len);
|
||||
PD_CHECK(r == 0, "speculate_clear_accept_nums_kernel failed.");
|
||||
}
|
||||
|
||||
return {output_padding_offset, output_cum_offsets};
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user