Mirror of https://github.com/PaddlePaddle/FastDeploy.git
prebind output by shareExternalData

This commit lets callers prebind output FDTensors so that, on GPU, the Paddle Inference backend writes results directly into user-provided buffers via paddle_infer::Tensor::ShareExternalData instead of copying them back after Run(). It adds Runtime::BindOutputTensor and ShareOutTensorFromFDTensor on the C++ side, exposes FDTensor.from_external_data and Runtime.bind_output_tensor (plus bind_output_tensor on the Python Runtime wrapper) to Python, and also raises the min_subgraph_size passed to config_.EnableTensorRtEngine from 3 to 20. The reconstructed diff hunks follow.
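Before the hunks, a minimal end-to-end sketch of how the new prebind flow is intended to be used from Python. This is a hedged illustration, not code from the commit: the model paths and shapes are placeholders, the RuntimeOption/Runtime helper names (set_model_path, use_gpu, use_paddle_backend, get_input_info, get_output_info) and the fd.C.FDTensor spelling are assumptions that may differ between FastDeploy versions, and torch is used only as a convenient owner of GPU buffers.

import fastdeploy as fd
import torch

# Build a GPU runtime with the Paddle Inference backend (assumed option names).
option = fd.RuntimeOption()
option.set_model_path("model.pdmodel", "model.pdiparams")  # illustrative paths
option.use_gpu(0)
option.use_paddle_backend()
runtime = fd.Runtime(option)

# Pre-allocate input/output buffers on the GPU; shapes are illustrative and
# must match what the model actually consumes and produces.
in_buf = torch.zeros((1, 3, 224, 224), dtype=torch.float32, device="cuda:0")
out_buf = torch.zeros((1, 1000), dtype=torch.float32, device="cuda:0")

# Wrap the raw device pointers as FDTensors without copying, matching the
# pybind lambda added in this commit:
# (name, data_addr, shape, data_type, data_place, device_id).
in_name = runtime.get_input_info(0).name
out_name = runtime.get_output_info(0).name
in_tensor = fd.C.FDTensor.from_external_data(
    in_name, in_buf.data_ptr(), list(in_buf.shape), "FP32", "gpu", 0)
out_tensor = fd.C.FDTensor.from_external_data(
    out_name, out_buf.data_ptr(), list(out_buf.shape), "FP32", "gpu", 0)

# Prebind both sides, then run; the backend writes the prebound output via
# ShareExternalData, so no extra copy happens after Run().
runtime.bind_input_tensor(in_name, in_tensor)
runtime.bind_output_tensor(out_name, out_tensor)
runtime.zero_copy_infer()
# out_buf now holds the inference result.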
@@ -15,9 +15,9 @@
 #include <dlpack/dlpack.h>
 
 #include "fastdeploy/core/fd_type.h"
-#include "fastdeploy/utils/utils.h"
 #include "fastdeploy/fastdeploy_model.h"
 #include "fastdeploy/pybind/main.h"
+#include "fastdeploy/utils/utils.h"
 
 namespace fastdeploy {
 
@@ -68,8 +68,8 @@ DLDataType FDToDlpackType(FDDataType fd_dtype) {
       break;
 
     default:
-      FDASSERT(false,
-               "Convert to DlPack, FDType \"%s\" is not supported.", Str(fd_dtype).c_str());
+      FDASSERT(false, "Convert to DlPack, FDType \"%s\" is not supported.",
+               Str(fd_dtype).c_str());
   }
 
   dl_dtype.code = dl_code;
@@ -77,10 +77,8 @@ DLDataType FDToDlpackType(FDDataType fd_dtype) {
   return dl_dtype;
 }
 
-FDDataType
-DlpackToFDType(const DLDataType& data_type) {
-  FDASSERT(data_type.lanes == 1,
-           "FDTensor does not support dlpack lanes != 1")
+FDDataType DlpackToFDType(const DLDataType& data_type) {
+  FDASSERT(data_type.lanes == 1, "FDTensor does not support dlpack lanes != 1")
 
   if (data_type.code == DLDataTypeCode::kDLFloat) {
     if (data_type.bits == 16) {
@@ -152,7 +150,7 @@ pybind11::capsule FDTensorToDLPack(FDTensor& fd_tensor) {
   dlpack_tensor->dl_tensor.dtype = FDToDlpackType(fd_tensor.dtype);
 
   dlpack_tensor->dl_tensor.device.device_id = fd_tensor.device_id;
-  if(fd_tensor.device == Device::GPU) {
+  if (fd_tensor.device == Device::GPU) {
     if (fd_tensor.is_pinned_memory) {
       dlpack_tensor->dl_tensor.device.device_type = DLDeviceType::kDLCUDAHost;
     } else {
@@ -162,8 +160,8 @@ pybind11::capsule FDTensorToDLPack(FDTensor& fd_tensor) {
     dlpack_tensor->dl_tensor.device.device_type = DLDeviceType::kDLCPU;
   }
 
-  return pybind11::capsule(
-      static_cast<void*>(dlpack_tensor), "dltensor", &DeleteUnusedDltensor);
+  return pybind11::capsule(static_cast<void*>(dlpack_tensor), "dltensor",
+                           &DeleteUnusedDltensor);
 }
 
 FDTensor FDTensorFromDLPack(const std::string& name,
@@ -178,9 +176,8 @@ FDTensor FDTensorFromDLPack(const std::string& name,
   int64_t* strides = dl_managed_tensor->dl_tensor.strides;
 
   int ndim = dl_managed_tensor->dl_tensor.ndim;
-  std::vector<int64_t> dims(
-      dl_managed_tensor->dl_tensor.shape,
-      dl_managed_tensor->dl_tensor.shape + ndim);
+  std::vector<int64_t> dims(dl_managed_tensor->dl_tensor.shape,
+                            dl_managed_tensor->dl_tensor.shape + ndim);
 
   // Check if the input is contiguous and in C order
   if (strides != nullptr) {
@@ -196,8 +193,8 @@ FDTensor FDTensorFromDLPack(const std::string& name,
     }
 
     FDASSERT(is_contiguous_c_order,
-        "DLPack tensor is not contiguous. Only contiguous DLPack "
-        "tensors that are stored in C-Order are supported.");
+             "DLPack tensor is not contiguous. Only contiguous DLPack "
+             "tensors that are stored in C-Order are supported.");
   }
 
   Device device;
@@ -216,21 +213,20 @@ FDTensor FDTensorFromDLPack(const std::string& name,
       is_pinned_memory = true;
       break;
     default:
-      FDASSERT(false,
+      FDASSERT(
+          false,
           ("DLDevice type " +
            std::to_string(dl_managed_tensor->dl_tensor.device.device_type) +
-           " is not support by Python backend.").c_str());
+           " is not support by Python backend.")
+              .c_str());
       break;
   }
 
-  FDDataType dtype =
-      DlpackToFDType(dl_managed_tensor->dl_tensor.dtype);
+  FDDataType dtype = DlpackToFDType(dl_managed_tensor->dl_tensor.dtype);
 
   PyCapsule_SetName(dlpack_tensor.ptr(), "used_dlpack");
   FDTensor fd_tensor(name);
-  fd_tensor.SetExternalData(
-      dims, dtype, memory_ptr, device, device_id
-  );
+  fd_tensor.SetExternalData(dims, dtype, memory_ptr, device, device_id);
   fd_tensor.is_pinned_memory = is_pinned_memory;
   return fd_tensor;
 }
@@ -242,15 +238,52 @@ void BindFDTensor(pybind11::module& m) {
       .def_readonly("shape", &FDTensor::shape)
       .def_readonly("dtype", &FDTensor::dtype)
      .def_readonly("device", &FDTensor::device)
-      .def("numpy", [](FDTensor& self) {
-        return TensorToPyArray(self);
-      })
+      .def("numpy", [](FDTensor& self) { return TensorToPyArray(self); })
       .def("data", &FDTensor::MutableData)
-      .def("from_numpy", [](FDTensor& self, pybind11::array& pyarray, bool share_buffer = false) {
-        PyArrayToTensor(pyarray, &self, share_buffer);
-      })
+      .def("from_numpy",
+           [](FDTensor& self, pybind11::array& pyarray,
+              bool share_buffer = false) {
+             PyArrayToTensor(pyarray, &self, share_buffer);
+           })
+      .def("from_external_data",
+           [](const std::string& name, size_t data_addr,
+              const std::vector<int64_t>& shape, const std::string& data_type,
+              const std::string& data_place, int device_id) {
+             auto fd_data_type = FDDataType::UNKNOWN1;
+             if (data_type == "FP32") {
+               fd_data_type = FDDataType::FP32;
+             } else if (data_type == "FP16") {
+               fd_data_type = FDDataType::FP16;
+             } else if (data_type == "INT32") {
+               fd_data_type = FDDataType::INT32;
+             } else if (data_type == "INT64") {
+               fd_data_type = FDDataType::INT64;
+             } else {
+               FDASSERT(false,
+                        "FDTensor.from_external_data, datatype \"%s\" is not "
+                        "supported.",
+                        data_type.c_str());
+             }
+
+             Device fd_data_place;
+             if (data_place.find("gpu") != data_place.npos) {
+               fd_data_place = Device::GPU;
+             } else {
+               FDASSERT(false,
+                        ("Device type " + data_place +
+                         " is not support by FDTensor.from_external_data.")
+                            .c_str());
+             }
+             void* data_ptr = nullptr;
+             data_ptr = reinterpret_cast<void*>(data_addr);
+             FDTensor fd_tensor(name);
+             fd_tensor.SetExternalData(shape, fd_data_type,
+                                       static_cast<void*>(data_ptr),
+                                       fd_data_place, device_id);
+             return fd_tensor;
+           })
       .def("to_dlpack", &FDTensorToDLPack)
-      .def("from_dlpack",&FDTensorFromDLPack)
+      .def("from_dlpack", &FDTensorFromDLPack)
       .def("print_info", &FDTensor::PrintInfo);
 }
 
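The from_external_data binding added above only accepts the dtype strings FP32, FP16, INT32 and INT64, and the data_place string must contain "gpu"; anything else trips the FDASSERTs in the lambda. A small hedged sketch of accepted and rejected argument combinations (assuming the binding is reachable as fd.C.FDTensor.from_external_data; the device pointer comes from a torch GPU tensor purely for illustration):

import fastdeploy as fd
import torch

buf = torch.zeros((8, 16), dtype=torch.int64, device="cuda:0")

# Accepted: dtype string in {FP32, FP16, INT32, INT64}, place containing "gpu".
ids = fd.C.FDTensor.from_external_data(
    "ids", buf.data_ptr(), [8, 16], "INT64", "gpu:0", 0)
ids.print_info()

# Rejected by the FDASSERTs in the binding (illustrative, do not run):
#   unsupported dtype string:
#     fd.C.FDTensor.from_external_data("ids", buf.data_ptr(), [8, 16], "UINT8", "gpu", 0)
#   placement string without "gpu":
#     fd.C.FDTensor.from_external_data("ids", buf.data_ptr(), [8, 16], "INT64", "cpu", 0)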
|
@@ -110,6 +110,7 @@ void BindRuntime(pybind11::module& m) {
             return outputs;
           })
       .def("bind_input_tensor", &Runtime::BindInputTensor)
+      .def("bind_output_tensor", &Runtime::BindOutputTensor)
       .def("infer", [](Runtime& self) { self.Infer(); })
       .def("get_output_tensor",
            [](Runtime& self, const std::string& name) {
@@ -25,6 +25,7 @@ void PaddleBackend::BuildOption(const PaddleBackendOption& option) {
   if (option.device == Device::GPU) {
     config_.EnableUseGpu(option.gpu_mem_init_size, option.device_id);
     if (option_.external_stream_) {
+      FDINFO << "Will use external stream for Paddle Backend." << std::endl;
       config_.SetExecStream(option_.external_stream_);
     }
     if (option.enable_trt) {
@@ -47,7 +48,7 @@ void PaddleBackend::BuildOption(const PaddleBackendOption& option) {
         config_.SetOptimCacheDir(option.trt_option.serialize_file);
       }
       config_.EnableTensorRtEngine(option.trt_option.max_workspace_size,
-                                   option.trt_option.max_batch_size, 3,
+                                   option.trt_option.max_batch_size, 20,
                                    precision, use_static);
       SetTRTDynamicShapeToConfig(option);
     }
@@ -124,9 +125,10 @@ bool PaddleBackend::InitFromPaddle(const std::string& model_buffer,
                   "file will save to the directory where paddle model saved."
                << std::endl;
       use_static = true;
+      config_.SetOptimCacheDir(option.trt_option.serialize_file);
     }
     config_.EnableTensorRtEngine(option.trt_option.max_workspace_size,
-                                 option.trt_option.max_batch_size, 3,
+                                 option.trt_option.max_batch_size, 20,
                                  paddle_infer::PrecisionType::kInt8,
                                  use_static, false);
     SetTRTDynamicShapeToConfig(option);
@@ -223,23 +225,47 @@ bool PaddleBackend::Infer(std::vector<FDTensor>& inputs,
              << inputs_desc_.size() << ")." << std::endl;
     return false;
   }
+  // output share backend memory only support CPU or GPU
+  if (option_.device == Device::IPU) {
+    copy_to_fd = true;
+  }
+
   RUNTIME_PROFILE_LOOP_H2D_D2H_BEGIN
   for (size_t i = 0; i < inputs.size(); ++i) {
     auto handle = predictor_->GetInputHandle(inputs[i].name);
     ShareTensorFromFDTensor(handle.get(), inputs[i]);
   }
+  std::unordered_set<std::string> prebinded_output_name;
+  // prebinded output only support for GPU
+  if (!copy_to_fd) {
+    for (size_t i = 0; i < (*outputs).size(); ++i) {
+      auto output_name = (*outputs)[i].name;
+      // if a output is not prebinded,
+      // the name of output is expected to be empty.
+      // We skip here
+      if (output_name.empty()) {
+        continue;
+      }
+      // Record the prebinded output_name.
+      // Those outputs do not need PaddleTensorToFDTensor
+      // after predictor_.Run()
+      prebinded_output_name.insert(output_name);
+      auto handle = predictor_->GetOutputHandle(output_name);
+      ShareOutTensorFromFDTensor(handle.get(), (*outputs)[i]);
+    }
+  }
+
   RUNTIME_PROFILE_LOOP_BEGIN(1)
   predictor_->Run();
   RUNTIME_PROFILE_LOOP_END
 
-  // output share backend memory only support CPU or GPU
-  if (option_.device == Device::IPU) {
-    copy_to_fd = true;
-  }
   outputs->resize(outputs_desc_.size());
   for (size_t i = 0; i < outputs_desc_.size(); ++i) {
+    // skip prebinded output
+    if (copy_to_fd == false &&
+        prebinded_output_name.count(outputs_desc_[i].name)) {
+      continue;
+    }
     auto handle = predictor_->GetOutputHandle(outputs_desc_[i].name);
     if (copy_to_fd) {
       (*outputs)[i].is_pinned_memory = option_.enable_pinned_memory;
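The bookkeeping in the Infer hunk above boils down to: outputs whose FDTensor already carries a name were prebound (shared) before Run(), so the post-Run loop skips them and only the remaining outputs go through PaddleTensorToFDTensor. A self-contained toy sketch of just that selection logic, in plain Python with illustrative names rather than FastDeploy types:

def select_outputs_to_copy(output_names, prebound_names, copy_to_fd):
    """Return the output names that still need PaddleTensorToFDTensor after Run()."""
    if copy_to_fd:
        # IPU (or any explicit copy-back request) handles every output after Run().
        return list(output_names)
    # On CPU/GPU, prebound outputs were already shared before Run() and are skipped.
    return [n for n in output_names if n not in prebound_names]

# "scores" and "boxes" were prebound via bind_output_tensor; "mask" was not.
print(select_outputs_to_copy(["scores", "boxes", "mask"], {"scores", "boxes"}, False))
# -> ['mask']
print(select_outputs_to_copy(["scores", "boxes", "mask"], {"scores", "boxes"}, True))
# -> ['scores', 'boxes', 'mask']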
@@ -35,6 +35,9 @@ paddle_infer::PlaceType ConvertFDDeviceToPlace(Device device);
 // Share memory buffer with paddle_infer::Tensor from fastdeploy::FDTensor
 void ShareTensorFromFDTensor(paddle_infer::Tensor* tensor, FDTensor& fd_tensor);
 
+void ShareOutTensorFromFDTensor(paddle_infer::Tensor* tensor,
+                                FDTensor& fd_tensor);
+
 // convert paddle_infer::Tensor to fastdeploy::FDTensor
 // if copy_to_fd is true, copy memory data to FDTensor
 /// else share memory to FDTensor
@@ -61,6 +61,43 @@ void ShareTensorFromFDTensor(paddle_infer::Tensor* tensor,
            Str(fd_tensor.dtype).c_str());
 }
 
+void ShareOutTensorFromFDTensor(paddle_infer::Tensor* tensor,
+                                FDTensor& fd_tensor) {
+  std::vector<int> shape(fd_tensor.shape.begin(), fd_tensor.shape.end());
+  auto place = ConvertFDDeviceToPlace(fd_tensor.device);
+  if (fd_tensor.dtype == FDDataType::FP32) {
+    if (place == paddle_infer::PlaceType::kGPU) {
+      tensor->ShareExternalData(static_cast<float*>(fd_tensor.MutableData()),
+                                shape, place);
+    } else {
+      tensor->CopyToCpu(static_cast<float*>(fd_tensor.MutableData()));
+    }
+    return;
+  } else if (fd_tensor.dtype == FDDataType::INT32) {
+    if (place == paddle_infer::PlaceType::kGPU) {
+      tensor->ShareExternalData(static_cast<int32_t*>(fd_tensor.MutableData()),
+                                shape, place);
+    } else {
+      tensor->CopyToCpu(static_cast<int32_t*>(fd_tensor.MutableData()));
+    }
+    return;
+  } else if (fd_tensor.dtype == FDDataType::INT64) {
+    if (place == paddle_infer::PlaceType::kGPU) {
+      tensor->ShareExternalData(static_cast<int64_t*>(fd_tensor.MutableData()),
+                                shape, place);
+    } else {
+      tensor->CopyToCpu(static_cast<int64_t*>(fd_tensor.MutableData()));
+    }
+    return;
+  } else if (fd_tensor.dtype == FDDataType::UINT8) {
+    tensor->ShareExternalData(static_cast<uint8_t*>(fd_tensor.MutableData()),
+                              shape, paddle_infer::PlaceType::kCPU);
+    return;
+  }
+  FDASSERT(false, "Unexpected data type(%s) while infer with PaddleBackend.",
+           Str(fd_tensor.dtype).c_str());
+}
+
 void PaddleTensorToFDTensor(std::unique_ptr<paddle_infer::Tensor>& tensor,
                             FDTensor* fd_tensor, bool copy_to_fd) {
   auto fd_dtype = PaddleDataTypeToFD(tensor->type());
@@ -198,6 +198,26 @@ void Runtime::BindInputTensor(const std::string& name, FDTensor& input) {
   }
 }
 
+void Runtime::BindOutputTensor(const std::string& name, FDTensor& output) {
+  bool is_exist = false;
+  for (auto& t : output_tensors_) {
+    if (t.name == name) {
+      // FDWARNING << "The output name [" << name << "] is exist." << std::endl;
+      is_exist = true;
+      t.SetExternalData(output.shape, output.dtype, output.MutableData(),
+                        output.device, output.device_id);
+      break;
+    }
+  }
+  if (!is_exist) {
+    // FDWARNING << "The output name [" << name << "] don't exist." <<
+    // std::endl;
+    FDTensor new_tensor(name);
+    new_tensor.SetExternalData(output.shape, output.dtype, output.MutableData(),
+                               output.device, output.device_id);
+    output_tensors_.emplace_back(std::move(new_tensor));
+  }
+}
 FDTensor* Runtime::GetOutputTensor(const std::string& name) {
   for (auto& t : output_tensors_) {
     if (t.name == name) {
@@ -72,6 +72,12 @@ struct FASTDEPLOY_DECL Runtime {
   /** \brief Bind FDTensor by name, no copy and share input memory
    */
   void BindInputTensor(const std::string& name, FDTensor& input);
 
+  /** \brief Bind FDTensor by name, no copy and share output memory.
+   *  Please make share the correctness of tensor shape of output.
+   */
+  void BindOutputTensor(const std::string& name, FDTensor& output);
+
   /** \brief Get output FDTensor by name, no copy and share backend output memory
    */
   FDTensor* GetOutputTensor(const std::string& name);
@@ -72,6 +72,14 @@ class Runtime:
         """
         self._runtime.bind_input_tensor(name, fdtensor)
 
+    def bind_output_tensor(self, name, fdtensor):
+        """Bind FDTensor by name, no copy and share output memory
+
+        :param name: (str)The name of output data.
+        :param fdtensor: (fastdeploy.FDTensor)The output FDTensor.
+        """
+        self._runtime.bind_output_tensor(name, fdtensor)
+
     def zero_copy_infer(self):
         """No params inference the model.
 
@@ -656,7 +664,8 @@ class RuntimeOption:
                 continue
             if hasattr(getattr(self._option, attr), "__call__"):
                 continue
-            message += " {} : {}\t\n".format(attr, getattr(self._option, attr))
+            message += " {} : {}\t\n".format(attr,
+                                             getattr(self._option, attr))
         message.strip("\n")
         message += ")"
         return message