fix_gather_next_token (#5311)

This commit is contained in:
cmcamdy
2025-12-01 18:00:30 +08:00
committed by GitHub
parent 0925d44f18
commit 3149aed750

View File

@@ -89,32 +89,27 @@ std::vector<paddle::Tensor> GatherNextToken(
return {out};
}
if (enc_batch <= 0) {
out = x.copy_to(x.place(), false);
if (output_padding_offset) {
int r = baidu::xpu::api::plugin::eb_mtp_gather_next_token<XPUType, XPUType>(
ctx,
reinterpret_cast<const XPUType*>(x.data<data_t>()),
reinterpret_cast<XPUType*>(out.data<data_t>()),
encoder_seqs_lods_vp,
decoder_seqs_lods_vp,
encoder_batch_map_vp,
decoder_batch_map_vp,
dim);
PD_CHECK(r == 0, "xpu::plugin::gather_next_token failed.");
} else {
if (output_padding_offset) {
int r =
baidu::xpu::api::plugin::eb_mtp_gather_next_token<XPUType, XPUType>(
ctx,
reinterpret_cast<const XPUType*>(x.data<data_t>()),
reinterpret_cast<XPUType*>(out.data<data_t>()),
encoder_seqs_lods_vp,
decoder_seqs_lods_vp,
encoder_batch_map_vp,
decoder_batch_map_vp,
dim);
PD_CHECK(r == 0, "xpu::plugin::gather_next_token failed.");
} else {
int r = baidu::xpu::api::plugin::eb_gather_next_token<XPUType, XPUType>(
ctx,
reinterpret_cast<const XPUType*>(x.data<data_t>()),
reinterpret_cast<XPUType*>(out.data<data_t>()),
encoder_seqs_lods_vp,
encoder_batch_map_vp,
decoder_batch_map_vp,
dim);
PD_CHECK(r == 0, "xpu::plugin::gather_next_token failed.");
}
int r = baidu::xpu::api::plugin::eb_gather_next_token<XPUType, XPUType>(
ctx,
reinterpret_cast<const XPUType*>(x.data<data_t>()),
reinterpret_cast<XPUType*>(out.data<data_t>()),
encoder_seqs_lods_vp,
encoder_batch_map_vp,
decoder_batch_map_vp,
dim);
PD_CHECK(r == 0, "xpu::plugin::gather_next_token failed.");
}
return {out};
}