diff --git a/fastdeploy/backends/backend.h b/fastdeploy/backends/backend.h
index 652d94cb8..02c94875d 100644
--- a/fastdeploy/backends/backend.h
+++ b/fastdeploy/backends/backend.h
@@ -62,8 +62,11 @@ class BaseBackend {
   virtual TensorInfo GetOutputInfo(int index) = 0;
   virtual std::vector<TensorInfo> GetInputInfos() = 0;
   virtual std::vector<TensorInfo> GetOutputInfos() = 0;
+  // If copy_to_fd is true, copy the memory data to the FDTensor;
+  // otherwise share the memory with the FDTensor (only Paddle, ORT, TRT and OpenVINO support it).
   virtual bool Infer(std::vector<FDTensor>& inputs,
-                     std::vector<FDTensor>* outputs) = 0;
+                     std::vector<FDTensor>* outputs,
+                     bool copy_to_fd = true) = 0;
   virtual std::unique_ptr<BaseBackend> Clone(void *stream = nullptr,
                                              int device_id = -1) {
     FDERROR << "Clone no support" << std::endl;
diff --git a/fastdeploy/backends/lite/lite_backend.cc b/fastdeploy/backends/lite/lite_backend.cc
index bdfad299c..e3c87aabd 100755
--- a/fastdeploy/backends/lite/lite_backend.cc
+++ b/fastdeploy/backends/lite/lite_backend.cc
@@ -187,7 +187,8 @@ TensorInfo LiteBackend::GetOutputInfo(int index) {
 std::vector<TensorInfo> LiteBackend::GetOutputInfos() { return outputs_desc_; }
 
 bool LiteBackend::Infer(std::vector<FDTensor>& inputs,
-                        std::vector<FDTensor>* outputs) {
+                        std::vector<FDTensor>* outputs,
+                        bool copy_to_fd) {
   if (inputs.size() != inputs_desc_.size()) {
     FDERROR << "[LiteBackend] Size of inputs(" << inputs.size()
             << ") should keep same with the inputs of this model("
diff --git a/fastdeploy/backends/lite/lite_backend.h b/fastdeploy/backends/lite/lite_backend.h
index 2922d4ea3..279acf5df 100755
--- a/fastdeploy/backends/lite/lite_backend.h
+++ b/fastdeploy/backends/lite/lite_backend.h
@@ -60,7 +60,9 @@ class LiteBackend : public BaseBackend {
                 const std::string& params_file,
                 const LiteBackendOption& option = LiteBackendOption());
 
-  bool Infer(std::vector<FDTensor>& inputs, std::vector<FDTensor>* outputs) override;  // NOLINT
+  bool Infer(std::vector<FDTensor>& inputs,
+             std::vector<FDTensor>* outputs,
+             bool copy_to_fd = true) override;  // NOLINT
 
   int NumInputs() const override { return inputs_desc_.size(); }
diff --git a/fastdeploy/backends/openvino/ov_backend.cc b/fastdeploy/backends/openvino/ov_backend.cc
index da3ec5404..9e8c2571a 100644
--- a/fastdeploy/backends/openvino/ov_backend.cc
+++ b/fastdeploy/backends/openvino/ov_backend.cc
@@ -341,7 +341,8 @@ int OpenVINOBackend::NumInputs() const { return input_infos_.size(); }
 int OpenVINOBackend::NumOutputs() const { return output_infos_.size(); }
 
 bool OpenVINOBackend::Infer(std::vector<FDTensor>& inputs,
-                            std::vector<FDTensor>* outputs) {
+                            std::vector<FDTensor>* outputs,
+                            bool copy_to_fd) {
   if (inputs.size() != input_infos_.size()) {
     FDERROR << "[OpenVINOBackend] Size of the inputs(" << inputs.size()
             << ") should keep same with the inputs of this model("
@@ -364,11 +365,20 @@ bool OpenVINOBackend::Infer(std::vector<FDTensor>& inputs,
     auto out_tensor_shape = out_tensor.get_shape();
     std::vector<int64_t> shape(out_tensor_shape.begin(),
                                out_tensor_shape.end());
-    (*outputs)[i].Allocate(shape,
+    if (copy_to_fd) {
+      (*outputs)[i].Resize(shape,
                            OpenVINODataTypeToFD(out_tensor.get_element_type()),
-                           output_infos_[i].name);
-    memcpy((*outputs)[i].MutableData(), out_tensor.data(),
-           (*outputs)[i].Nbytes());
+                           output_infos_[i].name,
+                           Device::CPU);
+      memcpy((*outputs)[i].MutableData(), out_tensor.data(),
+             (*outputs)[i].Nbytes());
+    } else {
+      (*outputs)[i].name = output_infos_[i].name;
+      (*outputs)[i].SetExternalData(shape,
+                                    OpenVINODataTypeToFD(out_tensor.get_element_type()),
+                                    out_tensor.data(),
+                                    Device::CPU);
+    }
   }
   return true;
 }
diff --git a/fastdeploy/backends/openvino/ov_backend.h b/fastdeploy/backends/openvino/ov_backend.h
index e224cdca5..2dadab29d 100644
--- a/fastdeploy/backends/openvino/ov_backend.h
+++ b/fastdeploy/backends/openvino/ov_backend.h
@@ -48,7 +48,8 @@ class OpenVINOBackend : public BaseBackend {
                 const OpenVINOBackendOption& option = OpenVINOBackendOption());
 
   bool Infer(std::vector<FDTensor>& inputs,
-             std::vector<FDTensor>* outputs) override;
+             std::vector<FDTensor>* outputs,
+             bool copy_to_fd = true) override;
 
   int NumInputs() const override;
diff --git a/fastdeploy/backends/ort/ort_backend.cc b/fastdeploy/backends/ort/ort_backend.cc
index 379ab9d1e..1e6d8bfb5 100755
--- a/fastdeploy/backends/ort/ort_backend.cc
+++ b/fastdeploy/backends/ort/ort_backend.cc
@@ -181,8 +181,8 @@ bool OrtBackend::InitFromOnnx(const std::string& model_file,
   return true;
 }
 
-void OrtBackend::CopyToCpu(const Ort::Value& value, FDTensor* tensor,
-                           const std::string& name) {
+void OrtBackend::OrtValueToFDTensor(const Ort::Value& value, FDTensor* tensor,
+                                    const std::string& name, bool copy_to_fd) {
   const auto info = value.GetTensorTypeAndShapeInfo();
   const auto data_type = info.GetElementType();
   size_t numel = info.GetElementCount();
@@ -210,12 +210,21 @@ void OrtBackend::CopyToCpu(const Ort::Value& value, FDTensor* tensor,
         "Unrecognized data type of %d while calling OrtBackend::CopyToCpu().",
         data_type);
   }
-  tensor->Resize(shape, dtype, name);
-  memcpy(tensor->MutableData(), value.GetTensorData<void*>(), numel);
+  const void* value_ptr = value.GetTensorData<void*>();
+  if (copy_to_fd) {
+    tensor->Resize(shape, dtype, name);
+    memcpy(tensor->MutableData(), value_ptr, numel);
+  } else {
+    tensor->name = name;
+    tensor->SetExternalData(shape, dtype,
+                            const_cast<void*>(value_ptr), Device::CPU);
+  }
 }
 
 bool OrtBackend::Infer(std::vector<FDTensor>& inputs,
-                       std::vector<FDTensor>* outputs) {
+                       std::vector<FDTensor>* outputs,
+                       bool copy_to_fd) {
   if (inputs.size() != inputs_desc_.size()) {
     FDERROR << "[OrtBackend] Size of the inputs(" << inputs.size()
             << ") should keep same with the inputs of this model("
@@ -243,11 +252,12 @@ bool OrtBackend::Infer(std::vector<FDTensor>& inputs,
     return false;
   }
 
-  // Copy result after inference
+  // Convert result after inference
   std::vector<Ort::Value> ort_outputs = binding_->GetOutputValues();
   outputs->resize(ort_outputs.size());
   for (size_t i = 0; i < ort_outputs.size(); ++i) {
-    CopyToCpu(ort_outputs[i], &((*outputs)[i]), outputs_desc_[i].name);
+    OrtValueToFDTensor(ort_outputs[i], &((*outputs)[i]),
+                       outputs_desc_[i].name, copy_to_fd);
   }
 
   return true;
diff --git a/fastdeploy/backends/ort/ort_backend.h b/fastdeploy/backends/ort/ort_backend.h
index 31c769824..ab5f38e61 100644
--- a/fastdeploy/backends/ort/ort_backend.h
+++ b/fastdeploy/backends/ort/ort_backend.h
@@ -68,7 +68,8 @@ class OrtBackend : public BaseBackend {
                     bool from_memory_buffer = false);
 
   bool Infer(std::vector<FDTensor>& inputs,
-             std::vector<FDTensor>* outputs) override;
+             std::vector<FDTensor>* outputs,
+             bool copy_to_fd = true) override;
 
   int NumInputs() const override { return inputs_desc_.size(); }
 
@@ -92,7 +93,7 @@ class OrtBackend : public BaseBackend {
   Ort::CustomOpDomain custom_op_domain_ = Ort::CustomOpDomain("Paddle");
 #endif
   OrtBackendOption option_;
-  void CopyToCpu(const Ort::Value& value, FDTensor* tensor,
-                 const std::string& name);
+  void OrtValueToFDTensor(const Ort::Value& value, FDTensor* tensor,
+                          const std::string& name, bool copy_to_fd);
 };
 }  // namespace fastdeploy
diff --git a/fastdeploy/backends/paddle/paddle_backend.cc b/fastdeploy/backends/paddle/paddle_backend.cc
index c1ecacee2..866bf578e 100644
--- a/fastdeploy/backends/paddle/paddle_backend.cc
+++ b/fastdeploy/backends/paddle/paddle_backend.cc
@@ -194,7 +194,8 @@ std::vector<TensorInfo> PaddleBackend::GetOutputInfos() {
 }
 
 bool PaddleBackend::Infer(std::vector<FDTensor>& inputs,
-                          std::vector<FDTensor>* outputs) {
+                          std::vector<FDTensor>* outputs,
+                          bool copy_to_fd) {
   if (inputs.size() != inputs_desc_.size()) {
     FDERROR << "[PaddleBackend] Size of inputs(" << inputs.size()
             << ") should keep same with the inputs of this model("
@@ -208,11 +209,18 @@ bool PaddleBackend::Infer(std::vector<FDTensor>& inputs,
   }
 
   predictor_->Run();
+
+  // Sharing the backend output memory is only supported on CPU or GPU
+  if (option_.use_ipu) {
+    copy_to_fd = true;
+  }
   outputs->resize(outputs_desc_.size());
   for (size_t i = 0; i < outputs_desc_.size(); ++i) {
     auto handle = predictor_->GetOutputHandle(outputs_desc_[i].name);
-    (*outputs)[i].is_pinned_memory = option_.enable_pinned_memory;
-    CopyTensorToCpu(handle, &((*outputs)[i]));
+    if (copy_to_fd) {
+      (*outputs)[i].is_pinned_memory = option_.enable_pinned_memory;
+    }
+    PaddleTensorToFDTensor(handle, &((*outputs)[i]), copy_to_fd);
   }
   return true;
 }
diff --git a/fastdeploy/backends/paddle/paddle_backend.h b/fastdeploy/backends/paddle/paddle_backend.h
index 0c674494e..ba083ae43 100755
--- a/fastdeploy/backends/paddle/paddle_backend.h
+++ b/fastdeploy/backends/paddle/paddle_backend.h
@@ -87,9 +87,12 @@ paddle_infer::PlaceType ConvertFDDeviceToPlace(Device device);
 // Share memory buffer with paddle_infer::Tensor from fastdeploy::FDTensor
 void ShareTensorFromFDTensor(paddle_infer::Tensor* tensor, FDTensor& fd_tensor);
 
-// Copy memory data from paddle_infer::Tensor to fastdeploy::FDTensor
-void CopyTensorToCpu(std::unique_ptr<paddle_infer::Tensor>& tensor,
-                     FDTensor* fd_tensor);
+// Convert paddle_infer::Tensor to fastdeploy::FDTensor.
+// If copy_to_fd is true, copy the memory data to the FDTensor;
+// otherwise share the memory with the FDTensor.
+void PaddleTensorToFDTensor(std::unique_ptr<paddle_infer::Tensor>& tensor,
+                            FDTensor* fd_tensor,
+                            bool copy_to_fd);
 
 // Convert data type from paddle inference to fastdeploy
 FDDataType PaddleDataTypeToFD(const paddle_infer::DataType& dtype);
@@ -108,7 +111,9 @@ class PaddleBackend : public BaseBackend {
                 const PaddleBackendOption& option = PaddleBackendOption());
 
   bool Infer(std::vector<FDTensor>& inputs,
-             std::vector<FDTensor>* outputs) override;
+             std::vector<FDTensor>* outputs,
+             bool copy_to_fd = true) override;
+
   int NumInputs() const override { return inputs_desc_.size(); }
diff --git a/fastdeploy/backends/paddle/util.cc b/fastdeploy/backends/paddle/util.cc
index d8cc1dbb9..4e493fe16 100644
--- a/fastdeploy/backends/paddle/util.cc
+++ b/fastdeploy/backends/paddle/util.cc
@@ -61,25 +61,41 @@ void ShareTensorFromFDTensor(paddle_infer::Tensor* tensor,
            Str(fd_tensor.dtype).c_str());
 }
 
-void CopyTensorToCpu(std::unique_ptr<paddle_infer::Tensor>& tensor,
-                     FDTensor* fd_tensor) {
+void PaddleTensorToFDTensor(std::unique_ptr<paddle_infer::Tensor>& tensor,
+                            FDTensor* fd_tensor,
+                            bool copy_to_fd) {
   auto fd_dtype = PaddleDataTypeToFD(tensor->type());
   std::vector<int64_t> shape;
   auto tmp_shape = tensor->shape();
   shape.assign(tmp_shape.begin(), tmp_shape.end());
-  fd_tensor->Resize(shape, fd_dtype, tensor->name());
-  if (fd_tensor->dtype == FDDataType::FP32) {
-    tensor->CopyToCpu(static_cast<float*>(fd_tensor->MutableData()));
-    return;
-  } else if (fd_tensor->dtype == FDDataType::INT32) {
-    tensor->CopyToCpu(static_cast<int32_t*>(fd_tensor->MutableData()));
-    return;
-  } else if (fd_tensor->dtype == FDDataType::INT64) {
-    tensor->CopyToCpu(static_cast<int64_t*>(fd_tensor->MutableData()));
-    return;
+  if (copy_to_fd) {
+    fd_tensor->Resize(shape, fd_dtype, tensor->name());
+    if (fd_tensor->dtype == FDDataType::FP32) {
+      tensor->CopyToCpu(static_cast<float*>(fd_tensor->MutableData()));
+      return;
+    } else if (fd_tensor->dtype == FDDataType::INT32) {
+      tensor->CopyToCpu(static_cast<int32_t*>(fd_tensor->MutableData()));
+      return;
+    } else if (fd_tensor->dtype == FDDataType::INT64) {
+      tensor->CopyToCpu(static_cast<int64_t*>(fd_tensor->MutableData()));
+      return;
+    }
+    FDASSERT(false, "Unexpected data type(%s) while infer with PaddleBackend.",
+             Str(fd_tensor->dtype).c_str());
+  } else {
+    paddle_infer::PlaceType place;
+    int size = 0;
+    // TODO(liqi): The tensor->data interface of paddle doesn't return the
+    // device id and doesn't support returning void*.
+    auto* out_data = tensor->data<uint8_t>(&place, &size);
+    Device device = Device::CPU;
+    if (place == paddle_infer::PlaceType::kGPU) {
+      device = Device::GPU;
+    }
+    fd_tensor->SetExternalData(
+        shape, fd_dtype,
+        reinterpret_cast<void*>(out_data), device);
   }
-  FDASSERT(false, "Unexpected data type(%s) while infer with PaddleBackend.",
-           Str(fd_tensor->dtype).c_str());
 }
 
 FDDataType PaddleDataTypeToFD(const paddle_infer::DataType& dtype) {
diff --git a/fastdeploy/backends/poros/poros_backend.cc b/fastdeploy/backends/poros/poros_backend.cc
index 8fdc42fb3..a7c96f7cd 100755
--- a/fastdeploy/backends/poros/poros_backend.cc
+++ b/fastdeploy/backends/poros/poros_backend.cc
@@ -188,7 +188,8 @@ bool PorosBackend::InitFromPoros(const std::string& model_file,
 }
 
 bool PorosBackend::Infer(std::vector<FDTensor>& inputs,
-                         std::vector<FDTensor>* outputs) {
+                         std::vector<FDTensor>* outputs,
+                         bool copy_to_fd) {
   // Convert FD Tensor to PyTorch Tensor
   std::vector<torch::jit::IValue> poros_inputs;
   bool is_backend_cuda =
diff --git a/fastdeploy/backends/poros/poros_backend.h b/fastdeploy/backends/poros/poros_backend.h
index 656249e00..00dfe4444 100755
--- a/fastdeploy/backends/poros/poros_backend.h
+++ b/fastdeploy/backends/poros/poros_backend.h
@@ -85,7 +85,9 @@ class PorosBackend : public BaseBackend {
       std::vector<std::vector<FDTensor>>& prewarm_tensors,
       const PorosBackendOption& option = PorosBackendOption());
 
-  bool Infer(std::vector<FDTensor>& inputs, std::vector<FDTensor>* outputs);
+  bool Infer(std::vector<FDTensor>& inputs,
+             std::vector<FDTensor>* outputs,
+             bool copy_to_fd = true) override;
 
   int NumInputs() const { return _numinputs; }
diff --git a/fastdeploy/backends/rknpu/rknpu2/rknpu2_backend.cc b/fastdeploy/backends/rknpu/rknpu2/rknpu2_backend.cc
index 2f0618dbe..16edf7561 100644
--- a/fastdeploy/backends/rknpu/rknpu2/rknpu2_backend.cc
+++ b/fastdeploy/backends/rknpu/rknpu2/rknpu2_backend.cc
@@ -289,7 +289,8 @@ std::vector<TensorInfo> RKNPU2Backend::GetOutputInfos() {
 }
 
 bool RKNPU2Backend::Infer(std::vector<FDTensor>& inputs,
-                          std::vector<FDTensor>* outputs) {
+                          std::vector<FDTensor>* outputs,
+                          bool copy_to_fd) {
   int ret = RKNN_SUCC;
   // Judge whether the input and output size are the same
   if (inputs.size() != inputs_desc_.size()) {
diff --git a/fastdeploy/backends/rknpu/rknpu2/rknpu2_backend.h b/fastdeploy/backends/rknpu/rknpu2/rknpu2_backend.h
index 1aba24ec3..af28fdddf 100644
--- a/fastdeploy/backends/rknpu/rknpu2/rknpu2_backend.h
+++ b/fastdeploy/backends/rknpu/rknpu2/rknpu2_backend.h
@@ -72,7 +72,8 @@ class RKNPU2Backend : public BaseBackend {
   std::vector<TensorInfo> GetInputInfos() override;
   std::vector<TensorInfo> GetOutputInfos() override;
   bool Infer(std::vector<FDTensor>& inputs,
-             std::vector<FDTensor>* outputs) override;
+             std::vector<FDTensor>* outputs,
+             bool copy_to_fd = true) override;
 
  private:
   // The object of rknn context.
diff --git a/fastdeploy/backends/tensorrt/trt_backend.cc b/fastdeploy/backends/tensorrt/trt_backend.cc
index 5cdc266b6..3a8659ace 100755
--- a/fastdeploy/backends/tensorrt/trt_backend.cc
+++ b/fastdeploy/backends/tensorrt/trt_backend.cc
@@ -283,7 +283,8 @@ int TrtBackend::ShapeRangeInfoUpdated(const std::vector<FDTensor>& inputs) {
 }
 
 bool TrtBackend::Infer(std::vector<FDTensor>& inputs,
-                       std::vector<FDTensor>* outputs) {
+                       std::vector<FDTensor>* outputs,
+                       bool copy_to_fd) {
   if (inputs.size() != NumInputs()) {
     FDERROR << "Require " << NumInputs() << "inputs, but get " << inputs.size()
             << "." << std::endl;
@@ -304,7 +305,7 @@ bool TrtBackend::Infer(std::vector<FDTensor>& inputs,
 
   cudaSetDevice(option_.gpu_id);
   SetInputs(inputs);
-  AllocateOutputsBuffer(outputs);
+  AllocateOutputsBuffer(outputs, copy_to_fd);
 
   if (!context_->enqueueV2(bindings_.data(), stream_, nullptr)) {
     FDERROR << "Failed to Infer with TensorRT." << std::endl;
@@ -323,6 +324,11 @@ bool TrtBackend::Infer(std::vector<FDTensor>& inputs,
       casted_output_tensors_[(*outputs)[i].name].Resize(
          (*outputs)[i].shape, (*outputs)[i].dtype, (*outputs)[i].name,
          Device::GPU);
       function::CudaCast(output_tensor,
                          &casted_output_tensors_[(*outputs)[i].name], stream_);
+      if (!copy_to_fd) {
+        (*outputs)[i].SetExternalData(
+            (*outputs)[i].shape, model_output_dtype,
+            casted_output_tensors_[(*outputs)[i].name].MutableData(),
+            Device::GPU, option_.gpu_id);
+      }
     } else {
       casted_output_tensors_[(*outputs)[i].name].SetExternalData(
           (*outputs)[i].shape, model_output_dtype,
@@ -330,15 +336,17 @@ bool TrtBackend::Infer(std::vector<FDTensor>& inputs,
           Device::GPU);
     }
   }
-  for (size_t i = 0; i < outputs->size(); ++i) {
-    FDASSERT(cudaMemcpyAsync((*outputs)[i].Data(),
-                             casted_output_tensors_[(*outputs)[i].name].Data(),
-                             (*outputs)[i].Nbytes(), cudaMemcpyDeviceToHost,
-                             stream_) == 0,
-             "[ERROR] Error occurs while copy memory from GPU to CPU.");
+  if (copy_to_fd) {
+    for (size_t i = 0; i < outputs->size(); ++i) {
+      FDASSERT(cudaMemcpyAsync((*outputs)[i].Data(),
+                               casted_output_tensors_[(*outputs)[i].name].Data(),
+                               (*outputs)[i].Nbytes(), cudaMemcpyDeviceToHost,
+                               stream_) == 0,
+               "[ERROR] Error occurs while copy memory from GPU to CPU.");
+    }
+    FDASSERT(cudaStreamSynchronize(stream_) == cudaSuccess,
+             "[ERROR] Error occurs while sync cuda stream.");
   }
-  FDASSERT(cudaStreamSynchronize(stream_) == cudaSuccess,
-           "[ERROR] Error occurs while sync cuda stream.");
   return true;
 }
 
@@ -427,7 +435,8 @@ void TrtBackend::SetInputs(const std::vector<FDTensor>& inputs) {
   }
 }
 
-void TrtBackend::AllocateOutputsBuffer(std::vector<FDTensor>* outputs) {
+void TrtBackend::AllocateOutputsBuffer(std::vector<FDTensor>* outputs,
+                                       bool copy_to_fd) {
   if (outputs->size() != outputs_desc_.size()) {
     outputs->resize(outputs_desc_.size());
   }
@@ -446,18 +455,26 @@ void TrtBackend::AllocateOutputsBuffer(std::vector<FDTensor>* outputs,
              outputs_desc_[i].name.c_str());
     auto ori_idx = iter->second;
 
-    // set user's outputs info
-    std::vector<int64_t> shape(output_dims.d,
-                               output_dims.d + output_dims.nbDims);
-    (*outputs)[ori_idx].is_pinned_memory = option_.enable_pinned_memory;
-    (*outputs)[ori_idx].Resize(shape, outputs_desc_[i].original_dtype,
-                               outputs_desc_[i].name);
-
     // Allocate output buffer memory
     outputs_device_buffer_[outputs_desc_[i].name].resize(output_dims);
     // binding output buffer
-    bindings_[idx] = outputs_device_buffer_[outputs_desc_[i].name].data();
+    bindings_[idx] = outputs_device_buffer_[outputs_desc_[i].name].data();
+
+    // set user's outputs info
+    std::vector<int64_t> shape(output_dims.d,
+                               output_dims.d + output_dims.nbDims);
+    if (copy_to_fd) {
+      (*outputs)[ori_idx].is_pinned_memory = option_.enable_pinned_memory;
+      (*outputs)[ori_idx].Resize(shape, outputs_desc_[i].original_dtype,
+                                 outputs_desc_[i].name);
+    } else {
+      (*outputs)[ori_idx].name = outputs_desc_[i].name;
+      (*outputs)[ori_idx].SetExternalData(
+          shape, outputs_desc_[i].original_dtype,
+          bindings_[idx], Device::GPU,
+          option_.gpu_id);
+    }
   }
 }
diff --git a/fastdeploy/backends/tensorrt/trt_backend.h b/fastdeploy/backends/tensorrt/trt_backend.h
index f08401e54..425087fad 100755
--- a/fastdeploy/backends/tensorrt/trt_backend.h
+++ b/fastdeploy/backends/tensorrt/trt_backend.h
@@ -97,7 +97,9 @@ class TrtBackend : public BaseBackend {
   bool InitFromOnnx(const std::string& model_file,
                     const TrtBackendOption& option = TrtBackendOption(),
                     bool from_memory_buffer = false);
-  bool Infer(std::vector<FDTensor>& inputs, std::vector<FDTensor>* outputs);
+  bool Infer(std::vector<FDTensor>& inputs,
+             std::vector<FDTensor>* outputs,
+             bool copy_to_fd = true) override;
 
   int NumInputs() const { return inputs_desc_.size(); }
   int NumOutputs() const { return outputs_desc_.size(); }
@@ -162,7 +164,8 @@ class TrtBackend : public BaseBackend {
   bool LoadTrtCache(const std::string& trt_engine_file);
   int ShapeRangeInfoUpdated(const std::vector<FDTensor>& inputs);
   void SetInputs(const std::vector<FDTensor>& inputs);
-  void AllocateOutputsBuffer(std::vector<FDTensor>* outputs);
+  void AllocateOutputsBuffer(std::vector<FDTensor>* outputs,
+                             bool copy_to_fd = true);
 };
 }  // namespace fastdeploy
diff --git a/fastdeploy/core/fd_tensor.cc b/fastdeploy/core/fd_tensor.cc
index 1c5a7b422..896f2ff3b 100644
--- a/fastdeploy/core/fd_tensor.cc
+++ b/fastdeploy/core/fd_tensor.cc
@@ -81,11 +81,13 @@ const void* FDTensor::CpuData() const {
 void FDTensor::SetExternalData(const std::vector<int64_t>& new_shape,
                                const FDDataType& data_type, void* data_buffer,
-                               const Device& new_device) {
+                               const Device& new_device,
+                               int new_device_id) {
   dtype = data_type;
   shape.assign(new_shape.begin(), new_shape.end());
   external_data_ptr = data_buffer;
   device = new_device;
+  device_id = new_device_id;
 }
 
 void FDTensor::ExpandDim(int64_t axis) {
@@ -316,6 +318,8 @@ void FDTensor::FreeFn() {
   }
 }
 
+// TODO(liqi): CopyBuffer takes no src_device/dst_device;
+// it should support copying from CPU or GPU to CPU or GPU.
 void FDTensor::CopyBuffer(void* dst, const void* src, size_t nbytes,
                           const Device& device, bool is_pinned_memory) {
   if (device == Device::GPU) {
@@ -383,7 +387,8 @@ FDTensor::FDTensor(const Scalar& scalar) {
 FDTensor::FDTensor(const FDTensor& other)
     : shape(other.shape), name(other.name), dtype(other.dtype),
-      device(other.device), external_data_ptr(other.external_data_ptr) {
+      device(other.device), external_data_ptr(other.external_data_ptr),
+      device_id(other.device_id) {
   // Copy buffer
   if (other.buffer_ == nullptr) {
     buffer_ = nullptr;
@@ -398,7 +403,8 @@ FDTensor::FDTensor(const FDTensor& other)
 FDTensor::FDTensor(FDTensor&& other)
     : buffer_(other.buffer_), shape(std::move(other.shape)),
       name(std::move(other.name)), dtype(other.dtype),
-      external_data_ptr(other.external_data_ptr), device(other.device) {
+      external_data_ptr(other.external_data_ptr), device(other.device),
+      device_id(other.device_id) {
   other.name = "";
   // Note(zhoushunjie): Avoid double free.
   other.buffer_ = nullptr;
@@ -408,6 +414,7 @@ FDTensor::FDTensor(FDTensor&& other)
 FDTensor& FDTensor::operator=(const FDTensor& other) {
   if (&other != this) {
     // Copy buffer
+    device_id = other.device_id;
     if (other.buffer_ == nullptr) {
       FreeFn();
       buffer_ = nullptr;
@@ -435,6 +442,7 @@ FDTensor& FDTensor::operator=(FDTensor&& other) {
     name = std::move(other.name);
     dtype = other.dtype;
     device = other.device;
+    device_id = other.device_id;
 
     other.name = "";
     // Note(zhoushunjie): Avoid double free.
diff --git a/fastdeploy/core/fd_tensor.h b/fastdeploy/core/fd_tensor.h
index ff443ae3b..ef9ff3796 100644
--- a/fastdeploy/core/fd_tensor.h
+++ b/fastdeploy/core/fd_tensor.h
@@ -78,7 +78,8 @@ struct FASTDEPLOY_DECL FDTensor {
   // So take care with the user buffer
   void SetExternalData(const std::vector<int64_t>& new_shape,
                        const FDDataType& data_type, void* data_buffer,
-                       const Device& new_device = Device::CPU);
+                       const Device& new_device = Device::CPU,
+                       int new_device_id = -1);
 
   // Expand the shape of a Tensor. Insert a new axis that will appear
   // at the `axis` position in the expanded Tensor shape.
diff --git a/fastdeploy/runtime.cc b/fastdeploy/runtime.cc
index b9924cb0b..2350817e5 100755
--- a/fastdeploy/runtime.cc
+++ b/fastdeploy/runtime.cc
@@ -580,6 +580,39 @@ bool Runtime::Infer(std::vector<FDTensor>& input_tensors,
   return backend_->Infer(input_tensors, output_tensors);
 }
 
+bool Runtime::Infer() {
+  return backend_->Infer(input_tensors_, &output_tensors_, false);
+}
+
+void Runtime::BindInputTensor(const std::string& name, FDTensor& input) {
+  bool is_exist = false;
+  for (auto& t : input_tensors_) {
+    if (t.name == name) {
+      is_exist = true;
+      t.SetExternalData(input.shape, input.dtype,
+                        input.MutableData(), input.device,
+                        input.device_id);
+      break;
+    }
+  }
+  if (!is_exist) {
+    FDTensor new_tensor(name);
+    new_tensor.SetExternalData(input.shape, input.dtype,
+                               input.MutableData(), input.device,
+                               input.device_id);
+    input_tensors_.emplace_back(std::move(new_tensor));
+  }
+}
+
+FDTensor* Runtime::GetOutputTensor(const std::string& name) {
+  for (auto& t : output_tensors_) {
+    if (t.name == name) {
+      return &t;
+    }
+  }
+  return nullptr;
+}
+
 void Runtime::CreatePaddleBackend() {
 #ifdef ENABLE_PADDLE_BACKEND
   auto pd_option = PaddleBackendOption();
diff --git a/fastdeploy/runtime.h b/fastdeploy/runtime.h
old mode 100755
new mode 100644
index 6ea584026..e96643345
--- a/fastdeploy/runtime.h
+++ b/fastdeploy/runtime.h
@@ -405,6 +405,12 @@ struct FASTDEPLOY_DECL Runtime {
   bool Infer(std::vector<FDTensor>& input_tensors,
              std::vector<FDTensor>* output_tensors);
 
+  /** \brief Run inference without passing inputs/outputs as parameters.
+   *
+   * Input and output data are exchanged through the BindInputTensor and GetOutputTensor interfaces.
+   */
+  bool Infer();
+
   /** \brief Compile TorchScript Module, only for Poros backend
    *
    * \param[in] prewarm_tensors Prewarm datas for compile
@@ -432,6 +438,12 @@ struct FASTDEPLOY_DECL Runtime {
   /** \brief Get all the output information
   */
   std::vector<TensorInfo> GetOutputInfos();
+  /** \brief Bind an FDTensor by name; the input memory is shared, not copied
+  */
+  void BindInputTensor(const std::string& name, FDTensor& input);
+  /** \brief Get an output FDTensor by name; the backend output memory is shared, not copied
+  */
+  FDTensor* GetOutputTensor(const std::string& name);
 
   /** \brief Clone new Runtime when multiple instances of the same model are created
    *
@@ -451,5 +463,7 @@ struct FASTDEPLOY_DECL Runtime {
   void CreateLiteBackend();
   void CreateRKNPU2Backend();
   std::unique_ptr<BaseBackend> backend_;
+  std::vector<FDTensor> input_tensors_;
+  std::vector<FDTensor> output_tensors_;
 };
 }  // namespace fastdeploy
diff --git a/serving/src/fastdeploy_runtime.cc b/serving/src/fastdeploy_runtime.cc
index 2e839b5ac..79479609c 100644
--- a/serving/src/fastdeploy_runtime.cc
+++ b/serving/src/fastdeploy_runtime.cc
@@ -607,9 +607,6 @@ class ModelInstanceState : public BackendModelInstance {
   std::vector<std::string> output_names_;
   std::vector<fastdeploy::TensorInfo> input_tensor_infos_;
   std::vector<fastdeploy::TensorInfo> output_tensor_infos_;
-
-  std::vector<fastdeploy::FDTensor> input_tensors_;
-  std::vector<fastdeploy::FDTensor> output_tensors_;
 };
 
 TRITONSERVER_Error* ModelInstanceState::Create(
@@ -647,8 +644,6 @@ ModelInstanceState::~ModelInstanceState() { ReleaseRunResources(); }
 void ModelInstanceState::ReleaseRunResources() {
   input_names_.clear();
   output_names_.clear();
-  input_tensors_.clear();
-  output_tensors_.clear();
   input_tensor_infos_.clear();
   output_tensor_infos_.clear();
 }
@@ -671,9 +666,7 @@ TRITONSERVER_Error* ModelInstanceState::ValidateInputs() {
   input_tensor_infos_ = runtime_->GetInputInfos();
   std::vector<std::string> names;
   GetInfoNames(input_tensor_infos_, names);
-  input_tensors_.clear();
   input_names_.clear();
-  input_tensors_.reserve(input_tensor_infos_.size());
 
   triton::common::TritonJson::Value ios;
   RETURN_IF_ERROR(model_state_->ModelConfig().MemberAsArray("input", &ios));
@@ -700,7 +693,6 @@ TRITONSERVER_Error* ModelInstanceState::ValidateInputs() {
       std::set<std::string> inames(names.begin(), names.end());
       RETURN_IF_ERROR(CheckAllowedModelInput(io, inames));
     }
-    input_tensors_.emplace_back(io_name);
 
     auto fd_data_type = ModelConfigDataTypeToFDType(io_dtype);
     if (fd_data_type == fastdeploy::FDDataType::UNKNOWN1) {
@@ -759,11 +751,8 @@ TRITONSERVER_Error* ModelInstanceState::ValidateInputs() {
 
 TRITONSERVER_Error* ModelInstanceState::ValidateOutputs() {
   output_tensor_infos_ = runtime_->GetOutputInfos();
-  output_tensors_.clear();
-  output_tensors_.reserve(output_tensor_infos_.size());
   std::set<std::string> out_names;
   for (const auto& info : output_tensor_infos_) {
-    output_tensors_.emplace_back(info.name);
     out_names.insert(info.name);
   }
   output_names_.clear();
@@ -793,7 +782,6 @@ TRITONSERVER_Error* ModelInstanceState::ValidateOutputs() {
     if (index < 0) {
       RETURN_IF_ERROR(CheckAllowedModelInput(io, out_names));
     }
-    // output_tensors_.emplace_back(io_name);
 
     auto fd_data_type = ModelConfigDataTypeToFDType(io_dtype);
     if (fd_data_type == fastdeploy::FDDataType::UNKNOWN1) {
@@ -1009,7 +997,7 @@ void ModelInstanceState::ProcessRequests(TRITONBACKEND_Request** requests,
 
 TRITONSERVER_Error* ModelInstanceState::Run(
     std::vector<TRITONBACKEND_Response*>* responses,
     const uint32_t response_count) {
-  runtime_->Infer(input_tensors_, &output_tensors_);
+  runtime_->Infer();
 #ifdef TRITON_ENABLE_GPU
   if (Kind() == TRITONSERVER_INSTANCEGROUPKIND_GPU) {
     cudaStreamSynchronize(CudaStream());
@@ -1042,18 +1030,7 @@ TRITONSERVER_Error* ModelInstanceState::SetInputTensors(
         input, &input_name, &input_datatype, &input_shape, &input_dims_count,
         nullptr, nullptr));
 
-    int index = GetInfoIndex(std::string(input_name), input_tensor_infos_);
-    if (index < 0) {
-      auto err = TRITONSERVER_ErrorNew(
-          TRITONSERVER_ERROR_INTERNAL,
-          (std::string("Input name [") + input_name +
-           std::string("] is not one of the FD predictor input: ") +
-           input_tensors_[index].name)
-              .c_str());
-      // SendErrorForResponses(responses, request_count, err);
-      return err;
-    }
-
+    std::string in_name = std::string(input_name);
     std::vector<int64_t> batchn_shape;
     // For a ragged input tensor, the tensor shape should be
    // the flatten shape of the whole batch
@@ -1082,23 +1059,40 @@ TRITONSERVER_Error* ModelInstanceState::SetInputTensors(
       }
     }
 
+    const char* input_buffer;
+    size_t batchn_byte_size;
     TRITONSERVER_MemoryType memory_type;
-    int64_t device_id = 0;
-    fastdeploy::Device device;
+    int64_t memory_type_id;
+    std::vector<std::pair<TRITONSERVER_MemoryType, int64_t>>
+        allowed_input_types;
     if (Kind() == TRITONSERVER_INSTANCEGROUPKIND_GPU) {
-      memory_type = TRITONSERVER_MEMORY_GPU;
+      allowed_input_types = {{TRITONSERVER_MEMORY_GPU, DeviceId()},
+                             {TRITONSERVER_MEMORY_CPU_PINNED, 0},
+                             {TRITONSERVER_MEMORY_CPU, 0}};
+    } else {
+      allowed_input_types = {{TRITONSERVER_MEMORY_CPU_PINNED, 0},
+                             {TRITONSERVER_MEMORY_CPU, 0}};
+    }
+
+    RETURN_IF_ERROR(collector->ProcessTensor(
+        input_name, nullptr, 0, allowed_input_types, &input_buffer,
+        &batchn_byte_size, &memory_type, &memory_type_id));
+
+    int32_t device_id = -1;
+    fastdeploy::Device device;
+    if (memory_type == TRITONSERVER_MEMORY_GPU) {
       device_id = DeviceId();
       device = fastdeploy::Device::GPU;
     } else {
-      memory_type = TRITONSERVER_MEMORY_CPU;
       device = fastdeploy::Device::CPU;
     }
 
-    input_tensors_[index].Resize(
-        batchn_shape, ConvertDataTypeToFD(input_datatype), input_name, device);
-    collector->ProcessTensor(
-        input_name,
-        reinterpret_cast<char*>(input_tensors_[index].MutableData()),
-        input_tensors_[index].Nbytes(), memory_type, device_id);
+    fastdeploy::FDTensor fdtensor(in_name);
+    fdtensor.SetExternalData(
+        batchn_shape, ConvertDataTypeToFD(input_datatype),
+        const_cast<char*>(input_buffer), device, device_id);
+    runtime_->BindInputTensor(in_name, fdtensor);
   }
 
   // Finalize...
@@ -1134,12 +1128,25 @@ TRITONSERVER_Error* ModelInstanceState::ReadOutputTensors(
   //   }
 
   for (auto& output_name : output_names_) {
-    int idx = GetInfoIndex(output_name, output_tensor_infos_);
+    auto* output_tensor = runtime_->GetOutputTensor(output_name);
+    if (output_tensor == nullptr) {
+      RETURN_IF_ERROR(TRITONSERVER_ErrorNew(
+          TRITONSERVER_ERROR_INTERNAL,
+          (std::string("output tensor '") + output_name + "' is not found")
+              .c_str()));
+    }
+    TRITONSERVER_MemoryType memory_type = TRITONSERVER_MEMORY_CPU;
+    int64_t memory_type_id = 0;
+    if (output_tensor->device == fastdeploy::Device::GPU) {
+      memory_type = TRITONSERVER_MEMORY_GPU;
+      memory_type_id = DeviceId();
+    }
     responder.ProcessTensor(
-        output_tensors_[idx].name, ConvertFDType(output_tensors_[idx].dtype),
-        output_tensors_[idx].shape,
-        reinterpret_cast<char*>(output_tensors_[idx].MutableData()),
-        TRITONSERVER_MEMORY_CPU, 0);
+        output_tensor->name, ConvertFDType(output_tensor->dtype),
+        output_tensor->shape,
+        reinterpret_cast<char*>(output_tensor->MutableData()),
+        memory_type, memory_type_id);
   }
 
   // Finalize and wait for any pending buffer copies.
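
Usage note (not part of the patch): a minimal sketch of how a caller might drive the zero-copy path added above. The model file paths and the input name "x" are placeholders; RuntimeOption, Init and GetOutputInfo are existing FastDeploy APIs, while BindInputTensor, the parameterless Infer() and GetOutputTensor are the interfaces introduced in this diff.

#include <vector>
#include "fastdeploy/runtime.h"

int main() {
  fastdeploy::RuntimeOption option;
  option.SetModelPath("model.pdmodel", "model.pdiparams");  // placeholder model files
  fastdeploy::Runtime runtime;
  if (!runtime.Init(option)) return -1;

  // Wrap an existing host buffer; BindInputTensor shares it, no copy is made.
  std::vector<float> data(1 * 3 * 224 * 224, 0.0f);
  fastdeploy::FDTensor input;
  input.SetExternalData({1, 3, 224, 224}, fastdeploy::FDDataType::FP32,
                        data.data(), fastdeploy::Device::CPU);

  runtime.BindInputTensor("x", input);  // "x" is a placeholder input name
  runtime.Infer();                      // calls the backend with copy_to_fd == false

  // The returned tensor shares the backend's output memory, so it is generally
  // valid only until the next Infer() call; copy it out if it must live longer.
  fastdeploy::FDTensor* output =
      runtime.GetOutputTensor(runtime.GetOutputInfo(0).name);
  (void)output;
  return 0;
}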