From c92eeed45d0e85883f5c972dd0512c2bb1fd7e41 Mon Sep 17 00:00:00 2001 From: ddchenhao66 <165133255+ddchenhao66@users.noreply.github.com> Date: Wed, 29 Oct 2025 16:17:01 +0800 Subject: [PATCH] [XPU] Update the return value of TextImageGatherScatter (#4636) Co-authored-by: ddchenhao66 --- custom_ops/xpu_ops/src/ops/pybind/pybind.cc | 15 ++++---- .../src/ops/text_image_gather_scatter.cc | 34 ++++++++++--------- 2 files changed, 26 insertions(+), 23 deletions(-) diff --git a/custom_ops/xpu_ops/src/ops/pybind/pybind.cc b/custom_ops/xpu_ops/src/ops/pybind/pybind.cc index 833cd04ec..79f89df37 100644 --- a/custom_ops/xpu_ops/src/ops/pybind/pybind.cc +++ b/custom_ops/xpu_ops/src/ops/pybind/pybind.cc @@ -482,13 +482,14 @@ std::vector SpeculateGetSeqLensOutput( void SetDataIpc(const paddle::Tensor& tmp_input, const std::string& shm_name); -void TextImageGatherScatter(paddle::Tensor& input, - paddle::Tensor& text_input, - paddle::Tensor& image_input, - paddle::Tensor& token_type_ids, - paddle::Tensor& text_index, - paddle::Tensor& image_index, - const bool is_scatter); +std::vector TextImageGatherScatter( + paddle::Tensor& input, + paddle::Tensor& text_input, + paddle::Tensor& image_input, + paddle::Tensor& token_type_ids, + paddle::Tensor& text_index, + paddle::Tensor& image_index, + const bool is_scatter); void TextImageIndexOut(const paddle::Tensor& token_type_ids, const paddle::Tensor& text_index, diff --git a/custom_ops/xpu_ops/src/ops/text_image_gather_scatter.cc b/custom_ops/xpu_ops/src/ops/text_image_gather_scatter.cc index 1df2ba82b..4d041528c 100644 --- a/custom_ops/xpu_ops/src/ops/text_image_gather_scatter.cc +++ b/custom_ops/xpu_ops/src/ops/text_image_gather_scatter.cc @@ -17,13 +17,18 @@ #include "paddle/extension.h" #include "xpu/plugin.h" -void TextImageGatherScatter(paddle::Tensor& input, - paddle::Tensor& text_input, - paddle::Tensor& image_input, - paddle::Tensor& token_type_ids, - paddle::Tensor& text_index, - paddle::Tensor& image_index, - const bool is_scatter) { +#ifndef PD_BUILD_STATIC_OP +#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name) +#endif + +std::vector TextImageGatherScatter( + paddle::Tensor& input, + paddle::Tensor& text_input, + paddle::Tensor& image_input, + paddle::Tensor& token_type_ids, + paddle::Tensor& text_index, + paddle::Tensor& image_index, + const bool is_scatter) { phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId()); auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place); auto xpu_ctx = static_cast(dev_ctx); @@ -58,22 +63,19 @@ void TextImageGatherScatter(paddle::Tensor& input, break; } } + return {input, text_input, image_input}; } -PD_BUILD_OP(text_image_gather_scatter) +PD_BUILD_STATIC_OP(text_image_gather_scatter) .Inputs({"input", "text_input", "image_input", "token_type_ids", "text_index", "image_index"}) - .Outputs({"text_input_out", - "image_input_out", - "text_index_out", - "image_index_out"}) + .Outputs({"output", "text_input_out", "image_input_out"}) .Attrs({"is_scatter:bool"}) - .SetInplaceMap({{"text_input", "text_input_out"}, - {"image_input", "image_input_out"}, - {"text_index", "text_index_out"}, - {"image_index", "image_index_out"}}) + .SetInplaceMap({{"input", "output"}, + {"text_input", "text_input_out"}, + {"image_input", "image_input_out"}}) .SetKernelFn(PD_KERNEL(TextImageGatherScatter));