[CVCUDA] CMake integration, vision processor CV-CUDA integration, PaddleClas supports CV-CUDA (#1074)

* cvcuda resize

* cvcuda center crop

* cvcuda resize

* add a fdtensor in fdmat

* get cv mat and get tensor support gpu

* paddleclas cvcuda preprocessor

* fix compile err

* fix windows compile error

* rename reused to cached

* address comment

* remove debug code

* add comment

* add manager run

* use cuda and cuda used

* use cv cuda doc

* address comment

---------

Co-authored-by: Jason <jiangjiajun@baidu.com>
Author: Wang Xinyu
Date: 2023-01-30 09:33:49 +08:00
Committed by: GitHub
Parent: 0c735e9c0b
Commit: 62e051e21d
26 changed files with 814 additions and 216 deletions


@@ -66,6 +66,7 @@ option(ENABLE_LITE_BACKEND "Whether to enable paddle lite backend." OFF)
option(ENABLE_VISION "Whether to enable vision models usage." OFF)
option(ENABLE_TEXT "Whether to enable text models usage." OFF)
option(ENABLE_FLYCV "Whether to enable flycv to boost image preprocess." OFF)
+option(ENABLE_CVCUDA "Whether to enable NVIDIA CV-CUDA to boost image preprocess." OFF)
option(ENABLE_ENCRYPTION "Whether to enable ENCRYPTION." OFF)
option(WITH_ASCEND "Whether to compile for Huawei Ascend deploy." OFF)
option(WITH_TIMVX "Whether to compile for TIMVX deploy." OFF)
@@ -373,6 +374,12 @@ if(ENABLE_VISION)
    include(${PROJECT_SOURCE_DIR}/cmake/flycv.cmake)
    list(APPEND DEPEND_LIBS external_flycv)
  endif()
+  if(ENABLE_CVCUDA)
+    include(${PROJECT_SOURCE_DIR}/cmake/cvcuda.cmake)
+    add_definitions(-DENABLE_CVCUDA)
+    list(APPEND DEPEND_LIBS nvcv_types cvcuda)
+  endif()
endif()

if(ENABLE_TEXT)


@@ -13,6 +13,7 @@ set(ENABLE_TRT_BACKEND @ENABLE_TRT_BACKEND@)
set(ENABLE_PADDLE2ONNX @ENABLE_PADDLE2ONNX@)
set(ENABLE_VISION @ENABLE_VISION@)
set(ENABLE_FLYCV @ENABLE_FLYCV@)
+set(ENABLE_CVCUDA @ENABLE_CVCUDA@)
set(ENABLE_TEXT @ENABLE_TEXT@)
set(ENABLE_ENCRYPTION @ENABLE_ENCRYPTION@)
set(BUILD_ON_JETSON @BUILD_ON_JETSON@)
@@ -140,6 +141,7 @@ if(WITH_GPU)
    message(FATAL_ERROR "[FastDeploy] Cannot find library cudart in ${CUDA_DIRECTORY}, Please define CUDA_DIRECTORY, e.g -DCUDA_DIRECTORY=/path/to/cuda")
  endif()
  list(APPEND FASTDEPLOY_LIBS ${CUDA_LIB})
+  list(APPEND FASTDEPLOY_INCS ${CUDA_DIRECTORY}/include)

  if (ENABLE_TRT_BACKEND)
    if(BUILD_ON_JETSON)
@@ -218,6 +220,12 @@ if(ENABLE_VISION)
    endif()
  endif()
+  if(ENABLE_CVCUDA)
+    find_library(CVCUDA_LIB cvcuda ${CMAKE_CURRENT_LIST_DIR}/third_libs/install/cvcuda/lib NO_DEFAULT_PATH)
+    find_library(NVCV_TYPES_LIB nvcv_types ${CMAKE_CURRENT_LIST_DIR}/third_libs/install/cvcuda/lib NO_DEFAULT_PATH)
+    list(APPEND FASTDEPLOY_LIBS ${CVCUDA_LIB} ${NVCV_TYPES_LIB})
+  endif()
endif()

if (ENABLE_TEXT)
@@ -288,6 +296,7 @@ if(ENABLE_OPENVINO_BACKEND)
endif()
message(STATUS " ENABLE_TRT_BACKEND : ${ENABLE_TRT_BACKEND}")
message(STATUS " ENABLE_VISION : ${ENABLE_VISION}")
+message(STATUS " ENABLE_CVCUDA : ${ENABLE_CVCUDA}")
message(STATUS " ENABLE_TEXT : ${ENABLE_TEXT}")
message(STATUS " ENABLE_ENCRYPTION : ${ENABLE_ENCRYPTION}")
if(WITH_GPU)

cmake/cvcuda.cmake (new file, 43 lines)

@@ -0,0 +1,43 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
if(NOT WITH_GPU)
message(FATAL_ERROR "ENABLE_CVCUDA is available on Linux and WITH_GPU=ON, but now WITH_GPU=OFF.")
endif()
if(APPLE OR ANDROID OR IOS OR WIN32)
message(FATAL_ERROR "Cannot enable CV-CUDA in mac/ios/android/windows os, please set -DENABLE_CVCUDA=OFF.")
endif()
if(NOT (CMAKE_HOST_SYSTEM_PROCESSOR MATCHES "x86_64"))
message(FATAL_ERROR "CV-CUDA only support x86_64.")
endif()
set(CVCUDA_LIB_URL https://github.com/CVCUDA/CV-CUDA/releases/download/v0.2.0-alpha/nvcv-lib-0.2.0_alpha-cuda11-x86_64-linux.tar.xz)
set(CVCUDA_LIB_FILENAME nvcv-lib-0.2.0_alpha-cuda11-x86_64-linux.tar.xz)
set(CVCUDA_DEV_URL https://github.com/CVCUDA/CV-CUDA/releases/download/v0.2.0-alpha/nvcv-dev-0.2.0_alpha-cuda11-x86_64-linux.tar.xz)
set(CVCUDA_DEV_FILENAME nvcv-dev-0.2.0_alpha-cuda11-x86_64-linux.tar.xz)
download_and_decompress(${CVCUDA_LIB_URL} ${CMAKE_CURRENT_BINARY_DIR}/${CVCUDA_LIB_FILENAME} ${THIRD_PARTY_PATH}/cvcuda)
download_and_decompress(${CVCUDA_DEV_URL} ${CMAKE_CURRENT_BINARY_DIR}/${CVCUDA_DEV_FILENAME} ${THIRD_PARTY_PATH}/cvcuda)
execute_process(COMMAND rm -rf ${THIRD_PARTY_PATH}/install/cvcuda)
execute_process(COMMAND mkdir -p ${THIRD_PARTY_PATH}/install/cvcuda)
execute_process(COMMAND cp -r ${THIRD_PARTY_PATH}/cvcuda/opt/nvidia/cvcuda0/lib/x86_64-linux-gnu/ ${THIRD_PARTY_PATH}/install/cvcuda/lib)
execute_process(COMMAND cp -r ${THIRD_PARTY_PATH}/cvcuda/opt/nvidia/cvcuda0/include/ ${THIRD_PARTY_PATH}/install/cvcuda/include)
link_directories(${THIRD_PARTY_PATH}/install/cvcuda/lib)
include_directories(${THIRD_PARTY_PATH}/install/cvcuda/include)
set(CMAKE_CXX_STANDARD 17)


@@ -0,0 +1,39 @@
# Using CV-CUDA/CUDA to accelerate GPU end-to-end inference

FastDeploy integrates CV-CUDA to accelerate pre/post-processing; the few operators CV-CUDA does not support are implemented as CUDA kernels.

FastDeploy's Vision Processor module wraps the CV-CUDA operators further, so users do not need to call CV-CUDA themselves:
using FastDeploy's model inference API is enough to benefit from CV-CUDA acceleration.

When integrating CV-CUDA, the Vision Processor module does the following to make it easy to use:
- GPU memory management: each operator's input and output tensors are cached to avoid repeated GPU allocations
- The few operators CV-CUDA does not support are implemented as CUDA kernels
- Operators supported by neither CV-CUDA nor CUDA can fall back to OpenCV/FlyCV

## Usage

Enable the CV-CUDA option when compiling FastDeploy:
```bash
# When building the C++ library, enable the CV-CUDA build option.
-DENABLE_CVCUDA=ON \

# When building the Python library, enable the CV-CUDA build option.
export ENABLE_CVCUDA=ON
```

Only model preprocessors that inherit from the ProcessorManager class can use CV-CUDA. Taking PaddleClasPreprocessor as an example:
```cpp
// C++
// After creating the model, call UseCuda on the model's preprocessor to enable CV-CUDA/CUDA preprocessing.
// The first parameter is enable_cv_cuda: true means use CV-CUDA, false means use only CUDA (fewer operators supported).
// The second parameter is the GPU id; -1 means no GPU is specified and the current GPU is used.
model.GetPreprocessor().UseCuda(true, 0);
```
```python
# Python
model.preprocessor.use_cuda(True, 0)
```

## Best practices

- If the first operator in the preprocessing pipeline is resize, decide from the actual workload whether resize should run on the GPU. When resize runs on the GPU but image decoding runs on the CPU, the original image has to be copied into GPU memory, which is expensive; copying to the GPU after resize usually moves much less data.
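For completeness, here is a minimal C++ sketch of end-to-end classification with CUDA/CV-CUDA preprocessing enabled. The model/config paths and image name are placeholders, and the surrounding API usage follows the common FastDeploy vision examples rather than anything introduced by this commit:

```cpp
// Minimal sketch; paths are placeholders and error handling is elided.
#include <iostream>
#include "fastdeploy/vision.h"

int main() {
  fastdeploy::RuntimeOption option;
  option.UseGpu(0);  // run inference on GPU 0

  auto model = fastdeploy::vision::classification::PaddleClasModel(
      "model/inference.pdmodel", "model/inference.pdiparams",
      "model/inference_cls.yaml", option);

  // Enable CV-CUDA preprocessing on GPU 0 (requires -DENABLE_CVCUDA=ON).
  model.GetPreprocessor().UseCuda(true, 0);

  cv::Mat im = cv::imread("test.jpg");
  fastdeploy::vision::ClassifyResult result;
  if (!model.Predict(im, &result)) {
    std::cerr << "Prediction failed." << std::endl;
    return -1;
  }
  std::cout << result.Str() << std::endl;
  return 0;
}
```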


@@ -70,7 +70,7 @@ class TritonPythonModel:
                yaml_path)
        if args['model_instance_kind'] == 'GPU':
            device_id = int(args['model_instance_device_id'])
-            self.preprocess_.use_gpu(device_id)
+            self.preprocess_.use_cuda(False, device_id)

    def execute(self, requests):
        """`execute` must be implemented in every Python model. `execute`


@@ -18,67 +18,89 @@ void BindPaddleClas(pybind11::module& m) {
  pybind11::class_<vision::classification::PaddleClasPreprocessor>(
      m, "PaddleClasPreprocessor")
      .def(pybind11::init<std::string>())
-      .def("run", [](vision::classification::PaddleClasPreprocessor& self, std::vector<pybind11::array>& im_list) {
+      .def("run",
+           [](vision::classification::PaddleClasPreprocessor& self,
+              std::vector<pybind11::array>& im_list) {
             std::vector<vision::FDMat> images;
             for (size_t i = 0; i < im_list.size(); ++i) {
               images.push_back(vision::WrapMat(PyArrayToCvMat(im_list[i])));
             }
             std::vector<FDTensor> outputs;
             if (!self.Run(&images, &outputs)) {
-              throw std::runtime_error("Failed to preprocess the input data in PaddleClasPreprocessor.");
+              throw std::runtime_error(
+                  "Failed to preprocess the input data in "
+                  "PaddleClasPreprocessor.");
             }
-            if (!self.WithGpu()) {
+            if (!self.CudaUsed()) {
               for (size_t i = 0; i < outputs.size(); ++i) {
                 outputs[i].StopSharing();
               }
             }
             return outputs;
           })
-      .def("use_gpu", [](vision::classification::PaddleClasPreprocessor& self, int gpu_id = -1) {
-        self.UseGpu(gpu_id);
-      })
-      .def("disable_normalize", [](vision::classification::PaddleClasPreprocessor& self) {
+      .def("use_cuda",
+           [](vision::classification::PaddleClasPreprocessor& self,
+              bool enable_cv_cuda = false,
+              int gpu_id = -1) { self.UseCuda(enable_cv_cuda, gpu_id); })
+      .def("disable_normalize",
+           [](vision::classification::PaddleClasPreprocessor& self) {
             self.DisableNormalize();
           })
-      .def("disable_permute", [](vision::classification::PaddleClasPreprocessor& self) {
+      .def("disable_permute",
+           [](vision::classification::PaddleClasPreprocessor& self) {
             self.DisablePermute();
           });

  pybind11::class_<vision::classification::PaddleClasPostprocessor>(
      m, "PaddleClasPostprocessor")
      .def(pybind11::init<int>())
-      .def("run", [](vision::classification::PaddleClasPostprocessor& self, std::vector<FDTensor>& inputs) {
+      .def("run",
+           [](vision::classification::PaddleClasPostprocessor& self,
+              std::vector<FDTensor>& inputs) {
             std::vector<vision::ClassifyResult> results;
             if (!self.Run(inputs, &results)) {
-              throw std::runtime_error("Failed to postprocess the runtime result in PaddleClasPostprocessor.");
+              throw std::runtime_error(
+                  "Failed to postprocess the runtime result in "
+                  "PaddleClasPostprocessor.");
             }
             return results;
           })
-      .def("run", [](vision::classification::PaddleClasPostprocessor& self, std::vector<pybind11::array>& input_array) {
+      .def("run",
+           [](vision::classification::PaddleClasPostprocessor& self,
+              std::vector<pybind11::array>& input_array) {
             std::vector<vision::ClassifyResult> results;
             std::vector<FDTensor> inputs;
             PyArrayToTensorList(input_array, &inputs, /*share_buffer=*/true);
             if (!self.Run(inputs, &results)) {
-              throw std::runtime_error("Failed to postprocess the runtime result in PaddleClasPostprocessor.");
+              throw std::runtime_error(
+                  "Failed to postprocess the runtime result in "
+                  "PaddleClasPostprocessor.");
             }
             return results;
           })
-      .def_property("topk", &vision::classification::PaddleClasPostprocessor::GetTopk, &vision::classification::PaddleClasPostprocessor::SetTopk);
+      .def_property("topk",
+                    &vision::classification::PaddleClasPostprocessor::GetTopk,
+                    &vision::classification::PaddleClasPostprocessor::SetTopk);

  pybind11::class_<vision::classification::PaddleClasModel, FastDeployModel>(
      m, "PaddleClasModel")
      .def(pybind11::init<std::string, std::string, std::string, RuntimeOption,
                          ModelFormat>())
-      .def("clone", [](vision::classification::PaddleClasModel& self) {
+      .def("clone",
+           [](vision::classification::PaddleClasModel& self) {
             return self.Clone();
           })
-      .def("predict", [](vision::classification::PaddleClasModel& self, pybind11::array& data) {
+      .def("predict",
+           [](vision::classification::PaddleClasModel& self,
+              pybind11::array& data) {
             cv::Mat im = PyArrayToCvMat(data);
             vision::ClassifyResult result;
             self.Predict(im, &result);
             return result;
           })
-      .def("batch_predict", [](vision::classification::PaddleClasModel& self, std::vector<pybind11::array>& data) {
+      .def("batch_predict",
+           [](vision::classification::PaddleClasModel& self,
+              std::vector<pybind11::array>& data) {
             std::vector<cv::Mat> images;
             for (size_t i = 0; i < data.size(); ++i) {
               images.push_back(PyArrayToCvMat(data[i]));
@@ -87,7 +109,11 @@ void BindPaddleClas(pybind11::module& m) {
             self.BatchPredict(images, &results);
             return results;
           })
-      .def_property_readonly("preprocessor", &vision::classification::PaddleClasModel::GetPreprocessor)
-      .def_property_readonly("postprocessor", &vision::classification::PaddleClasModel::GetPostprocessor);
+      .def_property_readonly(
+          "preprocessor",
+          &vision::classification::PaddleClasModel::GetPreprocessor)
+      .def_property_readonly(
+          "postprocessor",
+          &vision::classification::PaddleClasModel::GetPostprocessor);
}
} // namespace fastdeploy


@@ -13,11 +13,9 @@
// limitations under the License.
#include "fastdeploy/vision/classification/ppcls/preprocessor.h"
#include "fastdeploy/function/concat.h"
#include "yaml-cpp/yaml.h"
-#ifdef WITH_GPU
-#include <cuda_runtime_api.h>
-#endif

namespace fastdeploy {
namespace vision {
@@ -61,7 +59,8 @@ bool PaddleClasPreprocessor::BuildPreprocessPipelineFromConfig() {
      auto mean = op.begin()->second["mean"].as<std::vector<float>>();
      auto std = op.begin()->second["std"].as<std::vector<float>>();
      auto scale = op.begin()->second["scale"].as<float>();
-      FDASSERT((scale - 0.00392157) < 1e-06 && (scale - 0.00392157) > -1e-06,
+      FDASSERT(
+          (scale - 0.00392157) < 1e-06 && (scale - 0.00392157) > -1e-06,
          "Only support scale in Normalize be 0.00392157, means the pixel "
          "is in range of [0, 255].");
      processors_.push_back(std::make_shared<Normalize>(mean, std));
@@ -84,53 +83,32 @@ bool PaddleClasPreprocessor::BuildPreprocessPipelineFromConfig() {
void PaddleClasPreprocessor::DisableNormalize() {
  this->disable_normalize_ = true;
-  // the DisableNormalize function will be invalid if the configuration file is loaded during preprocessing
+  // the DisableNormalize function will be invalid if the configuration file is
+  // loaded during preprocessing
  if (!BuildPreprocessPipelineFromConfig()) {
-    FDERROR << "Failed to build preprocess pipeline from configuration file." << std::endl;
+    FDERROR << "Failed to build preprocess pipeline from configuration file."
+            << std::endl;
  }
}

void PaddleClasPreprocessor::DisablePermute() {
  this->disable_permute_ = true;
-  // the DisablePermute function will be invalid if the configuration file is loaded during preprocessing
+  // the DisablePermute function will be invalid if the configuration file is
+  // loaded during preprocessing
  if (!BuildPreprocessPipelineFromConfig()) {
-    FDERROR << "Failed to build preprocess pipeline from configuration file." << std::endl;
+    FDERROR << "Failed to build preprocess pipeline from configuration file."
+            << std::endl;
  }
}

-void PaddleClasPreprocessor::UseGpu(int gpu_id) {
-#ifdef WITH_GPU
-  use_cuda_ = true;
-  if (gpu_id < 0) return;
-  device_id_ = gpu_id;
-  cudaSetDevice(device_id_);
-#else
-  FDWARNING << "FastDeploy didn't compile with WITH_GPU. "
-            << "Will force to use CPU to run preprocessing." << std::endl;
-  use_cuda_ = false;
-#endif
-}
-
-bool PaddleClasPreprocessor::Run(std::vector<FDMat>* images, std::vector<FDTensor>* outputs) {
-  if (!initialized_) {
-    FDERROR << "The preprocessor is not initialized." << std::endl;
-    return false;
-  }
-  if (images->size() == 0) {
-    FDERROR << "The size of input images should be greater than 0." << std::endl;
-    return false;
-  }
+bool PaddleClasPreprocessor::Apply(std::vector<FDMat>* images,
+                                   std::vector<FDTensor>* outputs) {
  for (size_t i = 0; i < images->size(); ++i) {
    for (size_t j = 0; j < processors_.size(); ++j) {
      bool ret = false;
-      if (processors_[j]->Name() == "NormalizeAndPermute" && use_cuda_) {
-        ret = (*(processors_[j].get()))(&((*images)[i]), ProcLib::CUDA);
-      } else {
-        ret = (*(processors_[j].get()))(&((*images)[i]));
-      }
+      ret = (*(processors_[j].get()))(&((*images)[i]));
      if (!ret) {
        FDERROR << "Failed to processs image:" << i << " in "
-               << processors_[i]->Name() << "." << std::endl;
+               << processors_[j]->Name() << "." << std::endl;
        return false;
      }
    }
@@ -148,7 +126,7 @@ bool PaddleClasPreprocessor::Run(std::vector<FDMat>* images, std::vector<FDTenso
  } else {
    function::Concat(tensors, &((*outputs)[0]), 0);
  }
-  (*outputs)[0].device_id = device_id_;
+  (*outputs)[0].device_id = DeviceId();
  return true;
}


@@ -13,6 +13,7 @@
// limitations under the License.
#pragma once
+#include "fastdeploy/vision/common/processors/manager.h"
#include "fastdeploy/vision/common/processors/transform.h"
#include "fastdeploy/vision/common/result.h"
@@ -22,7 +23,7 @@
namespace classification {
/*! @brief Preprocessor object for PaddleClas serials model.
 */
-class FASTDEPLOY_DECL PaddleClasPreprocessor {
+class FASTDEPLOY_DECL PaddleClasPreprocessor : public ProcessorManager {
 public:
  /** \brief Create a preprocessor instance for PaddleClas serials model
   *
@@ -36,15 +37,8 @@ class FASTDEPLOY_DECL PaddleClasPreprocessor {
   * \param[in] outputs The output tensors which will feed in runtime
   * \return true if the preprocess successed, otherwise false
   */
-  bool Run(std::vector<FDMat>* images, std::vector<FDTensor>* outputs);
-
-  /** \brief Use GPU to run preprocessing
-   *
-   * \param[in] gpu_id GPU device id
-   */
-  void UseGpu(int gpu_id = -1);
-
-  bool WithGpu() { return use_cuda_; }
+  virtual bool Apply(std::vector<FDMat>* images,
+                     std::vector<FDTensor>* outputs);

  /// This function will disable normalize in preprocessing step.
  void DisableNormalize();
@@ -54,10 +48,6 @@ class FASTDEPLOY_DECL PaddleClasPreprocessor {
 private:
  bool BuildPreprocessPipelineFromConfig();
  std::vector<std::shared_ptr<Processor>> processors_;
-  bool initialized_ = false;
-  bool use_cuda_ = false;
-  // GPU device id
-  int device_id_ = -1;
  // for recording the switch of hwc2chw
  bool disable_permute_ = false;
  // for recording the switch of normalize


@@ -13,9 +13,9 @@
// limitations under the License.
#include "fastdeploy/vision/common/processors/base.h"
-#include "fastdeploy/vision/common/processors/proc_lib.h"

#include "fastdeploy/utils/utils.h"
+#include "fastdeploy/vision/common/processors/proc_lib.h"

namespace fastdeploy {
namespace vision {
@@ -33,27 +33,58 @@ bool Processor::operator()(Mat* mat, ProcLib lib) {
#endif
  } else if (target == ProcLib::CUDA) {
#ifdef WITH_GPU
+    FDASSERT(mat->Stream() != nullptr,
+             "CUDA processor requires cuda stream, please set stream for Mat");
    return ImplByCuda(mat);
#else
    FDASSERT(false, "FastDeploy didn't compile with WITH_GPU.");
+#endif
+  } else if (target == ProcLib::CVCUDA) {
+#ifdef ENABLE_CVCUDA
+    FDASSERT(mat->Stream() != nullptr,
+             "CV-CUDA requires cuda stream, please set stream for Mat");
+    return ImplByCvCuda(mat);
+#else
+    FDASSERT(false, "FastDeploy didn't compile with CV-CUDA.");
#endif
  }
  // DEFAULT & OPENCV
  return ImplByOpenCV(mat);
}

-FDTensor* Processor::UpdateAndGetReusedBuffer(
-    const std::vector<int64_t>& new_shape, const int& opencv_dtype,
-    const std::string& buffer_name, const Device& new_device,
+FDTensor* Processor::UpdateAndGetCachedTensor(
+    const std::vector<int64_t>& new_shape, const FDDataType& data_type,
+    const std::string& tensor_name, const Device& new_device,
    const bool& use_pinned_memory) {
-  if (reused_buffers_.count(buffer_name) == 0) {
-    reused_buffers_[buffer_name] = FDTensor();
+  if (cached_tensors_.count(tensor_name) == 0) {
+    cached_tensors_[tensor_name] = FDTensor();
  }
-  reused_buffers_[buffer_name].is_pinned_memory = use_pinned_memory;
-  reused_buffers_[buffer_name].Resize(new_shape,
-                                      OpenCVDataTypeToFD(opencv_dtype),
-                                      buffer_name, new_device);
-  return &reused_buffers_[buffer_name];
+  cached_tensors_[tensor_name].is_pinned_memory = use_pinned_memory;
+  cached_tensors_[tensor_name].Resize(new_shape, data_type, tensor_name,
+                                      new_device);
+  return &cached_tensors_[tensor_name];
+}
+
+FDTensor* Processor::CreateCachedGpuInputTensor(
+    Mat* mat, const std::string& tensor_name) {
+#ifdef WITH_GPU
+  FDTensor* src = mat->Tensor();
+  if (src->device == Device::GPU) {
+    return src;
+  } else if (src->device == Device::CPU) {
+    FDTensor* tensor = UpdateAndGetCachedTensor(src->Shape(), src->Dtype(),
+                                                tensor_name, Device::GPU);
+    FDASSERT(cudaMemcpyAsync(tensor->Data(), src->Data(), tensor->Nbytes(),
+                             cudaMemcpyHostToDevice, mat->Stream()) == 0,
+             "[ERROR] Error occurs while copy memory from CPU to GPU.");
+    return tensor;
+  } else {
+    FDASSERT(false, "FDMat is on unsupported device: %d", src->device);
+  }
+#else
+  FDASSERT(false, "FastDeploy didn't compile with WITH_GPU.");
+#endif
+  return nullptr;
}

void EnableFlyCV() {


@@ -59,16 +59,33 @@ class FASTDEPLOY_DECL Processor {
    return ImplByOpenCV(mat);
  }

+  virtual bool ImplByCvCuda(Mat* mat) {
+    return ImplByOpenCV(mat);
+  }
+
  virtual bool operator()(Mat* mat, ProcLib lib = ProcLib::DEFAULT);

 protected:
-  FDTensor* UpdateAndGetReusedBuffer(
-      const std::vector<int64_t>& new_shape, const int& opencv_dtype,
-      const std::string& buffer_name, const Device& new_device = Device::CPU,
+  // Update and get the cached tensor from the cached_tensors_ map.
+  // The tensor is indexed by a string.
+  // If the tensor doesn't exists in the map, then create a new tensor.
+  // If the tensor exists and shape is getting larger, then realloc the buffer.
+  // If the tensor exists and shape is not getting larger, then return the
+  // cached tensor directly.
+  FDTensor* UpdateAndGetCachedTensor(
+      const std::vector<int64_t>& new_shape, const FDDataType& data_type,
+      const std::string& tensor_name, const Device& new_device = Device::CPU,
      const bool& use_pinned_memory = false);

+  // Create an input tensor on GPU and save into cached_tensors_.
+  // If the Mat is on GPU, return the mat->Tensor() directly.
+  // If the Mat is on CPU, then create a cached GPU tensor and copy the mat's
+  // CPU tensor to this new GPU tensor.
+  FDTensor* CreateCachedGpuInputTensor(Mat* mat,
+                                       const std::string& tensor_name);
+
 private:
-  std::unordered_map<std::string, FDTensor> reused_buffers_;
+  std::unordered_map<std::string, FDTensor> cached_tensors_;
};

} // namespace vision
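To make the caching contract concrete, a hypothetical operator using these helpers might look like the sketch below; MyOp is invented for illustration, and only the buffer-handling pattern mirrors this commit:

```cpp
// Hypothetical processor sketch; "MyOp" is not part of this commit.
#ifdef ENABLE_CVCUDA
bool MyOp::ImplByCvCuda(Mat* mat) {
  // Reuse (or upload) the input on GPU under a stable cache key.
  FDTensor* src = CreateCachedGpuInputTensor(mat, Name() + "_cvcuda_src");

  // The output buffer is cached under its own key and only reallocated
  // when the requested shape grows.
  FDTensor* dst = UpdateAndGetCachedTensor(
      {mat->Height(), mat->Width(), mat->Channels()}, src->Dtype(),
      Name() + "_cvcuda_dst", Device::GPU);

  // ... wrap src/dst via CreateCvCudaTensorWrapData() and launch the
  // CV-CUDA op on mat->Stream() ...

  mat->SetTensor(dst);
  mat->device = Device::GPU;
  mat->mat_type = ProcLib::CVCUDA;
  return true;
}
#endif
```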


@@ -14,6 +14,12 @@
#include "fastdeploy/vision/common/processors/center_crop.h"

+#ifdef ENABLE_CVCUDA
+#include <cvcuda/OpCustomCrop.hpp>
+
+#include "fastdeploy/vision/common/processors/cvcuda_utils.h"
+#endif
+
namespace fastdeploy {
namespace vision {
@@ -56,6 +62,35 @@ bool CenterCrop::ImplByFlyCV(Mat* mat) {
}
#endif

+#ifdef ENABLE_CVCUDA
+bool CenterCrop::ImplByCvCuda(Mat* mat) {
+  // Prepare input tensor
+  std::string tensor_name = Name() + "_cvcuda_src";
+  FDTensor* src = CreateCachedGpuInputTensor(mat, tensor_name);
+  auto src_tensor = CreateCvCudaTensorWrapData(*src);
+
+  // Prepare output tensor
+  tensor_name = Name() + "_cvcuda_dst";
+  FDTensor* dst =
+      UpdateAndGetCachedTensor({height_, width_, mat->Channels()}, src->Dtype(),
+                               tensor_name, Device::GPU);
+  auto dst_tensor = CreateCvCudaTensorWrapData(*dst);
+
+  int offset_x = static_cast<int>((mat->Width() - width_) / 2);
+  int offset_y = static_cast<int>((mat->Height() - height_) / 2);
+  cvcuda::CustomCrop crop_op;
+  NVCVRectI crop_roi = {offset_x, offset_y, width_, height_};
+  crop_op(mat->Stream(), src_tensor, dst_tensor, crop_roi);
+
+  mat->SetTensor(dst);
+  mat->SetWidth(width_);
+  mat->SetHeight(height_);
+  mat->device = Device::GPU;
+  mat->mat_type = ProcLib::CVCUDA;
+  return true;
+}
+#endif
+
bool CenterCrop::Run(Mat* mat, const int& width, const int& height,
                     ProcLib lib) {
  auto c = CenterCrop(width, height);


@@ -25,6 +25,9 @@ class FASTDEPLOY_DECL CenterCrop : public Processor {
  bool ImplByOpenCV(Mat* mat);
#ifdef ENABLE_FLYCV
  bool ImplByFlyCV(Mat* mat);
+#endif
+#ifdef ENABLE_CVCUDA
+  bool ImplByCvCuda(Mat* mat);
#endif

  std::string Name() { return "CenterCrop"; }


@@ -0,0 +1,76 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/vision/common/processors/cvcuda_utils.h"
namespace fastdeploy {
namespace vision {
#ifdef ENABLE_CVCUDA
nvcv::ImageFormat CreateCvCudaImageFormat(FDDataType type, int channel) {
FDASSERT(channel == 1 || channel == 3 || channel == 4,
"Only support channel be 1/3/4 in CV-CUDA.");
if (type == FDDataType::UINT8) {
if (channel == 1) {
return nvcv::FMT_U8;
} else if (channel == 3) {
return nvcv::FMT_BGR8;
} else {
return nvcv::FMT_BGRA8;
}
} else if (type == FDDataType::FP32) {
if (channel == 1) {
return nvcv::FMT_F32;
} else if (channel == 3) {
return nvcv::FMT_BGRf32;
} else {
return nvcv::FMT_BGRAf32;
}
}
FDASSERT(false, "Data type of %s is not supported.", Str(type).c_str());
return nvcv::FMT_BGRf32;
}
nvcv::TensorWrapData CreateCvCudaTensorWrapData(const FDTensor& tensor) {
FDASSERT(tensor.shape.size() == 3,
"When create CVCUDA tensor from FD tensor,"
"tensor shape should be 3-Dim, HWC layout");
int batchsize = 1;
nvcv::TensorDataStridedCuda::Buffer buf;
buf.strides[3] = FDDataTypeSize(tensor.Dtype());
buf.strides[2] = tensor.shape[2] * buf.strides[3];
buf.strides[1] = tensor.shape[1] * buf.strides[2];
buf.strides[0] = tensor.shape[0] * buf.strides[1];
buf.basePtr = reinterpret_cast<NVCVByte*>(const_cast<void*>(tensor.Data()));
nvcv::Tensor::Requirements req = nvcv::Tensor::CalcRequirements(
batchsize, {tensor.shape[1], tensor.shape[0]},
CreateCvCudaImageFormat(tensor.Dtype(), tensor.shape[2]));
nvcv::TensorDataStridedCuda tensor_data(
nvcv::TensorShape{req.shape, req.rank, req.layout},
nvcv::DataType{req.dtype}, buf);
return nvcv::TensorWrapData(tensor_data);
}
void* GetCvCudaTensorDataPtr(const nvcv::TensorWrapData& tensor) {
auto data =
dynamic_cast<const nvcv::ITensorDataStridedCuda*>(tensor.exportData());
return reinterpret_cast<void*>(data->basePtr());
}
#endif
} // namespace vision
} // namespace fastdeploy
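As a side note, the stride layout that CreateCvCudaTensorWrapData builds for a packed HWC image can be checked with a small standalone computation; the concrete numbers below are just a worked example, not part of the commit:

```cpp
// Worked example of the NHWC stride math used above: strides are in bytes,
// innermost (per-element) first, for a 224x224x3 uint8 image.
#include <cstdint>
#include <iostream>

int main() {
  int64_t h = 224, w = 224, c = 3;
  int64_t elem = 1;     // FDDataTypeSize(FDDataType::UINT8)
  int64_t s3 = elem;    // between channel values of one pixel: 1
  int64_t s2 = c * s3;  // between adjacent pixels in a row: 3
  int64_t s1 = w * s2;  // between rows: 672
  int64_t s0 = h * s1;  // between images in a batch: 150528
  std::cout << s0 << ", " << s1 << ", " << s2 << ", " << s3 << std::endl;
  return 0;
}
```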


@@ -0,0 +1,31 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "fastdeploy/core/fd_tensor.h"
#ifdef ENABLE_CVCUDA
#include "nvcv/Tensor.hpp"
namespace fastdeploy {
namespace vision {
nvcv::ImageFormat CreateCvCudaImageFormat(FDDataType type, int channel);
nvcv::TensorWrapData CreateCvCudaTensorWrapData(const FDTensor& tensor);
void* GetCvCudaTensorDataPtr(const nvcv::TensorWrapData& tensor);
}
}
#endif


@@ -0,0 +1,80 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/vision/common/processors/manager.h"
namespace fastdeploy {
namespace vision {
ProcessorManager::~ProcessorManager() {
#ifdef WITH_GPU
if (stream_) cudaStreamDestroy(stream_);
#endif
}
void ProcessorManager::UseCuda(bool enable_cv_cuda, int gpu_id) {
#ifdef WITH_GPU
if (gpu_id >= 0) {
device_id_ = gpu_id;
FDASSERT(cudaSetDevice(device_id_) == cudaSuccess,
"[ERROR] Error occurs while setting cuda device.");
}
FDASSERT(cudaStreamCreate(&stream_) == cudaSuccess,
"[ERROR] Error occurs while creating cuda stream.");
DefaultProcLib::default_lib = ProcLib::CUDA;
#else
FDASSERT(false, "FastDeploy didn't compile with WITH_GPU.");
#endif
if (enable_cv_cuda) {
#ifdef ENABLE_CVCUDA
DefaultProcLib::default_lib = ProcLib::CVCUDA;
#else
FDASSERT(false, "FastDeploy didn't compile with CV-CUDA.");
#endif
}
}
bool ProcessorManager::CudaUsed() {
return (DefaultProcLib::default_lib == ProcLib::CUDA ||
DefaultProcLib::default_lib == ProcLib::CVCUDA);
}
bool ProcessorManager::Run(std::vector<FDMat>* images,
std::vector<FDTensor>* outputs) {
if (!initialized_) {
FDERROR << "The preprocessor is not initialized." << std::endl;
return false;
}
if (images->size() == 0) {
FDERROR << "The size of input images should be greater than 0."
<< std::endl;
return false;
}
for (size_t i = 0; i < images->size(); ++i) {
if (CudaUsed()) {
SetStream(&((*images)[i]));
}
}
bool ret = Apply(images, outputs);
if (CudaUsed()) {
SyncStream();
}
return ret;
}
} // namespace vision
} // namespace fastdeploy


@@ -0,0 +1,74 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "fastdeploy/utils/utils.h"
#include "fastdeploy/vision/common/processors/mat.h"
namespace fastdeploy {
namespace vision {
class FASTDEPLOY_DECL ProcessorManager {
public:
~ProcessorManager();
void UseCuda(bool enable_cv_cuda = false, int gpu_id = -1);
bool CudaUsed();
void SetStream(Mat* mat) {
#ifdef WITH_GPU
mat->SetStream(stream_);
#endif
}
void SyncStream() {
#ifdef WITH_GPU
FDASSERT(cudaStreamSynchronize(stream_) == cudaSuccess,
"[ERROR] Error occurs while sync cuda stream.");
#endif
}
int DeviceId() { return device_id_; }
/** \brief Process the input image and prepare input tensors for runtime
*
* \param[in] images The input image data list, all the elements are returned by cv::imread()
* \param[in] outputs The output tensors which will feed in runtime
* \return true if the preprocess successed, otherwise false
*/
bool Run(std::vector<FDMat>* images, std::vector<FDTensor>* outputs);
/** \brief The body of Run() function which needs to be implemented by a derived class
*
* \param[in] images The input image data list, all the elements are returned by cv::imread()
* \param[in] outputs The output tensors which will feed in runtime
* \return true if the preprocess successed, otherwise false
*/
virtual bool Apply(std::vector<FDMat>* images,
std::vector<FDTensor>* outputs) = 0;
protected:
bool initialized_ = false;
private:
#ifdef WITH_GPU
cudaStream_t stream_ = nullptr;
#endif
int device_id_ = -1;
};
} // namespace vision
} // namespace fastdeploy
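The intended division of labor is that ProcessorManager::Run owns stream setup and synchronization while a derived preprocessor only implements Apply. A hypothetical minimal subclass might look like the sketch below (MyPreprocessor and its pipeline are invented for illustration; only the override pattern comes from this commit):

```cpp
// Hypothetical ProcessorManager subclass; a sketch, not part of this commit.
class MyPreprocessor : public ProcessorManager {
 public:
  MyPreprocessor() {
    processors_.push_back(std::make_shared<ResizeByShort>(256));
    processors_.push_back(std::make_shared<CenterCrop>(224, 224));
    initialized_ = true;  // Run() fails fast until this is set
  }

  // Run() attaches its CUDA stream to every FDMat before calling Apply()
  // (when UseCuda was enabled) and synchronizes the stream afterwards.
  bool Apply(std::vector<FDMat>* images,
             std::vector<FDTensor>* outputs) override {
    for (size_t i = 0; i < images->size(); ++i) {
      for (auto& p : processors_) {
        // operator() dispatches to ImplByCvCuda / ImplByCuda / ImplByOpenCV
        // according to DefaultProcLib::default_lib.
        if (!(*p)(&((*images)[i]))) return false;
      }
    }
    // ... pack the processed mats into outputs, e.g. via Mat::Tensor() ...
    return true;
  }

 private:
  std::vector<std::shared_ptr<Processor>> processors_;
};
```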


@@ -19,6 +19,36 @@
namespace fastdeploy {
namespace vision {

+cv::Mat* Mat::GetOpenCVMat() {
+  if (mat_type == ProcLib::OPENCV) {
+    return &cpu_mat;
+  } else if (mat_type == ProcLib::FLYCV) {
+#ifdef ENABLE_FLYCV
+    // Just a reference to fcv_mat, zero copy. After you
+    // call this method, cpu_mat and fcv_mat will point
+    // to the same memory buffer.
+    cpu_mat = ConvertFlyCVMatToOpenCV(fcv_mat);
+    mat_type = ProcLib::OPENCV;
+    return &cpu_mat;
+#else
+    FDASSERT(false, "FastDeploy didn't compiled with FlyCV!");
+#endif
+  } else if (mat_type == ProcLib::CUDA || mat_type == ProcLib::CVCUDA) {
+#ifdef WITH_GPU
+    FDASSERT(cudaStreamSynchronize(stream) == cudaSuccess,
+             "[ERROR] Error occurs while sync cuda stream.");
+    cpu_mat = CreateZeroCopyOpenCVMatFromTensor(fd_tensor);
+    mat_type = ProcLib::OPENCV;
+    device = Device::CPU;
+    return &cpu_mat;
+#else
+    FDASSERT(false, "FastDeploy didn't compiled with -DWITH_GPU=ON");
+#endif
+  } else {
+    FDASSERT(false, "The mat_type of custom Mat can not be ProcLib::DEFAULT");
+  }
+}
+
void* Mat::Data() {
  if (mat_type == ProcLib::FLYCV) {
#ifdef ENABLE_FLYCV
@@ -28,10 +58,32 @@ void* Mat::Data() {
        "FastDeploy didn't compile with FlyCV, but met data type with "
        "fcv::Mat.");
#endif
+  } else if (device == Device::GPU) {
+    return fd_tensor.Data();
  }
  return cpu_mat.ptr();
}

+FDTensor* Mat::Tensor() {
+  if (mat_type == ProcLib::OPENCV) {
+    ShareWithTensor(&fd_tensor);
+  } else if (mat_type == ProcLib::FLYCV) {
+#ifdef ENABLE_FLYCV
+    cpu_mat = ConvertFlyCVMatToOpenCV(fcv_mat);
+    mat_type = ProcLib::OPENCV;
+    ShareWithTensor(&fd_tensor);
+#else
+    FDASSERT(false, "FastDeploy didn't compiled with FlyCV!");
+#endif
+  }
+  return &fd_tensor;
+}
+
+void Mat::SetTensor(FDTensor* tensor) {
+  fd_tensor.SetExternalData(tensor->Shape(), tensor->Dtype(), tensor->Data(),
+                            tensor->device, tensor->device_id);
+}
+
void Mat::ShareWithTensor(FDTensor* tensor) {
  tensor->SetExternalData({Channels(), Height(), Width()}, Type(), Data());
  tensor->device = device;
@@ -54,15 +106,15 @@ bool Mat::CopyToTensor(FDTensor* tensor) {
}

void Mat::PrintInfo(const std::string& flag) {
-  if (mat_type == ProcLib::FLYCV) {
-#ifdef ENABLE_FLYCV
-    fcv::Scalar mean = fcv::mean(fcv_mat);
  std::cout << flag << ": "
            << "DataType=" << Type() << ", "
            << "Channel=" << Channels() << ", "
            << "Height=" << Height() << ", "
            << "Width=" << Width() << ", "
            << "Mean=";
+  if (mat_type == ProcLib::FLYCV) {
+#ifdef ENABLE_FLYCV
+    fcv::Scalar mean = fcv::mean(fcv_mat);
    for (int i = 0; i < Channels(); ++i) {
      std::cout << mean[i] << " ";
    }
@@ -72,18 +124,25 @@ void Mat::PrintInfo(const std::string& flag) {
        "FastDeploy didn't compile with FlyCV, but met data type with "
        "fcv::Mat.");
#endif
-  } else {
+  } else if (mat_type == ProcLib::OPENCV) {
    cv::Scalar mean = cv::mean(cpu_mat);
-    std::cout << flag << ": "
-              << "DataType=" << Type() << ", "
-              << "Channel=" << Channels() << ", "
-              << "Height=" << Height() << ", "
-              << "Width=" << Width() << ", "
-              << "Mean=";
    for (int i = 0; i < Channels(); ++i) {
      std::cout << mean[i] << " ";
    }
    std::cout << std::endl;
+  } else if (mat_type == ProcLib::CUDA || mat_type == ProcLib::CVCUDA) {
+#ifdef WITH_GPU
+    FDASSERT(cudaStreamSynchronize(stream) == cudaSuccess,
+             "[ERROR] Error occurs while sync cuda stream.");
+    cv::Mat tmp_mat = CreateZeroCopyOpenCVMatFromTensor(fd_tensor);
+    cv::Scalar mean = cv::mean(tmp_mat);
+    for (int i = 0; i < Channels(); ++i) {
+      std::cout << mean[i] << " ";
+    }
+    std::cout << std::endl;
+#else
+    FDASSERT(false, "FastDeploy didn't compiled with -DWITH_GPU=ON");
+#endif
  }
}
@@ -97,6 +156,8 @@ FDDataType Mat::Type() {
        "FastDeploy didn't compile with FlyCV, but met data type with "
        "fcv::Mat.");
#endif
+  } else if (mat_type == ProcLib::CUDA || mat_type == ProcLib::CVCUDA) {
+    return fd_tensor.Dtype();
  }
  return OpenCVDataTypeToFD(cpu_mat.type());
}
@@ -134,42 +195,41 @@ Mat Mat::Create(const FDTensor& tensor, ProcLib lib) {
  return mat;
}

-Mat Mat::Create(int height, int width, int channels,
-                FDDataType type, void* data) {
+Mat Mat::Create(int height, int width, int channels, FDDataType type,
+                void* data) {
  if (DefaultProcLib::default_lib == ProcLib::FLYCV) {
#ifdef ENABLE_FLYCV
-    fcv::Mat tmp_fcv_mat = CreateZeroCopyFlyCVMatFromBuffer(
-        height, width, channels, type, data);
+    fcv::Mat tmp_fcv_mat =
+        CreateZeroCopyFlyCVMatFromBuffer(height, width, channels, type, data);
    Mat mat = Mat(tmp_fcv_mat);
    return mat;
#else
    FDASSERT(false, "FastDeploy didn't compiled with FlyCV!");
#endif
  }
-  cv::Mat tmp_ocv_mat = CreateZeroCopyOpenCVMatFromBuffer(
-      height, width, channels, type, data);
+  cv::Mat tmp_ocv_mat =
+      CreateZeroCopyOpenCVMatFromBuffer(height, width, channels, type, data);
  Mat mat = Mat(tmp_ocv_mat);
  return mat;
}

-Mat Mat::Create(int height, int width, int channels,
-                FDDataType type, void* data,
-                ProcLib lib) {
+Mat Mat::Create(int height, int width, int channels, FDDataType type,
+                void* data, ProcLib lib) {
  if (lib == ProcLib::DEFAULT) {
    return Create(height, width, channels, type, data);
  }
  if (lib == ProcLib::FLYCV) {
#ifdef ENABLE_FLYCV
-    fcv::Mat tmp_fcv_mat = CreateZeroCopyFlyCVMatFromBuffer(
-        height, width, channels, type, data);
+    fcv::Mat tmp_fcv_mat =
+        CreateZeroCopyFlyCVMatFromBuffer(height, width, channels, type, data);
    Mat mat = Mat(tmp_fcv_mat);
    return mat;
#else
    FDASSERT(false, "FastDeploy didn't compiled with FlyCV!");
#endif
  }
-  cv::Mat tmp_ocv_mat = CreateZeroCopyOpenCVMatFromBuffer(
-      height, width, channels, type, data);
+  cv::Mat tmp_ocv_mat =
+      CreateZeroCopyOpenCVMatFromBuffer(height, width, channels, type, data);
  Mat mat = Mat(tmp_ocv_mat);
  return mat;
}


@@ -17,6 +17,10 @@
#include "fastdeploy/vision/common/processors/proc_lib.h"
#include "opencv2/core/core.hpp"

+#ifdef WITH_GPU
+#include <cuda_runtime_api.h>
+#endif
+
namespace fastdeploy {
namespace vision {
@@ -60,24 +64,7 @@ struct FASTDEPLOY_DECL Mat {
    mat_type = ProcLib::OPENCV;
  }

-  cv::Mat* GetOpenCVMat() {
-    if (mat_type == ProcLib::OPENCV) {
-      return &cpu_mat;
-    } else if (mat_type == ProcLib::FLYCV) {
-#ifdef ENABLE_FLYCV
-      // Just a reference to fcv_mat, zero copy. After you
-      // call this method, cpu_mat and fcv_mat will point
-      // to the same memory buffer.
-      cpu_mat = ConvertFlyCVMatToOpenCV(fcv_mat);
-      mat_type = ProcLib::OPENCV;
-      return &cpu_mat;
-#else
-      FDASSERT(false, "FastDeploy didn't compiled with FlyCV!");
-#endif
-    } else {
-      FDASSERT(false, "The mat_type of custom Mat can not be ProcLib::DEFAULT");
-    }
-  }
+  cv::Mat* GetOpenCVMat();

#ifdef ENABLE_FLYCV
  void SetMat(const fcv::Mat& mat) {
@@ -103,6 +90,12 @@ struct FASTDEPLOY_DECL Mat {
  void* Data();

+  // Get fd_tensor
+  FDTensor* Tensor();
+
+  // Set fd_tensor
+  void SetTensor(FDTensor* tensor);
+
 private:
  int channels;
  int height;
@@ -111,6 +104,12 @@ struct FASTDEPLOY_DECL Mat {
#ifdef ENABLE_FLYCV
  fcv::Mat fcv_mat;
#endif
+#ifdef WITH_GPU
+  cudaStream_t stream = nullptr;
+#endif
+  // Currently, fd_tensor is only used by CUDA and CV-CUDA,
+  // OpenCV and FlyCV are not using it.
+  FDTensor fd_tensor;

 public:
  FDDataType Type();
@@ -120,6 +119,10 @@ struct FASTDEPLOY_DECL Mat {
  void SetChannels(int s) { channels = s; }
  void SetWidth(int w) { width = w; }
  void SetHeight(int h) { height = h; }
+#ifdef WITH_GPU
+  cudaStream_t Stream() const { return stream; }
+  void SetStream(cudaStream_t s) { stream = s; }
+#endif

  // Transfer the vision::Mat to FDTensor
  void ShareWithTensor(FDTensor* tensor);


@@ -37,49 +37,46 @@ __global__ void NormalizeAndPermuteKernel(uint8_t* src, float* dst,
}

bool NormalizeAndPermute::ImplByCuda(Mat* mat) {
-  cv::Mat* im = mat->GetOpenCVMat();
-  std::string buf_name = Name() + "_src";
-  std::vector<int64_t> shape = {im->rows, im->cols, im->channels()};
-  FDTensor* src =
-      UpdateAndGetReusedBuffer(shape, im->type(), buf_name, Device::GPU);
-  FDASSERT(cudaMemcpy(src->Data(), im->ptr(), src->Nbytes(),
-                      cudaMemcpyHostToDevice) == 0,
-           "Error occurs while copy memory from CPU to GPU.");
+  // Prepare input tensor
+  std::string tensor_name = Name() + "_cvcuda_src";
+  FDTensor* src = CreateCachedGpuInputTensor(mat, tensor_name);

-  buf_name = Name() + "_dst";
-  FDTensor* dst = UpdateAndGetReusedBuffer(shape, CV_32FC(im->channels()),
-                                           buf_name, Device::GPU);
-  cv::Mat res(im->rows, im->cols, CV_32FC(im->channels()), dst->Data());
+  // Prepare output tensor
+  tensor_name = Name() + "_dst";
+  FDTensor* dst = UpdateAndGetCachedTensor(src->Shape(), FDDataType::FP32,
+                                           tensor_name, Device::GPU);

-  buf_name = Name() + "_alpha";
-  FDTensor* alpha = UpdateAndGetReusedBuffer({(int64_t)alpha_.size()}, CV_32FC1,
-                                             buf_name, Device::GPU);
-  FDASSERT(cudaMemcpy(alpha->Data(), alpha_.data(), alpha->Nbytes(),
-                      cudaMemcpyHostToDevice) == 0,
-           "Error occurs while copy memory from CPU to GPU.");
+  // Copy alpha and beta to GPU
+  tensor_name = Name() + "_alpha";
+  FDMat alpha_mat =
+      FDMat::Create(1, 1, alpha_.size(), FDDataType::FP32, alpha_.data());
+  FDTensor* alpha = CreateCachedGpuInputTensor(&alpha_mat, tensor_name);

-  buf_name = Name() + "_beta";
-  FDTensor* beta = UpdateAndGetReusedBuffer({(int64_t)beta_.size()}, CV_32FC1,
-                                            buf_name, Device::GPU);
-  FDASSERT(cudaMemcpy(beta->Data(), beta_.data(), beta->Nbytes(),
-                      cudaMemcpyHostToDevice) == 0,
-           "Error occurs while copy memory from CPU to GPU.");
+  tensor_name = Name() + "_beta";
+  FDMat beta_mat =
+      FDMat::Create(1, 1, beta_.size(), FDDataType::FP32, beta_.data());
+  FDTensor* beta = CreateCachedGpuInputTensor(&beta_mat, tensor_name);

-  int jobs = im->cols * im->rows;
+  int jobs = mat->Width() * mat->Height();
  int threads = 256;
  int blocks = ceil(jobs / (float)threads);
-  NormalizeAndPermuteKernel<<<blocks, threads, 0, NULL>>>(
+  NormalizeAndPermuteKernel<<<blocks, threads, 0, mat->Stream()>>>(
      reinterpret_cast<uint8_t*>(src->Data()),
      reinterpret_cast<float*>(dst->Data()),
      reinterpret_cast<float*>(alpha->Data()),
-      reinterpret_cast<float*>(beta->Data()), im->channels(), swap_rb_, jobs);
+      reinterpret_cast<float*>(beta->Data()), mat->Channels(), swap_rb_, jobs);

-  mat->SetMat(res);
+  mat->SetTensor(dst);
  mat->device = Device::GPU;
  mat->layout = Layout::CHW;
+  mat->mat_type = ProcLib::CUDA;
  return true;
}

+#ifdef ENABLE_CVCUDA
+bool NormalizeAndPermute::ImplByCvCuda(Mat* mat) { return ImplByCuda(mat); }
+#endif
+
} // namespace vision
} // namespace fastdeploy

#endif
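For readers tracing the kernel's contract: assuming the usual FastDeploy convention that alpha/beta fold mean, std and scale into dst = src * alpha + beta, with an HWC-to-CHW permute and an optional B/R swap, a CPU reference of the fused transform could be sketched as follows; this is an illustration, not the committed kernel:

```cpp
// CPU reference sketch of fused normalize + HWC->CHW permute; assumes
// dst = src * alpha + beta per channel, with optional B<->R swap on read.
#include <cstdint>
#include <vector>

void NormalizeAndPermuteRef(const uint8_t* src, float* dst,
                            const std::vector<float>& alpha,
                            const std::vector<float>& beta, int height,
                            int width, int channels, bool swap_rb) {
  const int hw = height * width;
  for (int i = 0; i < hw; ++i) {  // one "job" per pixel, as in the kernel
    for (int c = 0; c < channels; ++c) {
      // Read from HWC input, optionally swapping B and R (3-channel case).
      int src_c = swap_rb ? (channels - 1 - c) : c;
      dst[c * hw + i] =
          static_cast<float>(src[i * channels + src_c]) * alpha[c] + beta[c];
    }
  }
}
```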


@@ -31,6 +31,9 @@ class FASTDEPLOY_DECL NormalizeAndPermute : public Processor {
#endif
#ifdef WITH_GPU
  bool ImplByCuda(Mat* mat);
+#endif
+#ifdef ENABLE_CVCUDA
+  bool ImplByCvCuda(Mat* mat);
#endif

  std::string Name() { return "NormalizeAndPermute"; }


@@ -18,7 +18,7 @@
namespace fastdeploy {
namespace vision {

-enum class FASTDEPLOY_DECL ProcLib { DEFAULT, OPENCV, FLYCV, CUDA };
+enum class FASTDEPLOY_DECL ProcLib { DEFAULT, OPENCV, FLYCV, CUDA, CVCUDA };

FASTDEPLOY_DECL std::ostream& operator<<(std::ostream& out, const ProcLib& p);


@@ -14,6 +14,12 @@
#include "fastdeploy/vision/common/processors/resize.h"

+#ifdef ENABLE_CVCUDA
+#include <cvcuda/OpResize.hpp>
+
+#include "fastdeploy/vision/common/processors/cvcuda_utils.h"
+#endif
+
namespace fastdeploy {
namespace vision {
@@ -116,6 +122,52 @@ bool Resize::ImplByFlyCV(Mat* mat) {
}
#endif

+#ifdef ENABLE_CVCUDA
+bool Resize::ImplByCvCuda(Mat* mat) {
+  if (width_ == mat->Width() && height_ == mat->Height()) {
+    return true;
+  }
+  if (fabs(scale_w_ - 1.0) < 1e-06 && fabs(scale_h_ - 1.0) < 1e-06) {
+    return true;
+  }
+
+  if (width_ > 0 && height_ > 0) {
+  } else if (scale_w_ > 0 && scale_h_ > 0) {
+    width_ = std::round(scale_w_ * mat->Width());
+    height_ = std::round(scale_h_ * mat->Height());
+  } else {
+    FDERROR << "Resize: the parameters must satisfy (width > 0 && height > 0) "
+               "or (scale_w > 0 && scale_h > 0)."
+            << std::endl;
+    return false;
+  }
+
+  // Prepare input tensor
+  std::string tensor_name = Name() + "_cvcuda_src";
+  FDTensor* src = CreateCachedGpuInputTensor(mat, tensor_name);
+  auto src_tensor = CreateCvCudaTensorWrapData(*src);
+
+  // Prepare output tensor
+  tensor_name = Name() + "_cvcuda_dst";
+  FDTensor* dst =
+      UpdateAndGetCachedTensor({height_, width_, mat->Channels()}, mat->Type(),
+                               tensor_name, Device::GPU);
+  auto dst_tensor = CreateCvCudaTensorWrapData(*dst);
+
+  // CV-CUDA Interp value is compatible with OpenCV
+  cvcuda::Resize resize_op;
+  resize_op(mat->Stream(), src_tensor, dst_tensor,
+            NVCVInterpolationType(interp_));
+
+  mat->SetTensor(dst);
+  mat->SetWidth(width_);
+  mat->SetHeight(height_);
+  mat->device = Device::GPU;
+  mat->mat_type = ProcLib::CVCUDA;
+  return true;
+}
+#endif
+
bool Resize::Run(Mat* mat, int width, int height, float scale_w, float scale_h,
                 int interp, bool use_scale, ProcLib lib) {
  if (mat->Height() == height && mat->Width() == width) {


@@ -34,6 +34,9 @@ class FASTDEPLOY_DECL Resize : public Processor {
  bool ImplByOpenCV(Mat* mat);
#ifdef ENABLE_FLYCV
  bool ImplByFlyCV(Mat* mat);
+#endif
+#ifdef ENABLE_CVCUDA
+  bool ImplByCvCuda(Mat* mat);
#endif

  std::string Name() { return "Resize"; }


@@ -14,6 +14,12 @@
#include "fastdeploy/vision/common/processors/resize_by_short.h"

+#ifdef ENABLE_CVCUDA
+#include <cvcuda/OpResize.hpp>
+
+#include "fastdeploy/vision/common/processors/cvcuda_utils.h"
+#endif
+
namespace fastdeploy {
namespace vision {
@@ -80,6 +86,37 @@ bool ResizeByShort::ImplByFlyCV(Mat* mat) {
}
#endif

+#ifdef ENABLE_CVCUDA
+bool ResizeByShort::ImplByCvCuda(Mat* mat) {
+  // Prepare input tensor
+  std::string tensor_name = Name() + "_cvcuda_src";
+  FDTensor* src = CreateCachedGpuInputTensor(mat, tensor_name);
+  auto src_tensor = CreateCvCudaTensorWrapData(*src);
+
+  double scale = GenerateScale(mat->Width(), mat->Height());
+  int width = static_cast<int>(round(scale * mat->Width()));
+  int height = static_cast<int>(round(scale * mat->Height()));
+
+  // Prepare output tensor
+  tensor_name = Name() + "_cvcuda_dst";
+  FDTensor* dst = UpdateAndGetCachedTensor(
+      {height, width, mat->Channels()}, mat->Type(), tensor_name, Device::GPU);
+  auto dst_tensor = CreateCvCudaTensorWrapData(*dst);
+
+  // CV-CUDA Interp value is compatible with OpenCV
+  cvcuda::Resize resize_op;
+  resize_op(mat->Stream(), src_tensor, dst_tensor,
+            NVCVInterpolationType(interp_));
+
+  mat->SetTensor(dst);
+  mat->SetWidth(width);
+  mat->SetHeight(height);
+  mat->device = Device::GPU;
+  mat->mat_type = ProcLib::CVCUDA;
+  return true;
+}
+#endif
+
double ResizeByShort::GenerateScale(const int origin_w, const int origin_h) {
  int im_size_max = std::max(origin_w, origin_h);
  int im_size_min = std::min(origin_w, origin_h);


@@ -31,6 +31,9 @@ class FASTDEPLOY_DECL ResizeByShort : public Processor {
  bool ImplByOpenCV(Mat* mat);
#ifdef ENABLE_FLYCV
  bool ImplByFlyCV(Mat* mat);
+#endif
+#ifdef ENABLE_CVCUDA
+  bool ImplByCvCuda(Mat* mat);
#endif

  std::string Name() { return "ResizeByShort"; }


@@ -35,12 +35,13 @@ class PaddleClasPreprocessor:
        """
        return self._preprocessor.run(input_ims)

-    def use_gpu(self, gpu_id=-1):
+    def use_cuda(self, enable_cv_cuda=False, gpu_id=-1):
        """Use CUDA preprocessors
+        :param: enable_cv_cuda: Whether to enable CV-CUDA
        :param: gpu_id: GPU device id
        """
-        return self._preprocessor.use_gpu(gpu_id)
+        return self._preprocessor.use_cuda(enable_cv_cuda, gpu_id)

    def disable_normalize(self):
        """