[XPU] Update the return value of TextImageGatherScatter (#4636)

Co-authored-by: ddchenhao66 <dhaochen163.com>
This commit is contained in:
ddchenhao66
2025-10-29 16:17:01 +08:00
committed by GitHub
parent 14f8cddaf1
commit c92eeed45d
2 changed files with 26 additions and 23 deletions

View File

@@ -482,13 +482,14 @@ std::vector<paddle::Tensor> SpeculateGetSeqLensOutput(
void SetDataIpc(const paddle::Tensor& tmp_input, const std::string& shm_name);
void TextImageGatherScatter(paddle::Tensor& input,
paddle::Tensor& text_input,
paddle::Tensor& image_input,
paddle::Tensor& token_type_ids,
paddle::Tensor& text_index,
paddle::Tensor& image_index,
const bool is_scatter);
std::vector<paddle::Tensor> TextImageGatherScatter(
paddle::Tensor& input,
paddle::Tensor& text_input,
paddle::Tensor& image_input,
paddle::Tensor& token_type_ids,
paddle::Tensor& text_index,
paddle::Tensor& image_index,
const bool is_scatter);
void TextImageIndexOut(const paddle::Tensor& token_type_ids,
const paddle::Tensor& text_index,

View File

@@ -17,13 +17,18 @@
#include "paddle/extension.h"
#include "xpu/plugin.h"
void TextImageGatherScatter(paddle::Tensor& input,
paddle::Tensor& text_input,
paddle::Tensor& image_input,
paddle::Tensor& token_type_ids,
paddle::Tensor& text_index,
paddle::Tensor& image_index,
const bool is_scatter) {
#ifndef PD_BUILD_STATIC_OP
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
#endif
std::vector<paddle::Tensor> TextImageGatherScatter(
paddle::Tensor& input,
paddle::Tensor& text_input,
paddle::Tensor& image_input,
paddle::Tensor& token_type_ids,
paddle::Tensor& text_index,
paddle::Tensor& image_index,
const bool is_scatter) {
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
auto xpu_ctx = static_cast<const phi::XPUContext*>(dev_ctx);
@@ -58,22 +63,19 @@ void TextImageGatherScatter(paddle::Tensor& input,
break;
}
}
return {input, text_input, image_input};
}
PD_BUILD_OP(text_image_gather_scatter)
PD_BUILD_STATIC_OP(text_image_gather_scatter)
.Inputs({"input",
"text_input",
"image_input",
"token_type_ids",
"text_index",
"image_index"})
.Outputs({"text_input_out",
"image_input_out",
"text_index_out",
"image_index_out"})
.Outputs({"output", "text_input_out", "image_input_out"})
.Attrs({"is_scatter:bool"})
.SetInplaceMap({{"text_input", "text_input_out"},
{"image_input", "image_input_out"},
{"text_index", "text_index_out"},
{"image_index", "image_index_out"}})
.SetInplaceMap({{"input", "output"},
{"text_input", "text_input_out"},
{"image_input", "image_input_out"}})
.SetKernelFn(PD_KERNEL(TextImageGatherScatter));