mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
fix_gather_next_token (#5311)
This commit is contained in:
@@ -89,32 +89,27 @@ std::vector<paddle::Tensor> GatherNextToken(
|
|||||||
return {out};
|
return {out};
|
||||||
}
|
}
|
||||||
|
|
||||||
if (enc_batch <= 0) {
|
if (output_padding_offset) {
|
||||||
out = x.copy_to(x.place(), false);
|
int r = baidu::xpu::api::plugin::eb_mtp_gather_next_token<XPUType, XPUType>(
|
||||||
|
ctx,
|
||||||
|
reinterpret_cast<const XPUType*>(x.data<data_t>()),
|
||||||
|
reinterpret_cast<XPUType*>(out.data<data_t>()),
|
||||||
|
encoder_seqs_lods_vp,
|
||||||
|
decoder_seqs_lods_vp,
|
||||||
|
encoder_batch_map_vp,
|
||||||
|
decoder_batch_map_vp,
|
||||||
|
dim);
|
||||||
|
PD_CHECK(r == 0, "xpu::plugin::gather_next_token failed.");
|
||||||
} else {
|
} else {
|
||||||
if (output_padding_offset) {
|
int r = baidu::xpu::api::plugin::eb_gather_next_token<XPUType, XPUType>(
|
||||||
int r =
|
ctx,
|
||||||
baidu::xpu::api::plugin::eb_mtp_gather_next_token<XPUType, XPUType>(
|
reinterpret_cast<const XPUType*>(x.data<data_t>()),
|
||||||
ctx,
|
reinterpret_cast<XPUType*>(out.data<data_t>()),
|
||||||
reinterpret_cast<const XPUType*>(x.data<data_t>()),
|
encoder_seqs_lods_vp,
|
||||||
reinterpret_cast<XPUType*>(out.data<data_t>()),
|
encoder_batch_map_vp,
|
||||||
encoder_seqs_lods_vp,
|
decoder_batch_map_vp,
|
||||||
decoder_seqs_lods_vp,
|
dim);
|
||||||
encoder_batch_map_vp,
|
PD_CHECK(r == 0, "xpu::plugin::gather_next_token failed.");
|
||||||
decoder_batch_map_vp,
|
|
||||||
dim);
|
|
||||||
PD_CHECK(r == 0, "xpu::plugin::gather_next_token failed.");
|
|
||||||
} else {
|
|
||||||
int r = baidu::xpu::api::plugin::eb_gather_next_token<XPUType, XPUType>(
|
|
||||||
ctx,
|
|
||||||
reinterpret_cast<const XPUType*>(x.data<data_t>()),
|
|
||||||
reinterpret_cast<XPUType*>(out.data<data_t>()),
|
|
||||||
encoder_seqs_lods_vp,
|
|
||||||
encoder_batch_map_vp,
|
|
||||||
decoder_batch_map_vp,
|
|
||||||
dim);
|
|
||||||
PD_CHECK(r == 0, "xpu::plugin::gather_next_token failed.");
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
return {out};
|
return {out};
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user