mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
fix_gather_next_token (#5311)
This commit is contained in:
@@ -89,32 +89,27 @@ std::vector<paddle::Tensor> GatherNextToken(
|
||||
return {out};
|
||||
}
|
||||
|
||||
if (enc_batch <= 0) {
|
||||
out = x.copy_to(x.place(), false);
|
||||
if (output_padding_offset) {
|
||||
int r = baidu::xpu::api::plugin::eb_mtp_gather_next_token<XPUType, XPUType>(
|
||||
ctx,
|
||||
reinterpret_cast<const XPUType*>(x.data<data_t>()),
|
||||
reinterpret_cast<XPUType*>(out.data<data_t>()),
|
||||
encoder_seqs_lods_vp,
|
||||
decoder_seqs_lods_vp,
|
||||
encoder_batch_map_vp,
|
||||
decoder_batch_map_vp,
|
||||
dim);
|
||||
PD_CHECK(r == 0, "xpu::plugin::gather_next_token failed.");
|
||||
} else {
|
||||
if (output_padding_offset) {
|
||||
int r =
|
||||
baidu::xpu::api::plugin::eb_mtp_gather_next_token<XPUType, XPUType>(
|
||||
ctx,
|
||||
reinterpret_cast<const XPUType*>(x.data<data_t>()),
|
||||
reinterpret_cast<XPUType*>(out.data<data_t>()),
|
||||
encoder_seqs_lods_vp,
|
||||
decoder_seqs_lods_vp,
|
||||
encoder_batch_map_vp,
|
||||
decoder_batch_map_vp,
|
||||
dim);
|
||||
PD_CHECK(r == 0, "xpu::plugin::gather_next_token failed.");
|
||||
} else {
|
||||
int r = baidu::xpu::api::plugin::eb_gather_next_token<XPUType, XPUType>(
|
||||
ctx,
|
||||
reinterpret_cast<const XPUType*>(x.data<data_t>()),
|
||||
reinterpret_cast<XPUType*>(out.data<data_t>()),
|
||||
encoder_seqs_lods_vp,
|
||||
encoder_batch_map_vp,
|
||||
decoder_batch_map_vp,
|
||||
dim);
|
||||
PD_CHECK(r == 0, "xpu::plugin::gather_next_token failed.");
|
||||
}
|
||||
int r = baidu::xpu::api::plugin::eb_gather_next_token<XPUType, XPUType>(
|
||||
ctx,
|
||||
reinterpret_cast<const XPUType*>(x.data<data_t>()),
|
||||
reinterpret_cast<XPUType*>(out.data<data_t>()),
|
||||
encoder_seqs_lods_vp,
|
||||
encoder_batch_map_vp,
|
||||
decoder_batch_map_vp,
|
||||
dim);
|
||||
PD_CHECK(r == 0, "xpu::plugin::gather_next_token failed.");
|
||||
}
|
||||
return {out};
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user