[CVCUDA] PaddleDetection preprocessor support CV-CUDA (#1493)

* ppdet preproc use manager

* pad_to_size chw opencv

* pad_to_size chw flycv

* fix pad_to_size flycv

* add warning message

* cvcuda convert cubic to linear, padToSize cvcuda

* stridedpad cvcuda

* fix flycv include

* fix flycv include

* fix flycv build

* cast cvcuda

* fix pybind

* fix normalize permute cuda

* base processor move funcs to cc

* Update pad_to_size.cc
This commit is contained in:
Wang Xinyu
2023-03-10 12:43:57 +08:00
committed by GitHub
parent 9ee2118e1f
commit cb7c8a07d4
23 changed files with 537 additions and 239 deletions

View File

@@ -138,8 +138,10 @@ bool PaddleClasPreprocessor::Apply(FDMatBatch* image_batch,
}
outputs->resize(1);
(*outputs)[0] = std::move(*(image_batch->Tensor()));
(*outputs)[0].device_id = DeviceId();
FDTensor* tensor = image_batch->Tensor();
(*outputs)[0].SetExternalData(tensor->Shape(), tensor->Dtype(),
tensor->Data(), tensor->device,
tensor->device_id);
return true;
}

View File

@@ -31,7 +31,9 @@ class FASTDEPLOY_DECL PaddleClasPreprocessor : public ProcessorManager {
*/
explicit PaddleClasPreprocessor(const std::string& config_file);
/** \brief Process the input image and prepare input tensors for runtime
/** \brief Implement the virtual function of ProcessorManager, Apply() is the
* body of Run(). Apply() contains the main logic of preprocessing, Run() is
* called by users to execute preprocessing
*
* \param[in] image_batch The input image batch
* \param[in] outputs The output tensors which will feed in runtime

View File

@@ -20,6 +20,63 @@
namespace fastdeploy {
namespace vision {
bool Processor::ImplByOpenCV(FDMat* mat) {
FDERROR << Name() << " Not Implement Yet." << std::endl;
return false;
}
bool Processor::ImplByOpenCV(FDMatBatch* mat_batch) {
for (size_t i = 0; i < mat_batch->mats->size(); ++i) {
if (ImplByOpenCV(&(*(mat_batch->mats))[i]) != true) {
return false;
}
}
return true;
}
bool Processor::ImplByFlyCV(FDMat* mat) { return ImplByOpenCV(mat); }
bool Processor::ImplByFlyCV(FDMatBatch* mat_batch) {
for (size_t i = 0; i < mat_batch->mats->size(); ++i) {
if (ImplByFlyCV(&(*(mat_batch->mats))[i]) != true) {
return false;
}
}
return true;
}
bool Processor::ImplByCuda(FDMat* mat) {
FDWARNING << Name()
<< " is not implemented with CUDA, will fallback to OpenCV."
<< std::endl;
return ImplByOpenCV(mat);
}
bool Processor::ImplByCuda(FDMatBatch* mat_batch) {
for (size_t i = 0; i < mat_batch->mats->size(); ++i) {
if (ImplByCuda(&(*(mat_batch->mats))[i]) != true) {
return false;
}
}
return true;
}
bool Processor::ImplByCvCuda(FDMat* mat) {
FDWARNING << Name()
<< " is not implemented with CV-CUDA, will fallback to OpenCV."
<< std::endl;
return ImplByOpenCV(mat);
}
bool Processor::ImplByCvCuda(FDMatBatch* mat_batch) {
for (size_t i = 0; i < mat_batch->mats->size(); ++i) {
if (ImplByCvCuda(&(*(mat_batch->mats))[i]) != true) {
return false;
}
}
return true;
}
bool Processor::operator()(FDMat* mat) {
ProcLib target = mat->proc_lib;
if (mat->proc_lib == ProcLib::DEFAULT) {

View File

@@ -47,58 +47,17 @@ class FASTDEPLOY_DECL Processor {
virtual std::string Name() = 0;
virtual bool ImplByOpenCV(FDMat* mat) {
FDERROR << Name() << " Not Implement Yet." << std::endl;
return false;
}
virtual bool ImplByOpenCV(FDMat* mat);
virtual bool ImplByOpenCV(FDMatBatch* mat_batch);
virtual bool ImplByOpenCV(FDMatBatch* mat_batch) {
for (size_t i = 0; i < mat_batch->mats->size(); ++i) {
if (ImplByOpenCV(&(*(mat_batch->mats))[i]) != true) {
return false;
}
}
return true;
}
virtual bool ImplByFlyCV(FDMat* mat);
virtual bool ImplByFlyCV(FDMatBatch* mat_batch);
virtual bool ImplByFlyCV(FDMat* mat) {
return ImplByOpenCV(mat);
}
virtual bool ImplByCuda(FDMat* mat);
virtual bool ImplByCuda(FDMatBatch* mat_batch);
virtual bool ImplByFlyCV(FDMatBatch* mat_batch) {
for (size_t i = 0; i < mat_batch->mats->size(); ++i) {
if (ImplByFlyCV(&(*(mat_batch->mats))[i]) != true) {
return false;
}
}
return true;
}
virtual bool ImplByCuda(FDMat* mat) {
return ImplByOpenCV(mat);
}
virtual bool ImplByCuda(FDMatBatch* mat_batch) {
for (size_t i = 0; i < mat_batch->mats->size(); ++i) {
if (ImplByCuda(&(*(mat_batch->mats))[i]) != true) {
return false;
}
}
return true;
}
virtual bool ImplByCvCuda(FDMat* mat) {
return ImplByOpenCV(mat);
}
virtual bool ImplByCvCuda(FDMatBatch* mat_batch) {
for (size_t i = 0; i < mat_batch->mats->size(); ++i) {
if (ImplByCvCuda(&(*(mat_batch->mats))[i]) != true) {
return false;
}
}
return true;
}
virtual bool ImplByCvCuda(FDMat* mat);
virtual bool ImplByCvCuda(FDMatBatch* mat_batch);
virtual bool operator()(FDMat* mat);

View File

@@ -14,6 +14,8 @@
#include "fastdeploy/vision/common/processors/cast.h"
#include "fastdeploy/vision/common/processors/utils.h"
namespace fastdeploy {
namespace vision {
@@ -68,6 +70,40 @@ bool Cast::ImplByFlyCV(Mat* mat) {
}
#endif
#ifdef ENABLE_CVCUDA
bool Cast::ImplByCvCuda(FDMat* mat) {
FDDataType dst_dtype;
if (dtype_ == "float") {
dst_dtype = FDDataType::FP32;
} else if (dtype_ == "double") {
dst_dtype = FDDataType::FP64;
} else {
FDWARNING << "Cast not support for " << dtype_
<< " now! will skip this operation." << std::endl;
return false;
}
if (mat->Type() == dst_dtype) {
return true;
}
// Prepare input tensor
FDTensor* src = CreateCachedGpuInputTensor(mat);
auto src_tensor = CreateCvCudaTensorWrapData(*src);
// Prepare output tensor
mat->output_cache->Resize(src->Shape(), dst_dtype, "output_cache",
Device::GPU);
auto dst_tensor =
CreateCvCudaTensorWrapData(*(mat->output_cache), mat->layout);
cvcuda_convert_op_(mat->Stream(), src_tensor, dst_tensor, 1.0f, 0.0f);
mat->SetTensor(mat->output_cache);
mat->mat_type = ProcLib::CVCUDA;
return true;
}
#endif
bool Cast::Run(Mat* mat, const std::string& dtype, ProcLib lib) {
auto c = Cast(dtype);
return c(mat, lib);

View File

@@ -15,6 +15,11 @@
#pragma once
#include "fastdeploy/vision/common/processors/base.h"
#ifdef ENABLE_CVCUDA
#include <cvcuda/OpConvertTo.hpp>
#include "fastdeploy/vision/common/processors/cvcuda_utils.h"
#endif
namespace fastdeploy {
namespace vision {
@@ -25,6 +30,9 @@ class FASTDEPLOY_DECL Cast : public Processor {
bool ImplByOpenCV(Mat* mat);
#ifdef ENABLE_FLYCV
bool ImplByFlyCV(Mat* mat);
#endif
#ifdef ENABLE_CVCUDA
bool ImplByCvCuda(FDMat* mat);
#endif
std::string Name() { return "Cast"; }
static bool Run(Mat* mat, const std::string& dtype,
@@ -34,6 +42,9 @@ class FASTDEPLOY_DECL Cast : public Processor {
private:
std::string dtype_;
#ifdef ENABLE_CVCUDA
cvcuda::ConvertTo cvcuda_convert_op_;
#endif
};
} // namespace vision
} // namespace fastdeploy

View File

@@ -111,6 +111,17 @@ void CreateCvCudaImageBatchVarShape(std::vector<FDTensor*>& tensors,
img_batch.pushBack(CreateImageWrapData(*(tensors[i])));
}
}
NVCVInterpolationType CreateCvCudaInterp(int interp) {
// CV-CUDA Interp value is compatible with OpenCV
auto nvcv_interp = NVCVInterpolationType(interp);
// Due to bug of CV-CUDA CUBIC resize, will force to convert CUBIC to LINEAR
if (nvcv_interp == NVCV_INTERP_CUBIC) {
return NVCV_INTERP_LINEAR;
}
return nvcv_interp;
}
#endif
} // namespace vision

View File

@@ -18,8 +18,9 @@
#include "fastdeploy/vision/common/processors/mat.h"
#ifdef ENABLE_CVCUDA
#include "nvcv/Tensor.hpp"
#include <nvcv/Tensor.hpp>
#include <nvcv/ImageBatch.hpp>
#include <cvcuda/Types.h>
namespace fastdeploy {
namespace vision {
@@ -32,6 +33,7 @@ void* GetCvCudaTensorDataPtr(const nvcv::TensorWrapData& tensor);
nvcv::ImageWrapData CreateImageWrapData(const FDTensor& tensor);
void CreateCvCudaImageBatchVarShape(std::vector<FDTensor*>& tensors,
nvcv::ImageBatchVarShape& img_batch);
NVCVInterpolationType CreateCvCudaInterp(int interp);
} // namespace vision
} // namespace fastdeploy

View File

@@ -77,6 +77,7 @@ bool HWC2CHW::ImplByCvCuda(FDMat* mat) {
cvcuda_reformat_op_(mat->Stream(), src_tensor, dst_tensor);
mat->layout = Layout::CHW;
mat->SetTensor(mat->output_cache);
mat->mat_type = ProcLib::CVCUDA;
return true;

View File

@@ -37,7 +37,7 @@ cv::Mat* Mat::GetOpenCVMat() {
#ifdef WITH_GPU
FDASSERT(cudaStreamSynchronize(stream) == cudaSuccess,
"[ERROR] Error occurs while sync cuda stream.");
cpu_mat = CreateZeroCopyOpenCVMatFromTensor(*fd_tensor);
cpu_mat = CreateZeroCopyOpenCVMatFromTensor(*fd_tensor, layout);
mat_type = ProcLib::OPENCV;
device = Device::CPU;
return &cpu_mat;
@@ -49,6 +49,23 @@ cv::Mat* Mat::GetOpenCVMat() {
}
}
#ifdef ENABLE_FLYCV
fcv::Mat* Mat::GetFlyCVMat() {
if (mat_type == ProcLib::FLYCV) {
return &fcv_mat;
} else if (mat_type == ProcLib::OPENCV) {
// Just a reference to cpu_mat, zero copy. After you
// call this method, fcv_mat and cpu_mat will point
// to the same memory buffer.
fcv_mat = ConvertOpenCVMatToFlyCV(cpu_mat);
mat_type = ProcLib::FLYCV;
return &fcv_mat;
} else {
FDASSERT(false, "The mat_type of custom Mat can not be ProcLib::DEFAULT");
}
}
#endif
void* Mat::Data() {
if (mat_type == ProcLib::FLYCV) {
#ifdef ENABLE_FLYCV
@@ -158,7 +175,7 @@ void Mat::PrintInfo(const std::string& flag) {
#ifdef WITH_GPU
FDASSERT(cudaStreamSynchronize(stream) == cudaSuccess,
"[ERROR] Error occurs while sync cuda stream.");
cv::Mat tmp_mat = CreateZeroCopyOpenCVMatFromTensor(*fd_tensor);
cv::Mat tmp_mat = CreateZeroCopyOpenCVMatFromTensor(*fd_tensor, layout);
cv::Scalar mean = cv::mean(tmp_mat);
for (int i = 0; i < Channels(); ++i) {
std::cout << mean[i] << " ";

View File

@@ -13,10 +13,13 @@
// limitations under the License.
#pragma once
#include "fastdeploy/core/fd_tensor.h"
#include "fastdeploy/vision/common/processors/utils.h"
#include "fastdeploy/vision/common/processors/proc_lib.h"
#include "opencv2/core/core.hpp"
#ifdef ENABLE_FLYCV
#include "flycv.h" // NOLINT
#endif
#ifdef WITH_GPU
#include <cuda_runtime_api.h>
#endif
@@ -70,21 +73,7 @@ struct FASTDEPLOY_DECL Mat {
fcv_mat = mat;
mat_type = ProcLib::FLYCV;
}
fcv::Mat* GetFlyCVMat() {
if (mat_type == ProcLib::FLYCV) {
return &fcv_mat;
} else if (mat_type == ProcLib::OPENCV) {
// Just a reference to cpu_mat, zero copy. After you
// call this method, fcv_mat and cpu_mat will point
// to the same memory buffer.
fcv_mat = ConvertOpenCVMatToFlyCV(cpu_mat);
mat_type = ProcLib::FLYCV;
return &fcv_mat;
} else {
FDASSERT(false, "The mat_type of custom Mat can not be ProcLib::DEFAULT");
}
}
fcv::Mat* GetFlyCVMat();
#endif
void* Data();

View File

@@ -40,12 +40,16 @@ __global__ void NormalizeAndPermuteKernel(const uint8_t* src, float* dst,
}
bool NormalizeAndPermute::ImplByCuda(FDMat* mat) {
if (mat->layout != Layout::HWC) {
FDERROR << "Only supports input with HWC layout." << std::endl;
return false;
}
// Prepare input tensor
FDTensor* src = CreateCachedGpuInputTensor(mat);
// Prepare output tensor
mat->output_cache->Resize(src->Shape(), FDDataType::FP32, "output_cache",
Device::GPU);
mat->output_cache->Resize({src->shape[2], src->shape[0], src->shape[1]},
FDDataType::FP32, "output_cache", Device::GPU);
// Copy alpha and beta to GPU
gpu_alpha_.Resize({1, 1, static_cast<int>(alpha_.size())}, FDDataType::FP32,
@@ -68,9 +72,8 @@ bool NormalizeAndPermute::ImplByCuda(FDMat* mat) {
reinterpret_cast<float*>(gpu_beta_.Data()), mat->Channels(), swap_rb_, 1,
jobs);
mat->SetTensor(mat->output_cache);
mat->device = Device::GPU;
mat->layout = Layout::CHW;
mat->SetTensor(mat->output_cache);
mat->mat_type = ProcLib::CUDA;
return true;
}
@@ -112,7 +115,6 @@ bool NormalizeAndPermute::ImplByCuda(FDMatBatch* mat_batch) {
mat_batch->output_cache->shape[0], jobs);
mat_batch->SetTensor(mat_batch->output_cache);
mat_batch->device = Device::GPU;
mat_batch->layout = FDMatBatchLayout::NCHW;
mat_batch->mat_type = ProcLib::CUDA;
return true;

View File

@@ -14,77 +14,62 @@
#include "fastdeploy/vision/common/processors/pad_to_size.h"
#include "fastdeploy/vision/common/processors/utils.h"
namespace fastdeploy {
namespace vision {
bool PadToSize::ImplByOpenCV(Mat* mat) {
if (width_ == -1 || height_ == -1) {
return true;
}
if (mat->layout != Layout::HWC) {
FDERROR << "PadToSize: The input data must be Layout::HWC format!"
<< std::endl;
return false;
}
if (mat->Channels() > 4) {
FDERROR << "PadToSize: Only support channels <= 4." << std::endl;
return false;
}
if (mat->Channels() != value_.size()) {
FDERROR
<< "PadToSize: Require input channels equals to size of padding value, "
"but now channels = "
<< mat->Channels() << ", the size of padding values = " << value_.size()
<< "." << std::endl;
return false;
}
static bool PadHWCByOpenCV(FDMat* mat, int width, int height,
const std::vector<float>& value) {
int origin_w = mat->Width();
int origin_h = mat->Height();
if (origin_w > width_) {
FDERROR << "PadToSize: the input width:" << origin_w
<< " is greater than the target width: " << width_ << "."
<< std::endl;
return false;
}
if (origin_h > height_) {
FDERROR << "PadToSize: the input height:" << origin_h
<< " is greater than the target height: " << height_ << "."
<< std::endl;
return false;
}
if (origin_w == width_ && origin_h == height_) {
return true;
}
cv::Mat* im = mat->GetOpenCVMat();
cv::Scalar value;
if (value_.size() == 1) {
value = cv::Scalar(value_[0]);
} else if (value_.size() == 2) {
value = cv::Scalar(value_[0], value_[1]);
} else if (value_.size() == 3) {
value = cv::Scalar(value_[0], value_[1], value_[2]);
cv::Scalar scalar;
if (value.size() == 1) {
scalar = cv::Scalar(value[0]);
} else if (value.size() == 2) {
scalar = cv::Scalar(value[0], value[1]);
} else if (value.size() == 3) {
scalar = cv::Scalar(value[0], value[1], value[2]);
} else {
value = cv::Scalar(value_[0], value_[1], value_[2], value_[3]);
scalar = cv::Scalar(value[0], value[1], value[2], value[3]);
}
// top, bottom, left, right
cv::copyMakeBorder(*im, *im, 0, height_ - origin_h, 0, width_ - origin_w,
cv::BORDER_CONSTANT, value);
mat->SetHeight(height_);
mat->SetWidth(width_);
cv::copyMakeBorder(*im, *im, 0, height - origin_h, 0, width - origin_w,
cv::BORDER_CONSTANT, scalar);
mat->SetHeight(height);
mat->SetWidth(width);
return true;
}
#ifdef ENABLE_FLYCV
bool PadToSize::ImplByFlyCV(Mat* mat) {
if (width_ == -1 || height_ == -1) {
return true;
}
if (mat->layout != Layout::HWC) {
FDERROR << "PadToSize: The input data must be Layout::HWC format!"
<< std::endl;
return false;
static bool PadCHWByOpenCV(FDMat* mat, int width, int height,
const std::vector<float>& value) {
int origin_w = mat->Width();
int origin_h = mat->Height();
cv::Mat* im = mat->GetOpenCVMat();
cv::Mat new_im(height, width,
CreateOpenCVDataType(mat->Type(), mat->Channels()));
for (int i = 0; i < mat->Channels(); ++i) {
uint8_t* src_data =
im->ptr() + i * origin_w * origin_h * FDDataTypeSize(mat->Type());
cv::Mat src(origin_h, origin_w, CreateOpenCVDataType(mat->Type(), 1),
src_data);
uint8_t* dst_data =
new_im.ptr() + i * width * height * FDDataTypeSize(mat->Type());
cv::Mat dst(height, width, CreateOpenCVDataType(mat->Type(), 1), dst_data);
cv::copyMakeBorder(src, dst, 0, height - origin_h, 0, width - origin_w,
cv::BORDER_CONSTANT, cv::Scalar(value[i]));
}
mat->SetMat(new_im);
mat->SetHeight(height);
mat->SetWidth(width);
return true;
}
bool PadToSize::CheckArgs(FDMat* mat) {
if (mat->Channels() > 4) {
FDERROR << "PadToSize: Only support channels <= 4." << std::endl;
return false;
@@ -97,45 +82,184 @@ bool PadToSize::ImplByFlyCV(Mat* mat) {
<< "." << std::endl;
return false;
}
int origin_w = mat->Width();
int origin_h = mat->Height();
if (origin_w > width_) {
FDERROR << "PadToSize: the input width:" << origin_w
if (mat->Width() > width_) {
FDERROR << "PadToSize: the input width:" << mat->Width()
<< " is greater than the target width: " << width_ << "."
<< std::endl;
return false;
}
if (origin_h > height_) {
FDERROR << "PadToSize: the input height:" << origin_h
if (mat->Height() > height_) {
FDERROR << "PadToSize: the input height:" << mat->Height()
<< " is greater than the target height: " << height_ << "."
<< std::endl;
return false;
}
if (origin_w == width_ && origin_h == height_) {
return true;
}
bool PadToSize::ImplByOpenCV(FDMat* mat) {
if (width_ == -1 || height_ == -1 ||
(mat->Width() == width_ && mat->Height() == height_)) {
return true;
}
if (CheckArgs(mat) == false) {
return false;
}
if (mat->layout == Layout::HWC) {
return PadHWCByOpenCV(mat, width_, height_, value_);
} else if (mat->layout == Layout::CHW) {
return PadCHWByOpenCV(mat, width_, height_, value_);
}
return false;
}
#ifdef ENABLE_FLYCV
static bool PadHWCByFlyCV(FDMat* mat, int width, int height,
const std::vector<float>& value) {
int origin_w = mat->Width();
int origin_h = mat->Height();
fcv::Mat* im = mat->GetFlyCVMat();
fcv::Scalar value;
if (value_.size() == 1) {
value = fcv::Scalar(value_[0]);
} else if (value_.size() == 2) {
value = fcv::Scalar(value_[0], value_[1]);
} else if (value_.size() == 3) {
value = fcv::Scalar(value_[0], value_[1], value_[2]);
fcv::Scalar scalar;
if (value.size() == 1) {
scalar = fcv::Scalar(value[0]);
} else if (value.size() == 2) {
scalar = fcv::Scalar(value[0], value[1]);
} else if (value.size() == 3) {
scalar = fcv::Scalar(value[0], value[1], value[2]);
} else {
value = fcv::Scalar(value_[0], value_[1], value_[2], value_[3]);
scalar = fcv::Scalar(value[0], value[1], value[2], value[3]);
}
fcv::Mat new_im;
// top, bottom, left, right
fcv::copy_make_border(*im, new_im, 0, height_ - origin_h, 0,
width_ - origin_w, fcv::BorderType::BORDER_CONSTANT,
value);
fcv::copy_make_border(*im, new_im, 0, height - origin_h, 0, width - origin_w,
fcv::BorderType::BORDER_CONSTANT, scalar);
mat->SetMat(new_im);
mat->SetHeight(height_);
mat->SetWidth(width_);
mat->SetHeight(height);
mat->SetWidth(width);
return true;
}
static bool PadCHWByFlyCV(FDMat* mat, int width, int height,
const std::vector<float>& value) {
int origin_w = mat->Width();
int origin_h = mat->Height();
fcv::Mat new_im(height, width,
CreateFlyCVDataType(mat->Type(), mat->Channels()));
for (int i = 0; i < mat->Channels(); ++i) {
uint8_t* src_data = reinterpret_cast<uint8_t*>(mat->Data()) +
i * origin_w * origin_h * FDDataTypeSize(mat->Type());
fcv::Mat src(origin_h, origin_w, CreateFlyCVDataType(mat->Type(), 1),
src_data);
uint8_t* dst_data = reinterpret_cast<uint8_t*>(new_im.data()) +
i * width * height * FDDataTypeSize(mat->Type());
fcv::Mat dst(height, width, CreateFlyCVDataType(mat->Type(), 1), dst_data);
fcv::copy_make_border(src, dst, 0, height - origin_h, 0, width - origin_w,
fcv::BorderType::BORDER_CONSTANT,
fcv::Scalar(value[i]));
}
mat->SetMat(new_im);
mat->SetHeight(height);
mat->SetWidth(width);
return true;
}
bool PadToSize::ImplByFlyCV(FDMat* mat) {
if (width_ == -1 || height_ == -1 ||
(mat->Width() == width_ && mat->Height() == height_)) {
return true;
}
if (CheckArgs(mat) == false) {
return false;
}
if (mat->layout == Layout::HWC) {
return PadHWCByFlyCV(mat, width_, height_, value_);
} else if (mat->layout == Layout::CHW) {
return PadCHWByFlyCV(mat, width_, height_, value_);
}
return false;
}
#endif
#ifdef ENABLE_CVCUDA
static bool PadHWCByCvCuda(cvcuda::CopyMakeBorder& pad_op, FDMat* mat,
int width, int height,
const std::vector<float>& value) {
float4 border_value;
if (value.size() == 1) {
border_value = make_float4(value[0], 0.0f, 0.0f, 0.0f);
} else if (value.size() == 2) {
border_value = make_float4(value[0], value[1], 0.0f, 0.0f);
} else if (value.size() == 3) {
border_value = make_float4(value[0], value[1], value[2], 0.0f);
} else {
border_value = make_float4(value[0], value[1], value[2], value[3]);
}
// Prepare input tensor
FDTensor* src = CreateCachedGpuInputTensor(mat);
auto src_tensor = CreateCvCudaTensorWrapData(*src);
// Prepare output tensor
mat->output_cache->Resize({height, width, mat->Channels()}, mat->Type(),
"output_cache", Device::GPU);
auto dst_tensor = CreateCvCudaTensorWrapData(*(mat->output_cache));
pad_op(mat->Stream(), src_tensor, dst_tensor, 0, 0, NVCV_BORDER_CONSTANT,
border_value);
mat->SetTensor(mat->output_cache);
mat->mat_type = ProcLib::CVCUDA;
return true;
}
static bool PadCHWByCvCuda(cvcuda::CopyMakeBorder& pad_op, FDMat* mat,
int width, int height,
const std::vector<float>& value) {
float4 border_value = make_float4(value[0], 0.0f, 0.0f, 0.0f);
FDTensor* input = CreateCachedGpuInputTensor(mat);
int channels = input->shape[0];
mat->output_cache->Resize({channels, height, width}, mat->Type(),
"output_cache", Device::GPU);
for (int i = 0; i < channels; ++i) {
uint8_t* src_data =
reinterpret_cast<uint8_t*>(input->Data()) +
i * mat->Width() * mat->Height() * FDDataTypeSize(mat->Type());
FDTensor src;
src.SetExternalData({mat->Height(), mat->Width(), 1}, input->Dtype(),
src_data, input->device, input->device_id);
auto src_tensor = CreateCvCudaTensorWrapData(src);
uint8_t* dst_data = reinterpret_cast<uint8_t*>(mat->output_cache->Data()) +
i * width * height * FDDataTypeSize(mat->Type());
FDTensor dst;
dst.SetExternalData({height, width, 1}, input->Dtype(), dst_data,
input->device, input->device_id);
auto dst_tensor = CreateCvCudaTensorWrapData(dst);
pad_op(mat->Stream(), src_tensor, dst_tensor, 0, 0, NVCV_BORDER_CONSTANT,
border_value);
}
mat->SetTensor(mat->output_cache);
mat->mat_type = ProcLib::CVCUDA;
return true;
}
bool PadToSize::ImplByCvCuda(FDMat* mat) {
if (width_ == -1 || height_ == -1 ||
(mat->Width() == width_ && mat->Height() == height_)) {
return true;
}
if (CheckArgs(mat) == false) {
return false;
}
if (mat->layout == Layout::HWC) {
return PadHWCByCvCuda(cvcuda_pad_op_, mat, width_, height_, value_);
} else if (mat->layout == Layout::CHW) {
return PadCHWByCvCuda(cvcuda_pad_op_, mat, width_, height_, value_);
}
return false;
}
#endif
bool PadToSize::Run(Mat* mat, int width, int height,

View File

@@ -15,6 +15,11 @@
#pragma once
#include "fastdeploy/vision/common/processors/base.h"
#ifdef ENABLE_CVCUDA
#include <cvcuda/OpCopyMakeBorder.hpp>
#include "fastdeploy/vision/common/processors/cvcuda_utils.h"
#endif
namespace fastdeploy {
namespace vision {
@@ -30,6 +35,9 @@ class FASTDEPLOY_DECL PadToSize : public Processor {
bool ImplByOpenCV(Mat* mat);
#ifdef ENABLE_FLYCV
bool ImplByFlyCV(Mat* mat);
#endif
#ifdef ENABLE_CVCUDA
bool ImplByCvCuda(FDMat* mat);
#endif
std::string Name() { return "PadToSize"; }
@@ -37,10 +45,19 @@ class FASTDEPLOY_DECL PadToSize : public Processor {
const std::vector<float>& value,
ProcLib lib = ProcLib::DEFAULT);
void SetWidthHeight(int width, int height) {
width_ = width;
height_ = height;
}
private:
bool CheckArgs(FDMat* mat);
int width_;
int height_;
std::vector<float> value_;
#ifdef ENABLE_CVCUDA
cvcuda::CopyMakeBorder cvcuda_pad_op_;
#endif
};
} // namespace vision
} // namespace fastdeploy

View File

@@ -147,7 +147,7 @@ bool Resize::ImplByCvCuda(FDMat* mat) {
// CV-CUDA Interp value is compatible with OpenCV
cvcuda_resize_op_(mat->Stream(), src_tensor, dst_tensor,
NVCVInterpolationType(interp_));
CreateCvCudaInterp(interp_));
mat->SetTensor(mat->output_cache);
mat->SetWidth(width_);

View File

@@ -95,9 +95,8 @@ bool ResizeByShort::ImplByCvCuda(FDMat* mat) {
"output_cache", Device::GPU);
auto dst_tensor = CreateCvCudaTensorWrapData(*(mat->output_cache));
// CV-CUDA Interp value is compatible with OpenCV
cvcuda_resize_op_(mat->Stream(), src_tensor, dst_tensor,
NVCVInterpolationType(interp_));
CreateCvCudaInterp(interp_));
mat->SetTensor(mat->output_cache);
mat->SetWidth(width);
@@ -138,7 +137,7 @@ bool ResizeByShort::ImplByCvCuda(FDMatBatch* mat_batch) {
// CV-CUDA Interp value is compatible with OpenCV
cvcuda_resize_op_(mat_batch->Stream(), src_batch, dst_batch,
NVCVInterpolationType(interp_));
CreateCvCudaInterp(interp_));
for (size_t i = 0; i < mat_batch->mats->size(); ++i) {
FDMat* mat = &(*(mat_batch->mats))[i];

View File

@@ -114,6 +114,68 @@ bool StridePad::ImplByFlyCV(Mat* mat) {
}
#endif
#ifdef ENABLE_CVCUDA
bool StridePad::ImplByCvCuda(FDMat* mat) {
if (mat->layout != Layout::HWC) {
FDERROR << "StridePad: The input data must be Layout::HWC format!"
<< std::endl;
return false;
}
if (mat->Channels() > 4) {
FDERROR << "StridePad: Only support channels <= 4." << std::endl;
return false;
}
if (mat->Channels() != value_.size()) {
FDERROR
<< "StridePad: Require input channels equals to size of padding value, "
"but now channels = "
<< mat->Channels() << ", the size of padding values = " << value_.size()
<< "." << std::endl;
return false;
}
int origin_w = mat->Width();
int origin_h = mat->Height();
int pad_h = (mat->Height() / stride_) * stride_ +
(mat->Height() % stride_ != 0) * stride_ - mat->Height();
int pad_w = (mat->Width() / stride_) * stride_ +
(mat->Width() % stride_ != 0) * stride_ - mat->Width();
if (pad_h == 0 && pad_w == 0) {
return true;
}
float4 value;
if (value_.size() == 1) {
value = make_float4(value_[0], 0.0f, 0.0f, 0.0f);
} else if (value_.size() == 2) {
value = make_float4(value_[0], value_[1], 0.0f, 0.0f);
} else if (value_.size() == 3) {
value = make_float4(value_[0], value_[1], value_[2], 0.0f);
} else {
value = make_float4(value_[0], value_[1], value_[2], value_[3]);
}
// Prepare input tensor
FDTensor* src = CreateCachedGpuInputTensor(mat);
auto src_tensor = CreateCvCudaTensorWrapData(*src);
int height = mat->Height() + pad_h;
int width = mat->Width() + pad_w;
// Prepare output tensor
mat->output_cache->Resize({height, width, mat->Channels()}, mat->Type(),
"output_cache", Device::GPU);
auto dst_tensor = CreateCvCudaTensorWrapData(*(mat->output_cache));
cvcuda_pad_op_(mat->Stream(), src_tensor, dst_tensor, 0, 0,
NVCV_BORDER_CONSTANT, value);
mat->SetTensor(mat->output_cache);
mat->mat_type = ProcLib::CVCUDA;
return true;
}
#endif
bool StridePad::Run(Mat* mat, int stride, const std::vector<float>& value,
ProcLib lib) {
auto p = StridePad(stride, value);

View File

@@ -15,6 +15,11 @@
#pragma once
#include "fastdeploy/vision/common/processors/base.h"
#ifdef ENABLE_CVCUDA
#include <cvcuda/OpCopyMakeBorder.hpp>
#include "fastdeploy/vision/common/processors/cvcuda_utils.h"
#endif
namespace fastdeploy {
namespace vision {
@@ -29,6 +34,9 @@ class FASTDEPLOY_DECL StridePad : public Processor {
bool ImplByOpenCV(Mat* mat);
#ifdef ENABLE_FLYCV
bool ImplByFlyCV(Mat* mat);
#endif
#ifdef ENABLE_CVCUDA
bool ImplByCvCuda(FDMat* mat);
#endif
std::string Name() { return "StridePad"; }
@@ -39,6 +47,9 @@ class FASTDEPLOY_DECL StridePad : public Processor {
private:
int stride_ = 32;
std::vector<float> value_;
#ifdef ENABLE_CVCUDA
cvcuda::CopyMakeBorder cvcuda_pad_op_;
#endif
};
} // namespace vision
} // namespace fastdeploy

View File

@@ -186,9 +186,8 @@ cv::Mat ConvertFlyCVMatToOpenCV(fcv::Mat& fim) {
}
#endif
cv::Mat CreateZeroCopyOpenCVMatFromBuffer(
int height, int width, int channels,
FDDataType type, void* data) {
cv::Mat CreateZeroCopyOpenCVMatFromBuffer(int height, int width, int channels,
FDDataType type, void* data) {
cv::Mat ocv_mat;
switch (type) {
case FDDataType::UINT8:
@@ -219,61 +218,61 @@ cv::Mat CreateZeroCopyOpenCVMatFromBuffer(
return ocv_mat;
}
cv::Mat CreateZeroCopyOpenCVMatFromTensor(const FDTensor& tensor) {
// TODO(qiuyanjun): Should add a Layout checking. Now, we
// assume that the input tensor is already in Layout::HWC.
FDASSERT(tensor.shape.size() == 3, "When create OepnCV Mat from tensor,"
"tensor shape should be 3-Dim, HWC layout");
cv::Mat CreateZeroCopyOpenCVMatFromTensor(const FDTensor& tensor,
Layout layout) {
FDASSERT(tensor.shape.size() == 3,
"When create OepnCV Mat from tensor,"
"tensor shape should be 3-Dim");
FDDataType type = tensor.dtype;
int height = static_cast<int>(tensor.shape[0]);
int width = static_cast<int>(tensor.shape[1]);
int channels = static_cast<int>(tensor.shape[2]);
return CreateZeroCopyOpenCVMatFromBuffer(
height, width, channels, type,
const_cast<void*>(tensor.CpuData()));
if (layout == Layout::CHW) {
channels = static_cast<int>(tensor.shape[0]);
height = static_cast<int>(tensor.shape[1]);
width = static_cast<int>(tensor.shape[2]);
}
return CreateZeroCopyOpenCVMatFromBuffer(height, width, channels, type,
const_cast<void*>(tensor.CpuData()));
}
#ifdef ENABLE_FLYCV
fcv::Mat CreateZeroCopyFlyCVMatFromBuffer(
int height, int width, int channels,
FDDataType type, void* data) {
fcv::Mat CreateZeroCopyFlyCVMatFromBuffer(int height, int width, int channels,
FDDataType type, void* data) {
fcv::Mat fcv_mat;
auto fcv_type = CreateFlyCVDataType(type, channels);
switch (type) {
case FDDataType::UINT8:
fcv_mat =
fcv::Mat(width, height, fcv_type, data);
fcv_mat = fcv::Mat(width, height, fcv_type, data);
break;
case FDDataType::FP32:
fcv_mat =
fcv::Mat(width, height, fcv_type, data);
fcv_mat = fcv::Mat(width, height, fcv_type, data);
break;
case FDDataType::FP64:
fcv_mat =
fcv::Mat(width, height, fcv_type, data);
break;
fcv_mat = fcv::Mat(width, height, fcv_type, data);
break;
default:
FDASSERT(false,
"Tensor type %d is not supported While calling "
"CreateZeroCopyFlyCVMat.",
"Tensor type %d is not supported While calling "
"CreateZeroCopyFlyCVMat.",
type);
break;
break;
}
return fcv_mat;
}
fcv::Mat CreateZeroCopyFlyCVMatFromTensor(const FDTensor& tensor) {
// TODO(qiuyanjun): Should add a Layout checking. Now, we
// assume that the input tensor is already in Layout::HWC.
FDASSERT(tensor.shape.size() == 3, "When create FlyCV Mat from tensor,"
"tensor shape should be 3-Dim, HWC layout");
// TODO(qiuyanjun): Should add a Layout checking. Now, we
// assume that the input tensor is already in Layout::HWC.
FDASSERT(tensor.shape.size() == 3,
"When create FlyCV Mat from tensor,"
"tensor shape should be 3-Dim, HWC layout");
FDDataType type = tensor.dtype;
int height = static_cast<int>(tensor.shape[0]);
int width = static_cast<int>(tensor.shape[1]);
int channels = static_cast<int>(tensor.shape[2]);
return CreateZeroCopyFlyCVMatFromBuffer(
height, width, channels, type,
const_cast<void*>(tensor.Data()));
return CreateZeroCopyFlyCVMatFromBuffer(height, width, channels, type,
const_cast<void*>(tensor.Data()));
}
#endif

View File

@@ -16,6 +16,7 @@
#include "fastdeploy/core/fd_tensor.h"
#include "fastdeploy/utils/utils.h"
#include "fastdeploy/vision/common/processors/mat.h"
#include "opencv2/core/core.hpp"
#ifdef ENABLE_FLYCV
@@ -42,11 +43,12 @@ cv::Mat ConvertFlyCVMatToOpenCV(fcv::Mat& fim);
// Create zero copy OpenCV/FlyCV Mat from FD Tensor / Buffer
cv::Mat CreateZeroCopyOpenCVMatFromBuffer(int height, int width,
int channels, FDDataType type, void* data);
cv::Mat CreateZeroCopyOpenCVMatFromTensor(const FDTensor& tensor);
int channels, FDDataType type, void* data);
cv::Mat CreateZeroCopyOpenCVMatFromTensor(const FDTensor& tensor,
Layout layout = Layout::HWC);
#ifdef ENABLE_FLYCV
fcv::Mat CreateZeroCopyFlyCVMatFromBuffer(int height, int width,
int channels, FDDataType type, void* data);
int channels, FDDataType type, void* data);
fcv::Mat CreateZeroCopyFlyCVMatFromTensor(const FDTensor& tensor);
#endif
} // namespace vision

View File

@@ -15,8 +15,8 @@
namespace fastdeploy {
void BindPPDet(pybind11::module& m) {
pybind11::class_<vision::detection::PaddleDetPreprocessor>(
m, "PaddleDetPreprocessor")
pybind11::class_<vision::detection::PaddleDetPreprocessor,
vision::ProcessorManager>(m, "PaddleDetPreprocessor")
.def(pybind11::init<std::string>())
.def("run",
[](vision::detection::PaddleDetPreprocessor& self,

View File

@@ -129,13 +129,13 @@ bool PaddleDetPreprocessor::BuildPreprocessPipelineFromConfig() {
return true;
}
bool PaddleDetPreprocessor::Run(std::vector<FDMat>* images,
std::vector<FDTensor>* outputs) {
bool PaddleDetPreprocessor::Apply(FDMatBatch* image_batch,
std::vector<FDTensor>* outputs) {
if (!initialized_) {
FDERROR << "The preprocessor is not initialized." << std::endl;
return false;
}
if (images->empty()) {
if (image_batch->mats->empty()) {
FDERROR << "The size of input images should be greater than 0."
<< std::endl;
return false;
@@ -146,7 +146,7 @@ bool PaddleDetPreprocessor::Run(std::vector<FDMat>* images,
// So preprocessor will output the 3 FDTensors, and how to use `im_shape`
// is decided by the model itself
outputs->resize(3);
int batch = static_cast<int>(images->size());
int batch = static_cast<int>(image_batch->mats->size());
// Allocate memory for scale_factor
(*outputs)[1].Resize({batch, 2}, FDDataType::FP32);
// Allocate memory for im_shape
@@ -158,63 +158,51 @@ bool PaddleDetPreprocessor::Run(std::vector<FDMat>* images,
auto* scale_factor_ptr =
reinterpret_cast<float*>((*outputs)[1].MutableData());
auto* im_shape_ptr = reinterpret_cast<float*>((*outputs)[2].MutableData());
for (size_t i = 0; i < images->size(); ++i) {
int origin_w = (*images)[i].Width();
int origin_h = (*images)[i].Height();
for (size_t i = 0; i < image_batch->mats->size(); ++i) {
FDMat* mat = &(image_batch->mats->at(i));
int origin_w = mat->Width();
int origin_h = mat->Height();
scale_factor_ptr[2 * i] = 1.0;
scale_factor_ptr[2 * i + 1] = 1.0;
for (size_t j = 0; j < processors_.size(); ++j) {
if (!(*(processors_[j].get()))(&((*images)[i]))) {
if (!(*(processors_[j].get()))(mat)) {
FDERROR << "Failed to processs image:" << i << " in "
<< processors_[i]->Name() << "." << std::endl;
<< processors_[j]->Name() << "." << std::endl;
return false;
}
if (processors_[j]->Name().find("Resize") != std::string::npos) {
scale_factor_ptr[2 * i] = (*images)[i].Height() * 1.0 / origin_h;
scale_factor_ptr[2 * i + 1] = (*images)[i].Width() * 1.0 / origin_w;
scale_factor_ptr[2 * i] = mat->Height() * 1.0 / origin_h;
scale_factor_ptr[2 * i + 1] = mat->Width() * 1.0 / origin_w;
}
}
if ((*images)[i].Height() > max_hw[0]) {
max_hw[0] = (*images)[i].Height();
if (mat->Height() > max_hw[0]) {
max_hw[0] = mat->Height();
}
if ((*images)[i].Width() > max_hw[1]) {
max_hw[1] = (*images)[i].Width();
if (mat->Width() > max_hw[1]) {
max_hw[1] = mat->Width();
}
im_shape_ptr[2 * i] = max_hw[0];
im_shape_ptr[2 * i + 1] = max_hw[1];
}
// Concat all the preprocessed data to a batch tensor
std::vector<FDTensor> im_tensors(images->size());
for (size_t i = 0; i < images->size(); ++i) {
if ((*images)[i].Height() < max_hw[0] || (*images)[i].Width() < max_hw[1]) {
// if the size of image less than max_hw, pad to max_hw
FDTensor tensor;
(*images)[i].ShareWithTensor(&tensor);
function::Pad(tensor, &(im_tensors[i]),
{0, 0, max_hw[0] - (*images)[i].Height(),
max_hw[1] - (*images)[i].Width()},
0);
} else {
// No need pad
(*images)[i].ShareWithTensor(&(im_tensors[i]));
// if the size of image less than max_hw, pad to max_hw
for (size_t i = 0; i < image_batch->mats->size(); ++i) {
FDMat* mat = &(image_batch->mats->at(i));
if (mat->Height() < max_hw[0] || mat->Width() < max_hw[1]) {
pad_op_->SetWidthHeight(max_hw[1], max_hw[0]);
(*pad_op_)(mat);
}
// Reshape to 1xCxHxW
im_tensors[i].ExpandDim(0);
}
if (im_tensors.size() == 1) {
// If there's only 1 input, no need to concat
// skip memory copy
(*outputs)[0] = std::move(im_tensors[0]);
} else {
// Else concat the im tensor for each input image
// compose a batched input tensor
function::Concat(im_tensors, &((*outputs)[0]), 0);
}
// Get the NCHW tensor
FDTensor* tensor = image_batch->Tensor();
(*outputs)[0].SetExternalData(tensor->Shape(), tensor->Dtype(),
tensor->Data(), tensor->device,
tensor->device_id);
return true;
}
void PaddleDetPreprocessor::DisableNormalize() {
this->disable_normalize_ = true;
// the DisableNormalize function will be invalid if the configuration file is
@@ -224,6 +212,7 @@ void PaddleDetPreprocessor::DisableNormalize() {
<< std::endl;
}
}
void PaddleDetPreprocessor::DisablePermute() {
this->disable_permute_ = true;
// the DisablePermute function will be invalid if the configuration file is

View File

@@ -13,6 +13,7 @@
// limitations under the License.
#pragma once
#include "fastdeploy/vision/common/processors/manager.h"
#include "fastdeploy/vision/common/processors/transform.h"
#include "fastdeploy/vision/common/result.h"
@@ -22,7 +23,7 @@ namespace vision {
namespace detection {
/*! @brief Preprocessor object for PaddleDet serials model.
*/
class FASTDEPLOY_DECL PaddleDetPreprocessor {
class FASTDEPLOY_DECL PaddleDetPreprocessor : public ProcessorManager {
public:
PaddleDetPreprocessor() = default;
/** \brief Create a preprocessor instance for PaddleDet serials model
@@ -31,13 +32,16 @@ class FASTDEPLOY_DECL PaddleDetPreprocessor {
*/
explicit PaddleDetPreprocessor(const std::string& config_file);
/** \brief Process the input image and prepare input tensors for runtime
/** \brief Implement the virtual function of ProcessorManager, Apply() is the
* body of Run(). Apply() contains the main logic of preprocessing, Run() is
* called by users to execute preprocessing
*
* \param[in] images The input image data list, all the elements are returned by cv::imread()
* \param[in] outputs The output tensors which will feed in runtime, include image, scale_factor, im_shape
* \param[in] image_batch The input image batch
* \param[in] outputs The output tensors which will feed in runtime
* \return true if the preprocess successed, otherwise false
*/
bool Run(std::vector<FDMat>* images, std::vector<FDTensor>* outputs);
virtual bool Apply(FDMatBatch* image_batch,
std::vector<FDTensor>* outputs);
/// This function will disable normalize in preprocessing step.
void DisableNormalize();
@@ -51,6 +55,8 @@ class FASTDEPLOY_DECL PaddleDetPreprocessor {
private:
bool BuildPreprocessPipelineFromConfig();
std::vector<std::shared_ptr<Processor>> processors_;
std::shared_ptr<PadToSize> pad_op_ =
std::make_shared<PadToSize>(0, 0, std::vector<float>(3, 0));
bool initialized_ = false;
// for recording the switch of hwc2chw
bool disable_permute_ = false;