From 62e051e21de5045b98b02b6b5599fd48013b80c9 Mon Sep 17 00:00:00 2001
From: Wang Xinyu
Date: Mon, 30 Jan 2023 09:33:49 +0800
Subject: [PATCH] [CVCUDA] CMake integration, vision processor CV-CUDA
 integration, PaddleClas supports CV-CUDA (#1074)

* cvcuda resize
* cvcuda center crop
* cvcuda resize
* add a fdtensor in fdmat
* get cv mat and get tensor support gpu
* paddleclas cvcuda preprocessor
* fix compile err
* fix windows compile error
* rename reused to cached
* address comment
* remove debug code
* add comment
* add manager run
* use cuda and cuda used
* use cv cuda doc
* address comment

---------

Co-authored-by: Jason
---
 CMakeLists.txt                                     |   7 +
 FastDeploy.cmake.in                                |   9 ++
 cmake/cvcuda.cmake                                 |  43 +++++
 docs/cn/faq/use_cv_cuda.md                         |  39 +++++
 .../serving/models/preprocess/1/model.py           |   2 +-
 .../classification/ppcls/ppcls_pybind.cc           | 150 ++++++++++--------
 .../classification/ppcls/preprocessor.cc           |  60 +++----
 .../classification/ppcls/preprocessor.h            |  18 +--
 fastdeploy/vision/common/processors/base.cc        |  53 +++++--
 fastdeploy/vision/common/processors/base.h         |  25 ++-
 .../vision/common/processors/center_crop.cc        |  35 ++++
 .../vision/common/processors/center_crop.h         |   3 +
 .../vision/common/processors/cvcuda_utils.cc       |  76 +++++++++
 .../vision/common/processors/cvcuda_utils.h        |  31 ++++
 .../vision/common/processors/manager.cc            |  80 ++++++++++
 fastdeploy/vision/common/processors/manager.h      |  74 +++++++++
 fastdeploy/vision/common/processors/mat.cc         | 122 ++++++++++----
 fastdeploy/vision/common/processors/mat.h          |  39 ++---
 .../processors/normalize_and_permute.cu            |  53 +++----
 .../common/processors/normalize_and_permute.h      |   3 +
 .../vision/common/processors/proc_lib.h            |   2 +-
 fastdeploy/vision/common/processors/resize.cc      |  54 ++++++-
 fastdeploy/vision/common/processors/resize.h       |   3 +
 .../common/processors/resize_by_short.cc           |  39 ++++-
 .../common/processors/resize_by_short.h            |   3 +
 .../vision/classification/ppcls/__init__.py        |   7 +-
 26 files changed, 814 insertions(+), 216 deletions(-)
 create mode 100644 cmake/cvcuda.cmake
 create mode 100644 docs/cn/faq/use_cv_cuda.md
 create mode 100644 fastdeploy/vision/common/processors/cvcuda_utils.cc
 create mode 100644 fastdeploy/vision/common/processors/cvcuda_utils.h
 create mode 100644 fastdeploy/vision/common/processors/manager.cc
 create mode 100644 fastdeploy/vision/common/processors/manager.h

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 35fae01c7..269fd40a0 100755
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -66,6 +66,7 @@ option(ENABLE_LITE_BACKEND "Whether to enable paddle lite backend." OFF)
 option(ENABLE_VISION "Whether to enable vision models usage." OFF)
 option(ENABLE_TEXT "Whether to enable text models usage." OFF)
 option(ENABLE_FLYCV "Whether to enable flycv to boost image preprocess." OFF)
+option(ENABLE_CVCUDA "Whether to enable NVIDIA CV-CUDA to boost image preprocess." OFF)
 option(ENABLE_ENCRYPTION "Whether to enable ENCRYPTION." OFF)
 option(WITH_ASCEND "Whether to compile for Huawei Ascend deploy." OFF)
 option(WITH_TIMVX "Whether to compile for TIMVX deploy."
OFF) @@ -373,6 +374,12 @@ if(ENABLE_VISION) include(${PROJECT_SOURCE_DIR}/cmake/flycv.cmake) list(APPEND DEPEND_LIBS external_flycv) endif() + + if(ENABLE_CVCUDA) + include(${PROJECT_SOURCE_DIR}/cmake/cvcuda.cmake) + add_definitions(-DENABLE_CVCUDA) + list(APPEND DEPEND_LIBS nvcv_types cvcuda) + endif() endif() if(ENABLE_TEXT) diff --git a/FastDeploy.cmake.in b/FastDeploy.cmake.in index 76b8f747c..fd0370653 100755 --- a/FastDeploy.cmake.in +++ b/FastDeploy.cmake.in @@ -13,6 +13,7 @@ set(ENABLE_TRT_BACKEND @ENABLE_TRT_BACKEND@) set(ENABLE_PADDLE2ONNX @ENABLE_PADDLE2ONNX@) set(ENABLE_VISION @ENABLE_VISION@) set(ENABLE_FLYCV @ENABLE_FLYCV@) +set(ENABLE_CVCUDA @ENABLE_CVCUDA@) set(ENABLE_TEXT @ENABLE_TEXT@) set(ENABLE_ENCRYPTION @ENABLE_ENCRYPTION@) set(BUILD_ON_JETSON @BUILD_ON_JETSON@) @@ -140,6 +141,7 @@ if(WITH_GPU) message(FATAL_ERROR "[FastDeploy] Cannot find library cudart in ${CUDA_DIRECTORY}, Please define CUDA_DIRECTORY, e.g -DCUDA_DIRECTORY=/path/to/cuda") endif() list(APPEND FASTDEPLOY_LIBS ${CUDA_LIB}) + list(APPEND FASTDEPLOY_INCS ${CUDA_DIRECTORY}/include) if (ENABLE_TRT_BACKEND) if(BUILD_ON_JETSON) @@ -218,6 +220,12 @@ if(ENABLE_VISION) endif() endif() + if(ENABLE_CVCUDA) + find_library(CVCUDA_LIB cvcuda ${CMAKE_CURRENT_LIST_DIR}/third_libs/install/cvcuda/lib NO_DEFAULT_PATH) + find_library(NVCV_TYPES_LIB nvcv_types ${CMAKE_CURRENT_LIST_DIR}/third_libs/install/cvcuda/lib NO_DEFAULT_PATH) + list(APPEND FASTDEPLOY_LIBS ${CVCUDA_LIB} ${NVCV_TYPES_LIB}) + endif() + endif() if (ENABLE_TEXT) @@ -288,6 +296,7 @@ if(ENABLE_OPENVINO_BACKEND) endif() message(STATUS " ENABLE_TRT_BACKEND : ${ENABLE_TRT_BACKEND}") message(STATUS " ENABLE_VISION : ${ENABLE_VISION}") +message(STATUS " ENABLE_CVCUDA : ${ENABLE_CVCUDA}") message(STATUS " ENABLE_TEXT : ${ENABLE_TEXT}") message(STATUS " ENABLE_ENCRYPTION : ${ENABLE_ENCRYPTION}") if(WITH_GPU) diff --git a/cmake/cvcuda.cmake b/cmake/cvcuda.cmake new file mode 100644 index 000000000..002af9021 --- /dev/null +++ b/cmake/cvcuda.cmake @@ -0,0 +1,43 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+if(NOT WITH_GPU)
+  message(FATAL_ERROR "ENABLE_CVCUDA is only available on Linux with WITH_GPU=ON, but now WITH_GPU=OFF.")
+endif()
+
+if(APPLE OR ANDROID OR IOS OR WIN32)
+  message(FATAL_ERROR "Cannot enable CV-CUDA on mac/ios/android/windows os, please set -DENABLE_CVCUDA=OFF.")
+endif()
+
+if(NOT (CMAKE_HOST_SYSTEM_PROCESSOR MATCHES "x86_64"))
+  message(FATAL_ERROR "CV-CUDA only supports x86_64.")
+endif()
+
+set(CVCUDA_LIB_URL https://github.com/CVCUDA/CV-CUDA/releases/download/v0.2.0-alpha/nvcv-lib-0.2.0_alpha-cuda11-x86_64-linux.tar.xz)
+set(CVCUDA_LIB_FILENAME nvcv-lib-0.2.0_alpha-cuda11-x86_64-linux.tar.xz)
+set(CVCUDA_DEV_URL https://github.com/CVCUDA/CV-CUDA/releases/download/v0.2.0-alpha/nvcv-dev-0.2.0_alpha-cuda11-x86_64-linux.tar.xz)
+set(CVCUDA_DEV_FILENAME nvcv-dev-0.2.0_alpha-cuda11-x86_64-linux.tar.xz)
+
+download_and_decompress(${CVCUDA_LIB_URL} ${CMAKE_CURRENT_BINARY_DIR}/${CVCUDA_LIB_FILENAME} ${THIRD_PARTY_PATH}/cvcuda)
+download_and_decompress(${CVCUDA_DEV_URL} ${CMAKE_CURRENT_BINARY_DIR}/${CVCUDA_DEV_FILENAME} ${THIRD_PARTY_PATH}/cvcuda)
+
+execute_process(COMMAND rm -rf ${THIRD_PARTY_PATH}/install/cvcuda)
+execute_process(COMMAND mkdir -p ${THIRD_PARTY_PATH}/install/cvcuda)
+execute_process(COMMAND cp -r ${THIRD_PARTY_PATH}/cvcuda/opt/nvidia/cvcuda0/lib/x86_64-linux-gnu/ ${THIRD_PARTY_PATH}/install/cvcuda/lib)
+execute_process(COMMAND cp -r ${THIRD_PARTY_PATH}/cvcuda/opt/nvidia/cvcuda0/include/ ${THIRD_PARTY_PATH}/install/cvcuda/include)
+
+link_directories(${THIRD_PARTY_PATH}/install/cvcuda/lib)
+include_directories(${THIRD_PARTY_PATH}/install/cvcuda/include)
+
+set(CMAKE_CXX_STANDARD 17)
diff --git a/docs/cn/faq/use_cv_cuda.md b/docs/cn/faq/use_cv_cuda.md
new file mode 100644
index 000000000..8fad01738
--- /dev/null
+++ b/docs/cn/faq/use_cv_cuda.md
@@ -0,0 +1,39 @@
+# Use CV-CUDA/CUDA to accelerate end-to-end GPU inference
+
+FastDeploy integrates CV-CUDA to accelerate pre-/post-processing; the few operators that CV-CUDA does not support are implemented as CUDA kernels.
+
+FastDeploy's Vision Processor module wraps the CV-CUDA operators, so users do not need to call CV-CUDA themselves;
+the CV-CUDA acceleration is available simply by using FastDeploy's model inference APIs.
+
+When integrating CV-CUDA, the Vision Processor module does the following to make it easier to use:
+- GPU memory management: the input/output tensors of each operator are cached to avoid repeated GPU memory allocation
+- The few operators not supported by CV-CUDA are implemented with CUDA kernels
+- Operators supported by neither CV-CUDA nor CUDA can fall back to OpenCV/FlyCV
+
+## Usage
+Enable the CV-CUDA option when building FastDeploy
+```bash
+# When building the C++ library, enable the CV-CUDA build option.
+-DENABLE_CVCUDA=ON \
+
+# When building the Python library, enable the CV-CUDA build option
+export ENABLE_CVCUDA=ON
+```
+
+Only model preprocessors that inherit from the ProcessorManager class can use CV-CUDA; PaddleClasPreprocessor is used as the example here
+```bash
+# C++
+# After creating the model, call the UseCuda API of the model preprocessor to enable CV-CUDA/CUDA preprocessing
+# The first argument is enable_cv_cuda: true means use CV-CUDA, false means use CUDA only (fewer operators are supported)
+# The second argument is the GPU id; -1 means unspecified, i.e. use the current GPU
+model.GetPreprocessor().UseCuda(true, 0);
+
+# Python
+model.preprocessor.use_cuda(True, 0)
+```
+
+## Best practices
+
+- If the first preprocessing operator is resize, decide based on the actual workload whether resize should run on the GPU. When resize runs on the GPU
+  while image decoding runs on the CPU, the original image has to be copied to GPU memory, which is expensive; copying to GPU memory after resize
+  usually involves much less data.
diff --git a/examples/vision/classification/paddleclas/serving/models/preprocess/1/model.py b/examples/vision/classification/paddleclas/serving/models/preprocess/1/model.py
index c05ef7b65..724f1f974 100644
--- a/examples/vision/classification/paddleclas/serving/models/preprocess/1/model.py
+++ b/examples/vision/classification/paddleclas/serving/models/preprocess/1/model.py
@@ -70,7 +70,7 @@ class TritonPythonModel:
                 yaml_path)
         if args['model_instance_kind'] == 'GPU':
             device_id = int(args['model_instance_device_id'])
-            self.preprocess_.use_gpu(device_id)
+            self.preprocess_.use_cuda(False, device_id)
 
     def execute(self, requests):
         """`execute` must be implemented in every Python model.
`execute` diff --git a/fastdeploy/vision/classification/ppcls/ppcls_pybind.cc b/fastdeploy/vision/classification/ppcls/ppcls_pybind.cc index b776d5c45..514f5ad9d 100644 --- a/fastdeploy/vision/classification/ppcls/ppcls_pybind.cc +++ b/fastdeploy/vision/classification/ppcls/ppcls_pybind.cc @@ -18,76 +18,102 @@ void BindPaddleClas(pybind11::module& m) { pybind11::class_( m, "PaddleClasPreprocessor") .def(pybind11::init()) - .def("run", [](vision::classification::PaddleClasPreprocessor& self, std::vector& im_list) { - std::vector images; - for (size_t i = 0; i < im_list.size(); ++i) { - images.push_back(vision::WrapMat(PyArrayToCvMat(im_list[i]))); - } - std::vector outputs; - if (!self.Run(&images, &outputs)) { - throw std::runtime_error("Failed to preprocess the input data in PaddleClasPreprocessor."); - } - if (!self.WithGpu()) { - for (size_t i = 0; i < outputs.size(); ++i) { - outputs[i].StopSharing(); - } - } - return outputs; - }) - .def("use_gpu", [](vision::classification::PaddleClasPreprocessor& self, int gpu_id = -1) { - self.UseGpu(gpu_id); - }) - .def("disable_normalize", [](vision::classification::PaddleClasPreprocessor& self) { - self.DisableNormalize(); - }) - .def("disable_permute", [](vision::classification::PaddleClasPreprocessor& self) { - self.DisablePermute(); - }); + .def("run", + [](vision::classification::PaddleClasPreprocessor& self, + std::vector& im_list) { + std::vector images; + for (size_t i = 0; i < im_list.size(); ++i) { + images.push_back(vision::WrapMat(PyArrayToCvMat(im_list[i]))); + } + std::vector outputs; + if (!self.Run(&images, &outputs)) { + throw std::runtime_error( + "Failed to preprocess the input data in " + "PaddleClasPreprocessor."); + } + if (!self.CudaUsed()) { + for (size_t i = 0; i < outputs.size(); ++i) { + outputs[i].StopSharing(); + } + } + return outputs; + }) + .def("use_cuda", + [](vision::classification::PaddleClasPreprocessor& self, + bool enable_cv_cuda = false, + int gpu_id = -1) { self.UseCuda(enable_cv_cuda, gpu_id); }) + .def("disable_normalize", + [](vision::classification::PaddleClasPreprocessor& self) { + self.DisableNormalize(); + }) + .def("disable_permute", + [](vision::classification::PaddleClasPreprocessor& self) { + self.DisablePermute(); + }); pybind11::class_( m, "PaddleClasPostprocessor") .def(pybind11::init()) - .def("run", [](vision::classification::PaddleClasPostprocessor& self, std::vector& inputs) { - std::vector results; - if (!self.Run(inputs, &results)) { - throw std::runtime_error("Failed to postprocess the runtime result in PaddleClasPostprocessor."); - } - return results; - }) - .def("run", [](vision::classification::PaddleClasPostprocessor& self, std::vector& input_array) { - std::vector results; - std::vector inputs; - PyArrayToTensorList(input_array, &inputs, /*share_buffer=*/true); - if (!self.Run(inputs, &results)) { - throw std::runtime_error("Failed to postprocess the runtime result in PaddleClasPostprocessor."); - } - return results; - }) - .def_property("topk", &vision::classification::PaddleClasPostprocessor::GetTopk, &vision::classification::PaddleClasPostprocessor::SetTopk); + .def("run", + [](vision::classification::PaddleClasPostprocessor& self, + std::vector& inputs) { + std::vector results; + if (!self.Run(inputs, &results)) { + throw std::runtime_error( + "Failed to postprocess the runtime result in " + "PaddleClasPostprocessor."); + } + return results; + }) + .def("run", + [](vision::classification::PaddleClasPostprocessor& self, + std::vector& input_array) { + std::vector results; + 
std::vector inputs; + PyArrayToTensorList(input_array, &inputs, /*share_buffer=*/true); + if (!self.Run(inputs, &results)) { + throw std::runtime_error( + "Failed to postprocess the runtime result in " + "PaddleClasPostprocessor."); + } + return results; + }) + .def_property("topk", + &vision::classification::PaddleClasPostprocessor::GetTopk, + &vision::classification::PaddleClasPostprocessor::SetTopk); pybind11::class_( m, "PaddleClasModel") .def(pybind11::init()) - .def("clone", [](vision::classification::PaddleClasModel& self) { - return self.Clone(); - }) - .def("predict", [](vision::classification::PaddleClasModel& self, pybind11::array& data) { - cv::Mat im = PyArrayToCvMat(data); - vision::ClassifyResult result; - self.Predict(im, &result); - return result; - }) - .def("batch_predict", [](vision::classification::PaddleClasModel& self, std::vector& data) { - std::vector images; - for (size_t i = 0; i < data.size(); ++i) { - images.push_back(PyArrayToCvMat(data[i])); - } - std::vector results; - self.BatchPredict(images, &results); - return results; - }) - .def_property_readonly("preprocessor", &vision::classification::PaddleClasModel::GetPreprocessor) - .def_property_readonly("postprocessor", &vision::classification::PaddleClasModel::GetPostprocessor); + .def("clone", + [](vision::classification::PaddleClasModel& self) { + return self.Clone(); + }) + .def("predict", + [](vision::classification::PaddleClasModel& self, + pybind11::array& data) { + cv::Mat im = PyArrayToCvMat(data); + vision::ClassifyResult result; + self.Predict(im, &result); + return result; + }) + .def("batch_predict", + [](vision::classification::PaddleClasModel& self, + std::vector& data) { + std::vector images; + for (size_t i = 0; i < data.size(); ++i) { + images.push_back(PyArrayToCvMat(data[i])); + } + std::vector results; + self.BatchPredict(images, &results); + return results; + }) + .def_property_readonly( + "preprocessor", + &vision::classification::PaddleClasModel::GetPreprocessor) + .def_property_readonly( + "postprocessor", + &vision::classification::PaddleClasModel::GetPostprocessor); } } // namespace fastdeploy diff --git a/fastdeploy/vision/classification/ppcls/preprocessor.cc b/fastdeploy/vision/classification/ppcls/preprocessor.cc index aa4314cf8..90d40e094 100644 --- a/fastdeploy/vision/classification/ppcls/preprocessor.cc +++ b/fastdeploy/vision/classification/ppcls/preprocessor.cc @@ -13,11 +13,9 @@ // limitations under the License. 
#include "fastdeploy/vision/classification/ppcls/preprocessor.h" + #include "fastdeploy/function/concat.h" #include "yaml-cpp/yaml.h" -#ifdef WITH_GPU -#include -#endif namespace fastdeploy { namespace vision { @@ -61,9 +59,10 @@ bool PaddleClasPreprocessor::BuildPreprocessPipelineFromConfig() { auto mean = op.begin()->second["mean"].as>(); auto std = op.begin()->second["std"].as>(); auto scale = op.begin()->second["scale"].as(); - FDASSERT((scale - 0.00392157) < 1e-06 && (scale - 0.00392157) > -1e-06, - "Only support scale in Normalize be 0.00392157, means the pixel " - "is in range of [0, 255]."); + FDASSERT( + (scale - 0.00392157) < 1e-06 && (scale - 0.00392157) > -1e-06, + "Only support scale in Normalize be 0.00392157, means the pixel " + "is in range of [0, 255]."); processors_.push_back(std::make_shared(mean, std)); } } else if (op_name == "ToCHWImage") { @@ -84,53 +83,32 @@ bool PaddleClasPreprocessor::BuildPreprocessPipelineFromConfig() { void PaddleClasPreprocessor::DisableNormalize() { this->disable_normalize_ = true; - // the DisableNormalize function will be invalid if the configuration file is loaded during preprocessing + // the DisableNormalize function will be invalid if the configuration file is + // loaded during preprocessing if (!BuildPreprocessPipelineFromConfig()) { - FDERROR << "Failed to build preprocess pipeline from configuration file." << std::endl; + FDERROR << "Failed to build preprocess pipeline from configuration file." + << std::endl; } } void PaddleClasPreprocessor::DisablePermute() { this->disable_permute_ = true; - // the DisablePermute function will be invalid if the configuration file is loaded during preprocessing + // the DisablePermute function will be invalid if the configuration file is + // loaded during preprocessing if (!BuildPreprocessPipelineFromConfig()) { - FDERROR << "Failed to build preprocess pipeline from configuration file." << std::endl; + FDERROR << "Failed to build preprocess pipeline from configuration file." + << std::endl; } } -void PaddleClasPreprocessor::UseGpu(int gpu_id) { -#ifdef WITH_GPU - use_cuda_ = true; - if (gpu_id < 0) return; - device_id_ = gpu_id; - cudaSetDevice(device_id_); -#else - FDWARNING << "FastDeploy didn't compile with WITH_GPU. " - << "Will force to use CPU to run preprocessing." << std::endl; - use_cuda_ = false; -#endif -} - -bool PaddleClasPreprocessor::Run(std::vector* images, std::vector* outputs) { - if (!initialized_) { - FDERROR << "The preprocessor is not initialized." << std::endl; - return false; - } - if (images->size() == 0) { - FDERROR << "The size of input images should be greater than 0." << std::endl; - return false; - } - +bool PaddleClasPreprocessor::Apply(std::vector* images, + std::vector* outputs) { for (size_t i = 0; i < images->size(); ++i) { for (size_t j = 0; j < processors_.size(); ++j) { bool ret = false; - if (processors_[j]->Name() == "NormalizeAndPermute" && use_cuda_) { - ret = (*(processors_[j].get()))(&((*images)[i]), ProcLib::CUDA); - } else { - ret = (*(processors_[j].get()))(&((*images)[i])); - } + ret = (*(processors_[j].get()))(&((*images)[i])); if (!ret) { FDERROR << "Failed to processs image:" << i << " in " - << processors_[i]->Name() << "." << std::endl; + << processors_[j]->Name() << "." 
<< std::endl; return false; } } @@ -138,7 +116,7 @@ bool PaddleClasPreprocessor::Run(std::vector* images, std::vectorresize(1); // Concat all the preprocessed data to a batch tensor - std::vector tensors(images->size()); + std::vector tensors(images->size()); for (size_t i = 0; i < images->size(); ++i) { (*images)[i].ShareWithTensor(&(tensors[i])); tensors[i].ExpandDim(0); @@ -148,7 +126,7 @@ bool PaddleClasPreprocessor::Run(std::vector* images, std::vector* images, std::vector* outputs); - - /** \brief Use GPU to run preprocessing - * - * \param[in] gpu_id GPU device id - */ - void UseGpu(int gpu_id = -1); - - bool WithGpu() { return use_cuda_; } + virtual bool Apply(std::vector* images, + std::vector* outputs); /// This function will disable normalize in preprocessing step. void DisableNormalize(); @@ -54,10 +48,6 @@ class FASTDEPLOY_DECL PaddleClasPreprocessor { private: bool BuildPreprocessPipelineFromConfig(); std::vector> processors_; - bool initialized_ = false; - bool use_cuda_ = false; - // GPU device id - int device_id_ = -1; // for recording the switch of hwc2chw bool disable_permute_ = false; // for recording the switch of normalize diff --git a/fastdeploy/vision/common/processors/base.cc b/fastdeploy/vision/common/processors/base.cc index f3d5b0a97..a47cfe378 100644 --- a/fastdeploy/vision/common/processors/base.cc +++ b/fastdeploy/vision/common/processors/base.cc @@ -13,9 +13,9 @@ // limitations under the License. #include "fastdeploy/vision/common/processors/base.h" -#include "fastdeploy/vision/common/processors/proc_lib.h" #include "fastdeploy/utils/utils.h" +#include "fastdeploy/vision/common/processors/proc_lib.h" namespace fastdeploy { namespace vision { @@ -33,27 +33,58 @@ bool Processor::operator()(Mat* mat, ProcLib lib) { #endif } else if (target == ProcLib::CUDA) { #ifdef WITH_GPU + FDASSERT(mat->Stream() != nullptr, + "CUDA processor requires cuda stream, please set stream for Mat"); return ImplByCuda(mat); #else FDASSERT(false, "FastDeploy didn't compile with WITH_GPU."); +#endif + } else if (target == ProcLib::CVCUDA) { +#ifdef ENABLE_CVCUDA + FDASSERT(mat->Stream() != nullptr, + "CV-CUDA requires cuda stream, please set stream for Mat"); + return ImplByCvCuda(mat); +#else + FDASSERT(false, "FastDeploy didn't compile with CV-CUDA."); #endif } // DEFAULT & OPENCV return ImplByOpenCV(mat); } -FDTensor* Processor::UpdateAndGetReusedBuffer( - const std::vector& new_shape, const int& opencv_dtype, - const std::string& buffer_name, const Device& new_device, +FDTensor* Processor::UpdateAndGetCachedTensor( + const std::vector& new_shape, const FDDataType& data_type, + const std::string& tensor_name, const Device& new_device, const bool& use_pinned_memory) { - if (reused_buffers_.count(buffer_name) == 0) { - reused_buffers_[buffer_name] = FDTensor(); + if (cached_tensors_.count(tensor_name) == 0) { + cached_tensors_[tensor_name] = FDTensor(); } - reused_buffers_[buffer_name].is_pinned_memory = use_pinned_memory; - reused_buffers_[buffer_name].Resize(new_shape, - OpenCVDataTypeToFD(opencv_dtype), - buffer_name, new_device); - return &reused_buffers_[buffer_name]; + cached_tensors_[tensor_name].is_pinned_memory = use_pinned_memory; + cached_tensors_[tensor_name].Resize(new_shape, data_type, tensor_name, + new_device); + return &cached_tensors_[tensor_name]; +} + +FDTensor* Processor::CreateCachedGpuInputTensor( + Mat* mat, const std::string& tensor_name) { +#ifdef WITH_GPU + FDTensor* src = mat->Tensor(); + if (src->device == Device::GPU) { + return src; + } else if 
(src->device == Device::CPU) { + FDTensor* tensor = UpdateAndGetCachedTensor(src->Shape(), src->Dtype(), + tensor_name, Device::GPU); + FDASSERT(cudaMemcpyAsync(tensor->Data(), src->Data(), tensor->Nbytes(), + cudaMemcpyHostToDevice, mat->Stream()) == 0, + "[ERROR] Error occurs while copy memory from CPU to GPU."); + return tensor; + } else { + FDASSERT(false, "FDMat is on unsupported device: %d", src->device); + } +#else + FDASSERT(false, "FastDeploy didn't compile with WITH_GPU."); +#endif + return nullptr; } void EnableFlyCV() { diff --git a/fastdeploy/vision/common/processors/base.h b/fastdeploy/vision/common/processors/base.h index 00bd9c82f..6fb3a33eb 100644 --- a/fastdeploy/vision/common/processors/base.h +++ b/fastdeploy/vision/common/processors/base.h @@ -59,16 +59,33 @@ class FASTDEPLOY_DECL Processor { return ImplByOpenCV(mat); } + virtual bool ImplByCvCuda(Mat* mat) { + return ImplByOpenCV(mat); + } + virtual bool operator()(Mat* mat, ProcLib lib = ProcLib::DEFAULT); protected: - FDTensor* UpdateAndGetReusedBuffer( - const std::vector& new_shape, const int& opencv_dtype, - const std::string& buffer_name, const Device& new_device = Device::CPU, + // Update and get the cached tensor from the cached_tensors_ map. + // The tensor is indexed by a string. + // If the tensor doesn't exists in the map, then create a new tensor. + // If the tensor exists and shape is getting larger, then realloc the buffer. + // If the tensor exists and shape is not getting larger, then return the + // cached tensor directly. + FDTensor* UpdateAndGetCachedTensor( + const std::vector& new_shape, const FDDataType& data_type, + const std::string& tensor_name, const Device& new_device = Device::CPU, const bool& use_pinned_memory = false); + // Create an input tensor on GPU and save into cached_tensors_. + // If the Mat is on GPU, return the mat->Tensor() directly. + // If the Mat is on CPU, then create a cached GPU tensor and copy the mat's + // CPU tensor to this new GPU tensor. 
+ FDTensor* CreateCachedGpuInputTensor(Mat* mat, + const std::string& tensor_name); + private: - std::unordered_map reused_buffers_; + std::unordered_map cached_tensors_; }; } // namespace vision diff --git a/fastdeploy/vision/common/processors/center_crop.cc b/fastdeploy/vision/common/processors/center_crop.cc index af7c74448..bb0c96947 100644 --- a/fastdeploy/vision/common/processors/center_crop.cc +++ b/fastdeploy/vision/common/processors/center_crop.cc @@ -14,6 +14,12 @@ #include "fastdeploy/vision/common/processors/center_crop.h" +#ifdef ENABLE_CVCUDA +#include + +#include "fastdeploy/vision/common/processors/cvcuda_utils.h" +#endif + namespace fastdeploy { namespace vision { @@ -56,6 +62,35 @@ bool CenterCrop::ImplByFlyCV(Mat* mat) { } #endif +#ifdef ENABLE_CVCUDA +bool CenterCrop::ImplByCvCuda(Mat* mat) { + // Prepare input tensor + std::string tensor_name = Name() + "_cvcuda_src"; + FDTensor* src = CreateCachedGpuInputTensor(mat, tensor_name); + auto src_tensor = CreateCvCudaTensorWrapData(*src); + + // Prepare output tensor + tensor_name = Name() + "_cvcuda_dst"; + FDTensor* dst = + UpdateAndGetCachedTensor({height_, width_, mat->Channels()}, src->Dtype(), + tensor_name, Device::GPU); + auto dst_tensor = CreateCvCudaTensorWrapData(*dst); + + int offset_x = static_cast((mat->Width() - width_) / 2); + int offset_y = static_cast((mat->Height() - height_) / 2); + cvcuda::CustomCrop crop_op; + NVCVRectI crop_roi = {offset_x, offset_y, width_, height_}; + crop_op(mat->Stream(), src_tensor, dst_tensor, crop_roi); + + mat->SetTensor(dst); + mat->SetWidth(width_); + mat->SetHeight(height_); + mat->device = Device::GPU; + mat->mat_type = ProcLib::CVCUDA; + return true; +} +#endif + bool CenterCrop::Run(Mat* mat, const int& width, const int& height, ProcLib lib) { auto c = CenterCrop(width, height); diff --git a/fastdeploy/vision/common/processors/center_crop.h b/fastdeploy/vision/common/processors/center_crop.h index 05f594249..7455773f6 100644 --- a/fastdeploy/vision/common/processors/center_crop.h +++ b/fastdeploy/vision/common/processors/center_crop.h @@ -25,6 +25,9 @@ class FASTDEPLOY_DECL CenterCrop : public Processor { bool ImplByOpenCV(Mat* mat); #ifdef ENABLE_FLYCV bool ImplByFlyCV(Mat* mat); +#endif +#ifdef ENABLE_CVCUDA + bool ImplByCvCuda(Mat* mat); #endif std::string Name() { return "CenterCrop"; } diff --git a/fastdeploy/vision/common/processors/cvcuda_utils.cc b/fastdeploy/vision/common/processors/cvcuda_utils.cc new file mode 100644 index 000000000..482d0dc69 --- /dev/null +++ b/fastdeploy/vision/common/processors/cvcuda_utils.cc @@ -0,0 +1,76 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "fastdeploy/vision/common/processors/cvcuda_utils.h" + +namespace fastdeploy { +namespace vision { + +#ifdef ENABLE_CVCUDA +nvcv::ImageFormat CreateCvCudaImageFormat(FDDataType type, int channel) { + FDASSERT(channel == 1 || channel == 3 || channel == 4, + "Only support channel be 1/3/4 in CV-CUDA."); + if (type == FDDataType::UINT8) { + if (channel == 1) { + return nvcv::FMT_U8; + } else if (channel == 3) { + return nvcv::FMT_BGR8; + } else { + return nvcv::FMT_BGRA8; + } + } else if (type == FDDataType::FP32) { + if (channel == 1) { + return nvcv::FMT_F32; + } else if (channel == 3) { + return nvcv::FMT_BGRf32; + } else { + return nvcv::FMT_BGRAf32; + } + } + FDASSERT(false, "Data type of %s is not supported.", Str(type).c_str()); + return nvcv::FMT_BGRf32; +} + +nvcv::TensorWrapData CreateCvCudaTensorWrapData(const FDTensor& tensor) { + FDASSERT(tensor.shape.size() == 3, + "When create CVCUDA tensor from FD tensor," + "tensor shape should be 3-Dim, HWC layout"); + int batchsize = 1; + + nvcv::TensorDataStridedCuda::Buffer buf; + buf.strides[3] = FDDataTypeSize(tensor.Dtype()); + buf.strides[2] = tensor.shape[2] * buf.strides[3]; + buf.strides[1] = tensor.shape[1] * buf.strides[2]; + buf.strides[0] = tensor.shape[0] * buf.strides[1]; + buf.basePtr = reinterpret_cast(const_cast(tensor.Data())); + + nvcv::Tensor::Requirements req = nvcv::Tensor::CalcRequirements( + batchsize, {tensor.shape[1], tensor.shape[0]}, + CreateCvCudaImageFormat(tensor.Dtype(), tensor.shape[2])); + + nvcv::TensorDataStridedCuda tensor_data( + nvcv::TensorShape{req.shape, req.rank, req.layout}, + nvcv::DataType{req.dtype}, buf); + return nvcv::TensorWrapData(tensor_data); +} + +void* GetCvCudaTensorDataPtr(const nvcv::TensorWrapData& tensor) { + auto data = + dynamic_cast(tensor.exportData()); + return reinterpret_cast(data->basePtr()); +} +#endif + +} // namespace vision +} // namespace fastdeploy diff --git a/fastdeploy/vision/common/processors/cvcuda_utils.h b/fastdeploy/vision/common/processors/cvcuda_utils.h new file mode 100644 index 000000000..cd4eae8f6 --- /dev/null +++ b/fastdeploy/vision/common/processors/cvcuda_utils.h @@ -0,0 +1,31 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "fastdeploy/core/fd_tensor.h" + +#ifdef ENABLE_CVCUDA +#include "nvcv/Tensor.hpp" + +namespace fastdeploy { +namespace vision { + +nvcv::ImageFormat CreateCvCudaImageFormat(FDDataType type, int channel); +nvcv::TensorWrapData CreateCvCudaTensorWrapData(const FDTensor& tensor); +void* GetCvCudaTensorDataPtr(const nvcv::TensorWrapData& tensor); + +} +} +#endif diff --git a/fastdeploy/vision/common/processors/manager.cc b/fastdeploy/vision/common/processors/manager.cc new file mode 100644 index 000000000..147e12ae8 --- /dev/null +++ b/fastdeploy/vision/common/processors/manager.cc @@ -0,0 +1,80 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include "fastdeploy/vision/common/processors/manager.h" + +namespace fastdeploy { +namespace vision { + +ProcessorManager::~ProcessorManager() { +#ifdef WITH_GPU + if (stream_) cudaStreamDestroy(stream_); +#endif +} + +void ProcessorManager::UseCuda(bool enable_cv_cuda, int gpu_id) { +#ifdef WITH_GPU + if (gpu_id >= 0) { + device_id_ = gpu_id; + FDASSERT(cudaSetDevice(device_id_) == cudaSuccess, + "[ERROR] Error occurs while setting cuda device."); + } + FDASSERT(cudaStreamCreate(&stream_) == cudaSuccess, + "[ERROR] Error occurs while creating cuda stream."); + DefaultProcLib::default_lib = ProcLib::CUDA; +#else + FDASSERT(false, "FastDeploy didn't compile with WITH_GPU."); +#endif + + if (enable_cv_cuda) { +#ifdef ENABLE_CVCUDA + DefaultProcLib::default_lib = ProcLib::CVCUDA; +#else + FDASSERT(false, "FastDeploy didn't compile with CV-CUDA."); +#endif + } +} + +bool ProcessorManager::CudaUsed() { + return (DefaultProcLib::default_lib == ProcLib::CUDA || + DefaultProcLib::default_lib == ProcLib::CVCUDA); +} + +bool ProcessorManager::Run(std::vector* images, + std::vector* outputs) { + if (!initialized_) { + FDERROR << "The preprocessor is not initialized." << std::endl; + return false; + } + if (images->size() == 0) { + FDERROR << "The size of input images should be greater than 0." + << std::endl; + return false; + } + + for (size_t i = 0; i < images->size(); ++i) { + if (CudaUsed()) { + SetStream(&((*images)[i])); + } + } + + bool ret = Apply(images, outputs); + + if (CudaUsed()) { + SyncStream(); + } + return ret; +} + +} // namespace vision +} // namespace fastdeploy diff --git a/fastdeploy/vision/common/processors/manager.h b/fastdeploy/vision/common/processors/manager.h new file mode 100644 index 000000000..8721c7e10 --- /dev/null +++ b/fastdeploy/vision/common/processors/manager.h @@ -0,0 +1,74 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "fastdeploy/utils/utils.h" +#include "fastdeploy/vision/common/processors/mat.h" + +namespace fastdeploy { +namespace vision { + +class FASTDEPLOY_DECL ProcessorManager { + public: + ~ProcessorManager(); + + void UseCuda(bool enable_cv_cuda = false, int gpu_id = -1); + + bool CudaUsed(); + + void SetStream(Mat* mat) { +#ifdef WITH_GPU + mat->SetStream(stream_); +#endif + } + + void SyncStream() { +#ifdef WITH_GPU + FDASSERT(cudaStreamSynchronize(stream_) == cudaSuccess, + "[ERROR] Error occurs while sync cuda stream."); +#endif + } + + int DeviceId() { return device_id_; } + + /** \brief Process the input image and prepare input tensors for runtime + * + * \param[in] images The input image data list, all the elements are returned by cv::imread() + * \param[in] outputs The output tensors which will feed in runtime + * \return true if the preprocess successed, otherwise false + */ + bool Run(std::vector* images, std::vector* outputs); + + /** \brief The body of Run() function which needs to be implemented by a derived class + * + * \param[in] images The input image data list, all the elements are returned by cv::imread() + * \param[in] outputs The output tensors which will feed in runtime + * \return true if the preprocess successed, otherwise false + */ + virtual bool Apply(std::vector* images, + std::vector* outputs) = 0; + + protected: + bool initialized_ = false; + + private: +#ifdef WITH_GPU + cudaStream_t stream_ = nullptr; +#endif + int device_id_ = -1; +}; + +} // namespace vision +} // namespace fastdeploy diff --git a/fastdeploy/vision/common/processors/mat.cc b/fastdeploy/vision/common/processors/mat.cc index 7e9a0efc1..93d11f871 100644 --- a/fastdeploy/vision/common/processors/mat.cc +++ b/fastdeploy/vision/common/processors/mat.cc @@ -19,6 +19,36 @@ namespace fastdeploy { namespace vision { +cv::Mat* Mat::GetOpenCVMat() { + if (mat_type == ProcLib::OPENCV) { + return &cpu_mat; + } else if (mat_type == ProcLib::FLYCV) { +#ifdef ENABLE_FLYCV + // Just a reference to fcv_mat, zero copy. After you + // call this method, cpu_mat and fcv_mat will point + // to the same memory buffer. 
+ cpu_mat = ConvertFlyCVMatToOpenCV(fcv_mat); + mat_type = ProcLib::OPENCV; + return &cpu_mat; +#else + FDASSERT(false, "FastDeploy didn't compiled with FlyCV!"); +#endif + } else if (mat_type == ProcLib::CUDA || mat_type == ProcLib::CVCUDA) { +#ifdef WITH_GPU + FDASSERT(cudaStreamSynchronize(stream) == cudaSuccess, + "[ERROR] Error occurs while sync cuda stream."); + cpu_mat = CreateZeroCopyOpenCVMatFromTensor(fd_tensor); + mat_type = ProcLib::OPENCV; + device = Device::CPU; + return &cpu_mat; +#else + FDASSERT(false, "FastDeploy didn't compiled with -DWITH_GPU=ON"); +#endif + } else { + FDASSERT(false, "The mat_type of custom Mat can not be ProcLib::DEFAULT"); + } +} + void* Mat::Data() { if (mat_type == ProcLib::FLYCV) { #ifdef ENABLE_FLYCV @@ -28,10 +58,32 @@ void* Mat::Data() { "FastDeploy didn't compile with FlyCV, but met data type with " "fcv::Mat."); #endif + } else if (device == Device::GPU) { + return fd_tensor.Data(); } return cpu_mat.ptr(); } +FDTensor* Mat::Tensor() { + if (mat_type == ProcLib::OPENCV) { + ShareWithTensor(&fd_tensor); + } else if (mat_type == ProcLib::FLYCV) { +#ifdef ENABLE_FLYCV + cpu_mat = ConvertFlyCVMatToOpenCV(fcv_mat); + mat_type = ProcLib::OPENCV; + ShareWithTensor(&fd_tensor); +#else + FDASSERT(false, "FastDeploy didn't compiled with FlyCV!"); +#endif + } + return &fd_tensor; +} + +void Mat::SetTensor(FDTensor* tensor) { + fd_tensor.SetExternalData(tensor->Shape(), tensor->Dtype(), tensor->Data(), + tensor->device, tensor->device_id); +} + void Mat::ShareWithTensor(FDTensor* tensor) { tensor->SetExternalData({Channels(), Height(), Width()}, Type(), Data()); tensor->device = device; @@ -54,15 +106,15 @@ bool Mat::CopyToTensor(FDTensor* tensor) { } void Mat::PrintInfo(const std::string& flag) { + std::cout << flag << ": " + << "DataType=" << Type() << ", " + << "Channel=" << Channels() << ", " + << "Height=" << Height() << ", " + << "Width=" << Width() << ", " + << "Mean="; if (mat_type == ProcLib::FLYCV) { #ifdef ENABLE_FLYCV fcv::Scalar mean = fcv::mean(fcv_mat); - std::cout << flag << ": " - << "DataType=" << Type() << ", " - << "Channel=" << Channels() << ", " - << "Height=" << Height() << ", " - << "Width=" << Width() << ", " - << "Mean="; for (int i = 0; i < Channels(); ++i) { std::cout << mean[i] << " "; } @@ -72,18 +124,25 @@ void Mat::PrintInfo(const std::string& flag) { "FastDeploy didn't compile with FlyCV, but met data type with " "fcv::Mat."); #endif - } else { + } else if (mat_type == ProcLib::OPENCV) { cv::Scalar mean = cv::mean(cpu_mat); - std::cout << flag << ": " - << "DataType=" << Type() << ", " - << "Channel=" << Channels() << ", " - << "Height=" << Height() << ", " - << "Width=" << Width() << ", " - << "Mean="; for (int i = 0; i < Channels(); ++i) { std::cout << mean[i] << " "; } std::cout << std::endl; + } else if (mat_type == ProcLib::CUDA || mat_type == ProcLib::CVCUDA) { +#ifdef WITH_GPU + FDASSERT(cudaStreamSynchronize(stream) == cudaSuccess, + "[ERROR] Error occurs while sync cuda stream."); + cv::Mat tmp_mat = CreateZeroCopyOpenCVMatFromTensor(fd_tensor); + cv::Scalar mean = cv::mean(tmp_mat); + for (int i = 0; i < Channels(); ++i) { + std::cout << mean[i] << " "; + } + std::cout << std::endl; +#else + FDASSERT(false, "FastDeploy didn't compiled with -DWITH_GPU=ON"); +#endif } } @@ -97,6 +156,8 @@ FDDataType Mat::Type() { "FastDeploy didn't compile with FlyCV, but met data type with " "fcv::Mat."); #endif + } else if (mat_type == ProcLib::CUDA || mat_type == ProcLib::CVCUDA) { + return fd_tensor.Dtype(); } return 
OpenCVDataTypeToFD(cpu_mat.type()); } @@ -128,50 +189,49 @@ Mat Mat::Create(const FDTensor& tensor, ProcLib lib) { #else FDASSERT(false, "FastDeploy didn't compiled with FlyCV!"); #endif - } + } cv::Mat tmp_ocv_mat = CreateZeroCopyOpenCVMatFromTensor(tensor); Mat mat = Mat(tmp_ocv_mat); return mat; } -Mat Mat::Create(int height, int width, int channels, - FDDataType type, void* data) { +Mat Mat::Create(int height, int width, int channels, FDDataType type, + void* data) { if (DefaultProcLib::default_lib == ProcLib::FLYCV) { #ifdef ENABLE_FLYCV - fcv::Mat tmp_fcv_mat = CreateZeroCopyFlyCVMatFromBuffer( - height, width, channels, type, data); + fcv::Mat tmp_fcv_mat = + CreateZeroCopyFlyCVMatFromBuffer(height, width, channels, type, data); Mat mat = Mat(tmp_fcv_mat); return mat; #else FDASSERT(false, "FastDeploy didn't compiled with FlyCV!"); #endif } - cv::Mat tmp_ocv_mat = CreateZeroCopyOpenCVMatFromBuffer( - height, width, channels, type, data); + cv::Mat tmp_ocv_mat = + CreateZeroCopyOpenCVMatFromBuffer(height, width, channels, type, data); Mat mat = Mat(tmp_ocv_mat); - return mat; + return mat; } -Mat Mat::Create(int height, int width, int channels, - FDDataType type, void* data, - ProcLib lib) { +Mat Mat::Create(int height, int width, int channels, FDDataType type, + void* data, ProcLib lib) { if (lib == ProcLib::DEFAULT) { return Create(height, width, channels, type, data); - } + } if (lib == ProcLib::FLYCV) { #ifdef ENABLE_FLYCV - fcv::Mat tmp_fcv_mat = CreateZeroCopyFlyCVMatFromBuffer( - height, width, channels, type, data); + fcv::Mat tmp_fcv_mat = + CreateZeroCopyFlyCVMatFromBuffer(height, width, channels, type, data); Mat mat = Mat(tmp_fcv_mat); return mat; #else FDASSERT(false, "FastDeploy didn't compiled with FlyCV!"); #endif - } - cv::Mat tmp_ocv_mat = CreateZeroCopyOpenCVMatFromBuffer( - height, width, channels, type, data); + } + cv::Mat tmp_ocv_mat = + CreateZeroCopyOpenCVMatFromBuffer(height, width, channels, type, data); Mat mat = Mat(tmp_ocv_mat); - return mat; + return mat; } FDMat WrapMat(const cv::Mat& image) { diff --git a/fastdeploy/vision/common/processors/mat.h b/fastdeploy/vision/common/processors/mat.h index dc13c823b..568744a04 100644 --- a/fastdeploy/vision/common/processors/mat.h +++ b/fastdeploy/vision/common/processors/mat.h @@ -17,6 +17,10 @@ #include "fastdeploy/vision/common/processors/proc_lib.h" #include "opencv2/core/core.hpp" +#ifdef WITH_GPU +#include +#endif + namespace fastdeploy { namespace vision { @@ -60,24 +64,7 @@ struct FASTDEPLOY_DECL Mat { mat_type = ProcLib::OPENCV; } - cv::Mat* GetOpenCVMat() { - if (mat_type == ProcLib::OPENCV) { - return &cpu_mat; - } else if (mat_type == ProcLib::FLYCV) { -#ifdef ENABLE_FLYCV - // Just a reference to fcv_mat, zero copy. After you - // call this method, cpu_mat and fcv_mat will point - // to the same memory buffer. 
- cpu_mat = ConvertFlyCVMatToOpenCV(fcv_mat); - mat_type = ProcLib::OPENCV; - return &cpu_mat; -#else - FDASSERT(false, "FastDeploy didn't compiled with FlyCV!"); -#endif - } else { - FDASSERT(false, "The mat_type of custom Mat can not be ProcLib::DEFAULT"); - } - } + cv::Mat* GetOpenCVMat(); #ifdef ENABLE_FLYCV void SetMat(const fcv::Mat& mat) { @@ -103,6 +90,12 @@ struct FASTDEPLOY_DECL Mat { void* Data(); + // Get fd_tensor + FDTensor* Tensor(); + + // Set fd_tensor + void SetTensor(FDTensor* tensor); + private: int channels; int height; @@ -111,6 +104,12 @@ struct FASTDEPLOY_DECL Mat { #ifdef ENABLE_FLYCV fcv::Mat fcv_mat; #endif +#ifdef WITH_GPU + cudaStream_t stream = nullptr; +#endif + // Currently, fd_tensor is only used by CUDA and CV-CUDA, + // OpenCV and FlyCV are not using it. + FDTensor fd_tensor; public: FDDataType Type(); @@ -120,6 +119,10 @@ struct FASTDEPLOY_DECL Mat { void SetChannels(int s) { channels = s; } void SetWidth(int w) { width = w; } void SetHeight(int h) { height = h; } +#ifdef WITH_GPU + cudaStream_t Stream() const { return stream; } + void SetStream(cudaStream_t s) { stream = s; } +#endif // Transfer the vision::Mat to FDTensor void ShareWithTensor(FDTensor* tensor); diff --git a/fastdeploy/vision/common/processors/normalize_and_permute.cu b/fastdeploy/vision/common/processors/normalize_and_permute.cu index fabd01fe6..69bb6af1d 100644 --- a/fastdeploy/vision/common/processors/normalize_and_permute.cu +++ b/fastdeploy/vision/common/processors/normalize_and_permute.cu @@ -37,49 +37,46 @@ __global__ void NormalizeAndPermuteKernel(uint8_t* src, float* dst, } bool NormalizeAndPermute::ImplByCuda(Mat* mat) { - cv::Mat* im = mat->GetOpenCVMat(); - std::string buf_name = Name() + "_src"; - std::vector shape = {im->rows, im->cols, im->channels()}; - FDTensor* src = - UpdateAndGetReusedBuffer(shape, im->type(), buf_name, Device::GPU); - FDASSERT(cudaMemcpy(src->Data(), im->ptr(), src->Nbytes(), - cudaMemcpyHostToDevice) == 0, - "Error occurs while copy memory from CPU to GPU."); + // Prepare input tensor + std::string tensor_name = Name() + "_cvcuda_src"; + FDTensor* src = CreateCachedGpuInputTensor(mat, tensor_name); - buf_name = Name() + "_dst"; - FDTensor* dst = UpdateAndGetReusedBuffer(shape, CV_32FC(im->channels()), - buf_name, Device::GPU); - cv::Mat res(im->rows, im->cols, CV_32FC(im->channels()), dst->Data()); + // Prepare output tensor + tensor_name = Name() + "_dst"; + FDTensor* dst = UpdateAndGetCachedTensor(src->Shape(), FDDataType::FP32, + tensor_name, Device::GPU); - buf_name = Name() + "_alpha"; - FDTensor* alpha = UpdateAndGetReusedBuffer({(int64_t)alpha_.size()}, CV_32FC1, - buf_name, Device::GPU); - FDASSERT(cudaMemcpy(alpha->Data(), alpha_.data(), alpha->Nbytes(), - cudaMemcpyHostToDevice) == 0, - "Error occurs while copy memory from CPU to GPU."); + // Copy alpha and beta to GPU + tensor_name = Name() + "_alpha"; + FDMat alpha_mat = + FDMat::Create(1, 1, alpha_.size(), FDDataType::FP32, alpha_.data()); + FDTensor* alpha = CreateCachedGpuInputTensor(&alpha_mat, tensor_name); - buf_name = Name() + "_beta"; - FDTensor* beta = UpdateAndGetReusedBuffer({(int64_t)beta_.size()}, CV_32FC1, - buf_name, Device::GPU); - FDASSERT(cudaMemcpy(beta->Data(), beta_.data(), beta->Nbytes(), - cudaMemcpyHostToDevice) == 0, - "Error occurs while copy memory from CPU to GPU."); + tensor_name = Name() + "_beta"; + FDMat beta_mat = + FDMat::Create(1, 1, beta_.size(), FDDataType::FP32, beta_.data()); + FDTensor* beta = CreateCachedGpuInputTensor(&beta_mat, tensor_name); - 
int jobs = im->cols * im->rows; + int jobs = mat->Width() * mat->Height(); int threads = 256; int blocks = ceil(jobs / (float)threads); - NormalizeAndPermuteKernel<<>>( + NormalizeAndPermuteKernel<<Stream()>>>( reinterpret_cast(src->Data()), reinterpret_cast(dst->Data()), reinterpret_cast(alpha->Data()), - reinterpret_cast(beta->Data()), im->channels(), swap_rb_, jobs); + reinterpret_cast(beta->Data()), mat->Channels(), swap_rb_, jobs); - mat->SetMat(res); + mat->SetTensor(dst); mat->device = Device::GPU; mat->layout = Layout::CHW; + mat->mat_type = ProcLib::CUDA; return true; } +#ifdef ENABLE_CVCUDA +bool NormalizeAndPermute::ImplByCvCuda(Mat* mat) { return ImplByCuda(mat); } +#endif + } // namespace vision } // namespace fastdeploy #endif diff --git a/fastdeploy/vision/common/processors/normalize_and_permute.h b/fastdeploy/vision/common/processors/normalize_and_permute.h index ea7649d92..ff8394c67 100644 --- a/fastdeploy/vision/common/processors/normalize_and_permute.h +++ b/fastdeploy/vision/common/processors/normalize_and_permute.h @@ -31,6 +31,9 @@ class FASTDEPLOY_DECL NormalizeAndPermute : public Processor { #endif #ifdef WITH_GPU bool ImplByCuda(Mat* mat); +#endif +#ifdef ENABLE_CVCUDA + bool ImplByCvCuda(Mat* mat); #endif std::string Name() { return "NormalizeAndPermute"; } diff --git a/fastdeploy/vision/common/processors/proc_lib.h b/fastdeploy/vision/common/processors/proc_lib.h index 512ed9f83..06ca4a4a5 100644 --- a/fastdeploy/vision/common/processors/proc_lib.h +++ b/fastdeploy/vision/common/processors/proc_lib.h @@ -18,7 +18,7 @@ namespace fastdeploy { namespace vision { -enum class FASTDEPLOY_DECL ProcLib { DEFAULT, OPENCV, FLYCV, CUDA }; +enum class FASTDEPLOY_DECL ProcLib { DEFAULT, OPENCV, FLYCV, CUDA, CVCUDA }; FASTDEPLOY_DECL std::ostream& operator<<(std::ostream& out, const ProcLib& p); diff --git a/fastdeploy/vision/common/processors/resize.cc b/fastdeploy/vision/common/processors/resize.cc index c64a67b23..29a8798ad 100644 --- a/fastdeploy/vision/common/processors/resize.cc +++ b/fastdeploy/vision/common/processors/resize.cc @@ -14,6 +14,12 @@ #include "fastdeploy/vision/common/processors/resize.h" +#ifdef ENABLE_CVCUDA +#include + +#include "fastdeploy/vision/common/processors/cvcuda_utils.h" +#endif + namespace fastdeploy { namespace vision { @@ -79,7 +85,7 @@ bool Resize::ImplByFlyCV(Mat* mat) { } else if (interp_ == 2) { interp_method = fcv::InterpolationType::INTER_CUBIC; } else if (interp_ == 3) { - interp_method = fcv::InterpolationType::INTER_AREA; + interp_method = fcv::InterpolationType::INTER_AREA; } else { FDERROR << "Resize: Only support interp_ be 0/1/2/3 with FlyCV, but " "now it's " @@ -116,6 +122,52 @@ bool Resize::ImplByFlyCV(Mat* mat) { } #endif +#ifdef ENABLE_CVCUDA +bool Resize::ImplByCvCuda(Mat* mat) { + if (width_ == mat->Width() && height_ == mat->Height()) { + return true; + } + if (fabs(scale_w_ - 1.0) < 1e-06 && fabs(scale_h_ - 1.0) < 1e-06) { + return true; + } + + if (width_ > 0 && height_ > 0) { + } else if (scale_w_ > 0 && scale_h_ > 0) { + width_ = std::round(scale_w_ * mat->Width()); + height_ = std::round(scale_h_ * mat->Height()); + } else { + FDERROR << "Resize: the parameters must satisfy (width > 0 && height > 0) " + "or (scale_w > 0 && scale_h > 0)." 
+ << std::endl; + return false; + } + + // Prepare input tensor + std::string tensor_name = Name() + "_cvcuda_src"; + FDTensor* src = CreateCachedGpuInputTensor(mat, tensor_name); + auto src_tensor = CreateCvCudaTensorWrapData(*src); + + // Prepare output tensor + tensor_name = Name() + "_cvcuda_dst"; + FDTensor* dst = + UpdateAndGetCachedTensor({height_, width_, mat->Channels()}, mat->Type(), + tensor_name, Device::GPU); + auto dst_tensor = CreateCvCudaTensorWrapData(*dst); + + // CV-CUDA Interp value is compatible with OpenCV + cvcuda::Resize resize_op; + resize_op(mat->Stream(), src_tensor, dst_tensor, + NVCVInterpolationType(interp_)); + + mat->SetTensor(dst); + mat->SetWidth(width_); + mat->SetHeight(height_); + mat->device = Device::GPU; + mat->mat_type = ProcLib::CVCUDA; + return true; +} +#endif + bool Resize::Run(Mat* mat, int width, int height, float scale_w, float scale_h, int interp, bool use_scale, ProcLib lib) { if (mat->Height() == height && mat->Width() == width) { diff --git a/fastdeploy/vision/common/processors/resize.h b/fastdeploy/vision/common/processors/resize.h index e6a4ba1b0..54480108b 100644 --- a/fastdeploy/vision/common/processors/resize.h +++ b/fastdeploy/vision/common/processors/resize.h @@ -34,6 +34,9 @@ class FASTDEPLOY_DECL Resize : public Processor { bool ImplByOpenCV(Mat* mat); #ifdef ENABLE_FLYCV bool ImplByFlyCV(Mat* mat); +#endif +#ifdef ENABLE_CVCUDA + bool ImplByCvCuda(Mat* mat); #endif std::string Name() { return "Resize"; } diff --git a/fastdeploy/vision/common/processors/resize_by_short.cc b/fastdeploy/vision/common/processors/resize_by_short.cc index 2dbbefd29..1d6309f5d 100644 --- a/fastdeploy/vision/common/processors/resize_by_short.cc +++ b/fastdeploy/vision/common/processors/resize_by_short.cc @@ -14,6 +14,12 @@ #include "fastdeploy/vision/common/processors/resize_by_short.h" +#ifdef ENABLE_CVCUDA +#include + +#include "fastdeploy/vision/common/processors/cvcuda_utils.h" +#endif + namespace fastdeploy { namespace vision { @@ -51,7 +57,7 @@ bool ResizeByShort::ImplByFlyCV(Mat* mat) { } else if (interp_ == 2) { interp_method = fcv::InterpolationType::INTER_CUBIC; } else if (interp_ == 3) { - interp_method = fcv::InterpolationType::INTER_AREA; + interp_method = fcv::InterpolationType::INTER_AREA; } else { FDERROR << "LimitByShort: Only support interp_ be 0/1/2/3 with FlyCV, but " "now it's " @@ -80,6 +86,37 @@ bool ResizeByShort::ImplByFlyCV(Mat* mat) { } #endif +#ifdef ENABLE_CVCUDA +bool ResizeByShort::ImplByCvCuda(Mat* mat) { + // Prepare input tensor + std::string tensor_name = Name() + "_cvcuda_src"; + FDTensor* src = CreateCachedGpuInputTensor(mat, tensor_name); + auto src_tensor = CreateCvCudaTensorWrapData(*src); + + double scale = GenerateScale(mat->Width(), mat->Height()); + int width = static_cast(round(scale * mat->Width())); + int height = static_cast(round(scale * mat->Height())); + + // Prepare output tensor + tensor_name = Name() + "_cvcuda_dst"; + FDTensor* dst = UpdateAndGetCachedTensor( + {height, width, mat->Channels()}, mat->Type(), tensor_name, Device::GPU); + auto dst_tensor = CreateCvCudaTensorWrapData(*dst); + + // CV-CUDA Interp value is compatible with OpenCV + cvcuda::Resize resize_op; + resize_op(mat->Stream(), src_tensor, dst_tensor, + NVCVInterpolationType(interp_)); + + mat->SetTensor(dst); + mat->SetWidth(width); + mat->SetHeight(height); + mat->device = Device::GPU; + mat->mat_type = ProcLib::CVCUDA; + return true; +} +#endif + double ResizeByShort::GenerateScale(const int origin_w, const int origin_h) { int 
im_size_max = std::max(origin_w, origin_h); int im_size_min = std::min(origin_w, origin_h); diff --git a/fastdeploy/vision/common/processors/resize_by_short.h b/fastdeploy/vision/common/processors/resize_by_short.h index 151605beb..64a7f09f0 100644 --- a/fastdeploy/vision/common/processors/resize_by_short.h +++ b/fastdeploy/vision/common/processors/resize_by_short.h @@ -31,6 +31,9 @@ class FASTDEPLOY_DECL ResizeByShort : public Processor { bool ImplByOpenCV(Mat* mat); #ifdef ENABLE_FLYCV bool ImplByFlyCV(Mat* mat); +#endif +#ifdef ENABLE_CVCUDA + bool ImplByCvCuda(Mat* mat); #endif std::string Name() { return "ResizeByShort"; } diff --git a/python/fastdeploy/vision/classification/ppcls/__init__.py b/python/fastdeploy/vision/classification/ppcls/__init__.py index b88c4361f..455702271 100644 --- a/python/fastdeploy/vision/classification/ppcls/__init__.py +++ b/python/fastdeploy/vision/classification/ppcls/__init__.py @@ -35,12 +35,13 @@ class PaddleClasPreprocessor: """ return self._preprocessor.run(input_ims) - def use_gpu(self, gpu_id=-1): + def use_cuda(self, enable_cv_cuda=False, gpu_id=-1): """Use CUDA preprocessors + :param: enable_cv_cuda: Whether to enable CV-CUDA :param: gpu_id: GPU device id """ - return self._preprocessor.use_gpu(gpu_id) + return self._preprocessor.use_cuda(enable_cv_cuda, gpu_id) def disable_normalize(self): """ @@ -52,7 +53,7 @@ class PaddleClasPreprocessor: """ This function will disable hwc2chw in preprocessing step. """ - self._preprocessor.disable_permute() + self._preprocessor.disable_permute() class PaddleClasPostprocessor:
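
To show how the pieces above fit together, here is a minimal end-to-end Python sketch built around the new `use_cuda` switch. The model file paths, the `RuntimeOption` setup, and the image path are illustrative assumptions and not part of this patch; only `preprocessor.use_cuda(enable_cv_cuda, gpu_id)` and `predict` come from the changes above.

```python
import cv2
import fastdeploy as fd

# Assumed paths to an exported PaddleClas inference model; replace with your own.
model_file = "ResNet50_vd_infer/inference.pdmodel"
params_file = "ResNet50_vd_infer/inference.pdiparams"
config_file = "ResNet50_vd_infer/inference_cls.yaml"

# Assumed RuntimeOption usage: run the model itself on GPU 0.
option = fd.RuntimeOption()
option.use_gpu(0)

model = fd.vision.classification.PaddleClasModel(
    model_file, params_file, config_file, runtime_option=option)

# Switch preprocessing to the CUDA/CV-CUDA path added by this patch:
# the first argument enables CV-CUDA (False = plain CUDA kernels only),
# the second argument is the GPU id (-1 = use the current device).
model.preprocessor.use_cuda(True, 0)

im = cv2.imread("test.jpg")
result = model.predict(im)
print(result)
```

Because the preprocessor caches its GPU tensors across calls (see `UpdateAndGetCachedTensor` above), the first request pays the allocation cost and subsequent requests reuse the same buffers.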