diff --git a/csrc/fastdeploy/backends/ort/ort_backend.cc b/csrc/fastdeploy/backends/ort/ort_backend.cc
index 6f36364dc..4615e9a38 100644
--- a/csrc/fastdeploy/backends/ort/ort_backend.cc
+++ b/csrc/fastdeploy/backends/ort/ort_backend.cc
@@ -13,7 +13,9 @@
 // limitations under the License.
 
 #include "fastdeploy/backends/ort/ort_backend.h"
+
 #include
+
 #include "fastdeploy/backends/ort/ops/multiclass_nms.h"
 #include "fastdeploy/backends/ort/utils.h"
 #include "fastdeploy/utils/utils.h"
@@ -164,33 +166,34 @@ bool OrtBackend::InitFromOnnx(const std::string& model_file,
   return true;
 }
 
-void OrtBackend::CopyToCpu(const Ort::Value& value, FDTensor* tensor, const std::string& name) {
+void OrtBackend::CopyToCpu(const Ort::Value& value, FDTensor* tensor,
+                           const std::string& name) {
   const auto info = value.GetTensorTypeAndShapeInfo();
   const auto data_type = info.GetElementType();
   size_t numel = info.GetElementCount();
+  auto shape = info.GetShape();
+  FDDataType dtype;
   if (data_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT) {
-    tensor->Allocate(info.GetShape(), FDDataType::FP32, name);
-    memcpy(static_cast<float*>(tensor->MutableData()), value.GetTensorData<float>(),
-           numel * sizeof(float));
+    dtype = FDDataType::FP32;
+    numel *= sizeof(float);
   } else if (data_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32) {
-    tensor->Allocate(info.GetShape(), FDDataType::INT32, name);
-    memcpy(static_cast<int32_t*>(tensor->MutableData()), value.GetTensorData<int32_t>(),
-           numel * sizeof(int32_t));
+    dtype = FDDataType::INT32;
+    numel *= sizeof(int32_t);
   } else if (data_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64) {
-    tensor->Allocate(info.GetShape(), FDDataType::INT64, name);
-    memcpy(static_cast<int64_t*>(tensor->MutableData()), value.GetTensorData<int64_t>(),
-           numel * sizeof(int64_t));
+    dtype = FDDataType::INT64;
+    numel *= sizeof(int64_t);
   } else if (data_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE) {
-    tensor->Allocate(info.GetShape(), FDDataType::FP64, name);
-    memcpy(static_cast<double*>(tensor->MutableData()), value.GetTensorData<double>(),
-           numel * sizeof(double));
+    dtype = FDDataType::FP64;
+    numel *= sizeof(double);
   } else {
     FDASSERT(
         false,
         "Unrecognized data type of %d while calling OrtBackend::CopyToCpu().",
         data_type);
   }
+  tensor->Resize(shape, dtype, name);
+  memcpy(tensor->MutableData(), value.GetTensorData(), numel);
 }
 
 bool OrtBackend::Infer(std::vector<FDTensor>& inputs,
diff --git a/csrc/fastdeploy/backends/ort/ort_backend.h b/csrc/fastdeploy/backends/ort/ort_backend.h
index 93fd1e45d..526e28fc7 100644
--- a/csrc/fastdeploy/backends/ort/ort_backend.h
+++ b/csrc/fastdeploy/backends/ort/ort_backend.h
@@ -88,6 +88,7 @@ class OrtBackend : public BaseBackend {
   Ort::CustomOpDomain custom_op_domain_ = Ort::CustomOpDomain("Paddle");
 #endif
   OrtBackendOption option_;
-  void CopyToCpu(const Ort::Value& value, FDTensor* tensor, const std::string& name);
+  void CopyToCpu(const Ort::Value& value, FDTensor* tensor,
+                 const std::string& name);
 };
 }  // namespace fastdeploy
diff --git a/csrc/fastdeploy/backends/paddle/paddle_backend.cc b/csrc/fastdeploy/backends/paddle/paddle_backend.cc
index 7210225b9..9f1b5bb8c 100644
--- a/csrc/fastdeploy/backends/paddle/paddle_backend.cc
+++ b/csrc/fastdeploy/backends/paddle/paddle_backend.cc
@@ -79,16 +79,23 @@ bool PaddleBackend::InitFromPaddle(const std::string& model_file,
 }
 
 TensorInfo PaddleBackend::GetInputInfo(int index) {
-  FDASSERT(index < NumInputs(), "The index: %d should less than the number of inputs: %d.", index, NumInputs());
+  FDASSERT(index < NumInputs(),
+           "The index: %d should less than the number of inputs: %d.", index,
+           NumInputs());
   return inputs_desc_[index];
 }
 
+std::vector<TensorInfo> PaddleBackend::GetInputInfo() { return inputs_desc_; }
+
 TensorInfo PaddleBackend::GetOutputInfo(int index) {
   FDASSERT(index < NumOutputs(),
-           "The index: %d should less than the number of outputs %d.", index, NumOutputs());
+           "The index: %d should less than the number of outputs %d.", index,
+           NumOutputs());
   return outputs_desc_[index];
 }
 
+std::vector<TensorInfo> PaddleBackend::GetOutputInfo() { return outputs_desc_; }
+
 bool PaddleBackend::Infer(std::vector<FDTensor>& inputs,
                           std::vector<FDTensor>* outputs) {
   if (inputs.size() != inputs_desc_.size()) {
@@ -100,7 +107,7 @@ bool PaddleBackend::Infer(std::vector<FDTensor>& inputs,
 
   for (size_t i = 0; i < inputs.size(); ++i) {
     auto handle = predictor_->GetInputHandle(inputs[i].name);
-    ShareTensorFromCpu(handle.get(), inputs[i]);
+    ShareTensorFromFDTensor(handle.get(), inputs[i]);
   }
 
   predictor_->Run();
diff --git a/csrc/fastdeploy/backends/paddle/paddle_backend.h b/csrc/fastdeploy/backends/paddle/paddle_backend.h
index e634fd7b6..7ce82c9b3 100644
--- a/csrc/fastdeploy/backends/paddle/paddle_backend.h
+++ b/csrc/fastdeploy/backends/paddle/paddle_backend.h
@@ -44,8 +44,11 @@ struct PaddleBackendOption {
   std::vector<std::string> delete_pass_names = {};
 };
 
+// convert FD device to paddle place type
+paddle_infer::PlaceType ConvertFDDeviceToPlace(Device device);
+
 // Share memory buffer with paddle_infer::Tensor from fastdeploy::FDTensor
-void ShareTensorFromCpu(paddle_infer::Tensor* tensor, FDTensor& fd_tensor);
+void ShareTensorFromFDTensor(paddle_infer::Tensor* tensor, FDTensor& fd_tensor);
 
 // Copy memory data from paddle_infer::Tensor to fastdeploy::FDTensor
 void CopyTensorToCpu(std::unique_ptr<paddle_infer::Tensor>& tensor,
@@ -72,6 +75,8 @@ class PaddleBackend : public BaseBackend {
 
   TensorInfo GetInputInfo(int index);
   TensorInfo GetOutputInfo(int index);
+  std::vector<TensorInfo> GetInputInfo();
+  std::vector<TensorInfo> GetOutputInfo();
 
  private:
   paddle_infer::Config config_;
diff --git a/csrc/fastdeploy/backends/paddle/util.cc b/csrc/fastdeploy/backends/paddle/util.cc
index c601c2a0e..4faba0039 100644
--- a/csrc/fastdeploy/backends/paddle/util.cc
+++ b/csrc/fastdeploy/backends/paddle/util.cc
@@ -15,23 +15,33 @@
 #include "fastdeploy/backends/paddle/paddle_backend.h"
 
 namespace fastdeploy {
-void ShareTensorFromCpu(paddle_infer::Tensor* tensor, FDTensor& fd_tensor) {
+paddle_infer::PlaceType ConvertFDDeviceToPlace(Device device) {
+  if (device == Device::GPU) {
+    return paddle_infer::PlaceType::kGPU;
+  }
+  return paddle_infer::PlaceType::kCPU;
+}
+
+void ShareTensorFromFDTensor(paddle_infer::Tensor* tensor,
+                             FDTensor& fd_tensor) {
  std::vector<int> shape(fd_tensor.shape.begin(), fd_tensor.shape.end());
   tensor->Reshape(shape);
+  auto place = ConvertFDDeviceToPlace(fd_tensor.device);
   if (fd_tensor.dtype == FDDataType::FP32) {
     tensor->ShareExternalData(static_cast<const float*>(fd_tensor.Data()),
-                              shape, paddle_infer::PlaceType::kCPU);
+                              shape, place);
     return;
   } else if (fd_tensor.dtype == FDDataType::INT32) {
     tensor->ShareExternalData(static_cast<const int32_t*>(fd_tensor.Data()),
-                              shape, paddle_infer::PlaceType::kCPU);
+                              shape, place);
     return;
   } else if (fd_tensor.dtype == FDDataType::INT64) {
     tensor->ShareExternalData(static_cast<const int64_t*>(fd_tensor.Data()),
-                              shape, paddle_infer::PlaceType::kCPU);
+                              shape, place);
     return;
   }
-  FDASSERT(false, "Unexpected data type(%s) while infer with PaddleBackend.", Str(fd_tensor.dtype).c_str());
+  FDASSERT(false, "Unexpected data type(%s) while infer with PaddleBackend.",
+           Str(fd_tensor.dtype).c_str());
 }
 
 void CopyTensorToCpu(std::unique_ptr<paddle_infer::Tensor>& tensor,
@@ -51,7 +61,8 @@ void CopyTensorToCpu(std::unique_ptr<paddle_infer::Tensor>& tensor,
     tensor->CopyToCpu(static_cast(fd_tensor->MutableData()));
     return;
   }
-  FDASSERT(false, "Unexpected data type(%s) while infer with PaddleBackend.", Str(fd_tensor->dtype).c_str());
+  FDASSERT(false, "Unexpected data type(%s) while infer with PaddleBackend.",
+           Str(fd_tensor->dtype).c_str());
 }
 
 FDDataType PaddleDataTypeToFD(const paddle_infer::DataType& dtype) {
@@ -65,7 +76,10 @@ FDDataType PaddleDataTypeToFD(const paddle_infer::DataType& dtype) {
   } else if (dtype == paddle_infer::UINT8) {
     fd_dtype = FDDataType::UINT8;
   } else {
-    FDASSERT(false, "Unexpected data type: %d while call CopyTensorToCpu in PaddleBackend.", int(dtype));
+    FDASSERT(
+        false,
+        "Unexpected data type: %d while call CopyTensorToCpu in PaddleBackend.",
+        int(dtype));
   }
   return fd_dtype;
 }
diff --git a/csrc/fastdeploy/backends/tensorrt/trt_backend.cc b/csrc/fastdeploy/backends/tensorrt/trt_backend.cc
index d65af6fcb..841888f56 100644
--- a/csrc/fastdeploy/backends/tensorrt/trt_backend.cc
+++ b/csrc/fastdeploy/backends/tensorrt/trt_backend.cc
@@ -13,9 +13,11 @@
 // limitations under the License.
 
 #include "fastdeploy/backends/tensorrt/trt_backend.h"
+
+#include
+
 #include "NvInferSafeRuntime.h"
 #include "fastdeploy/utils/utils.h"
-#include
 #ifdef ENABLE_PADDLE_FRONTEND
 #include "paddle2onnx/converter.h"
 #endif
@@ -210,9 +212,9 @@ bool TrtBackend::InitFromOnnx(const std::string& model_file,
   outputs_desc_.resize(onnx_reader.num_outputs);
   for (int i = 0; i < onnx_reader.num_inputs; ++i) {
     std::string name(onnx_reader.inputs[i].name);
-    std::vector shape(onnx_reader.inputs[i].shape,
-                      onnx_reader.inputs[i].shape +
-                      onnx_reader.inputs[i].rank);
+    std::vector shape(
+        onnx_reader.inputs[i].shape,
+        onnx_reader.inputs[i].shape + onnx_reader.inputs[i].rank);
     inputs_desc_[i].name = name;
     inputs_desc_[i].shape.assign(shape.begin(), shape.end());
     inputs_desc_[i].dtype = ReaderDtypeToTrtDtype(onnx_reader.inputs[i].dtype);
@@ -231,9 +233,9 @@ bool TrtBackend::InitFromOnnx(const std::string& model_file,
 
   for (int i = 0; i < onnx_reader.num_outputs; ++i) {
     std::string name(onnx_reader.outputs[i].name);
-    std::vector shape(onnx_reader.outputs[i].shape,
-                      onnx_reader.outputs[i].shape +
-                      onnx_reader.outputs[i].rank);
+    std::vector shape(
+        onnx_reader.outputs[i].shape,
+        onnx_reader.outputs[i].shape + onnx_reader.outputs[i].rank);
     outputs_desc_[i].name = name;
     outputs_desc_[i].shape.assign(shape.begin(), shape.end());
     outputs_desc_[i].dtype =
@@ -286,24 +288,8 @@ bool TrtBackend::Infer(std::vector<FDTensor>& inputs,
     BuildTrtEngine();
   }
 
-  AllocateBufferInDynamicShape(inputs, outputs);
-  std::vector input_binds(inputs.size());
-  for (size_t i = 0; i < inputs.size(); ++i) {
-    if (inputs[i].dtype == FDDataType::INT64) {
-      int64_t* data = static_cast<int64_t*>(inputs[i].Data());
-      std::vector<int32_t> casted_data(data, data + inputs[i].Numel());
-      FDASSERT(cudaMemcpyAsync(inputs_buffer_[inputs[i].name].data(),
-                               static_cast<void*>(casted_data.data()),
-                               inputs[i].Nbytes() / 2, cudaMemcpyHostToDevice,
-                               stream_) == 0,
-               "[ERROR] Error occurs while copy memory from CPU to GPU.");
-    } else {
-      FDASSERT(cudaMemcpyAsync(inputs_buffer_[inputs[i].name].data(),
-                               inputs[i].Data(), inputs[i].Nbytes(),
-                               cudaMemcpyHostToDevice, stream_) == 0,
-               "[ERROR] Error occurs while copy memory from CPU to GPU.");
-    }
-  }
+  SetInputs(inputs);
+  AllocateOutputsBuffer(outputs);
   if (!context_->enqueueV2(bindings_.data(), stream_, nullptr)) {
     FDERROR << "Failed to Infer with TensorRT."
            << std::endl;
     return false;
@@ -339,18 +325,50 @@ void TrtBackend::GetInputOutputInfo() {
   bindings_.resize(num_binds);
 }
 
-void TrtBackend::AllocateBufferInDynamicShape(
-    const std::vector<FDTensor>& inputs, std::vector<FDTensor>* outputs) {
+void TrtBackend::SetInputs(const std::vector<FDTensor>& inputs) {
   for (const auto& item : inputs) {
     auto idx = engine_->getBindingIndex(item.name.c_str());
     std::vector shape(item.shape.begin(), item.shape.end());
     auto dims = ToDims(shape);
     context_->setBindingDimensions(idx, dims);
-    if (item.Nbytes() > inputs_buffer_[item.name].nbBytes()) {
+
+    if (item.device == Device::GPU) {
+      if (item.dtype == FDDataType::INT64) {
+        // TODO(liqi): cast int64 to int32
+        // TRT don't support INT64
+        FDASSERT(false,
+                 "TRT don't support INT64 input on GPU, "
+                 "please use INT32 input");
+      } else {
+        // no copy
+        inputs_buffer_[item.name].SetExternalData(dims, item.Data());
+      }
+    } else {
+      // Allocate input buffer memory
       inputs_buffer_[item.name].resize(dims);
-      bindings_[idx] = inputs_buffer_[item.name].data();
+
+      // copy from cpu to gpu
+      if (item.dtype == FDDataType::INT64) {
+        int64_t* data = static_cast<int64_t*>(const_cast<void*>(item.Data()));
+        std::vector<int32_t> casted_data(data, data + item.Numel());
+        FDASSERT(cudaMemcpyAsync(inputs_buffer_[item.name].data(),
+                                 static_cast<void*>(casted_data.data()),
+                                 item.Nbytes() / 2, cudaMemcpyHostToDevice,
+                                 stream_) == 0,
+                 "Error occurs while copy memory from CPU to GPU.");
+      } else {
+        FDASSERT(cudaMemcpyAsync(inputs_buffer_[item.name].data(), item.Data(),
+                                 item.Nbytes(), cudaMemcpyHostToDevice,
+                                 stream_) == 0,
+                 "Error occurs while copy memory from CPU to GPU.");
+      }
     }
+    // binding input buffer
+    bindings_[idx] = inputs_buffer_[item.name].data();
   }
+}
+
+void TrtBackend::AllocateOutputsBuffer(std::vector<FDTensor>* outputs) {
   if (outputs->size() != outputs_desc_.size()) {
     outputs->resize(outputs_desc_.size());
   }
@@ -365,13 +383,15 @@ void TrtBackend::AllocateBufferInDynamicShape(
         "Cannot find output: %s of tensorrt network from the original model.",
         outputs_desc_[i].name.c_str());
     auto ori_idx = iter->second;
-    std::vector shape(output_dims.d, output_dims.d + output_dims.nbDims);
-    (*outputs)[ori_idx].Allocate(shape, GetFDDataType(outputs_desc_[i].dtype), outputs_desc_[i].name);
-    if ((*outputs)[ori_idx].Nbytes() >
-        outputs_buffer_[outputs_desc_[i].name].nbBytes()) {
-      outputs_buffer_[outputs_desc_[i].name].resize(output_dims);
-      bindings_[idx] = outputs_buffer_[outputs_desc_[i].name].data();
-    }
+    // set user's outputs info
+    std::vector<int64_t> shape(output_dims.d,
+                               output_dims.d + output_dims.nbDims);
+    (*outputs)[ori_idx].Resize(shape, GetFDDataType(outputs_desc_[i].dtype),
+                               outputs_desc_[i].name);
+    // Allocate output buffer memory
+    outputs_buffer_[outputs_desc_[i].name].resize(output_dims);
+    // binding output buffer
+    bindings_[idx] = outputs_buffer_[outputs_desc_[i].name].data();
   }
 }
 
@@ -580,4 +600,4 @@ TensorInfo TrtBackend::GetOutputInfo(int index) {
   info.dtype = GetFDDataType(outputs_desc_[index].dtype);
   return info;
 }
-}  // namespace fastdeploy
+}  // namespace fastdeploy
diff --git a/csrc/fastdeploy/backends/tensorrt/trt_backend.h b/csrc/fastdeploy/backends/tensorrt/trt_backend.h
index b04a75e55..ed249a2a7 100644
--- a/csrc/fastdeploy/backends/tensorrt/trt_backend.h
+++ b/csrc/fastdeploy/backends/tensorrt/trt_backend.h
@@ -14,6 +14,8 @@
 
 #pragma once
 
+#include
+
 #include
 #include
 #include
@@ -23,7 +25,6 @@
 #include "NvOnnxParser.h"
 #include "fastdeploy/backends/backend.h"
 #include "fastdeploy/backends/tensorrt/utils.h"
-#include
 
 namespace fastdeploy {
 
@@ -109,12 +110,12 @@ class TrtBackend : public BaseBackend {
   std::map<std::string, ShapeRangeInfo> shape_range_info_;
 
   void GetInputOutputInfo();
-  void AllocateBufferInDynamicShape(const std::vector<FDTensor>& inputs,
-                                    std::vector<FDTensor>* outputs);
   bool CreateTrtEngineFromOnnx(const std::string& onnx_model_buffer);
   bool BuildTrtEngine();
   bool LoadTrtCache(const std::string& trt_engine_file);
   int ShapeRangeInfoUpdated(const std::vector<FDTensor>& inputs);
+  void SetInputs(const std::vector<FDTensor>& inputs);
+  void AllocateOutputsBuffer(std::vector<FDTensor>* outputs);
 };
 
-}  // namespace fastdeploy
+}  // namespace fastdeploy
diff --git a/csrc/fastdeploy/backends/tensorrt/utils.h b/csrc/fastdeploy/backends/tensorrt/utils.h
index 4739cedbe..bf84b2ef2 100644
--- a/csrc/fastdeploy/backends/tensorrt/utils.h
+++ b/csrc/fastdeploy/backends/tensorrt/utils.h
@@ -14,11 +14,9 @@
 
 #pragma once
 
-#include "NvInfer.h"
-#include "fastdeploy/core/fd_tensor.h"
-#include "fastdeploy/utils/utils.h"
-#include
 #include
+
+#include
 #include
 #include
 #include
@@ -26,17 +24,24 @@
 #include
 #include
 
+#include "NvInfer.h"
+#include "fastdeploy/core/allocate.h"
+#include "fastdeploy/core/fd_tensor.h"
+#include "fastdeploy/utils/utils.h"
+
 namespace fastdeploy {
 
 struct FDInferDeleter {
-  template void operator()(T* obj) const {
+  template
+  void operator()(T* obj) const {
     if (obj) {
       obj->destroy();
     }
   }
 };
 
-template using FDUniquePtr = std::unique_ptr;
+template
+using FDUniquePtr = std::unique_ptr;
 
 int64_t Volume(const nvinfer1::Dims& d);
 
@@ -64,13 +69,18 @@ std::ostream& operator<<(std::ostream& out, const std::vector& vec) {
   return out;
 }
 
-template class FDGenericBuffer {
+template
+class FDGenericBuffer {
  public:
   //!
   //! \brief Construct an empty buffer.
   //!
   explicit FDGenericBuffer(nvinfer1::DataType type = nvinfer1::DataType::kFLOAT)
-      : mSize(0), mCapacity(0), mType(type), mBuffer(nullptr) {}
+      : mSize(0),
+        mCapacity(0),
+        mType(type),
+        mBuffer(nullptr),
+        mExternal_buffer(nullptr) {}
 
   //!
   //! \brief Construct a buffer with the specified allocation size in bytes.
@@ -82,8 +92,18 @@ template class FDGenericBuffer {
     }
   }
 
+  //!
+  //! \brief This use to skip memory copy step.
+  //!
+  FDGenericBuffer(size_t size, nvinfer1::DataType type, void* buffer)
+      : mSize(size), mCapacity(size), mType(type) {
+    mExternal_buffer = buffer;
+  }
+
   FDGenericBuffer(FDGenericBuffer&& buf)
-      : mSize(buf.mSize), mCapacity(buf.mCapacity), mType(buf.mType),
+      : mSize(buf.mSize),
+        mCapacity(buf.mCapacity),
+        mType(buf.mType),
         mBuffer(buf.mBuffer) {
     buf.mSize = 0;
     buf.mCapacity = 0;
@@ -109,12 +129,18 @@ template class FDGenericBuffer {
   //!
   //! \brief Returns pointer to underlying array.
   //!
-  void* data() { return mBuffer; }
+  void* data() {
+    if (mExternal_buffer != nullptr) return mExternal_buffer;
+    return mBuffer;
+  }
 
   //!
   //! \brief Returns pointer to underlying array.
   //!
-  const void* data() const { return mBuffer; }
+  const void* data() const {
+    if (mExternal_buffer != nullptr) return mExternal_buffer;
+    return mBuffer;
+  }
 
   //!
   //! \brief Returns the size (in number of elements) of the buffer.
@@ -126,11 +152,29 @@ template class FDGenericBuffer {
   //!
   size_t nbBytes() const { return this->size() * TrtDataTypeSize(mType); }
 
+  //!
+  //! \brief Set user memory buffer for TRT Buffer
+  //!
+  void SetExternalData(size_t size, nvinfer1::DataType type, void* buffer) {
+    mSize = mCapacity = size;
+    mType = type;
+    mExternal_buffer = const_cast<void*>(buffer);
+  }
+
+  //!
+  //! \brief Set user memory buffer for TRT Buffer
+  //!
+  void SetExternalData(const nvinfer1::Dims& dims, const void* buffer) {
+    mSize = mCapacity = Volume(dims);
+    mExternal_buffer = const_cast<void*>(buffer);
+  }
+
   //!
   //! \brief Resizes the buffer. This is a no-op if the new size is smaller than
   //! or equal to the current capacity.
   //!
   void resize(size_t newSize) {
+    mExternal_buffer = nullptr;
     mSize = newSize;
     if (mCapacity < newSize) {
       freeFn(mBuffer);
@@ -146,28 +190,20 @@ template class FDGenericBuffer {
   //!
   void resize(const nvinfer1::Dims& dims) { return this->resize(Volume(dims)); }
 
-  ~FDGenericBuffer() { freeFn(mBuffer); }
+  ~FDGenericBuffer() {
+    mExternal_buffer = nullptr;
+    freeFn(mBuffer);
+  }
 
  private:
   size_t mSize{0}, mCapacity{0};
   nvinfer1::DataType mType;
   void* mBuffer;
+  void* mExternal_buffer;
   AllocFunc allocFn;
   FreeFunc freeFn;
 };
 
-class FDDeviceAllocator {
- public:
-  bool operator()(void** ptr, size_t size) const {
-    return cudaMalloc(ptr, size) == cudaSuccess;
-  }
-};
-
-class FDDeviceFree {
- public:
-  void operator()(void* ptr) const { cudaFree(ptr); }
-};
-
 using FDDeviceBuffer = FDGenericBuffer<FDDeviceAllocator, FDDeviceFree>;
 
 class FDTrtLogger : public nvinfer1::ILogger {
@@ -197,7 +233,7 @@ class FDTrtLogger : public nvinfer1::ILogger {
 };
 
 struct ShapeRangeInfo {
-  ShapeRangeInfo(const std::vector<int64_t>& new_shape) {
+  explicit ShapeRangeInfo(const std::vector<int64_t>& new_shape) {
     shape.assign(new_shape.begin(), new_shape.end());
     min.resize(new_shape.size());
     max.resize(new_shape.size());
@@ -239,4 +275,4 @@ struct ShapeRangeInfo {
   }
 };
 
-}  // namespace fastdeploy
+}  // namespace fastdeploy
diff --git a/csrc/fastdeploy/core/allocate.cc b/csrc/fastdeploy/core/allocate.cc
new file mode 100644
index 000000000..0615ee46d
--- /dev/null
+++ b/csrc/fastdeploy/core/allocate.cc
@@ -0,0 +1,41 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+#pragma once
+
+#ifdef WITH_GPU
+#include
+#endif
+
+#include "fastdeploy/core/allocate.h"
+
+namespace fastdeploy {
+
+bool FDHostAllocator::operator()(void** ptr, size_t size) const {
+  *ptr = malloc(size);
+  return *ptr != nullptr;
+}
+
+void FDHostFree::operator()(void* ptr) const { free(ptr); }
+
+#ifdef WITH_GPU
+
+bool FDDeviceAllocator::operator()(void** ptr, size_t size) const {
+  return cudaMalloc(ptr, size) == cudaSuccess;
+}
+
+void FDDeviceFree::operator()(void* ptr) const { cudaFree(ptr); }
+
+#endif
+
+}  // namespace fastdeploy
diff --git a/csrc/fastdeploy/core/allocate.h b/csrc/fastdeploy/core/allocate.h
new file mode 100644
index 000000000..c48bb7cee
--- /dev/null
+++ b/csrc/fastdeploy/core/allocate.h
@@ -0,0 +1,50 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+#pragma once
+
+#include
+#include
+#include
+#include
+#include
+
+#include "fastdeploy/utils/utils.h"
+
+namespace fastdeploy {
+
+class FASTDEPLOY_DECL FDHostAllocator {
+ public:
+  bool operator()(void** ptr, size_t size) const;
+};
+
+class FASTDEPLOY_DECL FDHostFree {
+ public:
+  void operator()(void* ptr) const;
+};
+
+#ifdef WITH_GPU
+
+class FASTDEPLOY_DECL FDDeviceAllocator {
+ public:
+  bool operator()(void** ptr, size_t size) const;
+};
+
+class FASTDEPLOY_DECL FDDeviceFree {
+ public:
+  void operator()(void* ptr) const;
+};
+
+#endif
+
+}  // namespace fastdeploy
diff --git a/csrc/fastdeploy/core/fd_tensor.cc b/csrc/fastdeploy/core/fd_tensor.cc
index d4761bba6..5fe93299d 100644
--- a/csrc/fastdeploy/core/fd_tensor.cc
+++ b/csrc/fastdeploy/core/fd_tensor.cc
@@ -13,6 +13,7 @@
 // limitations under the License.
 
 #include "fastdeploy/core/fd_tensor.h"
+
 #include "fastdeploy/utils/utils.h"
 
 #ifdef WITH_GPU
@@ -25,55 +26,69 @@
 void* FDTensor::MutableData() {
   if (external_data_ptr != nullptr) {
     return external_data_ptr;
   }
-  return data.data();
+  return buffer_;
 }
 
 void* FDTensor::Data() {
   if (external_data_ptr != nullptr) {
-    if (device == Device::GPU) {
-#ifdef WITH_GPU
-      // need to copy cuda mem to cpu first
-      temporary_cpu_buffer.resize(Nbytes());
-      FDASSERT(cudaMemcpy(temporary_cpu_buffer.data(), external_data_ptr,
-                          Nbytes(), cudaMemcpyDeviceToHost) == 0,
-               "[ERROR] Error occurs while copy memory from GPU to CPU");
-      return temporary_cpu_buffer.data();
-#else
-      FDASSERT(false,
-               "The FastDeploy didn't compile under -DWITH_GPU=ON, so this is "
-               "an unexpected problem happend.");
-#endif
-    } else {
-      return external_data_ptr;
-    }
+    return external_data_ptr;
   }
-  return data.data();
+  return buffer_;
 }
 
 const void* FDTensor::Data() const {
   if (external_data_ptr != nullptr) {
     return external_data_ptr;
   }
-  return data.data();
+  return buffer_;
+}
+
+const void* FDTensor::CpuData() const {
+  if (device == Device::GPU) {
+#ifdef WITH_GPU
+    auto* cpu_ptr = const_cast*>(&temporary_cpu_buffer);
+    cpu_ptr->resize(Nbytes());
+    // need to copy cuda mem to cpu first
+    if (external_data_ptr != nullptr) {
+      FDASSERT(cudaMemcpy(cpu_ptr->data(), external_data_ptr, Nbytes(),
+                          cudaMemcpyDeviceToHost) == 0,
+               "[ERROR] Error occurs while copy memory from GPU to CPU");
+
+    } else {
+      FDASSERT(cudaMemcpy(cpu_ptr->data(), buffer_, Nbytes(),
+                          cudaMemcpyDeviceToHost) == 0,
+               "[ERROR] Error occurs while buffer copy memory from GPU to CPU");
+    }
+    return cpu_ptr->data();
+#else
+    FDASSERT(false,
+             "The FastDeploy didn't compile under -DWITH_GPU=ON, so this is "
+             "an unexpected problem happend.");
+#endif
+  }
+  return Data();
 }
 
 void FDTensor::SetExternalData(const std::vector<int64_t>& new_shape,
-                               const FDDataType& data_type, void* data_buffer) {
+                               const FDDataType& data_type, void* data_buffer,
+                               const Device& new_device) {
   dtype = data_type;
   shape.assign(new_shape.begin(), new_shape.end());
   external_data_ptr = data_buffer;
+  device = new_device;
 }
 
 void FDTensor::Allocate(const std::vector<int64_t>& new_shape,
                         const FDDataType& data_type,
-                        const std::string& tensor_name) {
+                        const std::string& tensor_name,
+                        const Device& new_device) {
   dtype = data_type;
   name = tensor_name;
   shape.assign(new_shape.begin(), new_shape.end());
-  int unit = FDDataTypeSize(data_type);
-  int total_size =
-      std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int>());
-  data.resize(total_size * unit);
+  device = new_device;
+  size_t nbytes = Nbytes();
+  FDASSERT(AllocFn(nbytes),
+           "The FastDeploy FDTensor allocate cpu memory error");
 }
 
 int FDTensor::Nbytes() const { return Numel() * FDDataTypeSize(dtype); }
@@ -82,6 +97,44 @@ int FDTensor::Numel() const {
   return std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int>());
 }
 
+void FDTensor::Resize(size_t new_nbytes) {
+  size_t nbytes = Nbytes();
+  if (new_nbytes > nbytes) {
+    FreeFn();
+    AllocFn(new_nbytes);
+  }
+}
+
+void FDTensor::Resize(const std::vector<int64_t>& new_shape) {
+  int numel = Numel();
+  int new_numel = std::accumulate(new_shape.begin(), new_shape.end(), 1,
+                                  std::multiplies<int>());
+  shape.assign(new_shape.begin(), new_shape.end());
+  if (new_numel > numel) {
+    FreeFn();
+    size_t nbytes = new_numel * FDDataTypeSize(dtype);
+    AllocFn(nbytes);
+  }
+}
+
+void FDTensor::Resize(const std::vector<int64_t>& new_shape,
+                      const FDDataType& data_type,
+                      const std::string& tensor_name,
+                      const Device& new_device) {
+  name = tensor_name;
+  device = new_device;
+  size_t nbytes = Nbytes();
+  shape.assign(new_shape.begin(), new_shape.end());
+  dtype = data_type;
+  int new_nbytes = std::accumulate(new_shape.begin(), new_shape.end(), 1,
+                                   std::multiplies<int>()) *
+                   FDDataTypeSize(data_type);
+  if (new_nbytes > nbytes) {
+    FreeFn();
+    AllocFn(new_nbytes);
+  }
+}
+
 template
 void CalculateStatisInfo(void* src_ptr, int size, double* mean, double* max,
                          double* min) {
diff --git a/csrc/fastdeploy/core/fd_tensor.h b/csrc/fastdeploy/core/fd_tensor.h
index 84e8c7ff0..8447bfaad 100644
--- a/csrc/fastdeploy/core/fd_tensor.h
+++ b/csrc/fastdeploy/core/fd_tensor.h
@@ -18,15 +18,17 @@
 #include
 #include
 
+#include "fastdeploy/core/allocate.h"
 #include "fastdeploy/core/fd_type.h"
 
 namespace fastdeploy {
 
 struct FASTDEPLOY_DECL FDTensor {
-  std::vector data;
-  std::vector shape;
+  // std::vector data;
+  void* buffer_ = nullptr;
+  std::vector<int64_t> shape = {0};
   std::string name = "";
-  FDDataType dtype;
+  FDDataType dtype = FDDataType::INT8;
 
   // This use to skip memory copy step
   // the external_data_ptr will point to the user allocated memory
@@ -46,28 +48,32 @@ struct FASTDEPLOY_DECL FDTensor {
   // Get data buffer pointer
   void* MutableData();
 
-  // Use this data to get the tensor data to process
-  // Since the most senario is process data in CPU
-  // this function weill return a pointer to cpu memory
-  // buffer.
-  // If the original data is on other device, the data
-  // will copy to cpu store in `temporary_cpu_buffer`
   void* Data();
 
   const void* Data() const;
 
+  // Use this data to get the tensor data to process
+  // Since the most senario is process data in CPU
+  // this function will return a pointer to cpu memory
+  // buffer.
+  // If the original data is on other device
+  // will copy to cpu store in `temporary_cpu_buffer`
+  const void* CpuData() const;
+
   // Set user memory buffer for Tensor, the memory is managed by
   // the user it self, but the Tensor will share the memory with user
   // So take care with the user buffer
   void SetExternalData(const std::vector<int64_t>& new_shape,
-                       const FDDataType& data_type, void* data_buffer);
+                       const FDDataType& data_type, void* data_buffer,
+                       const Device& new_device = Device::CPU);
 
   // Initialize Tensor
   // Include setting attribute for tensor
   // and allocate cpu memory buffer
   void Allocate(const std::vector<int64_t>& new_shape,
                 const FDDataType& data_type,
-                const std::string& tensor_name = "");
+                const std::string& tensor_name = "",
+                const Device& new_device = Device::CPU);
 
   // Total size of tensor memory buffer in bytes
   int Nbytes() const;
@@ -75,13 +81,51 @@ struct FASTDEPLOY_DECL FDTensor {
   // Total number of elements in this tensor
   int Numel() const;
 
+  void Resize(size_t nbytes);
+
+  void Resize(const std::vector<int64_t>& new_shape);
+
+  void Resize(const std::vector<int64_t>& new_shape,
+              const FDDataType& data_type, const std::string& tensor_name = "",
+              const Device& new_device = Device::CPU);
+
   // Debug function
   // Use this function to print shape, dtype, mean, max, min
   // prefix will also be printed as tag
   void PrintInfo(const std::string& prefix = "TensorInfo: ");
 
+  bool AllocFn(size_t nbytes) {
+    if (device == Device::GPU) {
+#ifdef WITH_GPU
+      return FDDeviceAllocator()(&buffer_, nbytes);
+#else
+      FDASSERT(false,
+               "The FastDeploy FDTensor allocator didn't compile under "
+               "-DWITH_GPU=ON,"
+               "so this is an unexpected problem happend.");
+#endif
+    }
+    return FDHostAllocator()(&buffer_, nbytes);
+  }
+
+  void FreeFn() {
+    if (external_data_ptr != nullptr) external_data_ptr = nullptr;
+    if (buffer_ != nullptr) {
+      if (device == Device::GPU) {
+#ifdef WITH_GPU
+        FDDeviceFree()(buffer_);
+#endif
+      } else {
+        FDHostFree()(buffer_);
+      }
+      buffer_ = nullptr;
+    }
+  }
+
   FDTensor() {}
   explicit FDTensor(const std::string& tensor_name);
+
+  ~FDTensor() { FreeFn(); }
 };
 
 }  // namespace fastdeploy
diff --git a/csrc/fastdeploy/core/fd_type.cc b/csrc/fastdeploy/core/fd_type.cc
index ccba7e232..15844802a 100644
--- a/csrc/fastdeploy/core/fd_type.cc
+++ b/csrc/fastdeploy/core/fd_type.cc
@@ -13,6 +13,7 @@
 // limitations under the License.
#include "fastdeploy/core/fd_type.h" + #include "fastdeploy/utils/utils.h" namespace fastdeploy { @@ -33,6 +34,8 @@ int FDDataTypeSize(const FDDataType& data_type) { return sizeof(double); } else if (data_type == FDDataType::UINT8) { return sizeof(uint8_t); + } else if (data_type == FDDataType::INT8) { + return sizeof(int8_t); } else { FDASSERT(false, "Unexpected data type: %s", Str(data_type).c_str()); } @@ -42,9 +45,6 @@ int FDDataTypeSize(const FDDataType& data_type) { std::string Str(const Device& d) { std::string out; switch (d) { - case Device::DEFAULT: - out = "Device::DEFAULT"; - break; case Device::CPU: out = "Device::CPU"; break; diff --git a/csrc/fastdeploy/core/fd_type.h b/csrc/fastdeploy/core/fd_type.h index 50b00dca8..22f33e0e5 100644 --- a/csrc/fastdeploy/core/fd_type.h +++ b/csrc/fastdeploy/core/fd_type.h @@ -22,7 +22,7 @@ namespace fastdeploy { -enum FASTDEPLOY_DECL Device { DEFAULT, CPU, GPU }; +enum FASTDEPLOY_DECL Device { CPU, GPU }; FASTDEPLOY_DECL std::string Str(const Device& d); diff --git a/csrc/fastdeploy/pybind/fastdeploy_runtime.cc b/csrc/fastdeploy/pybind/fastdeploy_runtime.cc index c4a07b50f..a5924b0af 100644 --- a/csrc/fastdeploy/pybind/fastdeploy_runtime.cc +++ b/csrc/fastdeploy/pybind/fastdeploy_runtime.cc @@ -71,14 +71,13 @@ void BindRuntime(pybind11::module& m) { std::vector inputs(data.size()); int index = 0; for (auto iter = data.begin(); iter != data.end(); ++iter) { - inputs[index].dtype = - NumpyDataTypeToFDDataType(iter->second.dtype()); - inputs[index].shape.insert( - inputs[index].shape.begin(), iter->second.shape(), - iter->second.shape() + iter->second.ndim()); + std::vector data_shape; + data_shape.insert(data_shape.begin(), iter->second.shape(), + iter->second.shape() + iter->second.ndim()); + auto dtype = NumpyDataTypeToFDDataType(iter->second.dtype()); // TODO(jiangjiajun) Maybe skip memory copy is a better choice // use SetExternalData - inputs[index].data.resize(iter->second.nbytes()); + inputs[index].Resize(data_shape, dtype); memcpy(inputs[index].MutableData(), iter->second.mutable_data(), iter->second.nbytes()); inputs[index].name = iter->first; @@ -134,4 +133,4 @@ void BindRuntime(pybind11::module& m) { m.def("get_available_backends", []() { return GetAvailableBackends(); }); } -} // namespace fastdeploy +} // namespace fastdeploy diff --git a/csrc/fastdeploy/pybind/main.cc.in b/csrc/fastdeploy/pybind/main.cc.in index ba3d799c0..8280fdebf 100644 --- a/csrc/fastdeploy/pybind/main.cc.in +++ b/csrc/fastdeploy/pybind/main.cc.in @@ -59,13 +59,15 @@ FDDataType NumpyDataTypeToFDDataType(const pybind11::dtype& np_dtype) { void PyArrayToTensor(pybind11::array& pyarray, FDTensor* tensor, bool share_buffer) { - tensor->dtype = NumpyDataTypeToFDDataType(pyarray.dtype()); - tensor->shape.insert(tensor->shape.begin(), pyarray.shape(), - pyarray.shape() + pyarray.ndim()); + auto dtype = NumpyDataTypeToFDDataType(pyarray.dtype()); + std::vector data_shape; + data_shape.insert(data_shape.begin(), pyarray.shape(), + pyarray.shape() + pyarray.ndim()); if (share_buffer) { - tensor->external_data_ptr = pyarray.mutable_data(); + tensor-> SetExternalData(data_shape, dtype, + pyarray.mutable_data()); } else { - tensor->data.resize(pyarray.nbytes()); + tensor->Resize(data_shape, dtype); memcpy(tensor->MutableData(), pyarray.mutable_data(), pyarray.nbytes()); } } diff --git a/csrc/fastdeploy/pybind/main.h b/csrc/fastdeploy/pybind/main.h index 4230eb059..09c42f876 100644 --- a/csrc/fastdeploy/pybind/main.h +++ b/csrc/fastdeploy/pybind/main.h @@ -17,6 +17,7 @@ 
 #include
 #include
 #include
+#include
 
 #include "fastdeploy/fastdeploy_runtime.h"
@@ -42,7 +43,8 @@ pybind11::array TensorToPyArray(const FDTensor& tensor);
 cv::Mat PyArrayToCvMat(pybind11::array& pyarray);
 #endif
 
-template FDDataType CTypeToFDDataType() {
+template
+FDDataType CTypeToFDDataType() {
   if (std::is_same<T, int32_t>::value) {
     return FDDataType::INT32;
   } else if (std::is_same::value) {
@@ -58,16 +60,17 @@ template FDDataType CTypeToFDDataType() {
 }
 
 template
-std::vector<pybind11::array>
-PyBackendInfer(T& self, const std::vector<std::string>& names,
-               std::vector<pybind11::array>& data) {
+std::vector<pybind11::array> PyBackendInfer(
+    T& self, const std::vector<std::string>& names,
+    std::vector<pybind11::array>& data) {
   std::vector<FDTensor> inputs(data.size());
   for (size_t i = 0; i < data.size(); ++i) {
     // TODO(jiangjiajun) here is considered to use user memory directly
-    inputs[i].dtype = NumpyDataTypeToFDDataType(data[i].dtype());
-    inputs[i].shape.insert(inputs[i].shape.begin(), data[i].shape(),
-                           data[i].shape() + data[i].ndim());
-    inputs[i].data.resize(data[i].nbytes());
+    auto dtype = NumpyDataTypeToFDDataType(data[i].dtype());
+    std::vector<int64_t> data_shape;
+    data_shape.insert(data_shape.begin(), data[i].shape(),
+                      data[i].shape() + data[i].ndim());
+    inputs[i].Resize(data_shape, dtype);
    memcpy(inputs[i].MutableData(), data[i].mutable_data(), data[i].nbytes());
     inputs[i].name = names[i];
   }
@@ -86,4 +89,4 @@ PyBackendInfer(T& self, const std::vector<std::string>& names,
   return results;
 }
-}  // namespace fastdeploy
+}  // namespace fastdeploy
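
For reference, a minimal usage sketch of the reworked FDTensor buffer API introduced by this patch. This is not part of the diff: the device pointer `d_input`, the tensor names, and the shapes are illustrative assumptions; only SetExternalData, Resize, MutableData, and Device::GPU come from the changes above.

#include <vector>

#include "fastdeploy/core/fd_tensor.h"

void PrepareInputs(void* d_input /* hypothetical GPU buffer */,
                   std::vector<fastdeploy::FDTensor>* inputs) {
  inputs->resize(2);

  // Zero-copy path: wrap an existing GPU buffer. With device == Device::GPU,
  // TrtBackend::SetInputs() binds it through FDGenericBuffer::SetExternalData
  // instead of issuing a cudaMemcpyAsync from host memory.
  (*inputs)[0].SetExternalData({1, 3, 224, 224}, fastdeploy::FDDataType::FP32,
                               d_input, fastdeploy::Device::GPU);
  (*inputs)[0].name = "image";

  // Owned-buffer path: Resize() allocates (or reuses) buffer_ through
  // FDHostAllocator, and the backend copies the host data to the device.
  (*inputs)[1].Resize({1, 2}, fastdeploy::FDDataType::FP32, "scale_factor",
                      fastdeploy::Device::CPU);
  auto* scale = static_cast<float*>((*inputs)[1].MutableData());
  scale[0] = 1.0f;
  scale[1] = 1.0f;
}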