[CVCUDA] Utilize CV-CUDA batch processing function (#1223)

* norm and permute batch processing

* move cache to mat, batch processors

* get batched tensor logic, resize on cpu logic

* fix cpu compile error

* remove vector mat api

* nits

* add comments

* nits

* fix batch size

* move initial resize on cpu option to use_cuda api

* fix pybind

* processor manager pybind

* rename mat and matbatch

* move initial resize on cpu to ppcls preprocessor

---------

Co-authored-by: Jason <jiangjiajun@baidu.com>
Author: Wang Xinyu
Date: 2023-02-07 13:44:30 +08:00
Committed by: GitHub
Parent: 7c9bf11c44
Commit: d3d914856d
29 changed files with 710 additions and 241 deletions

View File

@@ -12,11 +12,12 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/core/fd_tensor.h"
#include "fastdeploy/core/float16.h"
#include "fastdeploy/utils/utils.h"
#include <algorithm>
#include <cstring>
#include "fastdeploy/core/float16.h"
#include "fastdeploy/utils/utils.h"
#ifdef WITH_GPU
#include <cuda_runtime_api.h>
#endif
@@ -142,6 +143,9 @@ void FDTensor::Resize(const std::vector<int64_t>& new_shape,
const FDDataType& data_type,
const std::string& tensor_name,
const Device& new_device) {
if (device != new_device) {
FreeFn();
}
external_data_ptr = nullptr;
name = tensor_name;
device = new_device;
@@ -269,7 +273,8 @@ bool FDTensor::ReallocFn(size_t nbytes) {
}
return buffer_ != nullptr;
#else
FDASSERT(false, "The FastDeploy FDTensor allocator didn't compile under "
FDASSERT(false,
"The FastDeploy FDTensor allocator didn't compile under "
"-DWITH_GPU=ON,"
"so this is an unexpected problem happend.");
#endif
@@ -285,7 +290,8 @@ bool FDTensor::ReallocFn(size_t nbytes) {
}
return buffer_ != nullptr;
#else
FDASSERT(false, "The FastDeploy FDTensor allocator didn't compile under "
FDASSERT(false,
"The FastDeploy FDTensor allocator didn't compile under "
"-DWITH_GPU=ON,"
"so this is an unexpected problem happend.");
#endif
@@ -296,8 +302,7 @@ bool FDTensor::ReallocFn(size_t nbytes) {
}
void FDTensor::FreeFn() {
if (external_data_ptr != nullptr)
external_data_ptr = nullptr;
if (external_data_ptr != nullptr) external_data_ptr = nullptr;
if (buffer_ != nullptr) {
if (device == Device::GPU) {
#ifdef WITH_GPU
@@ -386,8 +391,11 @@ FDTensor::FDTensor(const Scalar& scalar) {
}
FDTensor::FDTensor(const FDTensor& other)
: shape(other.shape), name(other.name), dtype(other.dtype),
device(other.device), external_data_ptr(other.external_data_ptr),
: shape(other.shape),
name(other.name),
dtype(other.dtype),
device(other.device),
external_data_ptr(other.external_data_ptr),
device_id(other.device_id) {
// Copy buffer
if (other.buffer_ == nullptr) {
@@ -401,9 +409,12 @@ FDTensor::FDTensor(const FDTensor& other)
}
FDTensor::FDTensor(FDTensor&& other)
: buffer_(other.buffer_), shape(std::move(other.shape)),
name(std::move(other.name)), dtype(other.dtype),
external_data_ptr(other.external_data_ptr), device(other.device),
: buffer_(other.buffer_),
shape(std::move(other.shape)),
name(std::move(other.name)),
dtype(other.dtype),
external_data_ptr(other.external_data_ptr),
device(other.device),
device_id(other.device_id) {
other.name = "";
// Note(zhoushunjie): Avoid double free.
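
With this change, Resize() frees the existing buffer whenever the requested device differs from the tensor's current device. A minimal sketch of how a device-aware cache tensor can rely on that behavior (the helper name, shape, and dtype are illustrative; assumes a -DWITH_GPU=ON build):

#include "fastdeploy/core/fd_tensor.h"

// Hypothetical helper: reuse one FDTensor as a CPU/GPU scratch buffer.
void ReuseCacheTensor(fastdeploy::FDTensor* cache) {
  // First use: allocate an HWC FP32 buffer on the CPU.
  cache->Resize({224, 224, 3}, fastdeploy::FDDataType::FP32,
                "input_cache", fastdeploy::Device::CPU);
  // Later use on the GPU: since the device changed, Resize() calls FreeFn()
  // on the old CPU buffer before allocating the GPU one, so no stale host
  // pointer is kept around.
  cache->Resize({224, 224, 3}, fastdeploy::FDDataType::FP32,
                "input_cache", fastdeploy::Device::GPU);
}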

View File

@@ -15,33 +15,9 @@
namespace fastdeploy {
void BindPaddleClas(pybind11::module& m) {
pybind11::class_<vision::classification::PaddleClasPreprocessor>(
m, "PaddleClasPreprocessor")
pybind11::class_<vision::classification::PaddleClasPreprocessor,
vision::ProcessorManager>(m, "PaddleClasPreprocessor")
.def(pybind11::init<std::string>())
.def("run",
[](vision::classification::PaddleClasPreprocessor& self,
std::vector<pybind11::array>& im_list) {
std::vector<vision::FDMat> images;
for (size_t i = 0; i < im_list.size(); ++i) {
images.push_back(vision::WrapMat(PyArrayToCvMat(im_list[i])));
}
std::vector<FDTensor> outputs;
if (!self.Run(&images, &outputs)) {
throw std::runtime_error(
"Failed to preprocess the input data in "
"PaddleClasPreprocessor.");
}
if (!self.CudaUsed()) {
for (size_t i = 0; i < outputs.size(); ++i) {
outputs[i].StopSharing();
}
}
return outputs;
})
.def("use_cuda",
[](vision::classification::PaddleClasPreprocessor& self,
bool enable_cv_cuda = false,
int gpu_id = -1) { self.UseCuda(enable_cv_cuda, gpu_id); })
.def("disable_normalize",
[](vision::classification::PaddleClasPreprocessor& self) {
self.DisableNormalize();
@@ -49,6 +25,10 @@ void BindPaddleClas(pybind11::module& m) {
.def("disable_permute",
[](vision::classification::PaddleClasPreprocessor& self) {
self.DisablePermute();
})
.def("initial_resize_on_cpu",
[](vision::classification::PaddleClasPreprocessor& self, bool v) {
self.InitialResizeOnCpu(v);
});
pybind11::class_<vision::classification::PaddleClasPostprocessor>(

View File

@@ -100,32 +100,23 @@ void PaddleClasPreprocessor::DisablePermute() {
}
}
bool PaddleClasPreprocessor::Apply(std::vector<FDMat>* images,
bool PaddleClasPreprocessor::Apply(FDMatBatch* image_batch,
std::vector<FDTensor>* outputs) {
for (size_t i = 0; i < images->size(); ++i) {
for (size_t j = 0; j < processors_.size(); ++j) {
bool ret = false;
ret = (*(processors_[j].get()))(&((*images)[i]));
if (!ret) {
FDERROR << "Failed to processs image:" << i << " in "
<< processors_[j]->Name() << "." << std::endl;
return false;
ProcLib lib = ProcLib::DEFAULT;
if (initial_resize_on_cpu_ && j == 0 &&
processors_[j]->Name().find("Resize") == 0) {
lib = ProcLib::OPENCV;
}
if (!(*(processors_[j].get()))(image_batch, lib)) {
FDERROR << "Failed to processs image in " << processors_[j]->Name() << "."
<< std::endl;
return false;
}
}
outputs->resize(1);
// Concat all the preprocessed data to a batch tensor
std::vector<FDTensor> tensors(images->size());
for (size_t i = 0; i < images->size(); ++i) {
(*images)[i].ShareWithTensor(&(tensors[i]));
tensors[i].ExpandDim(0);
}
if (tensors.size() == 1) {
(*outputs)[0] = std::move(tensors[0]);
} else {
function::Concat(tensors, &((*outputs)[0]), 0);
}
(*outputs)[0] = std::move(*(image_batch->Tensor()));
(*outputs)[0].device_id = DeviceId();
return true;
}

View File

@@ -33,11 +33,11 @@ class FASTDEPLOY_DECL PaddleClasPreprocessor : public ProcessorManager {
/** \brief Process the input image and prepare input tensors for runtime
*
* \param[in] images The input image data list, all the elements are returned by cv::imread()
* \param[in] image_batch The input image batch
* \param[in] outputs The output tensors which will feed in runtime
* \return true if the preprocess succeeded, otherwise false
*/
virtual bool Apply(std::vector<FDMat>* images,
virtual bool Apply(FDMatBatch* image_batch,
std::vector<FDTensor>* outputs);
/// This function will disable normalize in preprocessing step.
@@ -45,6 +45,14 @@ class FASTDEPLOY_DECL PaddleClasPreprocessor : public ProcessorManager {
/// This function will disable hwc2chw in preprocessing step.
void DisablePermute();
/** \brief When the initial operator is Resize and the input image is large,
* it may be better to run the resize on CPU, because the HostToDevice memcpy
* is time consuming. Set this to true to run the initial resize on CPU.
*
* \param[in] v true or false
*/
void InitialResizeOnCpu(bool v) { initial_resize_on_cpu_ = v; }
private:
bool BuildPreprocessPipelineFromConfig();
std::vector<std::shared_ptr<Processor>> processors_;
@@ -54,6 +62,7 @@ class FASTDEPLOY_DECL PaddleClasPreprocessor : public ProcessorManager {
bool disable_normalize_ = false;
// read config file
std::string config_file_;
bool initial_resize_on_cpu_ = false;
};
} // namespace classification
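
For context, a hedged sketch of how the new options fit together from C++ (the umbrella header, config path, and image file below are placeholders, not taken from this commit):

#include "fastdeploy/vision.h"  // assumed umbrella header
#include <opencv2/imgcodecs.hpp>
#include <vector>

int main() {
  namespace cls = fastdeploy::vision::classification;
  cls::PaddleClasPreprocessor preprocessor("resnet50/inference_cls.yaml");

  // Run the pipeline with CUDA/CV-CUDA on GPU 0 ...
  preprocessor.UseCuda(/*enable_cv_cuda=*/true, /*gpu_id=*/0);
  // ... but keep the first Resize on the CPU, so the full-size image is
  // never copied HostToDevice.
  preprocessor.InitialResizeOnCpu(true);

  cv::Mat image = cv::imread("test.jpg");
  std::vector<fastdeploy::vision::FDMat> mats;
  mats.push_back(fastdeploy::vision::WrapMat(image));
  std::vector<fastdeploy::FDTensor> outputs;
  if (!preprocessor.Run(&mats, &outputs)) {
    return -1;
  }
  // outputs[0] is the batched NCHW tensor; it stays on the GPU when CUDA
  // preprocessing is enabled.
  return 0;
}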

View File

@@ -20,7 +20,7 @@
namespace fastdeploy {
namespace vision {
bool Processor::operator()(Mat* mat, ProcLib lib) {
bool Processor::operator()(FDMat* mat, ProcLib lib) {
ProcLib target = lib;
if (lib == ProcLib::DEFAULT) {
target = DefaultProcLib::default_lib;
@@ -52,39 +52,38 @@ bool Processor::operator()(Mat* mat, ProcLib lib) {
return ImplByOpenCV(mat);
}
FDTensor* Processor::UpdateAndGetCachedTensor(
const std::vector<int64_t>& new_shape, const FDDataType& data_type,
const std::string& tensor_name, const Device& new_device,
const bool& use_pinned_memory) {
if (cached_tensors_.count(tensor_name) == 0) {
cached_tensors_[tensor_name] = FDTensor();
bool Processor::operator()(FDMatBatch* mat_batch, ProcLib lib) {
ProcLib target = lib;
if (lib == ProcLib::DEFAULT) {
target = DefaultProcLib::default_lib;
}
cached_tensors_[tensor_name].is_pinned_memory = use_pinned_memory;
cached_tensors_[tensor_name].Resize(new_shape, data_type, tensor_name,
new_device);
return &cached_tensors_[tensor_name];
}
FDTensor* Processor::CreateCachedGpuInputTensor(
Mat* mat, const std::string& tensor_name) {
if (target == ProcLib::FLYCV) {
#ifdef ENABLE_FLYCV
return ImplByFlyCV(mat_batch);
#else
FDASSERT(false, "FastDeploy didn't compile with FlyCV.");
#endif
} else if (target == ProcLib::CUDA) {
#ifdef WITH_GPU
FDTensor* src = mat->Tensor();
if (src->device == Device::GPU) {
return src;
} else if (src->device == Device::CPU) {
FDTensor* tensor = UpdateAndGetCachedTensor(src->Shape(), src->Dtype(),
tensor_name, Device::GPU);
FDASSERT(cudaMemcpyAsync(tensor->Data(), src->Data(), tensor->Nbytes(),
cudaMemcpyHostToDevice, mat->Stream()) == 0,
"[ERROR] Error occurs while copy memory from CPU to GPU.");
return tensor;
} else {
FDASSERT(false, "FDMat is on unsupported device: %d", src->device);
}
FDASSERT(
mat_batch->Stream() != nullptr,
"CUDA processor requires cuda stream, please set stream for mat_batch");
return ImplByCuda(mat_batch);
#else
FDASSERT(false, "FastDeploy didn't compile with WITH_GPU.");
#endif
return nullptr;
} else if (target == ProcLib::CVCUDA) {
#ifdef ENABLE_CVCUDA
FDASSERT(mat_batch->Stream() != nullptr,
"CV-CUDA processor requires cuda stream, please set stream for "
"mat_batch");
return ImplByCvCuda(mat_batch);
#else
FDASSERT(false, "FastDeploy didn't compile with CV-CUDA.");
#endif
}
// DEFAULT & OPENCV
return ImplByOpenCV(mat_batch);
}
void EnableFlyCV() {

View File

@@ -16,6 +16,7 @@
#include "fastdeploy/utils/utils.h"
#include "fastdeploy/vision/common/processors/mat.h"
#include "fastdeploy/vision/common/processors/mat_batch.h"
#include "opencv2/highgui/highgui.hpp"
#include "opencv2/imgproc/imgproc.hpp"
#include <unordered_map>
@@ -46,46 +47,63 @@ class FASTDEPLOY_DECL Processor {
virtual std::string Name() = 0;
virtual bool ImplByOpenCV(Mat* mat) {
virtual bool ImplByOpenCV(FDMat* mat) {
FDERROR << Name() << " Not Implement Yet." << std::endl;
return false;
}
virtual bool ImplByFlyCV(Mat* mat) {
virtual bool ImplByOpenCV(FDMatBatch* mat_batch) {
for (size_t i = 0; i < mat_batch->mats->size(); ++i) {
if (ImplByOpenCV(&(*(mat_batch->mats))[i]) != true) {
return false;
}
}
return true;
}
virtual bool ImplByFlyCV(FDMat* mat) {
return ImplByOpenCV(mat);
}
virtual bool ImplByCuda(Mat* mat) {
virtual bool ImplByFlyCV(FDMatBatch* mat_batch) {
for (size_t i = 0; i < mat_batch->mats->size(); ++i) {
if (ImplByFlyCV(&(*(mat_batch->mats))[i]) != true) {
return false;
}
}
return true;
}
virtual bool ImplByCuda(FDMat* mat) {
return ImplByOpenCV(mat);
}
virtual bool ImplByCvCuda(Mat* mat) {
virtual bool ImplByCuda(FDMatBatch* mat_batch) {
for (size_t i = 0; i < mat_batch->mats->size(); ++i) {
if (ImplByCuda(&(*(mat_batch->mats))[i]) != true) {
return false;
}
}
return true;
}
virtual bool ImplByCvCuda(FDMat* mat) {
return ImplByOpenCV(mat);
}
virtual bool operator()(Mat* mat, ProcLib lib = ProcLib::DEFAULT);
virtual bool ImplByCvCuda(FDMatBatch* mat_batch) {
for (size_t i = 0; i < mat_batch->mats->size(); ++i) {
if (ImplByCvCuda(&(*(mat_batch->mats))[i]) != true) {
return false;
}
}
return true;
}
protected:
// Update and get the cached tensor from the cached_tensors_ map.
// The tensor is indexed by a string.
// If the tensor doesn't exists in the map, then create a new tensor.
// If the tensor exists and shape is getting larger, then realloc the buffer.
// If the tensor exists and shape is not getting larger, then return the
// cached tensor directly.
FDTensor* UpdateAndGetCachedTensor(
const std::vector<int64_t>& new_shape, const FDDataType& data_type,
const std::string& tensor_name, const Device& new_device = Device::CPU,
const bool& use_pinned_memory = false);
virtual bool operator()(FDMat* mat, ProcLib lib = ProcLib::DEFAULT);
// Create an input tensor on GPU and save into cached_tensors_.
// If the Mat is on GPU, return the mat->Tensor() directly.
// If the Mat is on CPU, then create a cached GPU tensor and copy the mat's
// CPU tensor to this new GPU tensor.
FDTensor* CreateCachedGpuInputTensor(Mat* mat,
const std::string& tensor_name);
private:
std::unordered_map<std::string, FDTensor> cached_tensors_;
virtual bool operator()(FDMatBatch* mat_batch,
ProcLib lib = ProcLib::DEFAULT);
};
} // namespace vision
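
The FDMatBatch overloads added above default to looping over mat_batch->mats and calling the per-image implementation, so a processor that only implements the single-image path still works on batches. A hedged sketch (the header path and the BGR2Gray processor are made up for illustration):

#include "fastdeploy/vision/common/processors/base.h"  // assumed path
#include <opencv2/imgproc/imgproc.hpp>

namespace fastdeploy {
namespace vision {

class BGR2Gray : public Processor {
 public:
  std::string Name() override { return "BGR2Gray"; }

  // Only the per-image OpenCV path is provided. The default
  // ImplByOpenCV(FDMatBatch*) / ImplByCuda(FDMatBatch*) overloads in the
  // Processor base class loop over the batch and call this method.
  bool ImplByOpenCV(FDMat* mat) override {
    cv::Mat* im = mat->GetOpenCVMat();
    cv::cvtColor(*im, *im, cv::COLOR_BGR2GRAY);
    mat->SetChannels(1);
    return true;
  }
};

}  // namespace vision
}  // namespace fastdeploy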

View File

@@ -23,7 +23,7 @@
namespace fastdeploy {
namespace vision {
bool CenterCrop::ImplByOpenCV(Mat* mat) {
bool CenterCrop::ImplByOpenCV(FDMat* mat) {
cv::Mat* im = mat->GetOpenCVMat();
int height = static_cast<int>(im->rows);
int width = static_cast<int>(im->cols);
@@ -42,7 +42,7 @@ bool CenterCrop::ImplByOpenCV(Mat* mat) {
}
#ifdef ENABLE_FLYCV
bool CenterCrop::ImplByFlyCV(Mat* mat) {
bool CenterCrop::ImplByFlyCV(FDMat* mat) {
fcv::Mat* im = mat->GetFlyCVMat();
int height = static_cast<int>(im->height());
int width = static_cast<int>(im->width());
@@ -63,18 +63,15 @@ bool CenterCrop::ImplByFlyCV(Mat* mat) {
#endif
#ifdef ENABLE_CVCUDA
bool CenterCrop::ImplByCvCuda(Mat* mat) {
bool CenterCrop::ImplByCvCuda(FDMat* mat) {
// Prepare input tensor
std::string tensor_name = Name() + "_cvcuda_src";
FDTensor* src = CreateCachedGpuInputTensor(mat, tensor_name);
FDTensor* src = CreateCachedGpuInputTensor(mat);
auto src_tensor = CreateCvCudaTensorWrapData(*src);
// Prepare output tensor
tensor_name = Name() + "_cvcuda_dst";
FDTensor* dst =
UpdateAndGetCachedTensor({height_, width_, mat->Channels()}, src->Dtype(),
tensor_name, Device::GPU);
auto dst_tensor = CreateCvCudaTensorWrapData(*dst);
mat->output_cache->Resize({height_, width_, mat->Channels()}, src->Dtype(),
"output_cache", Device::GPU);
auto dst_tensor = CreateCvCudaTensorWrapData(*(mat->output_cache));
int offset_x = static_cast<int>((mat->Width() - width_) / 2);
int offset_y = static_cast<int>((mat->Height() - height_) / 2);
@@ -82,16 +79,27 @@ bool CenterCrop::ImplByCvCuda(Mat* mat) {
NVCVRectI crop_roi = {offset_x, offset_y, width_, height_};
crop_op(mat->Stream(), src_tensor, dst_tensor, crop_roi);
mat->SetTensor(dst);
mat->SetTensor(mat->output_cache);
mat->SetWidth(width_);
mat->SetHeight(height_);
mat->device = Device::GPU;
mat->mat_type = ProcLib::CVCUDA;
return true;
}
bool CenterCrop::ImplByCvCuda(FDMatBatch* mat_batch) {
for (size_t i = 0; i < mat_batch->mats->size(); ++i) {
if (ImplByCvCuda(&((*(mat_batch->mats))[i])) != true) {
return false;
}
}
mat_batch->device = Device::GPU;
mat_batch->mat_type = ProcLib::CVCUDA;
return true;
}
#endif
bool CenterCrop::Run(Mat* mat, const int& width, const int& height,
bool CenterCrop::Run(FDMat* mat, const int& width, const int& height,
ProcLib lib) {
auto c = CenterCrop(width, height);
return c(mat, lib);

View File

@@ -22,16 +22,17 @@ namespace vision {
class FASTDEPLOY_DECL CenterCrop : public Processor {
public:
CenterCrop(int width, int height) : height_(height), width_(width) {}
bool ImplByOpenCV(Mat* mat);
bool ImplByOpenCV(FDMat* mat);
#ifdef ENABLE_FLYCV
bool ImplByFlyCV(Mat* mat);
bool ImplByFlyCV(FDMat* mat);
#endif
#ifdef ENABLE_CVCUDA
bool ImplByCvCuda(Mat* mat);
bool ImplByCvCuda(FDMat* mat);
bool ImplByCvCuda(FDMatBatch* mat_batch);
#endif
std::string Name() { return "CenterCrop"; }
static bool Run(Mat* mat, const int& width, const int& height,
static bool Run(FDMat* mat, const int& width, const int& height,
ProcLib lib = ProcLib::DEFAULT);
private:

View File

@@ -47,17 +47,19 @@ nvcv::TensorWrapData CreateCvCudaTensorWrapData(const FDTensor& tensor) {
"When create CVCUDA tensor from FD tensor,"
"tensor shape should be 3-Dim, HWC layout");
int batchsize = 1;
int h = tensor.Shape()[0];
int w = tensor.Shape()[1];
int c = tensor.Shape()[2];
nvcv::TensorDataStridedCuda::Buffer buf;
buf.strides[3] = FDDataTypeSize(tensor.Dtype());
buf.strides[2] = tensor.shape[2] * buf.strides[3];
buf.strides[1] = tensor.shape[1] * buf.strides[2];
buf.strides[0] = tensor.shape[0] * buf.strides[1];
buf.strides[2] = c * buf.strides[3];
buf.strides[1] = w * buf.strides[2];
buf.strides[0] = h * buf.strides[1];
buf.basePtr = reinterpret_cast<NVCVByte*>(const_cast<void*>(tensor.Data()));
nvcv::Tensor::Requirements req = nvcv::Tensor::CalcRequirements(
batchsize, {tensor.shape[1], tensor.shape[0]},
CreateCvCudaImageFormat(tensor.Dtype(), tensor.shape[2]));
batchsize, {w, h}, CreateCvCudaImageFormat(tensor.Dtype(), c));
nvcv::TensorDataStridedCuda tensor_data(
nvcv::TensorShape{req.shape, req.rank, req.layout},
@@ -70,6 +72,33 @@ void* GetCvCudaTensorDataPtr(const nvcv::TensorWrapData& tensor) {
dynamic_cast<const nvcv::ITensorDataStridedCuda*>(tensor.exportData());
return reinterpret_cast<void*>(data->basePtr());
}
nvcv::ImageWrapData CreateImageWrapData(const FDTensor& tensor) {
FDASSERT(tensor.shape.size() == 3,
"When create CVCUDA image from FD tensor,"
"tensor shape should be 3-Dim, HWC layout");
int h = tensor.Shape()[0];
int w = tensor.Shape()[1];
int c = tensor.Shape()[2];
nvcv::ImageDataStridedCuda::Buffer buf;
buf.numPlanes = 1;
buf.planes[0].width = w;
buf.planes[0].height = h;
buf.planes[0].rowStride = w * c * FDDataTypeSize(tensor.Dtype());
buf.planes[0].basePtr =
reinterpret_cast<NVCVByte*>(const_cast<void*>(tensor.Data()));
nvcv::ImageWrapData nvimg{nvcv::ImageDataStridedCuda{
nvcv::ImageFormat{CreateCvCudaImageFormat(tensor.Dtype(), c)}, buf}};
return nvimg;
}
void CreateCvCudaImageBatchVarShape(std::vector<FDTensor*>& tensors,
nvcv::ImageBatchVarShape& img_batch) {
for (size_t i = 0; i < tensors.size(); ++i) {
FDASSERT(tensors[i]->device == Device::GPU, "Tensor must on GPU.");
img_batch.pushBack(CreateImageWrapData(*(tensors[i])));
}
}
#endif
} // namespace vision
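
The stride setup in CreateCvCudaTensorWrapData() maps an HWC FDTensor onto CV-CUDA's NHWC layout with batch size 1. A small self-contained check of that arithmetic, using an illustrative 224x224x3 FP32 tensor:

#include <cstdint>
#include <cstdio>

int main() {
  const int64_t h = 224, w = 224, c = 3;
  const int64_t elem = 4;       // FDDataTypeSize(FDDataType::FP32) == 4 bytes
  int64_t strides[4];
  strides[3] = elem;            // step between channels:           4 B
  strides[2] = c * strides[3];  // step between pixels in a row:   12 B
  strides[1] = w * strides[2];  // step between rows:            2688 B
  strides[0] = h * strides[1];  // step between images:        602112 B
  std::printf("strides N=%lld H=%lld W=%lld C=%lld\n",
              (long long)strides[0], (long long)strides[1],
              (long long)strides[2], (long long)strides[3]);
  return 0;
}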

View File

@@ -18,6 +18,7 @@
#ifdef ENABLE_CVCUDA
#include "nvcv/Tensor.hpp"
#include <nvcv/ImageBatch.hpp>
namespace fastdeploy {
namespace vision {
@@ -25,7 +26,10 @@ namespace vision {
nvcv::ImageFormat CreateCvCudaImageFormat(FDDataType type, int channel);
nvcv::TensorWrapData CreateCvCudaTensorWrapData(const FDTensor& tensor);
void* GetCvCudaTensorDataPtr(const nvcv::TensorWrapData& tensor);
nvcv::ImageWrapData CreateImageWrapData(const FDTensor& tensor);
void CreateCvCudaImageBatchVarShape(std::vector<FDTensor*>& tensors,
nvcv::ImageBatchVarShape& img_batch);
}
}
} // namespace vision
} // namespace fastdeploy
#endif

View File

@@ -62,13 +62,24 @@ bool ProcessorManager::Run(std::vector<FDMat>* images,
return false;
}
for (size_t i = 0; i < images->size(); ++i) {
if (CudaUsed()) {
SetStream(&((*images)[i]));
}
if (images->size() > input_caches_.size()) {
input_caches_.resize(images->size());
output_caches_.resize(images->size());
}
bool ret = Apply(images, outputs);
FDMatBatch image_batch(images);
image_batch.input_cache = &batch_input_cache_;
image_batch.output_cache = &batch_output_cache_;
for (size_t i = 0; i < images->size(); ++i) {
if (CudaUsed()) {
SetStream(&image_batch);
}
(*images)[i].input_cache = &input_caches_[i];
(*images)[i].output_cache = &output_caches_[i];
}
bool ret = Apply(&image_batch, outputs);
if (CudaUsed()) {
SyncStream();

View File

@@ -16,6 +16,7 @@
#include "fastdeploy/utils/utils.h"
#include "fastdeploy/vision/common/processors/mat.h"
#include "fastdeploy/vision/common/processors/mat_batch.h"
namespace fastdeploy {
namespace vision {
@@ -24,16 +25,28 @@ class FASTDEPLOY_DECL ProcessorManager {
public:
~ProcessorManager();
/** \brief Use CUDA to boost the performance of processors
*
* \param[in] enable_cv_cuda true: use CV-CUDA, false: use CUDA only
* \param[in] gpu_id GPU device id
* \return true if the preprocess succeeded, otherwise false
*/
void UseCuda(bool enable_cv_cuda = false, int gpu_id = -1);
bool CudaUsed();
void SetStream(Mat* mat) {
void SetStream(FDMat* mat) {
#ifdef WITH_GPU
mat->SetStream(stream_);
#endif
}
void SetStream(FDMatBatch* mat_batch) {
#ifdef WITH_GPU
mat_batch->SetStream(stream_);
#endif
}
void SyncStream() {
#ifdef WITH_GPU
FDASSERT(cudaStreamSynchronize(stream_) == cudaSuccess,
@@ -51,13 +64,13 @@ class FASTDEPLOY_DECL ProcessorManager {
*/
bool Run(std::vector<FDMat>* images, std::vector<FDTensor>* outputs);
/** \brief The body of Run() function which needs to be implemented by a derived class
/** \brief Apply() is the body of the Run() function; it needs to be implemented by a derived class
*
* \param[in] images The input image data list, all the elements are returned by cv::imread()
* \param[in] image_batch The input image batch
* \param[in] outputs The output tensors which will feed in runtime
* \return true if the preprocess succeeded, otherwise false
*/
virtual bool Apply(std::vector<FDMat>* images,
virtual bool Apply(FDMatBatch* image_batch,
std::vector<FDTensor>* outputs) = 0;
protected:
@@ -68,6 +81,11 @@ class FASTDEPLOY_DECL ProcessorManager {
cudaStream_t stream_ = nullptr;
#endif
int device_id_ = -1;
std::vector<FDTensor> input_caches_;
std::vector<FDTensor> output_caches_;
FDTensor batch_input_cache_;
FDTensor batch_output_cache_;
};
} // namespace vision
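
A hedged sketch of what a derived preprocessor now provides, mirroring the pattern of PaddleClasPreprocessor::Apply() (the header paths and the exact processor pipeline are illustrative):

#include "fastdeploy/vision/common/processors/manager.h"           // assumed path
#include "fastdeploy/vision/common/processors/resize_by_short.h"   // assumed path
#include "fastdeploy/vision/common/processors/center_crop.h"       // assumed path
#include <memory>
#include <utility>
#include <vector>

namespace fastdeploy {
namespace vision {

class MyPreprocessor : public ProcessorManager {
 public:
  MyPreprocessor() {
    processors_.push_back(
        std::make_shared<ResizeByShort>(256, 1, true, std::vector<int>()));
    processors_.push_back(std::make_shared<CenterCrop>(224, 224));
  }

  bool Apply(FDMatBatch* image_batch, std::vector<FDTensor>* outputs) override {
    for (auto& p : processors_) {
      // With ProcLib::DEFAULT each processor dispatches to CUDA/CV-CUDA
      // when UseCuda() was called, and to OpenCV otherwise.
      if (!(*p)(image_batch)) {
        return false;
      }
    }
    outputs->resize(1);
    (*outputs)[0] = std::move(*(image_batch->Tensor()));
    (*outputs)[0].device_id = DeviceId();
    return true;
  }

 private:
  std::vector<std::shared_ptr<Processor>> processors_;
};

}  // namespace vision
}  // namespace fastdeploy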

View File

@@ -0,0 +1,41 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/pybind/main.h"
namespace fastdeploy {
void BindProcessorManager(pybind11::module& m) {
pybind11::class_<vision::ProcessorManager>(m, "ProcessorManager")
.def("run",
[](vision::ProcessorManager& self,
std::vector<pybind11::array>& im_list) {
std::vector<vision::FDMat> images;
for (size_t i = 0; i < im_list.size(); ++i) {
images.push_back(vision::WrapMat(PyArrayToCvMat(im_list[i])));
}
std::vector<FDTensor> outputs;
if (!self.Run(&images, &outputs)) {
throw std::runtime_error("Failed to process the input data");
}
if (!self.CudaUsed()) {
for (size_t i = 0; i < outputs.size(); ++i) {
outputs[i].StopSharing();
}
}
return outputs;
})
.def("use_cuda",
[](vision::ProcessorManager& self, bool enable_cv_cuda = false,
int gpu_id = -1) { self.UseCuda(enable_cv_cuda, gpu_id); });
}
} // namespace fastdeploy

View File

@@ -247,5 +247,40 @@ std::vector<FDMat> WrapMat(const std::vector<cv::Mat>& images) {
return mats;
}
bool CheckShapeConsistency(std::vector<Mat>* mats) {
for (size_t i = 1; i < mats->size(); ++i) {
if ((*mats)[i].Channels() != (*mats)[0].Channels() ||
(*mats)[i].Width() != (*mats)[0].Width() ||
(*mats)[i].Height() != (*mats)[0].Height()) {
return false;
}
}
return true;
}
FDTensor* CreateCachedGpuInputTensor(Mat* mat) {
#ifdef WITH_GPU
FDTensor* src = mat->Tensor();
if (src->device == Device::GPU) {
return src;
} else if (src->device == Device::CPU) {
// Mats on CPU, we need copy these tensors from CPU to GPU
FDASSERT(src->Shape().size() == 3, "The CPU tensor must has 3 dims.")
mat->input_cache->Resize(src->Shape(), src->Dtype(), "input_cache",
Device::GPU);
FDASSERT(
cudaMemcpyAsync(mat->input_cache->Data(), src->Data(), src->Nbytes(),
cudaMemcpyHostToDevice, mat->Stream()) == 0,
"[ERROR] Error occurs while copy memory from CPU to GPU.");
return mat->input_cache;
} else {
FDASSERT(false, "FDMat is on unsupported device: %d", src->device);
}
#else
FDASSERT(false, "FastDeploy didn't compile with WITH_GPU.");
#endif
return nullptr;
}
} // namespace vision
} // namespace fastdeploy

View File

@@ -119,6 +119,11 @@ struct FASTDEPLOY_DECL Mat {
void SetChannels(int s) { channels = s; }
void SetWidth(int w) { width = w; }
void SetHeight(int h) { height = h; }
// When using CV-CUDA/CUDA, please set input/output cache,
// refer to manager.cc
FDTensor* input_cache = nullptr;
FDTensor* output_cache = nullptr;
#ifdef WITH_GPU
cudaStream_t Stream() const { return stream; }
void SetStream(cudaStream_t s) { stream = s; }
@@ -165,5 +170,12 @@ FASTDEPLOY_DECL FDMat WrapMat(const cv::Mat& image);
*/
FASTDEPLOY_DECL std::vector<FDMat> WrapMat(const std::vector<cv::Mat>& images);
bool CheckShapeConsistency(std::vector<Mat>* mats);
// Create an input tensor on GPU and save into input_cache.
// If the Mat is on GPU, return the mat->Tensor() directly.
// If the Mat is on CPU, then update the input cache tensor and copy the mat's
// CPU tensor to this new GPU input cache tensor.
FDTensor* CreateCachedGpuInputTensor(Mat* mat);
} // namespace vision
} // namespace fastdeploy

View File

@@ -0,0 +1,81 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/vision/common/processors/mat_batch.h"
namespace fastdeploy {
namespace vision {
#ifdef WITH_GPU
void FDMatBatch::SetStream(cudaStream_t s) {
stream = s;
for (size_t i = 0; i < mats->size(); ++i) {
(*mats)[i].SetStream(s);
}
}
#endif
FDTensor* FDMatBatch::Tensor() {
if (has_batched_tensor) {
return &fd_tensor;
}
FDASSERT(CheckShapeConsistency(mats), "Mats shapes are not consistent.")
// Each mat has its own tensor,
// to get a batched tensor, we need copy these tensors to a batched tensor
FDTensor* src = (*mats)[0].Tensor();
auto new_shape = src->Shape();
new_shape.insert(new_shape.begin(), mats->size());
input_cache->Resize(new_shape, src->Dtype(), "batch_input_cache", device);
for (size_t i = 0; i < mats->size(); ++i) {
FDASSERT(device == (*mats)[i].Tensor()->device,
"Mats and MatBatch are not on the same device");
uint8_t* p = reinterpret_cast<uint8_t*>(input_cache->Data());
int num_bytes = (*mats)[i].Tensor()->Nbytes();
FDTensor::CopyBuffer(p + i * num_bytes, (*mats)[i].Tensor()->Data(),
num_bytes, device, false);
}
SetTensor(input_cache);
return &fd_tensor;
}
void FDMatBatch::SetTensor(FDTensor* tensor) {
fd_tensor.SetExternalData(tensor->Shape(), tensor->Dtype(), tensor->Data(),
tensor->device, tensor->device_id);
has_batched_tensor = true;
}
FDTensor* CreateCachedGpuInputTensor(FDMatBatch* mat_batch) {
#ifdef WITH_GPU
auto mats = mat_batch->mats;
FDASSERT(CheckShapeConsistency(mats), "Mats shapes are not consistent.")
FDTensor* src = (*mats)[0].Tensor();
if (mat_batch->device == Device::GPU) {
return mat_batch->Tensor();
} else if (mat_batch->device == Device::CPU) {
// Mats on CPU, we need copy them to GPU and then get a batched GPU tensor
for (size_t i = 0; i < mats->size(); ++i) {
FDTensor* tensor = CreateCachedGpuInputTensor(&(*mats)[i]);
(*mats)[i].SetTensor(tensor);
}
return mat_batch->Tensor();
} else {
FDASSERT(false, "FDMat is on unsupported device: %d", src->device);
}
#else
FDASSERT(false, "FastDeploy didn't compile with WITH_GPU.");
#endif
return nullptr;
}
} // namespace vision
} // namespace fastdeploy

View File

@@ -0,0 +1,76 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "fastdeploy/vision/common/processors/mat.h"
#ifdef WITH_GPU
#include <cuda_runtime_api.h>
#endif
namespace fastdeploy {
namespace vision {
enum FDMatBatchLayout { NHWC, NCHW };
struct FASTDEPLOY_DECL FDMatBatch {
FDMatBatch() = default;
// FDMatBatch is initialized with a list of mats;
// the data is stored in the mats separately.
// Call the Tensor() function to get a batched 4-dimensional tensor.
explicit FDMatBatch(std::vector<Mat>* _mats) {
mats = _mats;
layout = FDMatBatchLayout::NHWC;
mat_type = ProcLib::OPENCV;
}
// Get the batched 4-dimension tensor.
FDTensor* Tensor();
void SetTensor(FDTensor* tensor);
private:
#ifdef WITH_GPU
cudaStream_t stream = nullptr;
#endif
FDTensor fd_tensor;
public:
// When using CV-CUDA/CUDA, please set input/output cache,
// refer to manager.cc
FDTensor* input_cache;
FDTensor* output_cache;
#ifdef WITH_GPU
cudaStream_t Stream() const { return stream; }
void SetStream(cudaStream_t s);
#endif
std::vector<FDMat>* mats;
ProcLib mat_type = ProcLib::OPENCV;
FDMatBatchLayout layout = FDMatBatchLayout::NHWC;
Device device = Device::CPU;
// False: the data is stored in the mats separately
// True: the data is stored in the fd_tensor continuously in 4 dimensions
bool has_batched_tensor = false;
};
// Create a batched input tensor on GPU and save into input_cache.
// If the MatBatch is on GPU, return the Tensor() directly.
// If the MatBatch is on CPU, then copy the CPU tensors to GPU and get a GPU
// batched input tensor.
FDTensor* CreateCachedGpuInputTensor(FDMatBatch* mat_batch);
} // namespace vision
} // namespace fastdeploy
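
A hedged sketch of wiring an FDMatBatch by hand; in normal use ProcessorManager::Run() in manager.cc does this wiring (the header path is assumed):

#include "fastdeploy/vision/common/processors/mat_batch.h"  // assumed path
#include <vector>

void BatchTensorExample(std::vector<fastdeploy::vision::FDMat>* mats) {
  using fastdeploy::FDTensor;
  using fastdeploy::vision::FDMatBatch;

  FDMatBatch batch(mats);
  // The batch does not own its cache storage; the caller provides it,
  // just like ProcessorManager keeps batch_input_cache_/batch_output_cache_.
  FDTensor input_cache, output_cache;
  batch.input_cache = &input_cache;
  batch.output_cache = &output_cache;

  // Tensor() checks that all mats share one shape and concatenates their
  // HWC tensors into a single N x H x W x C tensor stored in the batch.
  FDTensor* batched = batch.Tensor();
  (void)batched;
}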

View File

@@ -56,7 +56,7 @@ NormalizeAndPermute::NormalizeAndPermute(const std::vector<float>& mean,
swap_rb_ = swap_rb;
}
bool NormalizeAndPermute::ImplByOpenCV(Mat* mat) {
bool NormalizeAndPermute::ImplByOpenCV(FDMat* mat) {
cv::Mat* im = mat->GetOpenCVMat();
int origin_w = im->cols;
int origin_h = im->rows;
@@ -79,7 +79,7 @@ bool NormalizeAndPermute::ImplByOpenCV(Mat* mat) {
}
#ifdef ENABLE_FLYCV
bool NormalizeAndPermute::ImplByFlyCV(Mat* mat) {
bool NormalizeAndPermute::ImplByFlyCV(FDMat* mat) {
if (mat->layout != Layout::HWC) {
FDERROR << "Only supports input with HWC layout." << std::endl;
return false;
@@ -109,7 +109,7 @@ bool NormalizeAndPermute::ImplByFlyCV(Mat* mat) {
}
#endif
bool NormalizeAndPermute::Run(Mat* mat, const std::vector<float>& mean,
bool NormalizeAndPermute::Run(FDMat* mat, const std::vector<float>& mean,
const std::vector<float>& std, bool is_scale,
const std::vector<float>& min,
const std::vector<float>& max, ProcLib lib,

View File

@@ -18,63 +18,110 @@
namespace fastdeploy {
namespace vision {
__global__ void NormalizeAndPermuteKernel(uint8_t* src, float* dst,
__global__ void NormalizeAndPermuteKernel(const uint8_t* src, float* dst,
const float* alpha, const float* beta,
int num_channel, bool swap_rb,
int edge) {
int batch_size, int edge) {
int idx = blockDim.x * blockIdx.x + threadIdx.x;
if (idx >= edge) return;
if (swap_rb) {
uint8_t tmp = src[num_channel * idx];
src[num_channel * idx] = src[num_channel * idx + 2];
src[num_channel * idx + 2] = tmp;
}
int img_size = edge / batch_size;
int n = idx / img_size; // batch index
int p = idx - (n * img_size); // pixel index within the image
for (int i = 0; i < num_channel; ++i) {
dst[idx + edge * i] = src[num_channel * idx + i] * alpha[i] + beta[i];
int j = i;
if (swap_rb) {
j = 2 - i;
}
dst[n * img_size * num_channel + i * img_size + p] =
src[num_channel * idx + j] * alpha[i] + beta[i];
}
}
bool NormalizeAndPermute::ImplByCuda(Mat* mat) {
bool NormalizeAndPermute::ImplByCuda(FDMat* mat) {
// Prepare input tensor
std::string tensor_name = Name() + "_cvcuda_src";
FDTensor* src = CreateCachedGpuInputTensor(mat, tensor_name);
FDTensor* src = CreateCachedGpuInputTensor(mat);
// Prepare output tensor
tensor_name = Name() + "_dst";
FDTensor* dst = UpdateAndGetCachedTensor(src->Shape(), FDDataType::FP32,
tensor_name, Device::GPU);
mat->output_cache->Resize(src->Shape(), FDDataType::FP32, "output_cache",
Device::GPU);
// Copy alpha and beta to GPU
tensor_name = Name() + "_alpha";
FDMat alpha_mat =
FDMat::Create(1, 1, alpha_.size(), FDDataType::FP32, alpha_.data());
FDTensor* alpha = CreateCachedGpuInputTensor(&alpha_mat, tensor_name);
gpu_alpha_.Resize({1, 1, static_cast<int>(alpha_.size())}, FDDataType::FP32,
"alpha", Device::GPU);
cudaMemcpy(gpu_alpha_.Data(), alpha_.data(), gpu_alpha_.Nbytes(),
cudaMemcpyHostToDevice);
tensor_name = Name() + "_beta";
FDMat beta_mat =
FDMat::Create(1, 1, beta_.size(), FDDataType::FP32, beta_.data());
FDTensor* beta = CreateCachedGpuInputTensor(&beta_mat, tensor_name);
gpu_beta_.Resize({1, 1, static_cast<int>(beta_.size())}, FDDataType::FP32,
"beta", Device::GPU);
cudaMemcpy(gpu_beta_.Data(), beta_.data(), gpu_beta_.Nbytes(),
cudaMemcpyHostToDevice);
int jobs = mat->Width() * mat->Height();
int jobs = 1 * mat->Width() * mat->Height();
int threads = 256;
int blocks = ceil(jobs / (float)threads);
NormalizeAndPermuteKernel<<<blocks, threads, 0, mat->Stream()>>>(
reinterpret_cast<uint8_t*>(src->Data()),
reinterpret_cast<float*>(dst->Data()),
reinterpret_cast<float*>(alpha->Data()),
reinterpret_cast<float*>(beta->Data()), mat->Channels(), swap_rb_, jobs);
reinterpret_cast<float*>(mat->output_cache->Data()),
reinterpret_cast<float*>(gpu_alpha_.Data()),
reinterpret_cast<float*>(gpu_beta_.Data()), mat->Channels(), swap_rb_, 1,
jobs);
mat->SetTensor(dst);
mat->SetTensor(mat->output_cache);
mat->device = Device::GPU;
mat->layout = Layout::CHW;
mat->mat_type = ProcLib::CUDA;
return true;
}
bool NormalizeAndPermute::ImplByCuda(FDMatBatch* mat_batch) {
// Prepare input tensor
FDTensor* src = CreateCachedGpuInputTensor(mat_batch);
// Prepare output tensor
mat_batch->output_cache->Resize(src->Shape(), FDDataType::FP32,
"output_cache", Device::GPU);
// NHWC -> NCHW
std::swap(mat_batch->output_cache->shape[1],
mat_batch->output_cache->shape[3]);
// Copy alpha and beta to GPU
gpu_alpha_.Resize({1, 1, static_cast<int>(alpha_.size())}, FDDataType::FP32,
"alpha", Device::GPU);
cudaMemcpy(gpu_alpha_.Data(), alpha_.data(), gpu_alpha_.Nbytes(),
cudaMemcpyHostToDevice);
gpu_beta_.Resize({1, 1, static_cast<int>(beta_.size())}, FDDataType::FP32,
"beta", Device::GPU);
cudaMemcpy(gpu_beta_.Data(), beta_.data(), gpu_beta_.Nbytes(),
cudaMemcpyHostToDevice);
int jobs =
mat_batch->output_cache->Numel() / mat_batch->output_cache->shape[1];
int threads = 256;
int blocks = ceil(jobs / (float)threads);
NormalizeAndPermuteKernel<<<blocks, threads, 0, mat_batch->Stream()>>>(
reinterpret_cast<uint8_t*>(src->Data()),
reinterpret_cast<float*>(mat_batch->output_cache->Data()),
reinterpret_cast<float*>(gpu_alpha_.Data()),
reinterpret_cast<float*>(gpu_beta_.Data()),
mat_batch->output_cache->shape[1], swap_rb_,
mat_batch->output_cache->shape[0], jobs);
mat_batch->SetTensor(mat_batch->output_cache);
mat_batch->device = Device::GPU;
mat_batch->layout = FDMatBatchLayout::NCHW;
mat_batch->mat_type = ProcLib::CUDA;
return true;
}
#ifdef ENABLE_CVCUDA
bool NormalizeAndPermute::ImplByCvCuda(Mat* mat) { return ImplByCuda(mat); }
bool NormalizeAndPermute::ImplByCvCuda(FDMat* mat) { return ImplByCuda(mat); }
bool NormalizeAndPermute::ImplByCvCuda(FDMatBatch* mat_batch) {
return ImplByCuda(mat_batch);
}
#endif
} // namespace vision
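
For clarity, a plain C++ reference of the kernel's index math; one iteration of the outer loop corresponds to one CUDA thread (self-contained, not part of the commit):

#include <cstdint>

void NormalizeAndPermuteRef(const uint8_t* src, float* dst,
                            const float* alpha, const float* beta,
                            int num_channel, bool swap_rb, int batch_size,
                            int edge /* = N * H * W */) {
  const int img_size = edge / batch_size;  // pixels per image (H * W)
  for (int idx = 0; idx < edge; ++idx) {   // idx plays the role of a thread
    const int n = idx / img_size;          // batch index
    const int p = idx - n * img_size;      // pixel index within the image
    for (int i = 0; i < num_channel; ++i) {
      const int j = swap_rb ? 2 - i : i;   // optional BGR <-> RGB swap
      // NHWC uint8 input -> NCHW float output, fused with normalization.
      dst[n * img_size * num_channel + i * img_size + p] =
          src[num_channel * idx + j] * alpha[i] + beta[i];
    }
  }
}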

View File

@@ -25,15 +25,17 @@ class FASTDEPLOY_DECL NormalizeAndPermute : public Processor {
const std::vector<float>& min = std::vector<float>(),
const std::vector<float>& max = std::vector<float>(),
bool swap_rb = false);
bool ImplByOpenCV(Mat* mat);
bool ImplByOpenCV(FDMat* mat);
#ifdef ENABLE_FLYCV
bool ImplByFlyCV(Mat* mat);
bool ImplByFlyCV(FDMat* mat);
#endif
#ifdef WITH_GPU
bool ImplByCuda(Mat* mat);
bool ImplByCuda(FDMat* mat);
bool ImplByCuda(FDMatBatch* mat_batch);
#endif
#ifdef ENABLE_CVCUDA
bool ImplByCvCuda(Mat* mat);
bool ImplByCvCuda(FDMat* mat);
bool ImplByCvCuda(FDMatBatch* mat_batch);
#endif
std::string Name() { return "NormalizeAndPermute"; }
@@ -47,7 +49,7 @@ class FASTDEPLOY_DECL NormalizeAndPermute : public Processor {
// There will be some precomputation in the constructor,
// and `norm(mat)` only needs to compute result = mat * alpha + beta,
// which will reduce lots of time
static bool Run(Mat* mat, const std::vector<float>& mean,
static bool Run(FDMat* mat, const std::vector<float>& mean,
const std::vector<float>& std, bool is_scale = true,
const std::vector<float>& min = std::vector<float>(),
const std::vector<float>& max = std::vector<float>(),
@@ -76,6 +78,8 @@ class FASTDEPLOY_DECL NormalizeAndPermute : public Processor {
private:
std::vector<float> alpha_;
std::vector<float> beta_;
FDTensor gpu_alpha_;
FDTensor gpu_beta_;
bool swap_rb_;
};
} // namespace vision

View File

@@ -23,7 +23,7 @@
namespace fastdeploy {
namespace vision {
bool Resize::ImplByOpenCV(Mat* mat) {
bool Resize::ImplByOpenCV(FDMat* mat) {
if (mat->layout != Layout::HWC) {
FDERROR << "Resize: The format of input is not HWC." << std::endl;
return false;
@@ -61,7 +61,7 @@ bool Resize::ImplByOpenCV(Mat* mat) {
}
#ifdef ENABLE_FLYCV
bool Resize::ImplByFlyCV(Mat* mat) {
bool Resize::ImplByFlyCV(FDMat* mat) {
if (mat->layout != Layout::HWC) {
FDERROR << "Resize: The format of input is not HWC." << std::endl;
return false;
@@ -123,7 +123,7 @@ bool Resize::ImplByFlyCV(Mat* mat) {
#endif
#ifdef ENABLE_CVCUDA
bool Resize::ImplByCvCuda(Mat* mat) {
bool Resize::ImplByCvCuda(FDMat* mat) {
if (width_ == mat->Width() && height_ == mat->Height()) {
return true;
}
@@ -143,23 +143,20 @@ bool Resize::ImplByCvCuda(Mat* mat) {
}
// Prepare input tensor
std::string tensor_name = Name() + "_cvcuda_src";
FDTensor* src = CreateCachedGpuInputTensor(mat, tensor_name);
FDTensor* src = CreateCachedGpuInputTensor(mat);
auto src_tensor = CreateCvCudaTensorWrapData(*src);
// Prepare output tensor
tensor_name = Name() + "_cvcuda_dst";
FDTensor* dst =
UpdateAndGetCachedTensor({height_, width_, mat->Channels()}, mat->Type(),
tensor_name, Device::GPU);
auto dst_tensor = CreateCvCudaTensorWrapData(*dst);
mat->output_cache->Resize({height_, width_, mat->Channels()}, mat->Type(),
"output_cache", Device::GPU);
auto dst_tensor = CreateCvCudaTensorWrapData(*(mat->output_cache));
// CV-CUDA Interp value is compatible with OpenCV
cvcuda::Resize resize_op;
resize_op(mat->Stream(), src_tensor, dst_tensor,
NVCVInterpolationType(interp_));
mat->SetTensor(dst);
mat->SetTensor(mat->output_cache);
mat->SetWidth(width_);
mat->SetHeight(height_);
mat->device = Device::GPU;
@@ -168,8 +165,8 @@ bool Resize::ImplByCvCuda(Mat* mat) {
}
#endif
bool Resize::Run(Mat* mat, int width, int height, float scale_w, float scale_h,
int interp, bool use_scale, ProcLib lib) {
bool Resize::Run(FDMat* mat, int width, int height, float scale_w,
float scale_h, int interp, bool use_scale, ProcLib lib) {
if (mat->Height() == height && mat->Width() == width) {
return true;
}

View File

@@ -31,16 +31,16 @@ class FASTDEPLOY_DECL Resize : public Processor {
use_scale_ = use_scale;
}
bool ImplByOpenCV(Mat* mat);
bool ImplByOpenCV(FDMat* mat);
#ifdef ENABLE_FLYCV
bool ImplByFlyCV(Mat* mat);
bool ImplByFlyCV(FDMat* mat);
#endif
#ifdef ENABLE_CVCUDA
bool ImplByCvCuda(Mat* mat);
bool ImplByCvCuda(FDMat* mat);
#endif
std::string Name() { return "Resize"; }
static bool Run(Mat* mat, int width, int height, float scale_w = -1.0,
static bool Run(FDMat* mat, int width, int height, float scale_w = -1.0,
float scale_h = -1.0, int interp = 1, bool use_scale = false,
ProcLib lib = ProcLib::DEFAULT);

View File

@@ -23,7 +23,7 @@
namespace fastdeploy {
namespace vision {
bool ResizeByShort::ImplByOpenCV(Mat* mat) {
bool ResizeByShort::ImplByOpenCV(FDMat* mat) {
cv::Mat* im = mat->GetOpenCVMat();
int origin_w = im->cols;
int origin_h = im->rows;
@@ -43,7 +43,7 @@ bool ResizeByShort::ImplByOpenCV(Mat* mat) {
}
#ifdef ENABLE_FLYCV
bool ResizeByShort::ImplByFlyCV(Mat* mat) {
bool ResizeByShort::ImplByFlyCV(FDMat* mat) {
fcv::Mat* im = mat->GetFlyCVMat();
int origin_w = im->width();
int origin_h = im->height();
@@ -87,10 +87,9 @@ bool ResizeByShort::ImplByFlyCV(Mat* mat) {
#endif
#ifdef ENABLE_CVCUDA
bool ResizeByShort::ImplByCvCuda(Mat* mat) {
bool ResizeByShort::ImplByCvCuda(FDMat* mat) {
// Prepare input tensor
std::string tensor_name = Name() + "_cvcuda_src";
FDTensor* src = CreateCachedGpuInputTensor(mat, tensor_name);
FDTensor* src = CreateCachedGpuInputTensor(mat);
auto src_tensor = CreateCvCudaTensorWrapData(*src);
double scale = GenerateScale(mat->Width(), mat->Height());
@@ -98,23 +97,69 @@ bool ResizeByShort::ImplByCvCuda(Mat* mat) {
int height = static_cast<int>(round(scale * mat->Height()));
// Prepare output tensor
tensor_name = Name() + "_cvcuda_dst";
FDTensor* dst = UpdateAndGetCachedTensor(
{height, width, mat->Channels()}, mat->Type(), tensor_name, Device::GPU);
auto dst_tensor = CreateCvCudaTensorWrapData(*dst);
mat->output_cache->Resize({height, width, mat->Channels()}, mat->Type(),
"output_cache", Device::GPU);
auto dst_tensor = CreateCvCudaTensorWrapData(*(mat->output_cache));
// CV-CUDA Interp value is compatible with OpenCV
cvcuda::Resize resize_op;
resize_op(mat->Stream(), src_tensor, dst_tensor,
NVCVInterpolationType(interp_));
mat->SetTensor(dst);
mat->SetTensor(mat->output_cache);
mat->SetWidth(width);
mat->SetHeight(height);
mat->device = Device::GPU;
mat->mat_type = ProcLib::CVCUDA;
return true;
}
bool ResizeByShort::ImplByCvCuda(FDMatBatch* mat_batch) {
// TODO(wangxinyu): to support batched tensor as input
FDASSERT(mat_batch->has_batched_tensor == false,
"ResizeByShort doesn't support batched tensor as input for now.");
// Prepare input batch
std::string tensor_name = Name() + "_cvcuda_src";
std::vector<FDTensor*> src_tensors;
for (size_t i = 0; i < mat_batch->mats->size(); ++i) {
FDTensor* src = CreateCachedGpuInputTensor(&(*(mat_batch->mats))[i]);
src_tensors.push_back(src);
}
nvcv::ImageBatchVarShape src_batch(mat_batch->mats->size());
CreateCvCudaImageBatchVarShape(src_tensors, src_batch);
// Prepare output batch
tensor_name = Name() + "_cvcuda_dst";
std::vector<FDTensor*> dst_tensors;
for (size_t i = 0; i < mat_batch->mats->size(); ++i) {
FDMat* mat = &(*(mat_batch->mats))[i];
double scale = GenerateScale(mat->Width(), mat->Height());
int width = static_cast<int>(round(scale * mat->Width()));
int height = static_cast<int>(round(scale * mat->Height()));
mat->output_cache->Resize({height, width, mat->Channels()}, mat->Type(),
"output_cache", Device::GPU);
dst_tensors.push_back(mat->output_cache);
}
nvcv::ImageBatchVarShape dst_batch(mat_batch->mats->size());
CreateCvCudaImageBatchVarShape(dst_tensors, dst_batch);
// CV-CUDA Interp value is compatible with OpenCV
cvcuda::Resize resize_op;
resize_op(mat_batch->Stream(), src_batch, dst_batch,
NVCVInterpolationType(interp_));
for (size_t i = 0; i < mat_batch->mats->size(); ++i) {
FDMat* mat = &(*(mat_batch->mats))[i];
mat->SetTensor(dst_tensors[i]);
mat->SetWidth(dst_tensors[i]->Shape()[1]);
mat->SetHeight(dst_tensors[i]->Shape()[0]);
mat->device = Device::GPU;
mat->mat_type = ProcLib::CVCUDA;
}
mat_batch->device = Device::GPU;
mat_batch->mat_type = ProcLib::CVCUDA;
return true;
}
#endif
double ResizeByShort::GenerateScale(const int origin_w, const int origin_h) {
@@ -143,7 +188,7 @@ double ResizeByShort::GenerateScale(const int origin_w, const int origin_h) {
return scale;
}
bool ResizeByShort::Run(Mat* mat, int target_size, int interp, bool use_scale,
bool ResizeByShort::Run(FDMat* mat, int target_size, int interp, bool use_scale,
const std::vector<int>& max_hw, ProcLib lib) {
auto r = ResizeByShort(target_size, interp, use_scale, max_hw);
return r(mat, lib);

View File

@@ -28,16 +28,17 @@ class FASTDEPLOY_DECL ResizeByShort : public Processor {
interp_ = interp;
use_scale_ = use_scale;
}
bool ImplByOpenCV(Mat* mat);
bool ImplByOpenCV(FDMat* mat);
#ifdef ENABLE_FLYCV
bool ImplByFlyCV(Mat* mat);
bool ImplByFlyCV(FDMat* mat);
#endif
#ifdef ENABLE_CVCUDA
bool ImplByCvCuda(Mat* mat);
bool ImplByCvCuda(FDMat* mat);
bool ImplByCvCuda(FDMatBatch* mat_batch);
#endif
std::string Name() { return "ResizeByShort"; }
static bool Run(Mat* mat, int target_size, int interp = 1,
static bool Run(FDMat* mat, int target_size, int interp = 1,
bool use_scale = true,
const std::vector<int>& max_hw = std::vector<int>(),
ProcLib lib = ProcLib::DEFAULT);

View File

@@ -16,6 +16,7 @@
namespace fastdeploy {
void BindProcessorManager(pybind11::module& m);
void BindDetection(pybind11::module& m);
void BindClassification(pybind11::module& m);
void BindSegmentation(pybind11::module& m);
@@ -204,6 +205,7 @@ void BindVision(pybind11::module& m) {
m.def("disable_flycv", &vision::DisableFlyCV,
"Disable image preprocessing by FlyCV, change to use OpenCV.");
BindProcessorManager(m);
BindDetection(m);
BindClassification(m);
BindSegmentation(m);

View File

@@ -16,44 +16,40 @@ from __future__ import absolute_import
import logging
from .... import FastDeployModel, ModelFormat
from .... import c_lib_wrap as C
from ...common import ProcessorManager
class PaddleClasPreprocessor:
class PaddleClasPreprocessor(ProcessorManager):
def __init__(self, config_file):
"""Create a preprocessor for PaddleClasModel from configuration file
:param config_file: (str)Path of configuration file, e.g resnet50/inference_cls.yaml
"""
self._preprocessor = C.vision.classification.PaddleClasPreprocessor(
super(PaddleClasPreprocessor, self).__init__()
self._manager = C.vision.classification.PaddleClasPreprocessor(
config_file)
def run(self, input_ims):
"""Preprocess input images for PaddleClasModel
:param: input_ims: (list of numpy.ndarray)The input image
:return: list of FDTensor
"""
return self._preprocessor.run(input_ims)
def use_cuda(self, enable_cv_cuda=False, gpu_id=-1):
"""Use CUDA preprocessors
:param: enable_cv_cuda: Whether to enable CV-CUDA
:param: gpu_id: GPU device id
"""
return self._preprocessor.use_cuda(enable_cv_cuda, gpu_id)
def disable_normalize(self):
"""
This function will disable normalize in preprocessing step.
"""
self._preprocessor.disable_normalize()
self._manager.disable_normalize()
def disable_permute(self):
"""
This function will disable hwc2chw in preprocessing step.
"""
self._preprocessor.disable_permute()
self._manager.disable_permute()
def initial_resize_on_cpu(self, v):
"""
When the initial operator is Resize and the input image is large,
it may be better to run the resize on CPU, because the HostToDevice memcpy
is time consuming. Set this to True to run the initial resize on CPU.
:param: v: True or False
"""
self._manager.initial_resize_on_cpu(v)
class PaddleClasPostprocessor:

View File

@@ -0,0 +1,16 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from .manager import ProcessorManager

View File

@@ -0,0 +1,36 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
class ProcessorManager:
def __init__(self):
self._manager = None
def run(self, input_ims):
"""Process input image
:param: input_ims: (list of numpy.ndarray) The input images
:return: list of FDTensor
"""
return self._manager.run(input_ims)
def use_cuda(self, enable_cv_cuda=False, gpu_id=-1):
"""Use CUDA processors
:param: enable_cv_cuda: True: use CV-CUDA, False: use CUDA only
:param: gpu_id: GPU device id
"""
return self._manager.use_cuda(enable_cv_cuda, gpu_id)

View File

@@ -72,6 +72,7 @@ setup_configs["PADDLELITE_URL"] = os.getenv("PADDLELITE_URL", "OFF")
setup_configs["ENABLE_VISION"] = os.getenv("ENABLE_VISION", "OFF")
setup_configs["ENABLE_ENCRYPTION"] = os.getenv("ENABLE_ENCRYPTION", "OFF")
setup_configs["ENABLE_FLYCV"] = os.getenv("ENABLE_FLYCV", "OFF")
setup_configs["ENABLE_CVCUDA"] = os.getenv("ENABLE_CVCUDA", "OFF")
setup_configs["ENABLE_TEXT"] = os.getenv("ENABLE_TEXT", "OFF")
setup_configs["ENABLE_BENCHMARK"] = os.getenv("ENABLE_BENCHMARK", "OFF")
setup_configs["WITH_GPU"] = os.getenv("WITH_GPU", "OFF")