diff --git a/custom_ops/xpu_ops/src/ops/pybind/pybind.cc b/custom_ops/xpu_ops/src/ops/pybind/pybind.cc index 833cd04ec..79f89df37 100644 --- a/custom_ops/xpu_ops/src/ops/pybind/pybind.cc +++ b/custom_ops/xpu_ops/src/ops/pybind/pybind.cc @@ -482,13 +482,14 @@ std::vector SpeculateGetSeqLensOutput( void SetDataIpc(const paddle::Tensor& tmp_input, const std::string& shm_name); -void TextImageGatherScatter(paddle::Tensor& input, - paddle::Tensor& text_input, - paddle::Tensor& image_input, - paddle::Tensor& token_type_ids, - paddle::Tensor& text_index, - paddle::Tensor& image_index, - const bool is_scatter); +std::vector TextImageGatherScatter( + paddle::Tensor& input, + paddle::Tensor& text_input, + paddle::Tensor& image_input, + paddle::Tensor& token_type_ids, + paddle::Tensor& text_index, + paddle::Tensor& image_index, + const bool is_scatter); void TextImageIndexOut(const paddle::Tensor& token_type_ids, const paddle::Tensor& text_index, diff --git a/custom_ops/xpu_ops/src/ops/text_image_gather_scatter.cc b/custom_ops/xpu_ops/src/ops/text_image_gather_scatter.cc index 1df2ba82b..4d041528c 100644 --- a/custom_ops/xpu_ops/src/ops/text_image_gather_scatter.cc +++ b/custom_ops/xpu_ops/src/ops/text_image_gather_scatter.cc @@ -17,13 +17,18 @@ #include "paddle/extension.h" #include "xpu/plugin.h" -void TextImageGatherScatter(paddle::Tensor& input, - paddle::Tensor& text_input, - paddle::Tensor& image_input, - paddle::Tensor& token_type_ids, - paddle::Tensor& text_index, - paddle::Tensor& image_index, - const bool is_scatter) { +#ifndef PD_BUILD_STATIC_OP +#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name) +#endif + +std::vector TextImageGatherScatter( + paddle::Tensor& input, + paddle::Tensor& text_input, + paddle::Tensor& image_input, + paddle::Tensor& token_type_ids, + paddle::Tensor& text_index, + paddle::Tensor& image_index, + const bool is_scatter) { phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId()); auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place); auto xpu_ctx = static_cast(dev_ctx); @@ -58,22 +63,19 @@ void TextImageGatherScatter(paddle::Tensor& input, break; } } + return {input, text_input, image_input}; } -PD_BUILD_OP(text_image_gather_scatter) +PD_BUILD_STATIC_OP(text_image_gather_scatter) .Inputs({"input", "text_input", "image_input", "token_type_ids", "text_index", "image_index"}) - .Outputs({"text_input_out", - "image_input_out", - "text_index_out", - "image_index_out"}) + .Outputs({"output", "text_input_out", "image_input_out"}) .Attrs({"is_scatter:bool"}) - .SetInplaceMap({{"text_input", "text_input_out"}, - {"image_input", "image_input_out"}, - {"text_index", "text_index_out"}, - {"image_index", "image_index_out"}}) + .SetInplaceMap({{"input", "output"}, + {"text_input", "text_input_out"}, + {"image_input", "image_input_out"}}) .SetKernelFn(PD_KERNEL(TextImageGatherScatter));