From cb7c8a07d42f2c2f445b28dbd11002a22aa807a1 Mon Sep 17 00:00:00 2001
From: Wang Xinyu
Date: Fri, 10 Mar 2023 12:43:57 +0800
Subject: [PATCH] [CVCUDA] PaddleDetection preprocessor support CV-CUDA (#1493)

* ppdet preproc use manager
* pad_to_size chw opencv
* pad_to_size chw flycv
* fix pad_to_size flycv
* add warning message
* cvcuda convert cubic to linear, padToSize cvcuda
* stridedpad cvcuda
* fix flycv include
* fix flycv include
* fix flycv build
* cast cvcuda
* fix pybind
* fix normalize permute cuda
* base processor move funcs to cc
* Update pad_to_size.cc
---
 .../classification/ppcls/preprocessor.cc     |   6 +-
 .../classification/ppcls/preprocessor.h      |   4 +-
 fastdeploy/vision/common/processors/base.cc  |  57 ++++
 fastdeploy/vision/common/processors/base.h   |  57 +---
 fastdeploy/vision/common/processors/cast.cc  |  36 +++
 fastdeploy/vision/common/processors/cast.h   |  11 +
 .../vision/common/processors/cvcuda_utils.cc |  11 +
 .../vision/common/processors/cvcuda_utils.h  |   4 +-
 .../vision/common/processors/hwc2chw.cc      |   1 +
 fastdeploy/vision/common/processors/mat.cc   |  21 +-
 fastdeploy/vision/common/processors/mat.h    |  21 +-
 .../processors/normalize_and_permute.cu      |  12 +-
 .../vision/common/processors/pad_to_size.cc  | 280 +++++++++++++-----
 .../vision/common/processors/pad_to_size.h   |  17 ++
 .../vision/common/processors/resize.cc       |   2 +-
 .../common/processors/resize_by_short.cc     |   5 +-
 .../vision/common/processors/stride_pad.cc   |  62 ++++
 .../vision/common/processors/stride_pad.h    |  11 +
 fastdeploy/vision/common/processors/utils.cc |  61 ++--
 fastdeploy/vision/common/processors/utils.h  |   8 +-
 .../vision/detection/ppdet/ppdet_pybind.cc   |   4 +-
 .../vision/detection/ppdet/preprocessor.cc   |  69 ++---
 .../vision/detection/ppdet/preprocessor.h    |  16 +-
 23 files changed, 537 insertions(+), 239 deletions(-)

diff --git a/fastdeploy/vision/classification/ppcls/preprocessor.cc b/fastdeploy/vision/classification/ppcls/preprocessor.cc
index 5e167dc28..ff65d06ea 100644
--- a/fastdeploy/vision/classification/ppcls/preprocessor.cc
+++ b/fastdeploy/vision/classification/ppcls/preprocessor.cc
@@ -138,8 +138,10 @@ bool PaddleClasPreprocessor::Apply(FDMatBatch* image_batch,
   }
 
   outputs->resize(1);
-  (*outputs)[0] = std::move(*(image_batch->Tensor()));
-  (*outputs)[0].device_id = DeviceId();
+  FDTensor* tensor = image_batch->Tensor();
+  (*outputs)[0].SetExternalData(tensor->Shape(), tensor->Dtype(),
+                                tensor->Data(), tensor->device,
+                                tensor->device_id);
   return true;
 }
 
diff --git a/fastdeploy/vision/classification/ppcls/preprocessor.h b/fastdeploy/vision/classification/ppcls/preprocessor.h
index ac2e82ef1..25a237f29 100644
--- a/fastdeploy/vision/classification/ppcls/preprocessor.h
+++ b/fastdeploy/vision/classification/ppcls/preprocessor.h
@@ -31,7 +31,9 @@ class FASTDEPLOY_DECL PaddleClasPreprocessor : public ProcessorManager {
    */
   explicit PaddleClasPreprocessor(const std::string& config_file);
 
-  /** \brief Process the input image and prepare input tensors for runtime
+  /** \brief Implement the virtual function of ProcessorManager. Apply() is
+   * the body of Run(): it contains the main preprocessing logic, while Run()
+   * is what users call to execute the preprocessing.
    *
    * \param[in] image_batch The input image batch
    * \param[in] outputs The output tensors which will feed in runtime

diff --git a/fastdeploy/vision/common/processors/base.cc b/fastdeploy/vision/common/processors/base.cc
index 7e34d07bf..99103e305 100644
--- a/fastdeploy/vision/common/processors/base.cc
+++ b/fastdeploy/vision/common/processors/base.cc
@@ -20,6 +20,63 @@ namespace fastdeploy {
 
 namespace vision {
 
+bool Processor::ImplByOpenCV(FDMat* mat) {
+  FDERROR << Name() << " is not implemented yet." << std::endl;
+  return false;
+}
+
+bool Processor::ImplByOpenCV(FDMatBatch* mat_batch) {
+  for (size_t i = 0; i < mat_batch->mats->size(); ++i) {
+    if (ImplByOpenCV(&(*(mat_batch->mats))[i]) != true) {
+      return false;
+    }
+  }
+  return true;
+}
+
+bool Processor::ImplByFlyCV(FDMat* mat) { return ImplByOpenCV(mat); }
+
+bool Processor::ImplByFlyCV(FDMatBatch* mat_batch) {
+  for (size_t i = 0; i < mat_batch->mats->size(); ++i) {
+    if (ImplByFlyCV(&(*(mat_batch->mats))[i]) != true) {
+      return false;
+    }
+  }
+  return true;
+}
+
+bool Processor::ImplByCuda(FDMat* mat) {
+  FDWARNING << Name()
+            << " is not implemented with CUDA, will fallback to OpenCV."
+            << std::endl;
+  return ImplByOpenCV(mat);
+}
+
+bool Processor::ImplByCuda(FDMatBatch* mat_batch) {
+  for (size_t i = 0; i < mat_batch->mats->size(); ++i) {
+    if (ImplByCuda(&(*(mat_batch->mats))[i]) != true) {
+      return false;
+    }
+  }
+  return true;
+}
+
+bool Processor::ImplByCvCuda(FDMat* mat) {
+  FDWARNING << Name()
+            << " is not implemented with CV-CUDA, will fallback to OpenCV."
+            << std::endl;
+  return ImplByOpenCV(mat);
+}
+
+bool Processor::ImplByCvCuda(FDMatBatch* mat_batch) {
+  for (size_t i = 0; i < mat_batch->mats->size(); ++i) {
+    if (ImplByCvCuda(&(*(mat_batch->mats))[i]) != true) {
+      return false;
+    }
+  }
+  return true;
+}
+
 bool Processor::operator()(FDMat* mat) {
   ProcLib target = mat->proc_lib;
   if (mat->proc_lib == ProcLib::DEFAULT) {
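Note: with the Impl* bodies moved into base.cc, the CUDA and CV-CUDA entry points now warn before falling back to OpenCV, so a processor only has to override the backends it actually accelerates. A minimal sketch of what this buys a downstream processor — the SwapRB class and its color-swap logic below are illustrative, not part of this patch:

#include "fastdeploy/vision/common/processors/base.h"
#include "opencv2/imgproc/imgproc.hpp"

namespace fastdeploy {
namespace vision {

// Hypothetical processor that only implements the OpenCV path. When the
// processing lib is CUDA or CV-CUDA, Processor::ImplByCuda()/ImplByCvCuda()
// print a warning and dispatch to this OpenCV implementation automatically.
class SwapRB : public Processor {
 public:
  std::string Name() override { return "SwapRB"; }
  bool ImplByOpenCV(FDMat* mat) override {
    cv::Mat* im = mat->GetOpenCVMat();
    cv::cvtColor(*im, *im, cv::COLOR_BGR2RGB);
    return true;
  }
};

}  // namespace vision
}  // namespace fastdeploy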
diff --git a/fastdeploy/vision/common/processors/base.h b/fastdeploy/vision/common/processors/base.h
index a1c64a2c1..820ef88e5 100644
--- a/fastdeploy/vision/common/processors/base.h
+++ b/fastdeploy/vision/common/processors/base.h
@@ -47,58 +47,17 @@ class FASTDEPLOY_DECL Processor {
 
   virtual std::string Name() = 0;
 
-  virtual bool ImplByOpenCV(FDMat* mat) {
-    FDERROR << Name() << " Not Implement Yet."
-            << std::endl;
-    return false;
-  }
+  virtual bool ImplByOpenCV(FDMat* mat);
+  virtual bool ImplByOpenCV(FDMatBatch* mat_batch);
 
-  virtual bool ImplByOpenCV(FDMatBatch* mat_batch) {
-    for (size_t i = 0; i < mat_batch->mats->size(); ++i) {
-      if (ImplByOpenCV(&(*(mat_batch->mats))[i]) != true) {
-        return false;
-      }
-    }
-    return true;
-  }
+  virtual bool ImplByFlyCV(FDMat* mat);
+  virtual bool ImplByFlyCV(FDMatBatch* mat_batch);
 
-  virtual bool ImplByFlyCV(FDMat* mat) {
-    return ImplByOpenCV(mat);
-  }
+  virtual bool ImplByCuda(FDMat* mat);
+  virtual bool ImplByCuda(FDMatBatch* mat_batch);
 
-  virtual bool ImplByFlyCV(FDMatBatch* mat_batch) {
-    for (size_t i = 0; i < mat_batch->mats->size(); ++i) {
-      if (ImplByFlyCV(&(*(mat_batch->mats))[i]) != true) {
-        return false;
-      }
-    }
-    return true;
-  }
-
-  virtual bool ImplByCuda(FDMat* mat) {
-    return ImplByOpenCV(mat);
-  }
-
-  virtual bool ImplByCuda(FDMatBatch* mat_batch) {
-    for (size_t i = 0; i < mat_batch->mats->size(); ++i) {
-      if (ImplByCuda(&(*(mat_batch->mats))[i]) != true) {
-        return false;
-      }
-    }
-    return true;
-  }
-
-  virtual bool ImplByCvCuda(FDMat* mat) {
-    return ImplByOpenCV(mat);
-  }
-
-  virtual bool ImplByCvCuda(FDMatBatch* mat_batch) {
-    for (size_t i = 0; i < mat_batch->mats->size(); ++i) {
-      if (ImplByCvCuda(&(*(mat_batch->mats))[i]) != true) {
-        return false;
-      }
-    }
-    return true;
-  }
+  virtual bool ImplByCvCuda(FDMat* mat);
+  virtual bool ImplByCvCuda(FDMatBatch* mat_batch);
 
   virtual bool operator()(FDMat* mat);

diff --git a/fastdeploy/vision/common/processors/cast.cc b/fastdeploy/vision/common/processors/cast.cc
index 0ca04a504..cb206185c 100644
--- a/fastdeploy/vision/common/processors/cast.cc
+++ b/fastdeploy/vision/common/processors/cast.cc
@@ -14,6 +14,8 @@
 
 #include "fastdeploy/vision/common/processors/cast.h"
 
+#include "fastdeploy/vision/common/processors/utils.h"
+
 namespace fastdeploy {
 namespace vision {
 
@@ -68,6 +70,40 @@ bool Cast::ImplByFlyCV(Mat* mat) {
 }
 #endif
 
+#ifdef ENABLE_CVCUDA
+bool Cast::ImplByCvCuda(FDMat* mat) {
+  FDDataType dst_dtype;
+  if (dtype_ == "float") {
+    dst_dtype = FDDataType::FP32;
+  } else if (dtype_ == "double") {
+    dst_dtype = FDDataType::FP64;
+  } else {
+    FDWARNING << "Cast does not support " << dtype_
+              << " yet, this operation will be skipped."
+              << std::endl;
+    return false;
+  }
+  if (mat->Type() == dst_dtype) {
+    return true;
+  }
+
+  // Prepare input tensor
+  FDTensor* src = CreateCachedGpuInputTensor(mat);
+  auto src_tensor = CreateCvCudaTensorWrapData(*src);
+
+  // Prepare output tensor
+  mat->output_cache->Resize(src->Shape(), dst_dtype, "output_cache",
+                            Device::GPU);
+  auto dst_tensor =
+      CreateCvCudaTensorWrapData(*(mat->output_cache), mat->layout);
+
+  cvcuda_convert_op_(mat->Stream(), src_tensor, dst_tensor, 1.0f, 0.0f);
+
+  mat->SetTensor(mat->output_cache);
+  mat->mat_type = ProcLib::CVCUDA;
+  return true;
+}
+#endif
+
 bool Cast::Run(Mat* mat, const std::string& dtype, ProcLib lib) {
   auto c = Cast(dtype);
   return c(mat, lib);

diff --git a/fastdeploy/vision/common/processors/cast.h b/fastdeploy/vision/common/processors/cast.h
index 891ae334c..34fe3dafb 100644
--- a/fastdeploy/vision/common/processors/cast.h
+++ b/fastdeploy/vision/common/processors/cast.h
@@ -15,6 +15,11 @@
 #pragma once
 
 #include "fastdeploy/vision/common/processors/base.h"
+#ifdef ENABLE_CVCUDA
+#include <cvcuda/OpConvertTo.hpp>
+
+#include "fastdeploy/vision/common/processors/cvcuda_utils.h"
+#endif
 
 namespace fastdeploy {
 namespace vision {
@@ -25,6 +30,9 @@ class FASTDEPLOY_DECL Cast : public Processor {
   bool ImplByOpenCV(Mat* mat);
 #ifdef ENABLE_FLYCV
   bool ImplByFlyCV(Mat* mat);
+#endif
+#ifdef ENABLE_CVCUDA
+  bool ImplByCvCuda(FDMat* mat);
 #endif
   std::string Name() { return "Cast"; }
 
   static bool Run(Mat* mat, const std::string& dtype,
                   ProcLib lib = ProcLib::DEFAULT);
 
  private:
   std::string dtype_;
+#ifdef ENABLE_CVCUDA
+  cvcuda::ConvertTo cvcuda_convert_op_;
+#endif
 };
 }  // namespace vision
 }  // namespace fastdeploy

diff --git a/fastdeploy/vision/common/processors/cvcuda_utils.cc b/fastdeploy/vision/common/processors/cvcuda_utils.cc
index ff0d5e3ba..017ad15ee 100644
--- a/fastdeploy/vision/common/processors/cvcuda_utils.cc
+++ b/fastdeploy/vision/common/processors/cvcuda_utils.cc
@@ -111,6 +111,17 @@ void CreateCvCudaImageBatchVarShape(std::vector<FDTensor*>& tensors,
     img_batch.pushBack(CreateImageWrapData(*(tensors[i])));
   }
 }
+
+NVCVInterpolationType CreateCvCudaInterp(int interp) {
+  // The CV-CUDA interpolation enum values are compatible with OpenCV's
+  auto nvcv_interp = NVCVInterpolationType(interp);
+
+  // Due to a bug in CV-CUDA CUBIC resize, CUBIC is forced to LINEAR for now
+  if (nvcv_interp == NVCV_INTERP_CUBIC) {
+    return NVCV_INTERP_LINEAR;
+  }
+  return nvcv_interp;
+}
 #endif
 
 }  // namespace vision

diff --git a/fastdeploy/vision/common/processors/cvcuda_utils.h b/fastdeploy/vision/common/processors/cvcuda_utils.h
index 2c84d073d..a3a62e702 100644
--- a/fastdeploy/vision/common/processors/cvcuda_utils.h
+++ b/fastdeploy/vision/common/processors/cvcuda_utils.h
@@ -18,8 +18,9 @@
 #include "fastdeploy/vision/common/processors/mat.h"
 
 #ifdef ENABLE_CVCUDA
-#include "nvcv/Tensor.hpp"
+#include <cvcuda/Types.h>
 #include <nvcv/ImageBatch.hpp>
+#include <nvcv/Tensor.hpp>
 
 namespace fastdeploy {
 namespace vision {
@@ -32,6 +33,7 @@ void* GetCvCudaTensorDataPtr(const nvcv::TensorWrapData& tensor);
 nvcv::ImageWrapData CreateImageWrapData(const FDTensor& tensor);
 void CreateCvCudaImageBatchVarShape(std::vector<FDTensor*>& tensors,
                                     nvcv::ImageBatchVarShape& img_batch);
+NVCVInterpolationType CreateCvCudaInterp(int interp);
 
 }  // namespace vision
 }  // namespace fastdeploy
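Note: CreateCvCudaInterp centralizes the conversion that resize.cc and resize_by_short.cc previously did inline with NVCVInterpolationType(interp_), and carries the CUBIC workaround with it. A small sketch of the behavior, assuming a build with ENABLE_CVCUDA and relying on the enum compatibility the function's comment states:

#include <cassert>
#include <opencv2/imgproc.hpp>

#include "fastdeploy/vision/common/processors/cvcuda_utils.h"

int main() {
  using fastdeploy::vision::CreateCvCudaInterp;
  // OpenCV's interpolation enum values line up with NVCVInterpolationType...
  assert(CreateCvCudaInterp(cv::INTER_NEAREST) == NVCV_INTERP_NEAREST);
  assert(CreateCvCudaInterp(cv::INTER_LINEAR) == NVCV_INTERP_LINEAR);
  // ...except CUBIC, which the workaround downgrades to LINEAR for now.
  assert(CreateCvCudaInterp(cv::INTER_CUBIC) == NVCV_INTERP_LINEAR);
  return 0;
}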
diff --git a/fastdeploy/vision/common/processors/hwc2chw.cc b/fastdeploy/vision/common/processors/hwc2chw.cc
index 9c45b396e..af13da129 100644
--- a/fastdeploy/vision/common/processors/hwc2chw.cc
+++ b/fastdeploy/vision/common/processors/hwc2chw.cc
@@ -77,6 +77,7 @@ bool HWC2CHW::ImplByCvCuda(FDMat* mat) {
 
   cvcuda_reformat_op_(mat->Stream(), src_tensor, dst_tensor);
 
+  mat->layout = Layout::CHW;
   mat->SetTensor(mat->output_cache);
   mat->mat_type = ProcLib::CVCUDA;
   return true;

diff --git a/fastdeploy/vision/common/processors/mat.cc b/fastdeploy/vision/common/processors/mat.cc
index da1d72ccb..f39a065e1 100644
--- a/fastdeploy/vision/common/processors/mat.cc
+++ b/fastdeploy/vision/common/processors/mat.cc
@@ -37,7 +37,7 @@ cv::Mat* Mat::GetOpenCVMat() {
 #ifdef WITH_GPU
     FDASSERT(cudaStreamSynchronize(stream) == cudaSuccess,
              "[ERROR] Error occurs while sync cuda stream.");
-    cpu_mat = CreateZeroCopyOpenCVMatFromTensor(*fd_tensor);
+    cpu_mat = CreateZeroCopyOpenCVMatFromTensor(*fd_tensor, layout);
     mat_type = ProcLib::OPENCV;
     device = Device::CPU;
     return &cpu_mat;
@@ -49,6 +49,23 @@ cv::Mat* Mat::GetOpenCVMat() {
   }
 }
 
+#ifdef ENABLE_FLYCV
+fcv::Mat* Mat::GetFlyCVMat() {
+  if (mat_type == ProcLib::FLYCV) {
+    return &fcv_mat;
+  } else if (mat_type == ProcLib::OPENCV) {
+    // Just a reference to cpu_mat, zero copy. After you
+    // call this method, fcv_mat and cpu_mat will point
+    // to the same memory buffer.
+    fcv_mat = ConvertOpenCVMatToFlyCV(cpu_mat);
+    mat_type = ProcLib::FLYCV;
+    return &fcv_mat;
+  } else {
+    FDASSERT(false, "The mat_type of custom Mat cannot be ProcLib::DEFAULT");
+  }
+}
+#endif
+
 void* Mat::Data() {
   if (mat_type == ProcLib::FLYCV) {
 #ifdef ENABLE_FLYCV
@@ -158,7 +175,7 @@ void Mat::PrintInfo(const std::string& flag) {
 #ifdef WITH_GPU
     FDASSERT(cudaStreamSynchronize(stream) == cudaSuccess,
              "[ERROR] Error occurs while sync cuda stream.");
-    cv::Mat tmp_mat = CreateZeroCopyOpenCVMatFromTensor(*fd_tensor);
+    cv::Mat tmp_mat = CreateZeroCopyOpenCVMatFromTensor(*fd_tensor, layout);
     cv::Scalar mean = cv::mean(tmp_mat);
     for (int i = 0; i < Channels(); ++i) {
       std::cout << mean[i] << " ";

diff --git a/fastdeploy/vision/common/processors/mat.h b/fastdeploy/vision/common/processors/mat.h
index 85f121b90..49c407a4b 100644
--- a/fastdeploy/vision/common/processors/mat.h
+++ b/fastdeploy/vision/common/processors/mat.h
@@ -13,10 +13,13 @@
 // limitations under the License.
 #pragma once
 #include "fastdeploy/core/fd_tensor.h"
-#include "fastdeploy/vision/common/processors/utils.h"
 #include "fastdeploy/vision/common/processors/proc_lib.h"
 #include "opencv2/core/core.hpp"
 
+#ifdef ENABLE_FLYCV
+#include "flycv.h"  // NOLINT
+#endif
+
 #ifdef WITH_GPU
 #include <cuda_runtime_api.h>
 #endif
@@ -70,21 +73,7 @@ struct FASTDEPLOY_DECL Mat {
     fcv_mat = mat;
     mat_type = ProcLib::FLYCV;
   }
-
-  fcv::Mat* GetFlyCVMat() {
-    if (mat_type == ProcLib::FLYCV) {
-      return &fcv_mat;
-    } else if (mat_type == ProcLib::OPENCV) {
-      // Just a reference to cpu_mat, zero copy. After you
-      // call this method, fcv_mat and cpu_mat will point
-      // to the same memory buffer.
-      fcv_mat = ConvertOpenCVMatToFlyCV(cpu_mat);
-      mat_type = ProcLib::FLYCV;
-      return &fcv_mat;
-    } else {
-      FDASSERT(false, "The mat_type of custom Mat can not be ProcLib::DEFAULT");
-    }
-  }
+  fcv::Mat* GetFlyCVMat();
 #endif
 
   void* Data();

diff --git a/fastdeploy/vision/common/processors/normalize_and_permute.cu b/fastdeploy/vision/common/processors/normalize_and_permute.cu
index da3f4ffb1..f1c7c2f25 100644
--- a/fastdeploy/vision/common/processors/normalize_and_permute.cu
+++ b/fastdeploy/vision/common/processors/normalize_and_permute.cu
@@ -40,12 +40,16 @@ __global__ void NormalizeAndPermuteKernel(const uint8_t* src, float* dst,
 }
 
 bool NormalizeAndPermute::ImplByCuda(FDMat* mat) {
+  if (mat->layout != Layout::HWC) {
+    FDERROR << "Only supports input with HWC layout." << std::endl;
+    return false;
+  }
   // Prepare input tensor
   FDTensor* src = CreateCachedGpuInputTensor(mat);
 
   // Prepare output tensor
-  mat->output_cache->Resize(src->Shape(), FDDataType::FP32, "output_cache",
-                            Device::GPU);
+  mat->output_cache->Resize({src->shape[2], src->shape[0], src->shape[1]},
+                            FDDataType::FP32, "output_cache", Device::GPU);
 
   // Copy alpha and beta to GPU
   gpu_alpha_.Resize({1, 1, static_cast<int>(alpha_.size())}, FDDataType::FP32,
@@ -68,9 +72,8 @@ bool NormalizeAndPermute::ImplByCuda(FDMat* mat) {
       reinterpret_cast<float*>(gpu_beta_.Data()), mat->Channels(), swap_rb_, 1,
       jobs);
 
-  mat->SetTensor(mat->output_cache);
-  mat->device = Device::GPU;
   mat->layout = Layout::CHW;
+  mat->SetTensor(mat->output_cache);
   mat->mat_type = ProcLib::CUDA;
   return true;
 }
@@ -112,7 +115,6 @@ bool NormalizeAndPermute::ImplByCuda(FDMatBatch* mat_batch) {
       mat_batch->output_cache->shape[0], jobs);
 
   mat_batch->SetTensor(mat_batch->output_cache);
-  mat_batch->device = Device::GPU;
   mat_batch->layout = FDMatBatchLayout::NCHW;
   mat_batch->mat_type = ProcLib::CUDA;
   return true;
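Note: the CUDA path of NormalizeAndPermute now rejects non-HWC input and sizes its output cache as {C, H, W} up front, so the kernel writes the CHW result directly. As a reference for what the fused op computes, here is a hedged CPU sketch — the function name and the swap_rb handling are illustrative, and the real kernel's indexing may differ in detail:

#include <cstdint>
#include <vector>

// CPU reference for the fused normalize + permute: for an HWC uint8 image,
// out[k][y][x] = in[y][x][k'] * alpha[k] + beta[k], where k' optionally
// swaps the R and B channels. This mirrors the semantics of the kernel,
// not its implementation.
std::vector<float> NormalizeAndPermuteRef(const uint8_t* src, int h, int w,
                                          int c,
                                          const std::vector<float>& alpha,
                                          const std::vector<float>& beta,
                                          bool swap_rb) {
  std::vector<float> dst(static_cast<size_t>(c) * h * w);
  for (int y = 0; y < h; ++y) {
    for (int x = 0; x < w; ++x) {
      for (int k = 0; k < c; ++k) {
        int src_c = swap_rb ? (c - 1 - k) : k;  // assumes a 3-channel swap
        dst[k * h * w + y * w + x] =
            src[(y * w + x) * c + src_c] * alpha[k] + beta[k];
      }
    }
  }
  return dst;
}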
diff --git a/fastdeploy/vision/common/processors/pad_to_size.cc b/fastdeploy/vision/common/processors/pad_to_size.cc
index faa6c5915..1f456dfc7 100644
--- a/fastdeploy/vision/common/processors/pad_to_size.cc
+++ b/fastdeploy/vision/common/processors/pad_to_size.cc
@@ -14,77 +14,62 @@
 
 #include "fastdeploy/vision/common/processors/pad_to_size.h"
 
+#include "fastdeploy/vision/common/processors/utils.h"
+
 namespace fastdeploy {
 namespace vision {
 
-bool PadToSize::ImplByOpenCV(Mat* mat) {
-  if (width_ == -1 || height_ == -1) {
-    return true;
-  }
-  if (mat->layout != Layout::HWC) {
-    FDERROR << "PadToSize: The input data must be Layout::HWC format!"
-            << std::endl;
-    return false;
-  }
-  if (mat->Channels() > 4) {
-    FDERROR << "PadToSize: Only support channels <= 4." << std::endl;
-    return false;
-  }
-  if (mat->Channels() != value_.size()) {
-    FDERROR
-        << "PadToSize: Require input channels equals to size of padding value, "
-           "but now channels = "
-        << mat->Channels() << ", the size of padding values = " << value_.size()
-        << "." << std::endl;
-    return false;
-  }
+static bool PadHWCByOpenCV(FDMat* mat, int width, int height,
+                           const std::vector<float>& value) {
   int origin_w = mat->Width();
   int origin_h = mat->Height();
-  if (origin_w > width_) {
-    FDERROR << "PadToSize: the input width:" << origin_w
-            << " is greater than the target width: " << width_ << "."
-            << std::endl;
-    return false;
-  }
-  if (origin_h > height_) {
-    FDERROR << "PadToSize: the input height:" << origin_h
-            << " is greater than the target height: " << height_ << "."
-            << std::endl;
-    return false;
-  }
-  if (origin_w == width_ && origin_h == height_) {
-    return true;
-  }
   cv::Mat* im = mat->GetOpenCVMat();
-  cv::Scalar value;
-  if (value_.size() == 1) {
-    value = cv::Scalar(value_[0]);
-  } else if (value_.size() == 2) {
-    value = cv::Scalar(value_[0], value_[1]);
-  } else if (value_.size() == 3) {
-    value = cv::Scalar(value_[0], value_[1], value_[2]);
+  cv::Scalar scalar;
+  if (value.size() == 1) {
+    scalar = cv::Scalar(value[0]);
+  } else if (value.size() == 2) {
+    scalar = cv::Scalar(value[0], value[1]);
+  } else if (value.size() == 3) {
+    scalar = cv::Scalar(value[0], value[1], value[2]);
   } else {
-    value = cv::Scalar(value_[0], value_[1], value_[2], value_[3]);
+    scalar = cv::Scalar(value[0], value[1], value[2], value[3]);
   }
   // top, bottom, left, right
-  cv::copyMakeBorder(*im, *im, 0, height_ - origin_h, 0, width_ - origin_w,
-                     cv::BORDER_CONSTANT, value);
-  mat->SetHeight(height_);
-  mat->SetWidth(width_);
+  cv::copyMakeBorder(*im, *im, 0, height - origin_h, 0, width - origin_w,
+                     cv::BORDER_CONSTANT, scalar);
+  mat->SetHeight(height);
+  mat->SetWidth(width);
   return true;
 }
 
+static bool PadCHWByOpenCV(FDMat* mat, int width, int height,
+                           const std::vector<float>& value) {
+  int origin_w = mat->Width();
+  int origin_h = mat->Height();
+  cv::Mat* im = mat->GetOpenCVMat();
+  cv::Mat new_im(height, width,
+                 CreateOpenCVDataType(mat->Type(), mat->Channels()));
+
+  for (int i = 0; i < mat->Channels(); ++i) {
+    uint8_t* src_data =
+        im->ptr() + i * origin_w * origin_h * FDDataTypeSize(mat->Type());
+    cv::Mat src(origin_h, origin_w, CreateOpenCVDataType(mat->Type(), 1),
+                src_data);
+
+    uint8_t* dst_data =
+        new_im.ptr() + i * width * height * FDDataTypeSize(mat->Type());
+    cv::Mat dst(height, width, CreateOpenCVDataType(mat->Type(), 1), dst_data);
+
+    cv::copyMakeBorder(src, dst, 0, height - origin_h, 0, width - origin_w,
+                       cv::BORDER_CONSTANT, cv::Scalar(value[i]));
+  }
+  mat->SetMat(new_im);
+  mat->SetHeight(height);
+  mat->SetWidth(width);
+  return true;
+}
+
+bool PadToSize::CheckArgs(FDMat* mat) {
   if (mat->Channels() > 4) {
     FDERROR << "PadToSize: Only support channels <= 4." << std::endl;
     return false;
   }
@@ -97,45 +82,184 @@ bool PadToSize::ImplByFlyCV(Mat* mat) {
             << "." << std::endl;
     return false;
   }
-  int origin_w = mat->Width();
-  int origin_h = mat->Height();
-  if (origin_w > width_) {
-    FDERROR << "PadToSize: the input width:" << origin_w
+  if (mat->Width() > width_) {
+    FDERROR << "PadToSize: the input width: " << mat->Width()
             << " is greater than the target width: " << width_ << "."
             << std::endl;
     return false;
   }
-  if (origin_h > height_) {
-    FDERROR << "PadToSize: the input height:" << origin_h
+  if (mat->Height() > height_) {
+    FDERROR << "PadToSize: the input height: " << mat->Height()
             << " is greater than the target height: " << height_ << "."
             << std::endl;
     return false;
   }
-  if (origin_w == width_ && origin_h == height_) {
+  return true;
+}
+
+bool PadToSize::ImplByOpenCV(FDMat* mat) {
+  if (width_ == -1 || height_ == -1 ||
+      (mat->Width() == width_ && mat->Height() == height_)) {
     return true;
   }
+  if (CheckArgs(mat) == false) {
+    return false;
+  }
+  if (mat->layout == Layout::HWC) {
+    return PadHWCByOpenCV(mat, width_, height_, value_);
+  } else if (mat->layout == Layout::CHW) {
+    return PadCHWByOpenCV(mat, width_, height_, value_);
+  }
+  return false;
+}
 
+#ifdef ENABLE_FLYCV
+static bool PadHWCByFlyCV(FDMat* mat, int width, int height,
+                          const std::vector<float>& value) {
+  int origin_w = mat->Width();
+  int origin_h = mat->Height();
   fcv::Mat* im = mat->GetFlyCVMat();
-  fcv::Scalar value;
-  if (value_.size() == 1) {
-    value = fcv::Scalar(value_[0]);
-  } else if (value_.size() == 2) {
-    value = fcv::Scalar(value_[0], value_[1]);
-  } else if (value_.size() == 3) {
-    value = fcv::Scalar(value_[0], value_[1], value_[2]);
+  fcv::Scalar scalar;
+  if (value.size() == 1) {
+    scalar = fcv::Scalar(value[0]);
+  } else if (value.size() == 2) {
+    scalar = fcv::Scalar(value[0], value[1]);
+  } else if (value.size() == 3) {
+    scalar = fcv::Scalar(value[0], value[1], value[2]);
   } else {
-    value = fcv::Scalar(value_[0], value_[1], value_[2], value_[3]);
+    scalar = fcv::Scalar(value[0], value[1], value[2], value[3]);
   }
   fcv::Mat new_im;
   // top, bottom, left, right
-  fcv::copy_make_border(*im, new_im, 0, height_ - origin_h, 0,
-                        width_ - origin_w, fcv::BorderType::BORDER_CONSTANT,
-                        value);
+  fcv::copy_make_border(*im, new_im, 0, height - origin_h, 0, width - origin_w,
+                        fcv::BorderType::BORDER_CONSTANT, scalar);
   mat->SetMat(new_im);
-  mat->SetHeight(height_);
-  mat->SetWidth(width_);
+  mat->SetHeight(height);
+  mat->SetWidth(width);
   return true;
 }
+
+static bool PadCHWByFlyCV(FDMat* mat, int width, int height,
+                          const std::vector<float>& value) {
+  int origin_w = mat->Width();
+  int origin_h = mat->Height();
+  fcv::Mat new_im(height, width,
+                  CreateFlyCVDataType(mat->Type(), mat->Channels()));
+  for (int i = 0; i < mat->Channels(); ++i) {
+    uint8_t* src_data = reinterpret_cast<uint8_t*>(mat->Data()) +
+                        i * origin_w * origin_h * FDDataTypeSize(mat->Type());
+    fcv::Mat src(origin_h, origin_w, CreateFlyCVDataType(mat->Type(), 1),
+                 src_data);
+
+    uint8_t* dst_data = reinterpret_cast<uint8_t*>(new_im.data()) +
+                        i * width * height * FDDataTypeSize(mat->Type());
+    fcv::Mat dst(height, width, CreateFlyCVDataType(mat->Type(), 1), dst_data);
+
+    fcv::copy_make_border(src, dst, 0, height - origin_h, 0, width - origin_w,
+                          fcv::BorderType::BORDER_CONSTANT,
+                          fcv::Scalar(value[i]));
+  }
+  mat->SetMat(new_im);
+  mat->SetHeight(height);
+  mat->SetWidth(width);
+  return true;
+}
+
+bool PadToSize::ImplByFlyCV(FDMat* mat) {
+  if (width_ == -1 || height_ == -1 ||
+      (mat->Width() == width_ && mat->Height() == height_)) {
+    return true;
+  }
+  if (CheckArgs(mat) == false) {
+    return false;
+  }
+  if (mat->layout == Layout::HWC) {
+    return PadHWCByFlyCV(mat, width_, height_, value_);
+  } else if (mat->layout == Layout::CHW) {
+    return PadCHWByFlyCV(mat, width_, height_, value_);
+  }
+  return false;
+}
+#endif
+
+#ifdef ENABLE_CVCUDA
+static bool PadHWCByCvCuda(cvcuda::CopyMakeBorder& pad_op, FDMat* mat,
+                           int width, int height,
+                           const std::vector<float>& value) {
+  float4 border_value;
+  if (value.size() == 1) {
+    border_value = make_float4(value[0], 0.0f, 0.0f, 0.0f);
+  } else if (value.size() == 2) {
+    border_value = make_float4(value[0], value[1], 0.0f, 0.0f);
+  } else if (value.size() == 3) {
+    border_value = make_float4(value[0], value[1], value[2], 0.0f);
+  } else {
+    border_value = make_float4(value[0], value[1], value[2], value[3]);
+  }
+
+  // Prepare input tensor
+  FDTensor* src = CreateCachedGpuInputTensor(mat);
+  auto src_tensor = CreateCvCudaTensorWrapData(*src);
+
+  // Prepare output tensor
+  mat->output_cache->Resize({height, width, mat->Channels()}, mat->Type(),
+                            "output_cache", Device::GPU);
+  auto dst_tensor = CreateCvCudaTensorWrapData(*(mat->output_cache));
+
+  pad_op(mat->Stream(), src_tensor, dst_tensor, 0, 0, NVCV_BORDER_CONSTANT,
+         border_value);
+
+  mat->SetTensor(mat->output_cache);
+  mat->mat_type = ProcLib::CVCUDA;
+  return true;
+}
+
+static bool PadCHWByCvCuda(cvcuda::CopyMakeBorder& pad_op, FDMat* mat,
+                           int width, int height,
+                           const std::vector<float>& value) {
+  float4 border_value = make_float4(value[0], 0.0f, 0.0f, 0.0f);
+  FDTensor* input = CreateCachedGpuInputTensor(mat);
+  int channels = input->shape[0];
+  mat->output_cache->Resize({channels, height, width}, mat->Type(),
+                            "output_cache", Device::GPU);
+  for (int i = 0; i < channels; ++i) {
+    uint8_t* src_data =
+        reinterpret_cast<uint8_t*>(input->Data()) +
+        i * mat->Width() * mat->Height() * FDDataTypeSize(mat->Type());
+    FDTensor src;
+    src.SetExternalData({mat->Height(), mat->Width(), 1}, input->Dtype(),
+                        src_data, input->device, input->device_id);
+    auto src_tensor = CreateCvCudaTensorWrapData(src);
+
+    uint8_t* dst_data = reinterpret_cast<uint8_t*>(mat->output_cache->Data()) +
+                        i * width * height * FDDataTypeSize(mat->Type());
+    FDTensor dst;
+    dst.SetExternalData({height, width, 1}, input->Dtype(), dst_data,
+                        input->device, input->device_id);
+    auto dst_tensor = CreateCvCudaTensorWrapData(dst);
+
+    pad_op(mat->Stream(), src_tensor, dst_tensor, 0, 0, NVCV_BORDER_CONSTANT,
+           border_value);
+  }
+  mat->SetTensor(mat->output_cache);
+  mat->mat_type = ProcLib::CVCUDA;
+  return true;
+}
+
+bool PadToSize::ImplByCvCuda(FDMat* mat) {
+  if (width_ == -1 || height_ == -1 ||
+      (mat->Width() == width_ && mat->Height() == height_)) {
+    return true;
+  }
+  if (CheckArgs(mat) == false) {
+    return false;
+  }
+  if (mat->layout == Layout::HWC) {
+    return PadHWCByCvCuda(cvcuda_pad_op_, mat, width_, height_, value_);
+  } else if (mat->layout == Layout::CHW) {
+    return PadCHWByCvCuda(cvcuda_pad_op_, mat, width_, height_, value_);
+  }
+  return false;
+}
 #endif
 
 bool PadToSize::Run(Mat* mat, int width, int height,

diff --git a/fastdeploy/vision/common/processors/pad_to_size.h b/fastdeploy/vision/common/processors/pad_to_size.h
index c73cee3c2..e445fd748 100644
--- a/fastdeploy/vision/common/processors/pad_to_size.h
+++ b/fastdeploy/vision/common/processors/pad_to_size.h
@@ -15,6 +15,11 @@
 #pragma once
 
 #include "fastdeploy/vision/common/processors/base.h"
+#ifdef ENABLE_CVCUDA
+#include <cvcuda/OpCopyMakeBorder.hpp>
+
+#include "fastdeploy/vision/common/processors/cvcuda_utils.h"
+#endif
 
 namespace fastdeploy {
 namespace vision {
@@ -30,6 +35,9 @@ class FASTDEPLOY_DECL PadToSize : public Processor {
   bool ImplByOpenCV(Mat* mat);
 #ifdef ENABLE_FLYCV
   bool ImplByFlyCV(Mat* mat);
+#endif
+#ifdef ENABLE_CVCUDA
+  bool ImplByCvCuda(FDMat* mat);
 #endif
   std::string Name() { return "PadToSize"; }
 
@@ -37,10 +45,19 @@ class FASTDEPLOY_DECL PadToSize : public Processor {
                   const std::vector<float>& value,
                   ProcLib lib = ProcLib::DEFAULT);
 
+  void SetWidthHeight(int width, int height) {
+    width_ = width;
+    height_ = height;
+  }
+
  private:
+  bool CheckArgs(FDMat* mat);
   int width_;
   int height_;
   std::vector<float> value_;
+#ifdef ENABLE_CVCUDA
+  cvcuda::CopyMakeBorder cvcuda_pad_op_;
+#endif
 };
 }  // namespace vision
 }  // namespace fastdeploy
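Note: PadToSize can now pad CHW data as well — each channel plane is wrapped as a single-channel image and padded on the bottom/right independently, which is what the PadCHWBy* helpers above do. A usage sketch, assuming WrapMat and the static Run() helpers behave as elsewhere in the codebase; the 608x608 target and 114 fill value are illustrative:

#include <opencv2/opencv.hpp>

#include "fastdeploy/vision/common/processors/hwc2chw.h"
#include "fastdeploy/vision/common/processors/pad_to_size.h"

int main() {
  cv::Mat image = cv::imread("test.jpg");
  fastdeploy::vision::FDMat mat = fastdeploy::vision::WrapMat(image);

  // Permute first; the mat's layout becomes Layout::CHW...
  fastdeploy::vision::HWC2CHW::Run(&mat);
  // ...and, with this patch, PadToSize still works, padding each of the
  // three planes to 608x608 with the per-channel fill value.
  fastdeploy::vision::PadToSize::Run(&mat, 608, 608,
                                     {114.0f, 114.0f, 114.0f});
  return 0;
}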
diff --git a/fastdeploy/vision/common/processors/resize.cc b/fastdeploy/vision/common/processors/resize.cc
index 806eab643..538ed419f 100644
--- a/fastdeploy/vision/common/processors/resize.cc
+++ b/fastdeploy/vision/common/processors/resize.cc
@@ -147,7 +147,7 @@ bool Resize::ImplByCvCuda(FDMat* mat) {
 
   // CV-CUDA Interp value is compatible with OpenCV
   cvcuda_resize_op_(mat->Stream(), src_tensor, dst_tensor,
-                    NVCVInterpolationType(interp_));
+                    CreateCvCudaInterp(interp_));
 
   mat->SetTensor(mat->output_cache);
   mat->SetWidth(width_);

diff --git a/fastdeploy/vision/common/processors/resize_by_short.cc b/fastdeploy/vision/common/processors/resize_by_short.cc
index 7fe644e0d..8ac650dd7 100644
--- a/fastdeploy/vision/common/processors/resize_by_short.cc
+++ b/fastdeploy/vision/common/processors/resize_by_short.cc
@@ -95,9 +95,8 @@ bool ResizeByShort::ImplByCvCuda(FDMat* mat) {
                             "output_cache", Device::GPU);
   auto dst_tensor = CreateCvCudaTensorWrapData(*(mat->output_cache));
 
-  // CV-CUDA Interp value is compatible with OpenCV
   cvcuda_resize_op_(mat->Stream(), src_tensor, dst_tensor,
-                    NVCVInterpolationType(interp_));
+                    CreateCvCudaInterp(interp_));
 
   mat->SetTensor(mat->output_cache);
   mat->SetWidth(width);
@@ -138,7 +137,7 @@ bool ResizeByShort::ImplByCvCuda(FDMatBatch* mat_batch) {
 
   // CV-CUDA Interp value is compatible with OpenCV
   cvcuda_resize_op_(mat_batch->Stream(), src_batch, dst_batch,
-                    NVCVInterpolationType(interp_));
+                    CreateCvCudaInterp(interp_));
 
   for (size_t i = 0; i < mat_batch->mats->size(); ++i) {
     FDMat* mat = &(*(mat_batch->mats))[i];

diff --git a/fastdeploy/vision/common/processors/stride_pad.cc b/fastdeploy/vision/common/processors/stride_pad.cc
index e41284709..6b9506e85 100644
--- a/fastdeploy/vision/common/processors/stride_pad.cc
+++ b/fastdeploy/vision/common/processors/stride_pad.cc
@@ -114,6 +114,68 @@ bool StridePad::ImplByFlyCV(Mat* mat) {
 }
 #endif
 
+#ifdef ENABLE_CVCUDA
+bool StridePad::ImplByCvCuda(FDMat* mat) {
+  if (mat->layout != Layout::HWC) {
+    FDERROR << "StridePad: The input data must be Layout::HWC format!"
+            << std::endl;
+    return false;
+  }
+  if (mat->Channels() > 4) {
+    FDERROR << "StridePad: Only support channels <= 4." << std::endl;
+    return false;
+  }
+  if (mat->Channels() != value_.size()) {
+    FDERROR
+        << "StridePad: Require input channels to equal the size of padding "
+           "values, but now channels = "
+        << mat->Channels() << ", the size of padding values = " << value_.size()
+        << "." << std::endl;
+    return false;
+  }
+  int origin_w = mat->Width();
+  int origin_h = mat->Height();
+
+  int pad_h = (mat->Height() / stride_) * stride_ +
+              (mat->Height() % stride_ != 0) * stride_ - mat->Height();
+  int pad_w = (mat->Width() / stride_) * stride_ +
+              (mat->Width() % stride_ != 0) * stride_ - mat->Width();
+  if (pad_h == 0 && pad_w == 0) {
+    return true;
+  }
+
+  float4 value;
+  if (value_.size() == 1) {
+    value = make_float4(value_[0], 0.0f, 0.0f, 0.0f);
+  } else if (value_.size() == 2) {
+    value = make_float4(value_[0], value_[1], 0.0f, 0.0f);
+  } else if (value_.size() == 3) {
+    value = make_float4(value_[0], value_[1], value_[2], 0.0f);
+  } else {
+    value = make_float4(value_[0], value_[1], value_[2], value_[3]);
+  }
+
+  // Prepare input tensor
+  FDTensor* src = CreateCachedGpuInputTensor(mat);
+  auto src_tensor = CreateCvCudaTensorWrapData(*src);
+
+  int height = mat->Height() + pad_h;
+  int width = mat->Width() + pad_w;
+
+  // Prepare output tensor
+  mat->output_cache->Resize({height, width, mat->Channels()}, mat->Type(),
+                            "output_cache", Device::GPU);
+  auto dst_tensor = CreateCvCudaTensorWrapData(*(mat->output_cache));
+
+  cvcuda_pad_op_(mat->Stream(), src_tensor, dst_tensor, 0, 0,
+                 NVCV_BORDER_CONSTANT, value);
+
+  mat->SetTensor(mat->output_cache);
+  mat->mat_type = ProcLib::CVCUDA;
+  return true;
+}
+#endif
+
 bool StridePad::Run(Mat* mat, int stride, const std::vector<float>& value,
                     ProcLib lib) {
   auto p = StridePad(stride, value);

diff --git a/fastdeploy/vision/common/processors/stride_pad.h b/fastdeploy/vision/common/processors/stride_pad.h
index 18eebd54e..b7cc6c815 100644
--- a/fastdeploy/vision/common/processors/stride_pad.h
+++ b/fastdeploy/vision/common/processors/stride_pad.h
@@ -15,6 +15,11 @@
 #pragma once
 
 #include "fastdeploy/vision/common/processors/base.h"
+#ifdef ENABLE_CVCUDA
+#include <cvcuda/OpCopyMakeBorder.hpp>
+
+#include "fastdeploy/vision/common/processors/cvcuda_utils.h"
+#endif
 
 namespace fastdeploy {
 namespace vision {
@@ -29,6 +34,9 @@ class FASTDEPLOY_DECL StridePad : public Processor {
   bool ImplByOpenCV(Mat* mat);
 #ifdef ENABLE_FLYCV
   bool ImplByFlyCV(Mat* mat);
+#endif
+#ifdef ENABLE_CVCUDA
+  bool ImplByCvCuda(FDMat* mat);
 #endif
   std::string Name() { return "StridePad"; }
 
@@ -39,6 +47,9 @@ class FASTDEPLOY_DECL StridePad : public Processor {
  private:
   int stride_ = 32;
   std::vector<float> value_;
+#ifdef ENABLE_CVCUDA
+  cvcuda::CopyMakeBorder cvcuda_pad_op_;
+#endif
 };
 }  // namespace vision
 }  // namespace fastdeploy
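Note: the pad_h/pad_w expressions in StridePad::ImplByCvCuda are round-up-to-a-multiple arithmetic. An equivalent, perhaps clearer formulation (the helper name is ours, not the patch's):

// For positive x and stride:
int RoundUpToMultiple(int x, int stride) {
  return ((x + stride - 1) / stride) * stride;
}
// pad_h == RoundUpToMultiple(h, stride) - h. For example, h = 417 with
// stride = 32 rounds up to 448, so pad_h == 31 rows of fill value.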
diff --git a/fastdeploy/vision/common/processors/utils.cc b/fastdeploy/vision/common/processors/utils.cc
index 74f3bd5d0..353d33afc 100644
--- a/fastdeploy/vision/common/processors/utils.cc
+++ b/fastdeploy/vision/common/processors/utils.cc
@@ -186,9 +186,8 @@ cv::Mat ConvertFlyCVMatToOpenCV(fcv::Mat& fim) {
 }
 #endif
 
-cv::Mat CreateZeroCopyOpenCVMatFromBuffer(
-    int height, int width, int channels,
-    FDDataType type, void* data) {
+cv::Mat CreateZeroCopyOpenCVMatFromBuffer(int height, int width, int channels,
+                                          FDDataType type, void* data) {
   cv::Mat ocv_mat;
   switch (type) {
     case FDDataType::UINT8:
@@ -219,61 +218,61 @@ cv::Mat CreateZeroCopyOpenCVMatFromBuffer(
   return ocv_mat;
 }
 
-cv::Mat CreateZeroCopyOpenCVMatFromTensor(const FDTensor& tensor) {
-  // TODO(qiuyanjun): Should add a Layout checking. Now, we
-  // assume that the input tensor is already in Layout::HWC.
-  FDASSERT(tensor.shape.size() == 3, "When create OepnCV Mat from tensor,"
-           "tensor shape should be 3-Dim, HWC layout");
+cv::Mat CreateZeroCopyOpenCVMatFromTensor(const FDTensor& tensor,
+                                          Layout layout) {
+  FDASSERT(tensor.shape.size() == 3,
+           "When creating OpenCV Mat from tensor, "
+           "tensor shape should be 3-Dim");
   FDDataType type = tensor.dtype;
   int height = static_cast<int>(tensor.shape[0]);
   int width = static_cast<int>(tensor.shape[1]);
   int channels = static_cast<int>(tensor.shape[2]);
-  return CreateZeroCopyOpenCVMatFromBuffer(
-      height, width, channels, type,
-      const_cast<void*>(tensor.CpuData()));
+  if (layout == Layout::CHW) {
+    channels = static_cast<int>(tensor.shape[0]);
+    height = static_cast<int>(tensor.shape[1]);
+    width = static_cast<int>(tensor.shape[2]);
+  }
+  return CreateZeroCopyOpenCVMatFromBuffer(height, width, channels, type,
+                                           const_cast<void*>(tensor.CpuData()));
 }
 
 #ifdef ENABLE_FLYCV
-fcv::Mat CreateZeroCopyFlyCVMatFromBuffer(
-    int height, int width, int channels,
-    FDDataType type, void* data) {
+fcv::Mat CreateZeroCopyFlyCVMatFromBuffer(int height, int width, int channels,
+                                          FDDataType type, void* data) {
   fcv::Mat fcv_mat;
   auto fcv_type = CreateFlyCVDataType(type, channels);
   switch (type) {
     case FDDataType::UINT8:
-      fcv_mat =
-          fcv::Mat(width, height, fcv_type, data);
+      fcv_mat = fcv::Mat(width, height, fcv_type, data);
       break;
     case FDDataType::FP32:
-      fcv_mat =
-          fcv::Mat(width, height, fcv_type, data);
+      fcv_mat = fcv::Mat(width, height, fcv_type, data);
       break;
     case FDDataType::FP64:
-      fcv_mat =
-          fcv::Mat(width, height, fcv_type, data);
-      break;
+      fcv_mat = fcv::Mat(width, height, fcv_type, data);
+      break;
     default:
       FDASSERT(false,
-          "Tensor type %d is not supported While calling "
-          "CreateZeroCopyFlyCVMat.",
+               "Tensor type %d is not supported while calling "
+               "CreateZeroCopyFlyCVMat.",
                type);
-      break;
+      break;
   }
   return fcv_mat;
 }
 
 fcv::Mat CreateZeroCopyFlyCVMatFromTensor(const FDTensor& tensor) {
-  // TODO(qiuyanjun): Should add a Layout checking. Now, we
-  // assume that the input tensor is already in Layout::HWC.
-  FDASSERT(tensor.shape.size() == 3, "When create FlyCV Mat from tensor,"
-           "tensor shape should be 3-Dim, HWC layout");
+  // TODO(qiuyanjun): Should add a Layout checking. Now, we
+  // assume that the input tensor is already in Layout::HWC.
+  FDASSERT(tensor.shape.size() == 3,
+           "When creating FlyCV Mat from tensor, "
+           "tensor shape should be 3-Dim, HWC layout");
   FDDataType type = tensor.dtype;
   int height = static_cast<int>(tensor.shape[0]);
   int width = static_cast<int>(tensor.shape[1]);
   int channels = static_cast<int>(tensor.shape[2]);
-  return CreateZeroCopyFlyCVMatFromBuffer(
-      height, width, channels, type,
-      const_cast<void*>(tensor.Data()));
+  return CreateZeroCopyFlyCVMatFromBuffer(height, width, channels, type,
+                                          const_cast<void*>(tensor.Data()));
 }
 #endif

diff --git a/fastdeploy/vision/common/processors/utils.h b/fastdeploy/vision/common/processors/utils.h
index 50074c3a0..8060e4a91 100644
--- a/fastdeploy/vision/common/processors/utils.h
+++ b/fastdeploy/vision/common/processors/utils.h
@@ -16,6 +16,7 @@
 
 #include "fastdeploy/core/fd_tensor.h"
 #include "fastdeploy/utils/utils.h"
+#include "fastdeploy/vision/common/processors/mat.h"
 #include "opencv2/core/core.hpp"
 
 #ifdef ENABLE_FLYCV
@@ -42,11 +43,12 @@ cv::Mat ConvertFlyCVMatToOpenCV(fcv::Mat& fim);
 
 // Create zero copy OpenCV/FlyCV Mat from FD Tensor / Buffer
 cv::Mat CreateZeroCopyOpenCVMatFromBuffer(int height, int width,
-    int channels, FDDataType type, void* data);
-cv::Mat CreateZeroCopyOpenCVMatFromTensor(const FDTensor& tensor);
+                                          int channels, FDDataType type, void* data);
+cv::Mat CreateZeroCopyOpenCVMatFromTensor(const FDTensor& tensor,
+                                          Layout layout = Layout::HWC);
 #ifdef ENABLE_FLYCV
 fcv::Mat CreateZeroCopyFlyCVMatFromBuffer(int height, int width,
-    int channels, FDDataType type, void* data);
+                                          int channels, FDDataType type, void* data);
 fcv::Mat CreateZeroCopyFlyCVMatFromTensor(const FDTensor& tensor);
 #endif
 }  // namespace vision

diff --git a/fastdeploy/vision/detection/ppdet/ppdet_pybind.cc b/fastdeploy/vision/detection/ppdet/ppdet_pybind.cc
index f17fb6f66..3b4af6f88 100644
--- a/fastdeploy/vision/detection/ppdet/ppdet_pybind.cc
+++ b/fastdeploy/vision/detection/ppdet/ppdet_pybind.cc
@@ -15,8 +15,8 @@
 
 namespace fastdeploy {
 void BindPPDet(pybind11::module& m) {
-  pybind11::class_<vision::detection::PaddleDetPreprocessor>(
-      m, "PaddleDetPreprocessor")
+  pybind11::class_<vision::detection::PaddleDetPreprocessor,
+                   vision::ProcessorManager>(m, "PaddleDetPreprocessor")
       .def(pybind11::init<std::string>())
       .def("run",
           [](vision::detection::PaddleDetPreprocessor& self,
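Note: with the new layout argument, CreateZeroCopyOpenCVMatFromTensor can view a planar CHW tensor without copying — only the shape interpretation changes, not the data. A minimal sketch, assuming FDTensor::Resize allocates CPU memory when the name and device arguments are left at their defaults:

#include "fastdeploy/core/fd_tensor.h"
#include "fastdeploy/vision/common/processors/utils.h"

int main() {
  // A {3, 224, 224} tensor is now read as channels=3, height=224, width=224
  // under Layout::CHW, instead of height=3, width=224, channels=224.
  fastdeploy::FDTensor tensor;
  tensor.Resize({3, 224, 224}, fastdeploy::FDDataType::FP32);
  cv::Mat view = fastdeploy::vision::CreateZeroCopyOpenCVMatFromTensor(
      tensor, fastdeploy::vision::Layout::CHW);
  return view.rows == 224 ? 0 : 1;  // the view shares the tensor's buffer
}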
diff --git a/fastdeploy/vision/detection/ppdet/preprocessor.cc b/fastdeploy/vision/detection/ppdet/preprocessor.cc
index 5755fa7a0..8c6141576 100644
--- a/fastdeploy/vision/detection/ppdet/preprocessor.cc
+++ b/fastdeploy/vision/detection/ppdet/preprocessor.cc
@@ -129,13 +129,13 @@ bool PaddleDetPreprocessor::BuildPreprocessPipelineFromConfig() {
   return true;
 }
 
-bool PaddleDetPreprocessor::Run(std::vector<FDMat>* images,
-                                std::vector<FDTensor>* outputs) {
+bool PaddleDetPreprocessor::Apply(FDMatBatch* image_batch,
+                                  std::vector<FDTensor>* outputs) {
   if (!initialized_) {
     FDERROR << "The preprocessor is not initialized." << std::endl;
     return false;
   }
-  if (images->empty()) {
+  if (image_batch->mats->empty()) {
     FDERROR << "The size of input images should be greater than 0."
             << std::endl;
     return false;
   }
@@ -146,7 +146,7 @@ bool PaddleDetPreprocessor::Run(std::vector<FDMat>* images,
   // So preprocessor will output the 3 FDTensors, and how to use `im_shape`
   // is decided by the model itself
   outputs->resize(3);
-  int batch = static_cast<int>(images->size());
+  int batch = static_cast<int>(image_batch->mats->size());
   // Allocate memory for scale_factor
   (*outputs)[1].Resize({batch, 2}, FDDataType::FP32);
   // Allocate memory for im_shape
@@ -158,63 +158,51 @@ bool PaddleDetPreprocessor::Run(std::vector<FDMat>* images,
   auto* scale_factor_ptr =
       reinterpret_cast<float*>((*outputs)[1].MutableData());
   auto* im_shape_ptr = reinterpret_cast<float*>((*outputs)[2].MutableData());
-  for (size_t i = 0; i < images->size(); ++i) {
-    int origin_w = (*images)[i].Width();
-    int origin_h = (*images)[i].Height();
+  for (size_t i = 0; i < image_batch->mats->size(); ++i) {
+    FDMat* mat = &(image_batch->mats->at(i));
+    int origin_w = mat->Width();
+    int origin_h = mat->Height();
     scale_factor_ptr[2 * i] = 1.0;
     scale_factor_ptr[2 * i + 1] = 1.0;
     for (size_t j = 0; j < processors_.size(); ++j) {
-      if (!(*(processors_[j].get()))(&((*images)[i]))) {
+      if (!(*(processors_[j].get()))(mat)) {
         FDERROR << "Failed to process image:" << i << " in "
-                << processors_[i]->Name() << "." << std::endl;
+                << processors_[j]->Name() << "." << std::endl;
         return false;
       }
       if (processors_[j]->Name().find("Resize") != std::string::npos) {
-        scale_factor_ptr[2 * i] = (*images)[i].Height() * 1.0 / origin_h;
-        scale_factor_ptr[2 * i + 1] = (*images)[i].Width() * 1.0 / origin_w;
+        scale_factor_ptr[2 * i] = mat->Height() * 1.0 / origin_h;
+        scale_factor_ptr[2 * i + 1] = mat->Width() * 1.0 / origin_w;
       }
     }
-    if ((*images)[i].Height() > max_hw[0]) {
-      max_hw[0] = (*images)[i].Height();
+    if (mat->Height() > max_hw[0]) {
+      max_hw[0] = mat->Height();
     }
-    if ((*images)[i].Width() > max_hw[1]) {
-      max_hw[1] = (*images)[i].Width();
+    if (mat->Width() > max_hw[1]) {
+      max_hw[1] = mat->Width();
     }
     im_shape_ptr[2 * i] = max_hw[0];
     im_shape_ptr[2 * i + 1] = max_hw[1];
   }
 
-  // Concat all the preprocessed data to a batch tensor
-  std::vector<FDTensor> im_tensors(images->size());
-  for (size_t i = 0; i < images->size(); ++i) {
-    if ((*images)[i].Height() < max_hw[0] || (*images)[i].Width() < max_hw[1]) {
-      // if the size of image less than max_hw, pad to max_hw
-      FDTensor tensor;
-      (*images)[i].ShareWithTensor(&tensor);
-      function::Pad(tensor, &(im_tensors[i]),
-                    {0, 0, max_hw[0] - (*images)[i].Height(),
-                     max_hw[1] - (*images)[i].Width()},
-                    0);
-    } else {
-      // No need pad
-      (*images)[i].ShareWithTensor(&(im_tensors[i]));
+  // If the size of an image is less than max_hw, pad it to max_hw
+  for (size_t i = 0; i < image_batch->mats->size(); ++i) {
+    FDMat* mat = &(image_batch->mats->at(i));
+    if (mat->Height() < max_hw[0] || mat->Width() < max_hw[1]) {
+      pad_op_->SetWidthHeight(max_hw[1], max_hw[0]);
+      (*pad_op_)(mat);
     }
-    // Reshape to 1xCxHxW
-    im_tensors[i].ExpandDim(0);
   }
 
-  if (im_tensors.size() == 1) {
-    // If there's only 1 input, no need to concat
-    // skip memory copy
-    (*outputs)[0] = std::move(im_tensors[0]);
-  } else {
-    // Else concat the im tensor for each input image
-    // compose a batched input tensor
-    function::Concat(im_tensors, &((*outputs)[0]), 0);
-  }
+  // Get the NCHW tensor
+  FDTensor* tensor = image_batch->Tensor();
+  (*outputs)[0].SetExternalData(tensor->Shape(), tensor->Dtype(),
+                                tensor->Data(), tensor->device,
+                                tensor->device_id);
   return true;
 }
 
+
 void PaddleDetPreprocessor::DisableNormalize() {
   this->disable_normalize_ = true;
   // the DisableNormalize function will be invalid if the configuration file is
   // loaded during preprocessing
   if (!BuildPreprocessPipelineFromConfig()) {
     FDERROR << "Failed to build preprocess pipeline from configuration file."
             << std::endl;
   }
 }
 
+
 void PaddleDetPreprocessor::DisablePermute() {
   this->disable_permute_ = true;
   // the DisablePermute function will be invalid if the configuration file is

diff --git a/fastdeploy/vision/detection/ppdet/preprocessor.h b/fastdeploy/vision/detection/ppdet/preprocessor.h
index 9ce9dec32..a4856c64c 100644
--- a/fastdeploy/vision/detection/ppdet/preprocessor.h
+++ b/fastdeploy/vision/detection/ppdet/preprocessor.h
@@ -13,6 +13,7 @@
 // limitations under the License.
 
 #pragma once
+#include "fastdeploy/vision/common/processors/manager.h"
 #include "fastdeploy/vision/common/processors/transform.h"
 #include "fastdeploy/vision/common/result.h"
 
@@ -22,7 +23,7 @@ namespace vision {
 
 namespace detection {
 /*! @brief Preprocessor object for PaddleDet series models. */
-class FASTDEPLOY_DECL PaddleDetPreprocessor {
+class FASTDEPLOY_DECL PaddleDetPreprocessor : public ProcessorManager {
  public:
   PaddleDetPreprocessor() = default;
   /** \brief Create a preprocessor instance for PaddleDet series models
@@ -31,13 +32,16 @@ class FASTDEPLOY_DECL PaddleDetPreprocessor {
    */
   explicit PaddleDetPreprocessor(const std::string& config_file);
 
-  /** \brief Process the input image and prepare input tensors for runtime
+  /** \brief Implement the virtual function of ProcessorManager. Apply() is
+   * the body of Run(): it contains the main preprocessing logic, while Run()
+   * is what users call to execute the preprocessing.
    *
-   * \param[in] images The input image data list, all the elements are returned by cv::imread()
-   * \param[in] outputs The output tensors which will feed in runtime, include image, scale_factor, im_shape
+   * \param[in] image_batch The input image batch
+   * \param[in] outputs The output tensors which will feed in runtime
    * \return true if the preprocessing succeeded, otherwise false
    */
-  bool Run(std::vector<FDMat>* images, std::vector<FDTensor>* outputs);
+  virtual bool Apply(FDMatBatch* image_batch,
+                     std::vector<FDTensor>* outputs);
 
   /// This function will disable normalize in preprocessing step.
   void DisableNormalize();
@@ -51,6 +55,8 @@ class FASTDEPLOY_DECL PaddleDetPreprocessor {
  private:
   bool BuildPreprocessPipelineFromConfig();
   std::vector<std::shared_ptr<Processor>> processors_;
+  std::shared_ptr<PadToSize> pad_op_ =
+      std::make_shared<PadToSize>(0, 0, std::vector<float>(3, 0));
   bool initialized_ = false;
   // for recording the switch of hwc2chw
   bool disable_permute_ = false;
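Note: taken together, PaddleDetPreprocessor now runs through ProcessorManager — Run() wraps the input mats into an FDMatBatch, Apply() executes the pipeline (padding stragglers up to the batch max with pad_op_), and the batched NCHW result is exposed zero-copy via SetExternalData. An end-to-end sketch of the intended usage, assuming ProcessorManager exposes UseCuda(bool enable_cv_cuda, int gpu_id) as in the other manager-based preprocessors; the config path and image file are illustrative:

#include "fastdeploy/vision.h"

int main() {
  // Build the pipeline from a PaddleDetection deploy config.
  fastdeploy::vision::detection::PaddleDetPreprocessor preprocessor(
      "ppyoloe/infer_cfg.yml");
  // Ask the manager to run on GPU 0 with CV-CUDA where an op implements it;
  // ops without a CV-CUDA path warn and fall back to OpenCV.
  preprocessor.UseCuda(true, 0);

  cv::Mat image = cv::imread("test.jpg");
  std::vector<fastdeploy::vision::FDMat> mats = {
      fastdeploy::vision::WrapMat(image)};
  std::vector<fastdeploy::FDTensor> outputs;
  if (!preprocessor.Run(&mats, &outputs)) {
    return -1;
  }
  // outputs[0]: NCHW image tensor, outputs[1]: scale_factor,
  // outputs[2]: im_shape — matching what Apply() fills in above.
  return 0;
}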