[CVCUDA] CMake integration, vision processor CV-CUDA integration, PaddleClas support CV-CUDA (#1074)
* cvcuda resize
* cvcuda center crop
* cvcuda resize
* add a fdtensor in fdmat
* get cv mat and get tensor support gpu
* paddleclas cvcuda preprocessor
* fix compile err
* fix windows compile error
* rename reused to cached
* address comment
* remove debug code
* add comment
* add manager run
* use cuda and cuda used
* use cv cuda doc
* address comment

---------

Co-authored-by: Jason <jiangjiajun@baidu.com>
@@ -66,6 +66,7 @@ option(ENABLE_LITE_BACKEND "Whether to enable paddle lite backend." OFF)
option(ENABLE_VISION "Whether to enable vision models usage." OFF)
option(ENABLE_TEXT "Whether to enable text models usage." OFF)
option(ENABLE_FLYCV "Whether to enable flycv to boost image preprocess." OFF)
option(ENABLE_CVCUDA "Whether to enable NVIDIA CV-CUDA to boost image preprocess." OFF)
option(ENABLE_ENCRYPTION "Whether to enable ENCRYPTION." OFF)
option(WITH_ASCEND "Whether to compile for Huawei Ascend deploy." OFF)
option(WITH_TIMVX "Whether to compile for TIMVX deploy." OFF)
@@ -373,6 +374,12 @@ if(ENABLE_VISION)
include(${PROJECT_SOURCE_DIR}/cmake/flycv.cmake)
list(APPEND DEPEND_LIBS external_flycv)
endif()

if(ENABLE_CVCUDA)
include(${PROJECT_SOURCE_DIR}/cmake/cvcuda.cmake)
add_definitions(-DENABLE_CVCUDA)
list(APPEND DEPEND_LIBS nvcv_types cvcuda)
endif()
endif()

if(ENABLE_TEXT)
@@ -13,6 +13,7 @@ set(ENABLE_TRT_BACKEND @ENABLE_TRT_BACKEND@)
set(ENABLE_PADDLE2ONNX @ENABLE_PADDLE2ONNX@)
set(ENABLE_VISION @ENABLE_VISION@)
set(ENABLE_FLYCV @ENABLE_FLYCV@)
set(ENABLE_CVCUDA @ENABLE_CVCUDA@)
set(ENABLE_TEXT @ENABLE_TEXT@)
set(ENABLE_ENCRYPTION @ENABLE_ENCRYPTION@)
set(BUILD_ON_JETSON @BUILD_ON_JETSON@)
@@ -140,6 +141,7 @@ if(WITH_GPU)
message(FATAL_ERROR "[FastDeploy] Cannot find library cudart in ${CUDA_DIRECTORY}, Please define CUDA_DIRECTORY, e.g -DCUDA_DIRECTORY=/path/to/cuda")
endif()
list(APPEND FASTDEPLOY_LIBS ${CUDA_LIB})
list(APPEND FASTDEPLOY_INCS ${CUDA_DIRECTORY}/include)

if (ENABLE_TRT_BACKEND)
if(BUILD_ON_JETSON)
@@ -218,6 +220,12 @@ if(ENABLE_VISION)
endif()
endif()

if(ENABLE_CVCUDA)
find_library(CVCUDA_LIB cvcuda ${CMAKE_CURRENT_LIST_DIR}/third_libs/install/cvcuda/lib NO_DEFAULT_PATH)
find_library(NVCV_TYPES_LIB nvcv_types ${CMAKE_CURRENT_LIST_DIR}/third_libs/install/cvcuda/lib NO_DEFAULT_PATH)
list(APPEND FASTDEPLOY_LIBS ${CVCUDA_LIB} ${NVCV_TYPES_LIB})
endif()

endif()

if (ENABLE_TEXT)
@@ -288,6 +296,7 @@ if(ENABLE_OPENVINO_BACKEND)
endif()
message(STATUS " ENABLE_TRT_BACKEND : ${ENABLE_TRT_BACKEND}")
message(STATUS " ENABLE_VISION : ${ENABLE_VISION}")
message(STATUS " ENABLE_CVCUDA : ${ENABLE_CVCUDA}")
message(STATUS " ENABLE_TEXT : ${ENABLE_TEXT}")
message(STATUS " ENABLE_ENCRYPTION : ${ENABLE_ENCRYPTION}")
if(WITH_GPU)
cmake/cvcuda.cmake (new file, 43 lines)
@@ -0,0 +1,43 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

if(NOT WITH_GPU)
message(FATAL_ERROR "ENABLE_CVCUDA is available on Linux and WITH_GPU=ON, but now WITH_GPU=OFF.")
endif()

if(APPLE OR ANDROID OR IOS OR WIN32)
message(FATAL_ERROR "Cannot enable CV-CUDA in mac/ios/android/windows os, please set -DENABLE_CVCUDA=OFF.")
endif()

if(NOT (CMAKE_HOST_SYSTEM_PROCESSOR MATCHES "x86_64"))
message(FATAL_ERROR "CV-CUDA only support x86_64.")
endif()

set(CVCUDA_LIB_URL https://github.com/CVCUDA/CV-CUDA/releases/download/v0.2.0-alpha/nvcv-lib-0.2.0_alpha-cuda11-x86_64-linux.tar.xz)
set(CVCUDA_LIB_FILENAME nvcv-lib-0.2.0_alpha-cuda11-x86_64-linux.tar.xz)
set(CVCUDA_DEV_URL https://github.com/CVCUDA/CV-CUDA/releases/download/v0.2.0-alpha/nvcv-dev-0.2.0_alpha-cuda11-x86_64-linux.tar.xz)
set(CVCUDA_DEV_FILENAME nvcv-dev-0.2.0_alpha-cuda11-x86_64-linux.tar.xz)

download_and_decompress(${CVCUDA_LIB_URL} ${CMAKE_CURRENT_BINARY_DIR}/${CVCUDA_LIB_FILENAME} ${THIRD_PARTY_PATH}/cvcuda)
download_and_decompress(${CVCUDA_DEV_URL} ${CMAKE_CURRENT_BINARY_DIR}/${CVCUDA_DEV_FILENAME} ${THIRD_PARTY_PATH}/cvcuda)

execute_process(COMMAND rm -rf ${THIRD_PARTY_PATH}/install/cvcuda)
execute_process(COMMAND mkdir -p ${THIRD_PARTY_PATH}/install/cvcuda)
execute_process(COMMAND cp -r ${THIRD_PARTY_PATH}/cvcuda/opt/nvidia/cvcuda0/lib/x86_64-linux-gnu/ ${THIRD_PARTY_PATH}/install/cvcuda/lib)
execute_process(COMMAND cp -r ${THIRD_PARTY_PATH}/cvcuda/opt/nvidia/cvcuda0/include/ ${THIRD_PARTY_PATH}/install/cvcuda/include)

link_directories(${THIRD_PARTY_PATH}/install/cvcuda/lib)
include_directories(${THIRD_PARTY_PATH}/install/cvcuda/include)

set(CMAKE_CXX_STANDARD 17)
docs/cn/faq/use_cv_cuda.md (new file, 39 lines)
@@ -0,0 +1,39 @@
# Using CV-CUDA/CUDA to accelerate end-to-end GPU inference

FastDeploy integrates CV-CUDA to accelerate pre/post-processing; the few operators that CV-CUDA does not support are implemented as CUDA kernels.

FastDeploy's Vision Processor module wraps the CV-CUDA operators, so users do not need to call CV-CUDA themselves; using FastDeploy's model inference interface is enough to benefit from the CV-CUDA acceleration.

When integrating CV-CUDA, the Vision Processor module does the following to make it easy to use:
- GPU memory management: each operator's input/output tensors are cached, so GPU memory is not allocated repeatedly
- The few operators not supported by CV-CUDA are implemented as CUDA kernels
- Operators supported by neither CV-CUDA nor CUDA can fall back to OpenCV/FlyCV

## Usage
Enable the CV-CUDA build option when compiling FastDeploy
```bash
# When building the C++ library, turn on the CV-CUDA build option.
-DENABLE_CVCUDA=ON \

# When building the Python library, turn on the CV-CUDA build option
export ENABLE_CVCUDA=ON
```

Only model preprocessors that inherit from the ProcessorManager class can use CV-CUDA. Taking PaddleClasPreprocessor as an example:
```bash
# C++
# After the model is created, call the preprocessor's UseCuda interface to enable CV-CUDA/CUDA preprocessing
# The first argument is enable_cv_cuda: true means use CV-CUDA, false means use CUDA only (fewer operators supported)
# The second argument is the GPU id; -1 means no specific device, use the current GPU
model.GetPreprocessor().UseCuda(true, 0);

# Python
model.preprocessor.use_cuda(True, 0)
```

## Best practices

- If the first preprocessing operator is resize, decide case by case whether resize should run on the GPU. When resize runs on the GPU while image decoding runs on the CPU, the original image has to be copied to GPU memory, which is expensive; copying to GPU memory after resize usually moves much less data.
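Putting the pieces of the FAQ above together, a minimal C++ sketch of classification with CV-CUDA preprocessing enabled might look like the following. It is an illustration rather than part of the commit: the umbrella header, model file names, `RuntimeOption::UseGpu` and `ClassifyResult::Str` are assumptions, while `GetPreprocessor().UseCuda()` and `Predict()` come from the diff below.

```cpp
#include <iostream>
#include <opencv2/opencv.hpp>
#include "fastdeploy/vision.h"  // assumed umbrella header

int main() {
  namespace cls = fastdeploy::vision::classification;

  fastdeploy::RuntimeOption option;
  option.UseGpu(0);  // run the runtime itself on GPU 0 (assumed RuntimeOption API)

  // Hypothetical PaddleClas export file names
  cls::PaddleClasModel model("inference.pdmodel", "inference.pdiparams",
                             "inference_cls.yaml", option);
  if (!model.Initialized()) {
    std::cerr << "Failed to initialize model." << std::endl;
    return -1;
  }

  // Enable GPU preprocessing: the first argument selects CV-CUDA (true) vs.
  // plain CUDA kernels (false); the second is the GPU id (-1 = current device).
  model.GetPreprocessor().UseCuda(true, 0);

  cv::Mat im = cv::imread("test.jpg");
  fastdeploy::vision::ClassifyResult result;
  if (!model.Predict(im, &result)) {
    std::cerr << "Prediction failed." << std::endl;
    return -1;
  }
  std::cout << result.Str() << std::endl;  // assumed pretty-printer on the result
  return 0;
}
```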
@@ -70,7 +70,7 @@ class TritonPythonModel:
yaml_path)
if args['model_instance_kind'] == 'GPU':
device_id = int(args['model_instance_device_id'])
self.preprocess_.use_gpu(device_id)
self.preprocess_.use_cuda(False, device_id)

def execute(self, requests):
"""`execute` must be implemented in every Python model. `execute`
@@ -18,67 +18,89 @@ void BindPaddleClas(pybind11::module& m) {
pybind11::class_<vision::classification::PaddleClasPreprocessor>(
m, "PaddleClasPreprocessor")
.def(pybind11::init<std::string>())
.def("run", [](vision::classification::PaddleClasPreprocessor& self, std::vector<pybind11::array>& im_list) {
.def("run",
[](vision::classification::PaddleClasPreprocessor& self,
std::vector<pybind11::array>& im_list) {
std::vector<vision::FDMat> images;
for (size_t i = 0; i < im_list.size(); ++i) {
images.push_back(vision::WrapMat(PyArrayToCvMat(im_list[i])));
}
std::vector<FDTensor> outputs;
if (!self.Run(&images, &outputs)) {
throw std::runtime_error("Failed to preprocess the input data in PaddleClasPreprocessor.");
throw std::runtime_error(
"Failed to preprocess the input data in "
"PaddleClasPreprocessor.");
}
if (!self.WithGpu()) {
if (!self.CudaUsed()) {
for (size_t i = 0; i < outputs.size(); ++i) {
outputs[i].StopSharing();
}
}
return outputs;
})
.def("use_gpu", [](vision::classification::PaddleClasPreprocessor& self, int gpu_id = -1) {
self.UseGpu(gpu_id);
})
.def("disable_normalize", [](vision::classification::PaddleClasPreprocessor& self) {
.def("use_cuda",
[](vision::classification::PaddleClasPreprocessor& self,
bool enable_cv_cuda = false,
int gpu_id = -1) { self.UseCuda(enable_cv_cuda, gpu_id); })
.def("disable_normalize",
[](vision::classification::PaddleClasPreprocessor& self) {
self.DisableNormalize();
})
.def("disable_permute", [](vision::classification::PaddleClasPreprocessor& self) {
.def("disable_permute",
[](vision::classification::PaddleClasPreprocessor& self) {
self.DisablePermute();
});

pybind11::class_<vision::classification::PaddleClasPostprocessor>(
m, "PaddleClasPostprocessor")
.def(pybind11::init<int>())
.def("run", [](vision::classification::PaddleClasPostprocessor& self, std::vector<FDTensor>& inputs) {
.def("run",
[](vision::classification::PaddleClasPostprocessor& self,
std::vector<FDTensor>& inputs) {
std::vector<vision::ClassifyResult> results;
if (!self.Run(inputs, &results)) {
throw std::runtime_error("Failed to postprocess the runtime result in PaddleClasPostprocessor.");
throw std::runtime_error(
"Failed to postprocess the runtime result in "
"PaddleClasPostprocessor.");
}
return results;
})
.def("run", [](vision::classification::PaddleClasPostprocessor& self, std::vector<pybind11::array>& input_array) {
.def("run",
[](vision::classification::PaddleClasPostprocessor& self,
std::vector<pybind11::array>& input_array) {
std::vector<vision::ClassifyResult> results;
std::vector<FDTensor> inputs;
PyArrayToTensorList(input_array, &inputs, /*share_buffer=*/true);
if (!self.Run(inputs, &results)) {
throw std::runtime_error("Failed to postprocess the runtime result in PaddleClasPostprocessor.");
throw std::runtime_error(
"Failed to postprocess the runtime result in "
"PaddleClasPostprocessor.");
}
return results;
})
.def_property("topk", &vision::classification::PaddleClasPostprocessor::GetTopk, &vision::classification::PaddleClasPostprocessor::SetTopk);
.def_property("topk",
&vision::classification::PaddleClasPostprocessor::GetTopk,
&vision::classification::PaddleClasPostprocessor::SetTopk);

pybind11::class_<vision::classification::PaddleClasModel, FastDeployModel>(
m, "PaddleClasModel")
.def(pybind11::init<std::string, std::string, std::string, RuntimeOption,
ModelFormat>())
.def("clone", [](vision::classification::PaddleClasModel& self) {
.def("clone",
[](vision::classification::PaddleClasModel& self) {
return self.Clone();
})
.def("predict", [](vision::classification::PaddleClasModel& self, pybind11::array& data) {
.def("predict",
[](vision::classification::PaddleClasModel& self,
pybind11::array& data) {
cv::Mat im = PyArrayToCvMat(data);
vision::ClassifyResult result;
self.Predict(im, &result);
return result;
})
.def("batch_predict", [](vision::classification::PaddleClasModel& self, std::vector<pybind11::array>& data) {
.def("batch_predict",
[](vision::classification::PaddleClasModel& self,
std::vector<pybind11::array>& data) {
std::vector<cv::Mat> images;
for (size_t i = 0; i < data.size(); ++i) {
images.push_back(PyArrayToCvMat(data[i]));
@@ -87,7 +109,11 @@ void BindPaddleClas(pybind11::module& m) {
self.BatchPredict(images, &results);
return results;
})
.def_property_readonly("preprocessor", &vision::classification::PaddleClasModel::GetPreprocessor)
.def_property_readonly("postprocessor", &vision::classification::PaddleClasModel::GetPostprocessor);
.def_property_readonly(
"preprocessor",
&vision::classification::PaddleClasModel::GetPreprocessor)
.def_property_readonly(
"postprocessor",
&vision::classification::PaddleClasModel::GetPostprocessor);
}
} // namespace fastdeploy
@@ -13,11 +13,9 @@
// limitations under the License.

#include "fastdeploy/vision/classification/ppcls/preprocessor.h"

#include "fastdeploy/function/concat.h"
#include "yaml-cpp/yaml.h"
#ifdef WITH_GPU
#include <cuda_runtime_api.h>
#endif

namespace fastdeploy {
namespace vision {
@@ -61,7 +59,8 @@ bool PaddleClasPreprocessor::BuildPreprocessPipelineFromConfig() {
auto mean = op.begin()->second["mean"].as<std::vector<float>>();
auto std = op.begin()->second["std"].as<std::vector<float>>();
auto scale = op.begin()->second["scale"].as<float>();
FDASSERT((scale - 0.00392157) < 1e-06 && (scale - 0.00392157) > -1e-06,
FDASSERT(
(scale - 0.00392157) < 1e-06 && (scale - 0.00392157) > -1e-06,
"Only support scale in Normalize be 0.00392157, means the pixel "
"is in range of [0, 255].");
processors_.push_back(std::make_shared<Normalize>(mean, std));
@@ -84,53 +83,32 @@ bool PaddleClasPreprocessor::BuildPreprocessPipelineFromConfig() {

void PaddleClasPreprocessor::DisableNormalize() {
this->disable_normalize_ = true;
// the DisableNormalize function will be invalid if the configuration file is loaded during preprocessing
// the DisableNormalize function will be invalid if the configuration file is
// loaded during preprocessing
if (!BuildPreprocessPipelineFromConfig()) {
FDERROR << "Failed to build preprocess pipeline from configuration file." << std::endl;
FDERROR << "Failed to build preprocess pipeline from configuration file."
<< std::endl;
}
}
void PaddleClasPreprocessor::DisablePermute() {
this->disable_permute_ = true;
// the DisablePermute function will be invalid if the configuration file is loaded during preprocessing
// the DisablePermute function will be invalid if the configuration file is
// loaded during preprocessing
if (!BuildPreprocessPipelineFromConfig()) {
FDERROR << "Failed to build preprocess pipeline from configuration file." << std::endl;
FDERROR << "Failed to build preprocess pipeline from configuration file."
<< std::endl;
}
}

void PaddleClasPreprocessor::UseGpu(int gpu_id) {
#ifdef WITH_GPU
use_cuda_ = true;
if (gpu_id < 0) return;
device_id_ = gpu_id;
cudaSetDevice(device_id_);
#else
FDWARNING << "FastDeploy didn't compile with WITH_GPU. "
<< "Will force to use CPU to run preprocessing." << std::endl;
use_cuda_ = false;
#endif
}

bool PaddleClasPreprocessor::Run(std::vector<FDMat>* images, std::vector<FDTensor>* outputs) {
if (!initialized_) {
FDERROR << "The preprocessor is not initialized." << std::endl;
return false;
}
if (images->size() == 0) {
FDERROR << "The size of input images should be greater than 0." << std::endl;
return false;
}

bool PaddleClasPreprocessor::Apply(std::vector<FDMat>* images,
std::vector<FDTensor>* outputs) {
for (size_t i = 0; i < images->size(); ++i) {
for (size_t j = 0; j < processors_.size(); ++j) {
bool ret = false;
if (processors_[j]->Name() == "NormalizeAndPermute" && use_cuda_) {
ret = (*(processors_[j].get()))(&((*images)[i]), ProcLib::CUDA);
} else {
ret = (*(processors_[j].get()))(&((*images)[i]));
}
if (!ret) {
FDERROR << "Failed to processs image:" << i << " in "
<< processors_[i]->Name() << "." << std::endl;
<< processors_[j]->Name() << "." << std::endl;
return false;
}
}
@@ -148,7 +126,7 @@ bool PaddleClasPreprocessor::Run(std::vector<FDMat>* images, std::vector<FDTenso
} else {
function::Concat(tensors, &((*outputs)[0]), 0);
}
(*outputs)[0].device_id = device_id_;
(*outputs)[0].device_id = DeviceId();
return true;
}
@@ -13,6 +13,7 @@
// limitations under the License.

#pragma once
#include "fastdeploy/vision/common/processors/manager.h"
#include "fastdeploy/vision/common/processors/transform.h"
#include "fastdeploy/vision/common/result.h"

@@ -22,7 +23,7 @@ namespace vision {
namespace classification {
/*! @brief Preprocessor object for PaddleClas serials model.
*/
class FASTDEPLOY_DECL PaddleClasPreprocessor {
class FASTDEPLOY_DECL PaddleClasPreprocessor : public ProcessorManager {
public:
/** \brief Create a preprocessor instance for PaddleClas serials model
*
@@ -36,15 +37,8 @@ class FASTDEPLOY_DECL PaddleClasPreprocessor {
* \param[in] outputs The output tensors which will feed in runtime
* \return true if the preprocess successed, otherwise false
*/
bool Run(std::vector<FDMat>* images, std::vector<FDTensor>* outputs);

/** \brief Use GPU to run preprocessing
*
* \param[in] gpu_id GPU device id
*/
void UseGpu(int gpu_id = -1);

bool WithGpu() { return use_cuda_; }
virtual bool Apply(std::vector<FDMat>* images,
std::vector<FDTensor>* outputs);

/// This function will disable normalize in preprocessing step.
void DisableNormalize();
@@ -54,10 +48,6 @@ class FASTDEPLOY_DECL PaddleClasPreprocessor {
private:
bool BuildPreprocessPipelineFromConfig();
std::vector<std::shared_ptr<Processor>> processors_;
bool initialized_ = false;
bool use_cuda_ = false;
// GPU device id
int device_id_ = -1;
// for recording the switch of hwc2chw
bool disable_permute_ = false;
// for recording the switch of normalize
@@ -13,9 +13,9 @@
// limitations under the License.

#include "fastdeploy/vision/common/processors/base.h"
#include "fastdeploy/vision/common/processors/proc_lib.h"

#include "fastdeploy/utils/utils.h"
#include "fastdeploy/vision/common/processors/proc_lib.h"

namespace fastdeploy {
namespace vision {
@@ -33,27 +33,58 @@ bool Processor::operator()(Mat* mat, ProcLib lib) {
#endif
} else if (target == ProcLib::CUDA) {
#ifdef WITH_GPU
FDASSERT(mat->Stream() != nullptr,
"CUDA processor requires cuda stream, please set stream for Mat");
return ImplByCuda(mat);
#else
FDASSERT(false, "FastDeploy didn't compile with WITH_GPU.");
#endif
} else if (target == ProcLib::CVCUDA) {
#ifdef ENABLE_CVCUDA
FDASSERT(mat->Stream() != nullptr,
"CV-CUDA requires cuda stream, please set stream for Mat");
return ImplByCvCuda(mat);
#else
FDASSERT(false, "FastDeploy didn't compile with CV-CUDA.");
#endif
}
// DEFAULT & OPENCV
return ImplByOpenCV(mat);
}

FDTensor* Processor::UpdateAndGetReusedBuffer(
const std::vector<int64_t>& new_shape, const int& opencv_dtype,
const std::string& buffer_name, const Device& new_device,
FDTensor* Processor::UpdateAndGetCachedTensor(
const std::vector<int64_t>& new_shape, const FDDataType& data_type,
const std::string& tensor_name, const Device& new_device,
const bool& use_pinned_memory) {
if (reused_buffers_.count(buffer_name) == 0) {
reused_buffers_[buffer_name] = FDTensor();
if (cached_tensors_.count(tensor_name) == 0) {
cached_tensors_[tensor_name] = FDTensor();
}
reused_buffers_[buffer_name].is_pinned_memory = use_pinned_memory;
reused_buffers_[buffer_name].Resize(new_shape,
OpenCVDataTypeToFD(opencv_dtype),
buffer_name, new_device);
return &reused_buffers_[buffer_name];
cached_tensors_[tensor_name].is_pinned_memory = use_pinned_memory;
cached_tensors_[tensor_name].Resize(new_shape, data_type, tensor_name,
new_device);
return &cached_tensors_[tensor_name];
}

FDTensor* Processor::CreateCachedGpuInputTensor(
Mat* mat, const std::string& tensor_name) {
#ifdef WITH_GPU
FDTensor* src = mat->Tensor();
if (src->device == Device::GPU) {
return src;
} else if (src->device == Device::CPU) {
FDTensor* tensor = UpdateAndGetCachedTensor(src->Shape(), src->Dtype(),
tensor_name, Device::GPU);
FDASSERT(cudaMemcpyAsync(tensor->Data(), src->Data(), tensor->Nbytes(),
cudaMemcpyHostToDevice, mat->Stream()) == 0,
"[ERROR] Error occurs while copy memory from CPU to GPU.");
return tensor;
} else {
FDASSERT(false, "FDMat is on unsupported device: %d", src->device);
}
#else
FDASSERT(false, "FastDeploy didn't compile with WITH_GPU.");
#endif
return nullptr;
}

void EnableFlyCV() {
@@ -59,16 +59,33 @@ class FASTDEPLOY_DECL Processor {
return ImplByOpenCV(mat);
}

virtual bool ImplByCvCuda(Mat* mat) {
return ImplByOpenCV(mat);
}

virtual bool operator()(Mat* mat, ProcLib lib = ProcLib::DEFAULT);

protected:
FDTensor* UpdateAndGetReusedBuffer(
const std::vector<int64_t>& new_shape, const int& opencv_dtype,
const std::string& buffer_name, const Device& new_device = Device::CPU,
// Update and get the cached tensor from the cached_tensors_ map.
// The tensor is indexed by a string.
// If the tensor doesn't exists in the map, then create a new tensor.
// If the tensor exists and shape is getting larger, then realloc the buffer.
// If the tensor exists and shape is not getting larger, then return the
// cached tensor directly.
FDTensor* UpdateAndGetCachedTensor(
const std::vector<int64_t>& new_shape, const FDDataType& data_type,
const std::string& tensor_name, const Device& new_device = Device::CPU,
const bool& use_pinned_memory = false);

// Create an input tensor on GPU and save into cached_tensors_.
// If the Mat is on GPU, return the mat->Tensor() directly.
// If the Mat is on CPU, then create a cached GPU tensor and copy the mat's
// CPU tensor to this new GPU tensor.
FDTensor* CreateCachedGpuInputTensor(Mat* mat,
const std::string& tensor_name);

private:
std::unordered_map<std::string, FDTensor> reused_buffers_;
std::unordered_map<std::string, FDTensor> cached_tensors_;
};

} // namespace vision
@@ -14,6 +14,12 @@

#include "fastdeploy/vision/common/processors/center_crop.h"

#ifdef ENABLE_CVCUDA
#include <cvcuda/OpCustomCrop.hpp>

#include "fastdeploy/vision/common/processors/cvcuda_utils.h"
#endif

namespace fastdeploy {
namespace vision {

@@ -56,6 +62,35 @@ bool CenterCrop::ImplByFlyCV(Mat* mat) {
}
#endif

#ifdef ENABLE_CVCUDA
bool CenterCrop::ImplByCvCuda(Mat* mat) {
// Prepare input tensor
std::string tensor_name = Name() + "_cvcuda_src";
FDTensor* src = CreateCachedGpuInputTensor(mat, tensor_name);
auto src_tensor = CreateCvCudaTensorWrapData(*src);

// Prepare output tensor
tensor_name = Name() + "_cvcuda_dst";
FDTensor* dst =
UpdateAndGetCachedTensor({height_, width_, mat->Channels()}, src->Dtype(),
tensor_name, Device::GPU);
auto dst_tensor = CreateCvCudaTensorWrapData(*dst);

int offset_x = static_cast<int>((mat->Width() - width_) / 2);
int offset_y = static_cast<int>((mat->Height() - height_) / 2);
cvcuda::CustomCrop crop_op;
NVCVRectI crop_roi = {offset_x, offset_y, width_, height_};
crop_op(mat->Stream(), src_tensor, dst_tensor, crop_roi);

mat->SetTensor(dst);
mat->SetWidth(width_);
mat->SetHeight(height_);
mat->device = Device::GPU;
mat->mat_type = ProcLib::CVCUDA;
return true;
}
#endif

bool CenterCrop::Run(Mat* mat, const int& width, const int& height,
ProcLib lib) {
auto c = CenterCrop(width, height);
@@ -25,6 +25,9 @@ class FASTDEPLOY_DECL CenterCrop : public Processor {
bool ImplByOpenCV(Mat* mat);
#ifdef ENABLE_FLYCV
bool ImplByFlyCV(Mat* mat);
#endif
#ifdef ENABLE_CVCUDA
bool ImplByCvCuda(Mat* mat);
#endif
std::string Name() { return "CenterCrop"; }
fastdeploy/vision/common/processors/cvcuda_utils.cc (new file, 76 lines)
@@ -0,0 +1,76 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "fastdeploy/vision/common/processors/cvcuda_utils.h"

namespace fastdeploy {
namespace vision {

#ifdef ENABLE_CVCUDA
nvcv::ImageFormat CreateCvCudaImageFormat(FDDataType type, int channel) {
FDASSERT(channel == 1 || channel == 3 || channel == 4,
"Only support channel be 1/3/4 in CV-CUDA.");
if (type == FDDataType::UINT8) {
if (channel == 1) {
return nvcv::FMT_U8;
} else if (channel == 3) {
return nvcv::FMT_BGR8;
} else {
return nvcv::FMT_BGRA8;
}
} else if (type == FDDataType::FP32) {
if (channel == 1) {
return nvcv::FMT_F32;
} else if (channel == 3) {
return nvcv::FMT_BGRf32;
} else {
return nvcv::FMT_BGRAf32;
}
}
FDASSERT(false, "Data type of %s is not supported.", Str(type).c_str());
return nvcv::FMT_BGRf32;
}

nvcv::TensorWrapData CreateCvCudaTensorWrapData(const FDTensor& tensor) {
FDASSERT(tensor.shape.size() == 3,
"When create CVCUDA tensor from FD tensor,"
"tensor shape should be 3-Dim, HWC layout");
int batchsize = 1;

nvcv::TensorDataStridedCuda::Buffer buf;
buf.strides[3] = FDDataTypeSize(tensor.Dtype());
buf.strides[2] = tensor.shape[2] * buf.strides[3];
buf.strides[1] = tensor.shape[1] * buf.strides[2];
buf.strides[0] = tensor.shape[0] * buf.strides[1];
buf.basePtr = reinterpret_cast<NVCVByte*>(const_cast<void*>(tensor.Data()));

nvcv::Tensor::Requirements req = nvcv::Tensor::CalcRequirements(
batchsize, {tensor.shape[1], tensor.shape[0]},
CreateCvCudaImageFormat(tensor.Dtype(), tensor.shape[2]));

nvcv::TensorDataStridedCuda tensor_data(
nvcv::TensorShape{req.shape, req.rank, req.layout},
nvcv::DataType{req.dtype}, buf);
return nvcv::TensorWrapData(tensor_data);
}

void* GetCvCudaTensorDataPtr(const nvcv::TensorWrapData& tensor) {
auto data =
dynamic_cast<const nvcv::ITensorDataStridedCuda*>(tensor.exportData());
return reinterpret_cast<void*>(data->basePtr());
}
#endif

} // namespace vision
} // namespace fastdeploy
fastdeploy/vision/common/processors/cvcuda_utils.h (new file, 31 lines)
@@ -0,0 +1,31 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "fastdeploy/core/fd_tensor.h"

#ifdef ENABLE_CVCUDA
#include "nvcv/Tensor.hpp"

namespace fastdeploy {
namespace vision {

nvcv::ImageFormat CreateCvCudaImageFormat(FDDataType type, int channel);
nvcv::TensorWrapData CreateCvCudaTensorWrapData(const FDTensor& tensor);
void* GetCvCudaTensorDataPtr(const nvcv::TensorWrapData& tensor);

}
}
#endif
fastdeploy/vision/common/processors/manager.cc (new file, 80 lines)
@@ -0,0 +1,80 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/vision/common/processors/manager.h"

namespace fastdeploy {
namespace vision {

ProcessorManager::~ProcessorManager() {
#ifdef WITH_GPU
if (stream_) cudaStreamDestroy(stream_);
#endif
}

void ProcessorManager::UseCuda(bool enable_cv_cuda, int gpu_id) {
#ifdef WITH_GPU
if (gpu_id >= 0) {
device_id_ = gpu_id;
FDASSERT(cudaSetDevice(device_id_) == cudaSuccess,
"[ERROR] Error occurs while setting cuda device.");
}
FDASSERT(cudaStreamCreate(&stream_) == cudaSuccess,
"[ERROR] Error occurs while creating cuda stream.");
DefaultProcLib::default_lib = ProcLib::CUDA;
#else
FDASSERT(false, "FastDeploy didn't compile with WITH_GPU.");
#endif

if (enable_cv_cuda) {
#ifdef ENABLE_CVCUDA
DefaultProcLib::default_lib = ProcLib::CVCUDA;
#else
FDASSERT(false, "FastDeploy didn't compile with CV-CUDA.");
#endif
}
}

bool ProcessorManager::CudaUsed() {
return (DefaultProcLib::default_lib == ProcLib::CUDA ||
DefaultProcLib::default_lib == ProcLib::CVCUDA);
}

bool ProcessorManager::Run(std::vector<FDMat>* images,
std::vector<FDTensor>* outputs) {
if (!initialized_) {
FDERROR << "The preprocessor is not initialized." << std::endl;
return false;
}
if (images->size() == 0) {
FDERROR << "The size of input images should be greater than 0."
<< std::endl;
return false;
}

for (size_t i = 0; i < images->size(); ++i) {
if (CudaUsed()) {
SetStream(&((*images)[i]));
}
}

bool ret = Apply(images, outputs);

if (CudaUsed()) {
SyncStream();
}
return ret;
}

} // namespace vision
} // namespace fastdeploy
fastdeploy/vision/common/processors/manager.h (new file, 74 lines)
@@ -0,0 +1,74 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "fastdeploy/utils/utils.h"
#include "fastdeploy/vision/common/processors/mat.h"

namespace fastdeploy {
namespace vision {

class FASTDEPLOY_DECL ProcessorManager {
public:
~ProcessorManager();

void UseCuda(bool enable_cv_cuda = false, int gpu_id = -1);

bool CudaUsed();

void SetStream(Mat* mat) {
#ifdef WITH_GPU
mat->SetStream(stream_);
#endif
}

void SyncStream() {
#ifdef WITH_GPU
FDASSERT(cudaStreamSynchronize(stream_) == cudaSuccess,
"[ERROR] Error occurs while sync cuda stream.");
#endif
}

int DeviceId() { return device_id_; }

/** \brief Process the input image and prepare input tensors for runtime
*
* \param[in] images The input image data list, all the elements are returned by cv::imread()
* \param[in] outputs The output tensors which will feed in runtime
* \return true if the preprocess successed, otherwise false
*/
bool Run(std::vector<FDMat>* images, std::vector<FDTensor>* outputs);

/** \brief The body of Run() function which needs to be implemented by a derived class
*
* \param[in] images The input image data list, all the elements are returned by cv::imread()
* \param[in] outputs The output tensors which will feed in runtime
* \return true if the preprocess successed, otherwise false
*/
virtual bool Apply(std::vector<FDMat>* images,
std::vector<FDTensor>* outputs) = 0;

protected:
bool initialized_ = false;

private:
#ifdef WITH_GPU
cudaStream_t stream_ = nullptr;
#endif
int device_id_ = -1;
};

} // namespace vision
} // namespace fastdeploy
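To make the Run()/Apply() split above concrete, here is a hedged sketch (not part of this commit) of a custom preprocessor that plugs into ProcessorManager. The class name, the constructor arguments of Resize and NormalizeAndPermute, and the per-image output layout are assumptions for illustration; Apply(), initialized_, ShareWithTensor() and the Processor call operator come from the code in this diff.

```cpp
#include <memory>
#include <vector>

#include "fastdeploy/vision/common/processors/manager.h"
#include "fastdeploy/vision/common/processors/transform.h"

namespace fastdeploy {
namespace vision {

// Hypothetical preprocessor: ProcessorManager::Run() handles CUDA stream setup
// and synchronization once UseCuda() has been called; the subclass only has to
// implement Apply().
class MyPreprocessor : public ProcessorManager {
 public:
  MyPreprocessor() {
    // Assumed constructor signatures for the two processors.
    processors_.push_back(std::make_shared<Resize>(224, 224));
    processors_.push_back(std::make_shared<NormalizeAndPermute>(
        std::vector<float>{0.485f, 0.456f, 0.406f},
        std::vector<float>{0.229f, 0.224f, 0.225f}));
    initialized_ = true;
  }

  bool Apply(std::vector<FDMat>* images,
             std::vector<FDTensor>* outputs) override {
    for (size_t i = 0; i < images->size(); ++i) {
      for (auto& processor : processors_) {
        // Uses the default ProcLib, which UseCuda() switches to CUDA/CVCUDA.
        if (!(*processor)(&((*images)[i]))) return false;
      }
    }
    // One output tensor per image; batch concatenation is omitted for brevity.
    outputs->resize(images->size());
    for (size_t i = 0; i < images->size(); ++i) {
      (*images)[i].ShareWithTensor(&((*outputs)[i]));
    }
    return true;
  }

 private:
  std::vector<std::shared_ptr<Processor>> processors_;
};

}  // namespace vision
}  // namespace fastdeploy
```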
@@ -19,6 +19,36 @@
namespace fastdeploy {
namespace vision {

cv::Mat* Mat::GetOpenCVMat() {
if (mat_type == ProcLib::OPENCV) {
return &cpu_mat;
} else if (mat_type == ProcLib::FLYCV) {
#ifdef ENABLE_FLYCV
// Just a reference to fcv_mat, zero copy. After you
// call this method, cpu_mat and fcv_mat will point
// to the same memory buffer.
cpu_mat = ConvertFlyCVMatToOpenCV(fcv_mat);
mat_type = ProcLib::OPENCV;
return &cpu_mat;
#else
FDASSERT(false, "FastDeploy didn't compiled with FlyCV!");
#endif
} else if (mat_type == ProcLib::CUDA || mat_type == ProcLib::CVCUDA) {
#ifdef WITH_GPU
FDASSERT(cudaStreamSynchronize(stream) == cudaSuccess,
"[ERROR] Error occurs while sync cuda stream.");
cpu_mat = CreateZeroCopyOpenCVMatFromTensor(fd_tensor);
mat_type = ProcLib::OPENCV;
device = Device::CPU;
return &cpu_mat;
#else
FDASSERT(false, "FastDeploy didn't compiled with -DWITH_GPU=ON");
#endif
} else {
FDASSERT(false, "The mat_type of custom Mat can not be ProcLib::DEFAULT");
}
}

void* Mat::Data() {
if (mat_type == ProcLib::FLYCV) {
#ifdef ENABLE_FLYCV
@@ -28,10 +58,32 @@ void* Mat::Data() {
"FastDeploy didn't compile with FlyCV, but met data type with "
"fcv::Mat.");
#endif
} else if (device == Device::GPU) {
return fd_tensor.Data();
}
return cpu_mat.ptr();
}

FDTensor* Mat::Tensor() {
if (mat_type == ProcLib::OPENCV) {
ShareWithTensor(&fd_tensor);
} else if (mat_type == ProcLib::FLYCV) {
#ifdef ENABLE_FLYCV
cpu_mat = ConvertFlyCVMatToOpenCV(fcv_mat);
mat_type = ProcLib::OPENCV;
ShareWithTensor(&fd_tensor);
#else
FDASSERT(false, "FastDeploy didn't compiled with FlyCV!");
#endif
}
return &fd_tensor;
}

void Mat::SetTensor(FDTensor* tensor) {
fd_tensor.SetExternalData(tensor->Shape(), tensor->Dtype(), tensor->Data(),
tensor->device, tensor->device_id);
}

void Mat::ShareWithTensor(FDTensor* tensor) {
tensor->SetExternalData({Channels(), Height(), Width()}, Type(), Data());
tensor->device = device;
@@ -54,15 +106,15 @@ bool Mat::CopyToTensor(FDTensor* tensor) {
}

void Mat::PrintInfo(const std::string& flag) {
if (mat_type == ProcLib::FLYCV) {
#ifdef ENABLE_FLYCV
fcv::Scalar mean = fcv::mean(fcv_mat);
std::cout << flag << ": "
<< "DataType=" << Type() << ", "
<< "Channel=" << Channels() << ", "
<< "Height=" << Height() << ", "
<< "Width=" << Width() << ", "
<< "Mean=";
if (mat_type == ProcLib::FLYCV) {
#ifdef ENABLE_FLYCV
fcv::Scalar mean = fcv::mean(fcv_mat);
for (int i = 0; i < Channels(); ++i) {
std::cout << mean[i] << " ";
}
@@ -72,18 +124,25 @@ void Mat::PrintInfo(const std::string& flag) {
"FastDeploy didn't compile with FlyCV, but met data type with "
"fcv::Mat.");
#endif
} else {
} else if (mat_type == ProcLib::OPENCV) {
cv::Scalar mean = cv::mean(cpu_mat);
std::cout << flag << ": "
<< "DataType=" << Type() << ", "
<< "Channel=" << Channels() << ", "
<< "Height=" << Height() << ", "
<< "Width=" << Width() << ", "
<< "Mean=";
for (int i = 0; i < Channels(); ++i) {
std::cout << mean[i] << " ";
}
std::cout << std::endl;
} else if (mat_type == ProcLib::CUDA || mat_type == ProcLib::CVCUDA) {
#ifdef WITH_GPU
FDASSERT(cudaStreamSynchronize(stream) == cudaSuccess,
"[ERROR] Error occurs while sync cuda stream.");
cv::Mat tmp_mat = CreateZeroCopyOpenCVMatFromTensor(fd_tensor);
cv::Scalar mean = cv::mean(tmp_mat);
for (int i = 0; i < Channels(); ++i) {
std::cout << mean[i] << " ";
}
std::cout << std::endl;
#else
FDASSERT(false, "FastDeploy didn't compiled with -DWITH_GPU=ON");
#endif
}
}

@@ -97,6 +156,8 @@ FDDataType Mat::Type() {
"FastDeploy didn't compile with FlyCV, but met data type with "
"fcv::Mat.");
#endif
} else if (mat_type == ProcLib::CUDA || mat_type == ProcLib::CVCUDA) {
return fd_tensor.Dtype();
}
return OpenCVDataTypeToFD(cpu_mat.type());
}
@@ -134,42 +195,41 @@ Mat Mat::Create(const FDTensor& tensor, ProcLib lib) {
return mat;
}

Mat Mat::Create(int height, int width, int channels,
FDDataType type, void* data) {
Mat Mat::Create(int height, int width, int channels, FDDataType type,
void* data) {
if (DefaultProcLib::default_lib == ProcLib::FLYCV) {
#ifdef ENABLE_FLYCV
fcv::Mat tmp_fcv_mat = CreateZeroCopyFlyCVMatFromBuffer(
height, width, channels, type, data);
fcv::Mat tmp_fcv_mat =
CreateZeroCopyFlyCVMatFromBuffer(height, width, channels, type, data);
Mat mat = Mat(tmp_fcv_mat);
return mat;
#else
FDASSERT(false, "FastDeploy didn't compiled with FlyCV!");
#endif
}
cv::Mat tmp_ocv_mat = CreateZeroCopyOpenCVMatFromBuffer(
height, width, channels, type, data);
cv::Mat tmp_ocv_mat =
CreateZeroCopyOpenCVMatFromBuffer(height, width, channels, type, data);
Mat mat = Mat(tmp_ocv_mat);
return mat;
}

Mat Mat::Create(int height, int width, int channels,
FDDataType type, void* data,
ProcLib lib) {
Mat Mat::Create(int height, int width, int channels, FDDataType type,
void* data, ProcLib lib) {
if (lib == ProcLib::DEFAULT) {
return Create(height, width, channels, type, data);
}
if (lib == ProcLib::FLYCV) {
#ifdef ENABLE_FLYCV
fcv::Mat tmp_fcv_mat = CreateZeroCopyFlyCVMatFromBuffer(
height, width, channels, type, data);
fcv::Mat tmp_fcv_mat =
CreateZeroCopyFlyCVMatFromBuffer(height, width, channels, type, data);
Mat mat = Mat(tmp_fcv_mat);
return mat;
#else
FDASSERT(false, "FastDeploy didn't compiled with FlyCV!");
#endif
}
cv::Mat tmp_ocv_mat = CreateZeroCopyOpenCVMatFromBuffer(
height, width, channels, type, data);
cv::Mat tmp_ocv_mat =
CreateZeroCopyOpenCVMatFromBuffer(height, width, channels, type, data);
Mat mat = Mat(tmp_ocv_mat);
return mat;
}
@@ -17,6 +17,10 @@
#include "fastdeploy/vision/common/processors/proc_lib.h"
#include "opencv2/core/core.hpp"

#ifdef WITH_GPU
#include <cuda_runtime_api.h>
#endif

namespace fastdeploy {
namespace vision {

@@ -60,24 +64,7 @@ struct FASTDEPLOY_DECL Mat {
mat_type = ProcLib::OPENCV;
}

cv::Mat* GetOpenCVMat() {
if (mat_type == ProcLib::OPENCV) {
return &cpu_mat;
} else if (mat_type == ProcLib::FLYCV) {
#ifdef ENABLE_FLYCV
// Just a reference to fcv_mat, zero copy. After you
// call this method, cpu_mat and fcv_mat will point
// to the same memory buffer.
cpu_mat = ConvertFlyCVMatToOpenCV(fcv_mat);
mat_type = ProcLib::OPENCV;
return &cpu_mat;
#else
FDASSERT(false, "FastDeploy didn't compiled with FlyCV!");
#endif
} else {
FDASSERT(false, "The mat_type of custom Mat can not be ProcLib::DEFAULT");
}
}
cv::Mat* GetOpenCVMat();

#ifdef ENABLE_FLYCV
void SetMat(const fcv::Mat& mat) {
@@ -103,6 +90,12 @@ struct FASTDEPLOY_DECL Mat {

void* Data();

// Get fd_tensor
FDTensor* Tensor();

// Set fd_tensor
void SetTensor(FDTensor* tensor);

private:
int channels;
int height;
@@ -111,6 +104,12 @@ struct FASTDEPLOY_DECL Mat {
#ifdef ENABLE_FLYCV
fcv::Mat fcv_mat;
#endif
#ifdef WITH_GPU
cudaStream_t stream = nullptr;
#endif
// Currently, fd_tensor is only used by CUDA and CV-CUDA,
// OpenCV and FlyCV are not using it.
FDTensor fd_tensor;

public:
FDDataType Type();
@@ -120,6 +119,10 @@ struct FASTDEPLOY_DECL Mat {
void SetChannels(int s) { channels = s; }
void SetWidth(int w) { width = w; }
void SetHeight(int h) { height = h; }
#ifdef WITH_GPU
cudaStream_t Stream() const { return stream; }
void SetStream(cudaStream_t s) { stream = s; }
#endif

// Transfer the vision::Mat to FDTensor
void ShareWithTensor(FDTensor* tensor);
@@ -37,49 +37,46 @@ __global__ void NormalizeAndPermuteKernel(uint8_t* src, float* dst,
}

bool NormalizeAndPermute::ImplByCuda(Mat* mat) {
cv::Mat* im = mat->GetOpenCVMat();
std::string buf_name = Name() + "_src";
std::vector<int64_t> shape = {im->rows, im->cols, im->channels()};
FDTensor* src =
UpdateAndGetReusedBuffer(shape, im->type(), buf_name, Device::GPU);
FDASSERT(cudaMemcpy(src->Data(), im->ptr(), src->Nbytes(),
cudaMemcpyHostToDevice) == 0,
"Error occurs while copy memory from CPU to GPU.");
// Prepare input tensor
std::string tensor_name = Name() + "_cvcuda_src";
FDTensor* src = CreateCachedGpuInputTensor(mat, tensor_name);

buf_name = Name() + "_dst";
FDTensor* dst = UpdateAndGetReusedBuffer(shape, CV_32FC(im->channels()),
buf_name, Device::GPU);
cv::Mat res(im->rows, im->cols, CV_32FC(im->channels()), dst->Data());
// Prepare output tensor
tensor_name = Name() + "_dst";
FDTensor* dst = UpdateAndGetCachedTensor(src->Shape(), FDDataType::FP32,
tensor_name, Device::GPU);

buf_name = Name() + "_alpha";
FDTensor* alpha = UpdateAndGetReusedBuffer({(int64_t)alpha_.size()}, CV_32FC1,
buf_name, Device::GPU);
FDASSERT(cudaMemcpy(alpha->Data(), alpha_.data(), alpha->Nbytes(),
cudaMemcpyHostToDevice) == 0,
"Error occurs while copy memory from CPU to GPU.");
// Copy alpha and beta to GPU
tensor_name = Name() + "_alpha";
FDMat alpha_mat =
FDMat::Create(1, 1, alpha_.size(), FDDataType::FP32, alpha_.data());
FDTensor* alpha = CreateCachedGpuInputTensor(&alpha_mat, tensor_name);

buf_name = Name() + "_beta";
FDTensor* beta = UpdateAndGetReusedBuffer({(int64_t)beta_.size()}, CV_32FC1,
buf_name, Device::GPU);
FDASSERT(cudaMemcpy(beta->Data(), beta_.data(), beta->Nbytes(),
cudaMemcpyHostToDevice) == 0,
"Error occurs while copy memory from CPU to GPU.");
tensor_name = Name() + "_beta";
FDMat beta_mat =
FDMat::Create(1, 1, beta_.size(), FDDataType::FP32, beta_.data());
FDTensor* beta = CreateCachedGpuInputTensor(&beta_mat, tensor_name);

int jobs = im->cols * im->rows;
int jobs = mat->Width() * mat->Height();
int threads = 256;
int blocks = ceil(jobs / (float)threads);
NormalizeAndPermuteKernel<<<blocks, threads, 0, NULL>>>(
NormalizeAndPermuteKernel<<<blocks, threads, 0, mat->Stream()>>>(
reinterpret_cast<uint8_t*>(src->Data()),
reinterpret_cast<float*>(dst->Data()),
reinterpret_cast<float*>(alpha->Data()),
reinterpret_cast<float*>(beta->Data()), im->channels(), swap_rb_, jobs);
reinterpret_cast<float*>(beta->Data()), mat->Channels(), swap_rb_, jobs);

mat->SetMat(res);
mat->SetTensor(dst);
mat->device = Device::GPU;
mat->layout = Layout::CHW;
mat->mat_type = ProcLib::CUDA;
return true;
}

#ifdef ENABLE_CVCUDA
bool NormalizeAndPermute::ImplByCvCuda(Mat* mat) { return ImplByCuda(mat); }
#endif

} // namespace vision
} // namespace fastdeploy
#endif
@@ -31,6 +31,9 @@ class FASTDEPLOY_DECL NormalizeAndPermute : public Processor {
#endif
#ifdef WITH_GPU
bool ImplByCuda(Mat* mat);
#endif
#ifdef ENABLE_CVCUDA
bool ImplByCvCuda(Mat* mat);
#endif
std::string Name() { return "NormalizeAndPermute"; }

@@ -18,7 +18,7 @@
namespace fastdeploy {
namespace vision {

enum class FASTDEPLOY_DECL ProcLib { DEFAULT, OPENCV, FLYCV, CUDA };
enum class FASTDEPLOY_DECL ProcLib { DEFAULT, OPENCV, FLYCV, CUDA, CVCUDA };

FASTDEPLOY_DECL std::ostream& operator<<(std::ostream& out, const ProcLib& p);

@@ -14,6 +14,12 @@

#include "fastdeploy/vision/common/processors/resize.h"

#ifdef ENABLE_CVCUDA
#include <cvcuda/OpResize.hpp>

#include "fastdeploy/vision/common/processors/cvcuda_utils.h"
#endif

namespace fastdeploy {
namespace vision {

@@ -116,6 +122,52 @@ bool Resize::ImplByFlyCV(Mat* mat) {
}
#endif

#ifdef ENABLE_CVCUDA
bool Resize::ImplByCvCuda(Mat* mat) {
if (width_ == mat->Width() && height_ == mat->Height()) {
return true;
}
if (fabs(scale_w_ - 1.0) < 1e-06 && fabs(scale_h_ - 1.0) < 1e-06) {
return true;
}

if (width_ > 0 && height_ > 0) {
} else if (scale_w_ > 0 && scale_h_ > 0) {
width_ = std::round(scale_w_ * mat->Width());
height_ = std::round(scale_h_ * mat->Height());
} else {
FDERROR << "Resize: the parameters must satisfy (width > 0 && height > 0) "
"or (scale_w > 0 && scale_h > 0)."
<< std::endl;
return false;
}

// Prepare input tensor
std::string tensor_name = Name() + "_cvcuda_src";
FDTensor* src = CreateCachedGpuInputTensor(mat, tensor_name);
auto src_tensor = CreateCvCudaTensorWrapData(*src);

// Prepare output tensor
tensor_name = Name() + "_cvcuda_dst";
FDTensor* dst =
UpdateAndGetCachedTensor({height_, width_, mat->Channels()}, mat->Type(),
tensor_name, Device::GPU);
auto dst_tensor = CreateCvCudaTensorWrapData(*dst);

// CV-CUDA Interp value is compatible with OpenCV
cvcuda::Resize resize_op;
resize_op(mat->Stream(), src_tensor, dst_tensor,
NVCVInterpolationType(interp_));

mat->SetTensor(dst);
mat->SetWidth(width_);
mat->SetHeight(height_);
mat->device = Device::GPU;
mat->mat_type = ProcLib::CVCUDA;
return true;
}
#endif

bool Resize::Run(Mat* mat, int width, int height, float scale_w, float scale_h,
int interp, bool use_scale, ProcLib lib) {
if (mat->Height() == height && mat->Width() == width) {

@@ -34,6 +34,9 @@ class FASTDEPLOY_DECL Resize : public Processor {
bool ImplByOpenCV(Mat* mat);
#ifdef ENABLE_FLYCV
bool ImplByFlyCV(Mat* mat);
#endif
#ifdef ENABLE_CVCUDA
bool ImplByCvCuda(Mat* mat);
#endif
std::string Name() { return "Resize"; }

@@ -14,6 +14,12 @@

#include "fastdeploy/vision/common/processors/resize_by_short.h"

#ifdef ENABLE_CVCUDA
#include <cvcuda/OpResize.hpp>

#include "fastdeploy/vision/common/processors/cvcuda_utils.h"
#endif

namespace fastdeploy {
namespace vision {

@@ -80,6 +86,37 @@ bool ResizeByShort::ImplByFlyCV(Mat* mat) {
}
#endif

#ifdef ENABLE_CVCUDA
bool ResizeByShort::ImplByCvCuda(Mat* mat) {
// Prepare input tensor
std::string tensor_name = Name() + "_cvcuda_src";
FDTensor* src = CreateCachedGpuInputTensor(mat, tensor_name);
auto src_tensor = CreateCvCudaTensorWrapData(*src);

double scale = GenerateScale(mat->Width(), mat->Height());
int width = static_cast<int>(round(scale * mat->Width()));
int height = static_cast<int>(round(scale * mat->Height()));

// Prepare output tensor
tensor_name = Name() + "_cvcuda_dst";
FDTensor* dst = UpdateAndGetCachedTensor(
{height, width, mat->Channels()}, mat->Type(), tensor_name, Device::GPU);
auto dst_tensor = CreateCvCudaTensorWrapData(*dst);

// CV-CUDA Interp value is compatible with OpenCV
cvcuda::Resize resize_op;
resize_op(mat->Stream(), src_tensor, dst_tensor,
NVCVInterpolationType(interp_));

mat->SetTensor(dst);
mat->SetWidth(width);
mat->SetHeight(height);
mat->device = Device::GPU;
mat->mat_type = ProcLib::CVCUDA;
return true;
}
#endif

double ResizeByShort::GenerateScale(const int origin_w, const int origin_h) {
int im_size_max = std::max(origin_w, origin_h);
int im_size_min = std::min(origin_w, origin_h);

@@ -31,6 +31,9 @@ class FASTDEPLOY_DECL ResizeByShort : public Processor {
bool ImplByOpenCV(Mat* mat);
#ifdef ENABLE_FLYCV
bool ImplByFlyCV(Mat* mat);
#endif
#ifdef ENABLE_CVCUDA
bool ImplByCvCuda(Mat* mat);
#endif
std::string Name() { return "ResizeByShort"; }

@@ -35,12 +35,13 @@ class PaddleClasPreprocessor:
"""
return self._preprocessor.run(input_ims)

def use_gpu(self, gpu_id=-1):
def use_cuda(self, enable_cv_cuda=False, gpu_id=-1):
"""Use CUDA preprocessors

:param: enable_cv_cuda: Whether to enable CV-CUDA
:param: gpu_id: GPU device id
"""
return self._preprocessor.use_gpu(gpu_id)
return self._preprocessor.use_cuda(enable_cv_cuda, gpu_id)

def disable_normalize(self):
"""