fix_gather_next_token (#5311)

This commit is contained in:
cmcamdy
2025-12-01 18:00:30 +08:00
committed by GitHub
parent 0925d44f18
commit 3149aed750

View File

@@ -89,32 +89,27 @@ std::vector<paddle::Tensor> GatherNextToken(
     return {out};
   }
-  if (enc_batch <= 0) {
-    out = x.copy_to(x.place(), false);
+  if (output_padding_offset) {
+    int r = baidu::xpu::api::plugin::eb_mtp_gather_next_token<XPUType, XPUType>(
+        ctx,
+        reinterpret_cast<const XPUType*>(x.data<data_t>()),
+        reinterpret_cast<XPUType*>(out.data<data_t>()),
+        encoder_seqs_lods_vp,
+        decoder_seqs_lods_vp,
+        encoder_batch_map_vp,
+        decoder_batch_map_vp,
+        dim);
+    PD_CHECK(r == 0, "xpu::plugin::gather_next_token failed.");
   } else {
-    if (output_padding_offset) {
-      int r =
-          baidu::xpu::api::plugin::eb_mtp_gather_next_token<XPUType, XPUType>(
-              ctx,
-              reinterpret_cast<const XPUType*>(x.data<data_t>()),
-              reinterpret_cast<XPUType*>(out.data<data_t>()),
-              encoder_seqs_lods_vp,
-              decoder_seqs_lods_vp,
-              encoder_batch_map_vp,
-              decoder_batch_map_vp,
-              dim);
-      PD_CHECK(r == 0, "xpu::plugin::gather_next_token failed.");
-    } else {
-      int r = baidu::xpu::api::plugin::eb_gather_next_token<XPUType, XPUType>(
-          ctx,
-          reinterpret_cast<const XPUType*>(x.data<data_t>()),
-          reinterpret_cast<XPUType*>(out.data<data_t>()),
-          encoder_seqs_lods_vp,
-          encoder_batch_map_vp,
-          decoder_batch_map_vp,
-          dim);
-      PD_CHECK(r == 0, "xpu::plugin::gather_next_token failed.");
-    }
+    int r = baidu::xpu::api::plugin::eb_gather_next_token<XPUType, XPUType>(
+        ctx,
+        reinterpret_cast<const XPUType*>(x.data<data_t>()),
+        reinterpret_cast<XPUType*>(out.data<data_t>()),
+        encoder_seqs_lods_vp,
+        encoder_batch_map_vp,
+        decoder_batch_map_vp,
+        dim);
+    PD_CHECK(r == 0, "xpu::plugin::gather_next_token failed.");
   }
   return {out};
 }