mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
[XPU] Update the return value of TextImageGatherScatter (#4636)
Co-authored-by: ddchenhao66 <dhaochen163.com>
This commit is contained in:
@@ -482,13 +482,14 @@ std::vector<paddle::Tensor> SpeculateGetSeqLensOutput(
|
||||
|
||||
void SetDataIpc(const paddle::Tensor& tmp_input, const std::string& shm_name);
|
||||
|
||||
void TextImageGatherScatter(paddle::Tensor& input,
|
||||
paddle::Tensor& text_input,
|
||||
paddle::Tensor& image_input,
|
||||
paddle::Tensor& token_type_ids,
|
||||
paddle::Tensor& text_index,
|
||||
paddle::Tensor& image_index,
|
||||
const bool is_scatter);
|
||||
std::vector<paddle::Tensor> TextImageGatherScatter(
|
||||
paddle::Tensor& input,
|
||||
paddle::Tensor& text_input,
|
||||
paddle::Tensor& image_input,
|
||||
paddle::Tensor& token_type_ids,
|
||||
paddle::Tensor& text_index,
|
||||
paddle::Tensor& image_index,
|
||||
const bool is_scatter);
|
||||
|
||||
void TextImageIndexOut(const paddle::Tensor& token_type_ids,
|
||||
const paddle::Tensor& text_index,
|
||||
|
||||
@@ -17,13 +17,18 @@
|
||||
#include "paddle/extension.h"
|
||||
#include "xpu/plugin.h"
|
||||
|
||||
void TextImageGatherScatter(paddle::Tensor& input,
|
||||
paddle::Tensor& text_input,
|
||||
paddle::Tensor& image_input,
|
||||
paddle::Tensor& token_type_ids,
|
||||
paddle::Tensor& text_index,
|
||||
paddle::Tensor& image_index,
|
||||
const bool is_scatter) {
|
||||
#ifndef PD_BUILD_STATIC_OP
|
||||
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
|
||||
#endif
|
||||
|
||||
std::vector<paddle::Tensor> TextImageGatherScatter(
|
||||
paddle::Tensor& input,
|
||||
paddle::Tensor& text_input,
|
||||
paddle::Tensor& image_input,
|
||||
paddle::Tensor& token_type_ids,
|
||||
paddle::Tensor& text_index,
|
||||
paddle::Tensor& image_index,
|
||||
const bool is_scatter) {
|
||||
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
|
||||
auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
|
||||
auto xpu_ctx = static_cast<const phi::XPUContext*>(dev_ctx);
|
||||
@@ -58,22 +63,19 @@ void TextImageGatherScatter(paddle::Tensor& input,
|
||||
break;
|
||||
}
|
||||
}
|
||||
return {input, text_input, image_input};
|
||||
}
|
||||
|
||||
PD_BUILD_OP(text_image_gather_scatter)
|
||||
PD_BUILD_STATIC_OP(text_image_gather_scatter)
|
||||
.Inputs({"input",
|
||||
"text_input",
|
||||
"image_input",
|
||||
"token_type_ids",
|
||||
"text_index",
|
||||
"image_index"})
|
||||
.Outputs({"text_input_out",
|
||||
"image_input_out",
|
||||
"text_index_out",
|
||||
"image_index_out"})
|
||||
.Outputs({"output", "text_input_out", "image_input_out"})
|
||||
.Attrs({"is_scatter:bool"})
|
||||
.SetInplaceMap({{"text_input", "text_input_out"},
|
||||
{"image_input", "image_input_out"},
|
||||
{"text_index", "text_index_out"},
|
||||
{"image_index", "image_index_out"}})
|
||||
.SetInplaceMap({{"input", "output"},
|
||||
{"text_input", "text_input_out"},
|
||||
{"image_input", "image_input_out"}})
|
||||
.SetKernelFn(PD_KERNEL(TextImageGatherScatter));
|
||||
|
||||
Reference in New Issue
Block a user