mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-07 09:31:35 +08:00
[CVCUDA] PaddleDetection preprocessor support CV-CUDA (#1493)
* ppdet preproc use manager
* pad_to_size chw opencv
* pad_to_size chw flycv
* fix pad_to_size flycv
* add warning message
* cvcuda convert cubic to linear, padToSize cvcuda
* stridedpad cvcuda
* fix flycv include
* fix flycv include
* fix flycv build
* cast cvcuda
* fix pybind
* fix normalize permute cuda
* base processor move funcs to cc
* Update pad_to_size.cc
@@ -138,8 +138,10 @@ bool PaddleClasPreprocessor::Apply(FDMatBatch* image_batch,
   }
 
   outputs->resize(1);
-  (*outputs)[0] = std::move(*(image_batch->Tensor()));
-  (*outputs)[0].device_id = DeviceId();
+  FDTensor* tensor = image_batch->Tensor();
+  (*outputs)[0].SetExternalData(tensor->Shape(), tensor->Dtype(),
+                                tensor->Data(), tensor->device,
+                                tensor->device_id);
   return true;
 }
 
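Note: the PaddleClas change above swaps a tensor move for a zero-copy view. A minimal sketch of the idea, using only the FDTensor calls visible in this diff (the surrounding setup is assumed):

// The batch's cached tensor stays owned by the cache; the output tensor
// becomes a non-owning view over the same buffer, so no copy happens and
// the cache can be reused by the next Run() call.
FDTensor* cached = image_batch->Tensor();
FDTensor view;
view.SetExternalData(cached->Shape(), cached->Dtype(), cached->Data(),
                     cached->device, cached->device_id);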
@@ -31,7 +31,9 @@ class FASTDEPLOY_DECL PaddleClasPreprocessor : public ProcessorManager {
    */
   explicit PaddleClasPreprocessor(const std::string& config_file);
 
-  /** \brief Process the input image and prepare input tensors for runtime
+  /** \brief Implement the virtual function of ProcessorManager, Apply() is the
+   * body of Run(). Apply() contains the main logic of preprocessing, Run() is
+   * called by users to execute preprocessing
    *
    * \param[in] image_batch The input image batch
    * \param[in] outputs The output tensors which will feed in runtime
@@ -20,6 +20,63 @@
 namespace fastdeploy {
 namespace vision {
 
+bool Processor::ImplByOpenCV(FDMat* mat) {
+  FDERROR << Name() << " Not Implement Yet." << std::endl;
+  return false;
+}
+
+bool Processor::ImplByOpenCV(FDMatBatch* mat_batch) {
+  for (size_t i = 0; i < mat_batch->mats->size(); ++i) {
+    if (ImplByOpenCV(&(*(mat_batch->mats))[i]) != true) {
+      return false;
+    }
+  }
+  return true;
+}
+
+bool Processor::ImplByFlyCV(FDMat* mat) { return ImplByOpenCV(mat); }
+
+bool Processor::ImplByFlyCV(FDMatBatch* mat_batch) {
+  for (size_t i = 0; i < mat_batch->mats->size(); ++i) {
+    if (ImplByFlyCV(&(*(mat_batch->mats))[i]) != true) {
+      return false;
+    }
+  }
+  return true;
+}
+
+bool Processor::ImplByCuda(FDMat* mat) {
+  FDWARNING << Name()
+            << " is not implemented with CUDA, will fallback to OpenCV."
+            << std::endl;
+  return ImplByOpenCV(mat);
+}
+
+bool Processor::ImplByCuda(FDMatBatch* mat_batch) {
+  for (size_t i = 0; i < mat_batch->mats->size(); ++i) {
+    if (ImplByCuda(&(*(mat_batch->mats))[i]) != true) {
+      return false;
+    }
+  }
+  return true;
+}
+
+bool Processor::ImplByCvCuda(FDMat* mat) {
+  FDWARNING << Name()
+            << " is not implemented with CV-CUDA, will fallback to OpenCV."
+            << std::endl;
+  return ImplByOpenCV(mat);
+}
+
+bool Processor::ImplByCvCuda(FDMatBatch* mat_batch) {
+  for (size_t i = 0; i < mat_batch->mats->size(); ++i) {
+    if (ImplByCvCuda(&(*(mat_batch->mats))[i]) != true) {
+      return false;
+    }
+  }
+  return true;
+}
+
 bool Processor::operator()(FDMat* mat) {
   ProcLib target = mat->proc_lib;
   if (mat->proc_lib == ProcLib::DEFAULT) {
@@ -47,58 +47,17 @@ class FASTDEPLOY_DECL Processor {
 
   virtual std::string Name() = 0;
 
-  virtual bool ImplByOpenCV(FDMat* mat) {
-    FDERROR << Name() << " Not Implement Yet." << std::endl;
-    return false;
-  }
-
-  virtual bool ImplByOpenCV(FDMatBatch* mat_batch) {
-    for (size_t i = 0; i < mat_batch->mats->size(); ++i) {
-      if (ImplByOpenCV(&(*(mat_batch->mats))[i]) != true) {
-        return false;
-      }
-    }
-    return true;
-  }
-
-  virtual bool ImplByFlyCV(FDMat* mat) {
-    return ImplByOpenCV(mat);
-  }
-
-  virtual bool ImplByFlyCV(FDMatBatch* mat_batch) {
-    for (size_t i = 0; i < mat_batch->mats->size(); ++i) {
-      if (ImplByFlyCV(&(*(mat_batch->mats))[i]) != true) {
-        return false;
-      }
-    }
-    return true;
-  }
-
-  virtual bool ImplByCuda(FDMat* mat) {
-    return ImplByOpenCV(mat);
-  }
-
-  virtual bool ImplByCuda(FDMatBatch* mat_batch) {
-    for (size_t i = 0; i < mat_batch->mats->size(); ++i) {
-      if (ImplByCuda(&(*(mat_batch->mats))[i]) != true) {
-        return false;
-      }
-    }
-    return true;
-  }
-
-  virtual bool ImplByCvCuda(FDMat* mat) {
-    return ImplByOpenCV(mat);
-  }
-
-  virtual bool ImplByCvCuda(FDMatBatch* mat_batch) {
-    for (size_t i = 0; i < mat_batch->mats->size(); ++i) {
-      if (ImplByCvCuda(&(*(mat_batch->mats))[i]) != true) {
-        return false;
-      }
-    }
-    return true;
-  }
+  virtual bool ImplByOpenCV(FDMat* mat);
+  virtual bool ImplByOpenCV(FDMatBatch* mat_batch);
+
+  virtual bool ImplByFlyCV(FDMat* mat);
+  virtual bool ImplByFlyCV(FDMatBatch* mat_batch);
+
+  virtual bool ImplByCuda(FDMat* mat);
+  virtual bool ImplByCuda(FDMatBatch* mat_batch);
+
+  virtual bool ImplByCvCuda(FDMat* mat);
+  virtual bool ImplByCvCuda(FDMatBatch* mat_batch);
 
   virtual bool operator()(FDMat* mat);
 
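Note: base.h/base.cc above move the fallback bodies out of the header; any backend method a concrete processor does not override degrades to the OpenCV path (with a warning for the GPU backends). A self-contained sketch of the same dispatch pattern using hypothetical stand-in types, not the FastDeploy classes themselves:

#include <iostream>
#include <string>

struct Mat {};  // stand-in for FDMat

struct ProcessorSketch {
  virtual ~ProcessorSketch() = default;
  virtual std::string Name() = 0;
  virtual bool ImplByOpenCV(Mat* mat) {
    std::cerr << Name() << " not implemented yet\n";
    return false;
  }
  // GPU backends fall back to OpenCV unless a subclass overrides them.
  virtual bool ImplByCvCuda(Mat* mat) {
    std::cerr << Name() << ": no CV-CUDA impl, falling back to OpenCV\n";
    return ImplByOpenCV(mat);
  }
};

struct MyProcessor : ProcessorSketch {
  std::string Name() override { return "MyProcessor"; }
  bool ImplByOpenCV(Mat* mat) override { return true; }  // real work here
};

int main() {
  Mat m;
  MyProcessor p;
  return p.ImplByCvCuda(&m) ? 0 : 1;  // warns, falls back, succeeds
}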
@@ -14,6 +14,8 @@
 
 #include "fastdeploy/vision/common/processors/cast.h"
 
+#include "fastdeploy/vision/common/processors/utils.h"
+
 namespace fastdeploy {
 namespace vision {
 
@@ -68,6 +70,40 @@ bool Cast::ImplByFlyCV(Mat* mat) {
 }
 #endif
 
+#ifdef ENABLE_CVCUDA
+bool Cast::ImplByCvCuda(FDMat* mat) {
+  FDDataType dst_dtype;
+  if (dtype_ == "float") {
+    dst_dtype = FDDataType::FP32;
+  } else if (dtype_ == "double") {
+    dst_dtype = FDDataType::FP64;
+  } else {
+    FDWARNING << "Cast not support for " << dtype_
+              << " now! will skip this operation." << std::endl;
+    return false;
+  }
+  if (mat->Type() == dst_dtype) {
+    return true;
+  }
+
+  // Prepare input tensor
+  FDTensor* src = CreateCachedGpuInputTensor(mat);
+  auto src_tensor = CreateCvCudaTensorWrapData(*src);
+
+  // Prepare output tensor
+  mat->output_cache->Resize(src->Shape(), dst_dtype, "output_cache",
+                            Device::GPU);
+  auto dst_tensor =
+      CreateCvCudaTensorWrapData(*(mat->output_cache), mat->layout);
+
+  cvcuda_convert_op_(mat->Stream(), src_tensor, dst_tensor, 1.0f, 0.0f);
+
+  mat->SetTensor(mat->output_cache);
+  mat->mat_type = ProcLib::CVCUDA;
+  return true;
+}
+#endif
+
 bool Cast::Run(Mat* mat, const std::string& dtype, ProcLib lib) {
   auto c = Cast(dtype);
   return c(mat, lib);
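Note: Cast::ImplByCvCuda follows the same prepare-input / prepare-output / launch pattern as the other GPU processors in this commit. A hedged usage sketch built on the static Run helper shown above (ProcLib::CVCUDA as the lib argument is an assumption based on the enum values that appear elsewhere in this diff):

// Cast an FDMat's data to float via the CV-CUDA path; requires a build
// with ENABLE_CVCUDA and GPU-resident input. Only "float" and "double"
// are accepted; anything else returns false and skips the operation.
bool ok = Cast::Run(&mat, "float", ProcLib::CVCUDA);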
@@ -15,6 +15,11 @@
 #pragma once
 
 #include "fastdeploy/vision/common/processors/base.h"
+#ifdef ENABLE_CVCUDA
+#include <cvcuda/OpConvertTo.hpp>
+
+#include "fastdeploy/vision/common/processors/cvcuda_utils.h"
+#endif
 
 namespace fastdeploy {
 namespace vision {
@@ -25,6 +30,9 @@ class FASTDEPLOY_DECL Cast : public Processor {
   bool ImplByOpenCV(Mat* mat);
 #ifdef ENABLE_FLYCV
   bool ImplByFlyCV(Mat* mat);
+#endif
+#ifdef ENABLE_CVCUDA
+  bool ImplByCvCuda(FDMat* mat);
 #endif
   std::string Name() { return "Cast"; }
   static bool Run(Mat* mat, const std::string& dtype,
@@ -34,6 +42,9 @@ class FASTDEPLOY_DECL Cast : public Processor {
 
  private:
   std::string dtype_;
+#ifdef ENABLE_CVCUDA
+  cvcuda::ConvertTo cvcuda_convert_op_;
+#endif
 };
 }  // namespace vision
 }  // namespace fastdeploy
@@ -111,6 +111,17 @@ void CreateCvCudaImageBatchVarShape(std::vector<FDTensor*>& tensors,
     img_batch.pushBack(CreateImageWrapData(*(tensors[i])));
   }
 }
 
+NVCVInterpolationType CreateCvCudaInterp(int interp) {
+  // CV-CUDA Interp value is compatible with OpenCV
+  auto nvcv_interp = NVCVInterpolationType(interp);
+
+  // Due to bug of CV-CUDA CUBIC resize, will force to convert CUBIC to LINEAR
+  if (nvcv_interp == NVCV_INTERP_CUBIC) {
+    return NVCV_INTERP_LINEAR;
+  }
+  return nvcv_interp;
+}
 #endif
 
 }  // namespace vision
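Note: CreateCvCudaInterp leans on OpenCV and CV-CUDA sharing the same integer values for interpolation modes, then patches around the CUBIC bug. A small illustrative check (OpenCV values per cv::InterpolationFlags; the CUBIC workaround is specific to the CV-CUDA version this commit targets):

#include <cassert>
#include <opencv2/imgproc.hpp>

void CheckInterpMapping() {
  // cv::INTER_NEAREST (0) -> NVCV_INTERP_NEAREST
  // cv::INTER_LINEAR  (1) -> NVCV_INTERP_LINEAR
  // cv::INTER_CUBIC   (2) -> would be NVCV_INTERP_CUBIC, remapped to LINEAR
  NVCVInterpolationType interp = CreateCvCudaInterp(cv::INTER_CUBIC);
  assert(interp == NVCV_INTERP_LINEAR);  // cubic is deliberately downgraded
}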
@@ -18,8 +18,9 @@
 #include "fastdeploy/vision/common/processors/mat.h"
 
 #ifdef ENABLE_CVCUDA
-#include "nvcv/Tensor.hpp"
+#include <nvcv/Tensor.hpp>
 #include <nvcv/ImageBatch.hpp>
+#include <cvcuda/Types.h>
 
 namespace fastdeploy {
 namespace vision {
@@ -32,6 +33,7 @@ void* GetCvCudaTensorDataPtr(const nvcv::TensorWrapData& tensor);
 nvcv::ImageWrapData CreateImageWrapData(const FDTensor& tensor);
 void CreateCvCudaImageBatchVarShape(std::vector<FDTensor*>& tensors,
                                     nvcv::ImageBatchVarShape& img_batch);
+NVCVInterpolationType CreateCvCudaInterp(int interp);
 
 }  // namespace vision
 }  // namespace fastdeploy
@@ -77,6 +77,7 @@ bool HWC2CHW::ImplByCvCuda(FDMat* mat) {
 
   cvcuda_reformat_op_(mat->Stream(), src_tensor, dst_tensor);
 
+  mat->layout = Layout::CHW;
   mat->SetTensor(mat->output_cache);
   mat->mat_type = ProcLib::CVCUDA;
   return true;
@@ -37,7 +37,7 @@ cv::Mat* Mat::GetOpenCVMat() {
 #ifdef WITH_GPU
     FDASSERT(cudaStreamSynchronize(stream) == cudaSuccess,
              "[ERROR] Error occurs while sync cuda stream.");
-    cpu_mat = CreateZeroCopyOpenCVMatFromTensor(*fd_tensor);
+    cpu_mat = CreateZeroCopyOpenCVMatFromTensor(*fd_tensor, layout);
     mat_type = ProcLib::OPENCV;
     device = Device::CPU;
     return &cpu_mat;
@@ -49,6 +49,23 @@ cv::Mat* Mat::GetOpenCVMat() {
   }
 }
 
+#ifdef ENABLE_FLYCV
+fcv::Mat* Mat::GetFlyCVMat() {
+  if (mat_type == ProcLib::FLYCV) {
+    return &fcv_mat;
+  } else if (mat_type == ProcLib::OPENCV) {
+    // Just a reference to cpu_mat, zero copy. After you
+    // call this method, fcv_mat and cpu_mat will point
+    // to the same memory buffer.
+    fcv_mat = ConvertOpenCVMatToFlyCV(cpu_mat);
+    mat_type = ProcLib::FLYCV;
+    return &fcv_mat;
+  } else {
+    FDASSERT(false, "The mat_type of custom Mat can not be ProcLib::DEFAULT");
+  }
+}
+#endif
+
 void* Mat::Data() {
   if (mat_type == ProcLib::FLYCV) {
 #ifdef ENABLE_FLYCV
@@ -158,7 +175,7 @@ void Mat::PrintInfo(const std::string& flag) {
 #ifdef WITH_GPU
     FDASSERT(cudaStreamSynchronize(stream) == cudaSuccess,
              "[ERROR] Error occurs while sync cuda stream.");
-    cv::Mat tmp_mat = CreateZeroCopyOpenCVMatFromTensor(*fd_tensor);
+    cv::Mat tmp_mat = CreateZeroCopyOpenCVMatFromTensor(*fd_tensor, layout);
     cv::Scalar mean = cv::mean(tmp_mat);
     for (int i = 0; i < Channels(); ++i) {
       std::cout << mean[i] << " ";
@@ -13,10 +13,13 @@
 // limitations under the License.
 #pragma once
 #include "fastdeploy/core/fd_tensor.h"
-#include "fastdeploy/vision/common/processors/utils.h"
 #include "fastdeploy/vision/common/processors/proc_lib.h"
 #include "opencv2/core/core.hpp"
 
+#ifdef ENABLE_FLYCV
+#include "flycv.h"  // NOLINT
+#endif
+
 #ifdef WITH_GPU
 #include <cuda_runtime_api.h>
 #endif
@@ -70,21 +73,7 @@ struct FASTDEPLOY_DECL Mat {
     fcv_mat = mat;
     mat_type = ProcLib::FLYCV;
   }
-  fcv::Mat* GetFlyCVMat() {
-    if (mat_type == ProcLib::FLYCV) {
-      return &fcv_mat;
-    } else if (mat_type == ProcLib::OPENCV) {
-      // Just a reference to cpu_mat, zero copy. After you
-      // call this method, fcv_mat and cpu_mat will point
-      // to the same memory buffer.
-      fcv_mat = ConvertOpenCVMatToFlyCV(cpu_mat);
-      mat_type = ProcLib::FLYCV;
-      return &fcv_mat;
-    } else {
-      FDASSERT(false, "The mat_type of custom Mat can not be ProcLib::DEFAULT");
-    }
-  }
+  fcv::Mat* GetFlyCVMat();
 #endif
 
   void* Data();
@@ -40,12 +40,16 @@ __global__ void NormalizeAndPermuteKernel(const uint8_t* src, float* dst,
 }
 
 bool NormalizeAndPermute::ImplByCuda(FDMat* mat) {
+  if (mat->layout != Layout::HWC) {
+    FDERROR << "Only supports input with HWC layout." << std::endl;
+    return false;
+  }
   // Prepare input tensor
   FDTensor* src = CreateCachedGpuInputTensor(mat);
 
   // Prepare output tensor
-  mat->output_cache->Resize(src->Shape(), FDDataType::FP32, "output_cache",
-                            Device::GPU);
+  mat->output_cache->Resize({src->shape[2], src->shape[0], src->shape[1]},
+                            FDDataType::FP32, "output_cache", Device::GPU);
 
   // Copy alpha and beta to GPU
   gpu_alpha_.Resize({1, 1, static_cast<int>(alpha_.size())}, FDDataType::FP32,
@@ -68,9 +72,8 @@ bool NormalizeAndPermute::ImplByCuda(FDMat* mat) {
       reinterpret_cast<float*>(gpu_beta_.Data()), mat->Channels(), swap_rb_, 1,
       jobs);
 
-  mat->SetTensor(mat->output_cache);
-  mat->device = Device::GPU;
   mat->layout = Layout::CHW;
+  mat->SetTensor(mat->output_cache);
   mat->mat_type = ProcLib::CUDA;
   return true;
 }
@@ -112,7 +115,6 @@ bool NormalizeAndPermute::ImplByCuda(FDMatBatch* mat_batch) {
       mat_batch->output_cache->shape[0], jobs);
 
   mat_batch->SetTensor(mat_batch->output_cache);
-  mat_batch->device = Device::GPU;
   mat_batch->layout = FDMatBatchLayout::NCHW;
   mat_batch->mat_type = ProcLib::CUDA;
   return true;
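Note: the CUDA NormalizeAndPermute now rejects non-HWC input and resizes its output cache to the permuted CHW shape before launching the kernel; SetTensor is also moved after the layout update so the Mat's metadata matches the tensor it wraps. A sketch of the shape permutation used in the Resize call:

#include <cstdint>
#include <vector>

// An HWC tensor {H, W, C} becomes a CHW tensor {C, H, W}; the indices
// mirror {src->shape[2], src->shape[0], src->shape[1]} above.
std::vector<int64_t> hwc = {224, 224, 3};
std::vector<int64_t> chw = {hwc[2], hwc[0], hwc[1]};  // {3, 224, 224}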
|
@@ -14,77 +14,62 @@
|
|||||||
|
|
||||||
#include "fastdeploy/vision/common/processors/pad_to_size.h"
|
#include "fastdeploy/vision/common/processors/pad_to_size.h"
|
||||||
|
|
||||||
|
#include "fastdeploy/vision/common/processors/utils.h"
|
||||||
|
|
||||||
namespace fastdeploy {
|
namespace fastdeploy {
|
||||||
namespace vision {
|
namespace vision {
|
||||||
|
|
||||||
bool PadToSize::ImplByOpenCV(Mat* mat) {
|
static bool PadHWCByOpenCV(FDMat* mat, int width, int height,
|
||||||
if (width_ == -1 || height_ == -1) {
|
const std::vector<float>& value) {
|
||||||
return true;
|
|
||||||
}
|
|
||||||
if (mat->layout != Layout::HWC) {
|
|
||||||
FDERROR << "PadToSize: The input data must be Layout::HWC format!"
|
|
||||||
<< std::endl;
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
if (mat->Channels() > 4) {
|
|
||||||
FDERROR << "PadToSize: Only support channels <= 4." << std::endl;
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
if (mat->Channels() != value_.size()) {
|
|
||||||
FDERROR
|
|
||||||
<< "PadToSize: Require input channels equals to size of padding value, "
|
|
||||||
"but now channels = "
|
|
||||||
<< mat->Channels() << ", the size of padding values = " << value_.size()
|
|
||||||
<< "." << std::endl;
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
int origin_w = mat->Width();
|
int origin_w = mat->Width();
|
||||||
int origin_h = mat->Height();
|
int origin_h = mat->Height();
|
||||||
if (origin_w > width_) {
|
|
||||||
FDERROR << "PadToSize: the input width:" << origin_w
|
|
||||||
<< " is greater than the target width: " << width_ << "."
|
|
||||||
<< std::endl;
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
if (origin_h > height_) {
|
|
||||||
FDERROR << "PadToSize: the input height:" << origin_h
|
|
||||||
<< " is greater than the target height: " << height_ << "."
|
|
||||||
<< std::endl;
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
if (origin_w == width_ && origin_h == height_) {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
cv::Mat* im = mat->GetOpenCVMat();
|
cv::Mat* im = mat->GetOpenCVMat();
|
||||||
cv::Scalar value;
|
cv::Scalar scalar;
|
||||||
if (value_.size() == 1) {
|
if (value.size() == 1) {
|
||||||
value = cv::Scalar(value_[0]);
|
scalar = cv::Scalar(value[0]);
|
||||||
} else if (value_.size() == 2) {
|
} else if (value.size() == 2) {
|
||||||
value = cv::Scalar(value_[0], value_[1]);
|
scalar = cv::Scalar(value[0], value[1]);
|
||||||
} else if (value_.size() == 3) {
|
} else if (value.size() == 3) {
|
||||||
value = cv::Scalar(value_[0], value_[1], value_[2]);
|
scalar = cv::Scalar(value[0], value[1], value[2]);
|
||||||
} else {
|
} else {
|
||||||
value = cv::Scalar(value_[0], value_[1], value_[2], value_[3]);
|
scalar = cv::Scalar(value[0], value[1], value[2], value[3]);
|
||||||
}
|
}
|
||||||
// top, bottom, left, right
|
// top, bottom, left, right
|
||||||
cv::copyMakeBorder(*im, *im, 0, height_ - origin_h, 0, width_ - origin_w,
|
cv::copyMakeBorder(*im, *im, 0, height - origin_h, 0, width - origin_w,
|
||||||
cv::BORDER_CONSTANT, value);
|
cv::BORDER_CONSTANT, scalar);
|
||||||
mat->SetHeight(height_);
|
mat->SetHeight(height);
|
||||||
mat->SetWidth(width_);
|
mat->SetWidth(width);
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
#ifdef ENABLE_FLYCV
|
static bool PadCHWByOpenCV(FDMat* mat, int width, int height,
|
||||||
bool PadToSize::ImplByFlyCV(Mat* mat) {
|
const std::vector<float>& value) {
|
||||||
if (width_ == -1 || height_ == -1) {
|
int origin_w = mat->Width();
|
||||||
|
int origin_h = mat->Height();
|
||||||
|
cv::Mat* im = mat->GetOpenCVMat();
|
||||||
|
cv::Mat new_im(height, width,
|
||||||
|
CreateOpenCVDataType(mat->Type(), mat->Channels()));
|
||||||
|
|
||||||
|
for (int i = 0; i < mat->Channels(); ++i) {
|
||||||
|
uint8_t* src_data =
|
||||||
|
im->ptr() + i * origin_w * origin_h * FDDataTypeSize(mat->Type());
|
||||||
|
cv::Mat src(origin_h, origin_w, CreateOpenCVDataType(mat->Type(), 1),
|
||||||
|
src_data);
|
||||||
|
|
||||||
|
uint8_t* dst_data =
|
||||||
|
new_im.ptr() + i * width * height * FDDataTypeSize(mat->Type());
|
||||||
|
cv::Mat dst(height, width, CreateOpenCVDataType(mat->Type(), 1), dst_data);
|
||||||
|
|
||||||
|
cv::copyMakeBorder(src, dst, 0, height - origin_h, 0, width - origin_w,
|
||||||
|
cv::BORDER_CONSTANT, cv::Scalar(value[i]));
|
||||||
|
}
|
||||||
|
mat->SetMat(new_im);
|
||||||
|
mat->SetHeight(height);
|
||||||
|
mat->SetWidth(width);
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
if (mat->layout != Layout::HWC) {
|
|
||||||
FDERROR << "PadToSize: The input data must be Layout::HWC format!"
|
bool PadToSize::CheckArgs(FDMat* mat) {
|
||||||
<< std::endl;
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
if (mat->Channels() > 4) {
|
if (mat->Channels() > 4) {
|
||||||
FDERROR << "PadToSize: Only support channels <= 4." << std::endl;
|
FDERROR << "PadToSize: Only support channels <= 4." << std::endl;
|
||||||
return false;
|
return false;
|
||||||
@@ -97,45 +82,184 @@ bool PadToSize::ImplByFlyCV(Mat* mat) {
         << "." << std::endl;
     return false;
   }
-  int origin_w = mat->Width();
-  int origin_h = mat->Height();
-  if (origin_w > width_) {
-    FDERROR << "PadToSize: the input width:" << origin_w
+  if (mat->Width() > width_) {
+    FDERROR << "PadToSize: the input width:" << mat->Width()
             << " is greater than the target width: " << width_ << "."
             << std::endl;
     return false;
   }
-  if (origin_h > height_) {
-    FDERROR << "PadToSize: the input height:" << origin_h
+  if (mat->Height() > height_) {
+    FDERROR << "PadToSize: the input height:" << mat->Height()
             << " is greater than the target height: " << height_ << "."
             << std::endl;
     return false;
   }
-  if (origin_w == width_ && origin_h == height_) {
+  return true;
+}
+
+bool PadToSize::ImplByOpenCV(FDMat* mat) {
+  if (width_ == -1 || height_ == -1 ||
+      (mat->Width() == width_ && mat->Height() == height_)) {
     return true;
   }
+  if (CheckArgs(mat) == false) {
+    return false;
+  }
+  if (mat->layout == Layout::HWC) {
+    return PadHWCByOpenCV(mat, width_, height_, value_);
+  } else if (mat->layout == Layout::CHW) {
+    return PadCHWByOpenCV(mat, width_, height_, value_);
+  }
+  return false;
+}
+
+#ifdef ENABLE_FLYCV
+static bool PadHWCByFlyCV(FDMat* mat, int width, int height,
+                          const std::vector<float>& value) {
+  int origin_w = mat->Width();
+  int origin_h = mat->Height();
   fcv::Mat* im = mat->GetFlyCVMat();
-  fcv::Scalar value;
-  if (value_.size() == 1) {
-    value = fcv::Scalar(value_[0]);
-  } else if (value_.size() == 2) {
-    value = fcv::Scalar(value_[0], value_[1]);
-  } else if (value_.size() == 3) {
-    value = fcv::Scalar(value_[0], value_[1], value_[2]);
+  fcv::Scalar scalar;
+  if (value.size() == 1) {
+    scalar = fcv::Scalar(value[0]);
+  } else if (value.size() == 2) {
+    scalar = fcv::Scalar(value[0], value[1]);
+  } else if (value.size() == 3) {
+    scalar = fcv::Scalar(value[0], value[1], value[2]);
   } else {
-    value = fcv::Scalar(value_[0], value_[1], value_[2], value_[3]);
+    scalar = fcv::Scalar(value[0], value[1], value[2], value[3]);
   }
   fcv::Mat new_im;
   // top, bottom, left, right
-  fcv::copy_make_border(*im, new_im, 0, height_ - origin_h, 0,
-                        width_ - origin_w, fcv::BorderType::BORDER_CONSTANT,
-                        value);
+  fcv::copy_make_border(*im, new_im, 0, height - origin_h, 0, width - origin_w,
+                        fcv::BorderType::BORDER_CONSTANT, scalar);
   mat->SetMat(new_im);
-  mat->SetHeight(height_);
-  mat->SetWidth(width_);
+  mat->SetHeight(height);
+  mat->SetWidth(width);
   return true;
 }
+
+static bool PadCHWByFlyCV(FDMat* mat, int width, int height,
+                          const std::vector<float>& value) {
+  int origin_w = mat->Width();
+  int origin_h = mat->Height();
+  fcv::Mat new_im(height, width,
+                  CreateFlyCVDataType(mat->Type(), mat->Channels()));
+  for (int i = 0; i < mat->Channels(); ++i) {
+    uint8_t* src_data = reinterpret_cast<uint8_t*>(mat->Data()) +
+                        i * origin_w * origin_h * FDDataTypeSize(mat->Type());
+    fcv::Mat src(origin_h, origin_w, CreateFlyCVDataType(mat->Type(), 1),
+                 src_data);
+
+    uint8_t* dst_data = reinterpret_cast<uint8_t*>(new_im.data()) +
+                        i * width * height * FDDataTypeSize(mat->Type());
+    fcv::Mat dst(height, width, CreateFlyCVDataType(mat->Type(), 1), dst_data);
+
+    fcv::copy_make_border(src, dst, 0, height - origin_h, 0, width - origin_w,
+                          fcv::BorderType::BORDER_CONSTANT,
+                          fcv::Scalar(value[i]));
+  }
+  mat->SetMat(new_im);
+  mat->SetHeight(height);
+  mat->SetWidth(width);
+  return true;
+}
+
+bool PadToSize::ImplByFlyCV(FDMat* mat) {
+  if (width_ == -1 || height_ == -1 ||
+      (mat->Width() == width_ && mat->Height() == height_)) {
+    return true;
+  }
+  if (CheckArgs(mat) == false) {
+    return false;
+  }
+  if (mat->layout == Layout::HWC) {
+    return PadHWCByFlyCV(mat, width_, height_, value_);
+  } else if (mat->layout == Layout::CHW) {
+    return PadCHWByFlyCV(mat, width_, height_, value_);
+  }
+  return false;
+}
+#endif
+
+#ifdef ENABLE_CVCUDA
+static bool PadHWCByCvCuda(cvcuda::CopyMakeBorder& pad_op, FDMat* mat,
+                           int width, int height,
+                           const std::vector<float>& value) {
+  float4 border_value;
+  if (value.size() == 1) {
+    border_value = make_float4(value[0], 0.0f, 0.0f, 0.0f);
+  } else if (value.size() == 2) {
+    border_value = make_float4(value[0], value[1], 0.0f, 0.0f);
+  } else if (value.size() == 3) {
+    border_value = make_float4(value[0], value[1], value[2], 0.0f);
+  } else {
+    border_value = make_float4(value[0], value[1], value[2], value[3]);
+  }
+
+  // Prepare input tensor
+  FDTensor* src = CreateCachedGpuInputTensor(mat);
+  auto src_tensor = CreateCvCudaTensorWrapData(*src);
+
+  // Prepare output tensor
+  mat->output_cache->Resize({height, width, mat->Channels()}, mat->Type(),
+                            "output_cache", Device::GPU);
+  auto dst_tensor = CreateCvCudaTensorWrapData(*(mat->output_cache));
+
+  pad_op(mat->Stream(), src_tensor, dst_tensor, 0, 0, NVCV_BORDER_CONSTANT,
+         border_value);
+
+  mat->SetTensor(mat->output_cache);
+  mat->mat_type = ProcLib::CVCUDA;
+  return true;
+}
+
+static bool PadCHWByCvCuda(cvcuda::CopyMakeBorder& pad_op, FDMat* mat,
+                           int width, int height,
+                           const std::vector<float>& value) {
+  float4 border_value = make_float4(value[0], 0.0f, 0.0f, 0.0f);
+  FDTensor* input = CreateCachedGpuInputTensor(mat);
+  int channels = input->shape[0];
+  mat->output_cache->Resize({channels, height, width}, mat->Type(),
+                            "output_cache", Device::GPU);
+  for (int i = 0; i < channels; ++i) {
+    uint8_t* src_data =
+        reinterpret_cast<uint8_t*>(input->Data()) +
+        i * mat->Width() * mat->Height() * FDDataTypeSize(mat->Type());
+    FDTensor src;
+    src.SetExternalData({mat->Height(), mat->Width(), 1}, input->Dtype(),
+                        src_data, input->device, input->device_id);
+    auto src_tensor = CreateCvCudaTensorWrapData(src);
+
+    uint8_t* dst_data = reinterpret_cast<uint8_t*>(mat->output_cache->Data()) +
+                        i * width * height * FDDataTypeSize(mat->Type());
+    FDTensor dst;
+    dst.SetExternalData({height, width, 1}, input->Dtype(), dst_data,
+                        input->device, input->device_id);
+    auto dst_tensor = CreateCvCudaTensorWrapData(dst);
+
+    pad_op(mat->Stream(), src_tensor, dst_tensor, 0, 0, NVCV_BORDER_CONSTANT,
+           border_value);
+  }
+  mat->SetTensor(mat->output_cache);
+  mat->mat_type = ProcLib::CVCUDA;
+  return true;
+}
+
+bool PadToSize::ImplByCvCuda(FDMat* mat) {
+  if (width_ == -1 || height_ == -1 ||
+      (mat->Width() == width_ && mat->Height() == height_)) {
+    return true;
+  }
+  if (CheckArgs(mat) == false) {
+    return false;
+  }
+  if (mat->layout == Layout::HWC) {
+    return PadHWCByCvCuda(cvcuda_pad_op_, mat, width_, height_, value_);
+  } else if (mat->layout == Layout::CHW) {
+    return PadCHWByCvCuda(cvcuda_pad_op_, mat, width_, height_, value_);
+  }
+  return false;
+}
 #endif
 
 bool PadToSize::Run(Mat* mat, int width, int height,
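Note: the CHW variants above pad each channel plane separately, because border ops assume interleaved (HWC) pixels while a CHW buffer is C contiguous single-channel planes. A runnable sketch of that per-plane idea with plain OpenCV (sizes and border values are illustrative):

#include <cstdint>
#include <vector>
#include <opencv2/opencv.hpp>

int main() {
  int C = 3, H = 4, W = 5, newH = 8, newW = 8;
  std::vector<uint8_t> chw(C * H * W, 127);        // source planes
  std::vector<uint8_t> out(C * newH * newW, 0);    // padded planes
  std::vector<float> value = {0.f, 114.f, 255.f};  // per-plane border value
  for (int i = 0; i < C; ++i) {
    // Wrap plane i of each buffer as a single-channel Mat (zero copy),
    // then pad bottom/right, just like PadCHWByOpenCV does above.
    cv::Mat src(H, W, CV_8UC1, chw.data() + i * H * W);
    cv::Mat dst(newH, newW, CV_8UC1, out.data() + i * newH * newW);
    cv::copyMakeBorder(src, dst, 0, newH - H, 0, newW - W,
                       cv::BORDER_CONSTANT, cv::Scalar(value[i]));
  }
  return 0;
}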
@@ -15,6 +15,11 @@
 #pragma once
 
 #include "fastdeploy/vision/common/processors/base.h"
+#ifdef ENABLE_CVCUDA
+#include <cvcuda/OpCopyMakeBorder.hpp>
+
+#include "fastdeploy/vision/common/processors/cvcuda_utils.h"
+#endif
 
 namespace fastdeploy {
 namespace vision {
@@ -30,6 +35,9 @@ class FASTDEPLOY_DECL PadToSize : public Processor {
   bool ImplByOpenCV(Mat* mat);
 #ifdef ENABLE_FLYCV
   bool ImplByFlyCV(Mat* mat);
+#endif
+#ifdef ENABLE_CVCUDA
+  bool ImplByCvCuda(FDMat* mat);
 #endif
   std::string Name() { return "PadToSize"; }
 
@@ -37,10 +45,19 @@ class FASTDEPLOY_DECL PadToSize : public Processor {
                   const std::vector<float>& value,
                   ProcLib lib = ProcLib::DEFAULT);
 
+  void SetWidthHeight(int width, int height) {
+    width_ = width;
+    height_ = height;
+  }
+
  private:
+  bool CheckArgs(FDMat* mat);
   int width_;
   int height_;
   std::vector<float> value_;
+#ifdef ENABLE_CVCUDA
+  cvcuda::CopyMakeBorder cvcuda_pad_op_;
+#endif
 };
 }  // namespace vision
 }  // namespace fastdeploy
@@ -147,7 +147,7 @@ bool Resize::ImplByCvCuda(FDMat* mat) {
 
   // CV-CUDA Interp value is compatible with OpenCV
   cvcuda_resize_op_(mat->Stream(), src_tensor, dst_tensor,
-                    NVCVInterpolationType(interp_));
+                    CreateCvCudaInterp(interp_));
 
   mat->SetTensor(mat->output_cache);
   mat->SetWidth(width_);
@@ -95,9 +95,8 @@ bool ResizeByShort::ImplByCvCuda(FDMat* mat) {
                             "output_cache", Device::GPU);
   auto dst_tensor = CreateCvCudaTensorWrapData(*(mat->output_cache));
 
-  // CV-CUDA Interp value is compatible with OpenCV
   cvcuda_resize_op_(mat->Stream(), src_tensor, dst_tensor,
-                    NVCVInterpolationType(interp_));
+                    CreateCvCudaInterp(interp_));
 
   mat->SetTensor(mat->output_cache);
   mat->SetWidth(width);
@@ -138,7 +137,7 @@ bool ResizeByShort::ImplByCvCuda(FDMatBatch* mat_batch) {
 
   // CV-CUDA Interp value is compatible with OpenCV
   cvcuda_resize_op_(mat_batch->Stream(), src_batch, dst_batch,
-                    NVCVInterpolationType(interp_));
+                    CreateCvCudaInterp(interp_));
 
   for (size_t i = 0; i < mat_batch->mats->size(); ++i) {
     FDMat* mat = &(*(mat_batch->mats))[i];
@@ -114,6 +114,68 @@ bool StridePad::ImplByFlyCV(Mat* mat) {
 }
 #endif
 
+#ifdef ENABLE_CVCUDA
+bool StridePad::ImplByCvCuda(FDMat* mat) {
+  if (mat->layout != Layout::HWC) {
+    FDERROR << "StridePad: The input data must be Layout::HWC format!"
+            << std::endl;
+    return false;
+  }
+  if (mat->Channels() > 4) {
+    FDERROR << "StridePad: Only support channels <= 4." << std::endl;
+    return false;
+  }
+  if (mat->Channels() != value_.size()) {
+    FDERROR
+        << "StridePad: Require input channels equals to size of padding value, "
+           "but now channels = "
+        << mat->Channels() << ", the size of padding values = " << value_.size()
+        << "." << std::endl;
+    return false;
+  }
+  int origin_w = mat->Width();
+  int origin_h = mat->Height();
+
+  int pad_h = (mat->Height() / stride_) * stride_ +
+              (mat->Height() % stride_ != 0) * stride_ - mat->Height();
+  int pad_w = (mat->Width() / stride_) * stride_ +
+              (mat->Width() % stride_ != 0) * stride_ - mat->Width();
+  if (pad_h == 0 && pad_w == 0) {
+    return true;
+  }
+
+  float4 value;
+  if (value_.size() == 1) {
+    value = make_float4(value_[0], 0.0f, 0.0f, 0.0f);
+  } else if (value_.size() == 2) {
+    value = make_float4(value_[0], value_[1], 0.0f, 0.0f);
+  } else if (value_.size() == 3) {
+    value = make_float4(value_[0], value_[1], value_[2], 0.0f);
+  } else {
+    value = make_float4(value_[0], value_[1], value_[2], value_[3]);
+  }
+
+  // Prepare input tensor
+  FDTensor* src = CreateCachedGpuInputTensor(mat);
+  auto src_tensor = CreateCvCudaTensorWrapData(*src);
+
+  int height = mat->Height() + pad_h;
+  int width = mat->Width() + pad_w;
+
+  // Prepare output tensor
+  mat->output_cache->Resize({height, width, mat->Channels()}, mat->Type(),
+                            "output_cache", Device::GPU);
+  auto dst_tensor = CreateCvCudaTensorWrapData(*(mat->output_cache));
+
+  cvcuda_pad_op_(mat->Stream(), src_tensor, dst_tensor, 0, 0,
+                 NVCV_BORDER_CONSTANT, value);
+
+  mat->SetTensor(mat->output_cache);
+  mat->mat_type = ProcLib::CVCUDA;
+  return true;
+}
+#endif
+
 bool StridePad::Run(Mat* mat, int stride, const std::vector<float>& value,
                     ProcLib lib) {
   auto p = StridePad(stride, value);
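Note: StridePad's pad_h/pad_w arithmetic rounds each dimension up to the next multiple of stride_. An equivalent, more compact form of the same computation, shown as a sketch:

// Pad needed to round dim up to the next multiple of stride.
int RoundUpPad(int dim, int stride) {
  return (dim + stride - 1) / stride * stride - dim;
}
// e.g. RoundUpPad(500, 32) == 12, since 500 is padded up to 512.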
@@ -15,6 +15,11 @@
 #pragma once
 
 #include "fastdeploy/vision/common/processors/base.h"
+#ifdef ENABLE_CVCUDA
+#include <cvcuda/OpCopyMakeBorder.hpp>
+
+#include "fastdeploy/vision/common/processors/cvcuda_utils.h"
+#endif
 
 namespace fastdeploy {
 namespace vision {
@@ -29,6 +34,9 @@ class FASTDEPLOY_DECL StridePad : public Processor {
   bool ImplByOpenCV(Mat* mat);
 #ifdef ENABLE_FLYCV
   bool ImplByFlyCV(Mat* mat);
+#endif
+#ifdef ENABLE_CVCUDA
+  bool ImplByCvCuda(FDMat* mat);
 #endif
   std::string Name() { return "StridePad"; }
 
@@ -39,6 +47,9 @@ class FASTDEPLOY_DECL StridePad : public Processor {
  private:
   int stride_ = 32;
   std::vector<float> value_;
+#ifdef ENABLE_CVCUDA
+  cvcuda::CopyMakeBorder cvcuda_pad_op_;
+#endif
 };
 }  // namespace vision
 }  // namespace fastdeploy
@@ -186,8 +186,7 @@ cv::Mat ConvertFlyCVMatToOpenCV(fcv::Mat& fim) {
 }
 #endif
 
-cv::Mat CreateZeroCopyOpenCVMatFromBuffer(
-    int height, int width, int channels,
-    FDDataType type, void* data) {
+cv::Mat CreateZeroCopyOpenCVMatFromBuffer(int height, int width, int channels,
+                                          FDDataType type, void* data) {
   cv::Mat ocv_mat;
   switch (type) {
@@ -219,38 +218,38 @@ cv::Mat CreateZeroCopyOpenCVMatFromBuffer(
   return ocv_mat;
 }
 
-cv::Mat CreateZeroCopyOpenCVMatFromTensor(const FDTensor& tensor) {
-  // TODO(qiuyanjun): Should add a Layout checking. Now, we
-  // assume that the input tensor is already in Layout::HWC.
-  FDASSERT(tensor.shape.size() == 3, "When create OepnCV Mat from tensor,"
-                                     "tensor shape should be 3-Dim, HWC layout");
+cv::Mat CreateZeroCopyOpenCVMatFromTensor(const FDTensor& tensor,
+                                          Layout layout) {
+  FDASSERT(tensor.shape.size() == 3,
+           "When create OepnCV Mat from tensor,"
+           "tensor shape should be 3-Dim");
   FDDataType type = tensor.dtype;
   int height = static_cast<int>(tensor.shape[0]);
   int width = static_cast<int>(tensor.shape[1]);
   int channels = static_cast<int>(tensor.shape[2]);
-  return CreateZeroCopyOpenCVMatFromBuffer(
-      height, width, channels, type,
+  if (layout == Layout::CHW) {
+    channels = static_cast<int>(tensor.shape[0]);
+    height = static_cast<int>(tensor.shape[1]);
+    width = static_cast<int>(tensor.shape[2]);
+  }
+  return CreateZeroCopyOpenCVMatFromBuffer(height, width, channels, type,
                                            const_cast<void*>(tensor.CpuData()));
 }
 
 #ifdef ENABLE_FLYCV
-fcv::Mat CreateZeroCopyFlyCVMatFromBuffer(
-    int height, int width, int channels,
-    FDDataType type, void* data) {
+fcv::Mat CreateZeroCopyFlyCVMatFromBuffer(int height, int width, int channels,
+                                          FDDataType type, void* data) {
   fcv::Mat fcv_mat;
   auto fcv_type = CreateFlyCVDataType(type, channels);
   switch (type) {
     case FDDataType::UINT8:
-      fcv_mat =
-          fcv::Mat(width, height, fcv_type, data);
+      fcv_mat = fcv::Mat(width, height, fcv_type, data);
       break;
     case FDDataType::FP32:
-      fcv_mat =
-          fcv::Mat(width, height, fcv_type, data);
+      fcv_mat = fcv::Mat(width, height, fcv_type, data);
      break;
    case FDDataType::FP64:
-      fcv_mat =
-          fcv::Mat(width, height, fcv_type, data);
+      fcv_mat = fcv::Mat(width, height, fcv_type, data);
       break;
     default:
       FDASSERT(false,
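Note: with the new layout argument, the same 3-D tensor is read as {H, W, C} or {C, H, W} before the zero-copy wrap. A stand-alone sketch of the dimension mapping (LayoutSketch is a stand-in for vision::Layout, not the FastDeploy enum):

#include <array>
#include <cstdint>

enum class LayoutSketch { HWC, CHW };  // stand-in for vision::Layout
struct Dims { int height, width, channels; };

Dims DimsFromShape(const std::array<int64_t, 3>& s, LayoutSketch layout) {
  if (layout == LayoutSketch::CHW) {  // shape = {C, H, W}
    return {static_cast<int>(s[1]), static_cast<int>(s[2]),
            static_cast<int>(s[0])};
  }
  // LayoutSketch::HWC: shape = {H, W, C}
  return {static_cast<int>(s[0]), static_cast<int>(s[1]),
          static_cast<int>(s[2])};
}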
@@ -265,14 +264,14 @@ fcv::Mat CreateZeroCopyFlyCVMatFromBuffer(
 fcv::Mat CreateZeroCopyFlyCVMatFromTensor(const FDTensor& tensor) {
   // TODO(qiuyanjun): Should add a Layout checking. Now, we
   // assume that the input tensor is already in Layout::HWC.
-  FDASSERT(tensor.shape.size() == 3, "When create FlyCV Mat from tensor,"
+  FDASSERT(tensor.shape.size() == 3,
+           "When create FlyCV Mat from tensor,"
            "tensor shape should be 3-Dim, HWC layout");
   FDDataType type = tensor.dtype;
   int height = static_cast<int>(tensor.shape[0]);
   int width = static_cast<int>(tensor.shape[1]);
   int channels = static_cast<int>(tensor.shape[2]);
-  return CreateZeroCopyFlyCVMatFromBuffer(
-      height, width, channels, type,
+  return CreateZeroCopyFlyCVMatFromBuffer(height, width, channels, type,
                                           const_cast<void*>(tensor.Data()));
 }
 #endif
@@ -16,6 +16,7 @@
 
 #include "fastdeploy/core/fd_tensor.h"
 #include "fastdeploy/utils/utils.h"
+#include "fastdeploy/vision/common/processors/mat.h"
 #include "opencv2/core/core.hpp"
 
 #ifdef ENABLE_FLYCV
@@ -43,7 +44,8 @@ cv::Mat ConvertFlyCVMatToOpenCV(fcv::Mat& fim);
 // Create zero copy OpenCV/FlyCV Mat from FD Tensor / Buffer
 cv::Mat CreateZeroCopyOpenCVMatFromBuffer(int height, int width,
                                           int channels, FDDataType type, void* data);
-cv::Mat CreateZeroCopyOpenCVMatFromTensor(const FDTensor& tensor);
+cv::Mat CreateZeroCopyOpenCVMatFromTensor(const FDTensor& tensor,
+                                          Layout layout = Layout::HWC);
 #ifdef ENABLE_FLYCV
 fcv::Mat CreateZeroCopyFlyCVMatFromBuffer(int height, int width,
                                           int channels, FDDataType type, void* data);
|
@@ -15,8 +15,8 @@
|
|||||||
|
|
||||||
namespace fastdeploy {
|
namespace fastdeploy {
|
||||||
void BindPPDet(pybind11::module& m) {
|
void BindPPDet(pybind11::module& m) {
|
||||||
pybind11::class_<vision::detection::PaddleDetPreprocessor>(
|
pybind11::class_<vision::detection::PaddleDetPreprocessor,
|
||||||
m, "PaddleDetPreprocessor")
|
vision::ProcessorManager>(m, "PaddleDetPreprocessor")
|
||||||
.def(pybind11::init<std::string>())
|
.def(pybind11::init<std::string>())
|
||||||
.def("run",
|
.def("run",
|
||||||
[](vision::detection::PaddleDetPreprocessor& self,
|
[](vision::detection::PaddleDetPreprocessor& self,
|
||||||
|
@@ -129,13 +129,13 @@ bool PaddleDetPreprocessor::BuildPreprocessPipelineFromConfig() {
   return true;
 }
 
-bool PaddleDetPreprocessor::Run(std::vector<FDMat>* images,
-                                std::vector<FDTensor>* outputs) {
+bool PaddleDetPreprocessor::Apply(FDMatBatch* image_batch,
+                                  std::vector<FDTensor>* outputs) {
   if (!initialized_) {
     FDERROR << "The preprocessor is not initialized." << std::endl;
     return false;
   }
-  if (images->empty()) {
+  if (image_batch->mats->empty()) {
     FDERROR << "The size of input images should be greater than 0."
             << std::endl;
     return false;
@@ -146,7 +146,7 @@ bool PaddleDetPreprocessor::Run(std::vector<FDMat>* images,
   // So preprocessor will output the 3 FDTensors, and how to use `im_shape`
   // is decided by the model itself
   outputs->resize(3);
-  int batch = static_cast<int>(images->size());
+  int batch = static_cast<int>(image_batch->mats->size());
   // Allocate memory for scale_factor
   (*outputs)[1].Resize({batch, 2}, FDDataType::FP32);
   // Allocate memory for im_shape
@@ -158,63 +158,51 @@ bool PaddleDetPreprocessor::Run(std::vector<FDMat>* images,
   auto* scale_factor_ptr =
       reinterpret_cast<float*>((*outputs)[1].MutableData());
   auto* im_shape_ptr = reinterpret_cast<float*>((*outputs)[2].MutableData());
-  for (size_t i = 0; i < images->size(); ++i) {
-    int origin_w = (*images)[i].Width();
-    int origin_h = (*images)[i].Height();
+  for (size_t i = 0; i < image_batch->mats->size(); ++i) {
+    FDMat* mat = &(image_batch->mats->at(i));
+    int origin_w = mat->Width();
+    int origin_h = mat->Height();
     scale_factor_ptr[2 * i] = 1.0;
     scale_factor_ptr[2 * i + 1] = 1.0;
     for (size_t j = 0; j < processors_.size(); ++j) {
-      if (!(*(processors_[j].get()))(&((*images)[i]))) {
+      if (!(*(processors_[j].get()))(mat)) {
         FDERROR << "Failed to processs image:" << i << " in "
-                << processors_[i]->Name() << "." << std::endl;
+                << processors_[j]->Name() << "." << std::endl;
         return false;
       }
       if (processors_[j]->Name().find("Resize") != std::string::npos) {
-        scale_factor_ptr[2 * i] = (*images)[i].Height() * 1.0 / origin_h;
-        scale_factor_ptr[2 * i + 1] = (*images)[i].Width() * 1.0 / origin_w;
+        scale_factor_ptr[2 * i] = mat->Height() * 1.0 / origin_h;
+        scale_factor_ptr[2 * i + 1] = mat->Width() * 1.0 / origin_w;
      }
     }
-    if ((*images)[i].Height() > max_hw[0]) {
-      max_hw[0] = (*images)[i].Height();
+    if (mat->Height() > max_hw[0]) {
+      max_hw[0] = mat->Height();
     }
-    if ((*images)[i].Width() > max_hw[1]) {
-      max_hw[1] = (*images)[i].Width();
+    if (mat->Width() > max_hw[1]) {
+      max_hw[1] = mat->Width();
     }
     im_shape_ptr[2 * i] = max_hw[0];
     im_shape_ptr[2 * i + 1] = max_hw[1];
   }
 
-  // Concat all the preprocessed data to a batch tensor
-  std::vector<FDTensor> im_tensors(images->size());
-  for (size_t i = 0; i < images->size(); ++i) {
-    if ((*images)[i].Height() < max_hw[0] || (*images)[i].Width() < max_hw[1]) {
-      // if the size of image less than max_hw, pad to max_hw
-      FDTensor tensor;
-      (*images)[i].ShareWithTensor(&tensor);
-      function::Pad(tensor, &(im_tensors[i]),
-                    {0, 0, max_hw[0] - (*images)[i].Height(),
-                     max_hw[1] - (*images)[i].Width()},
-                    0);
-    } else {
-      // No need pad
-      (*images)[i].ShareWithTensor(&(im_tensors[i]));
-    }
-    // Reshape to 1xCxHxW
-    im_tensors[i].ExpandDim(0);
-  }
-
-  if (im_tensors.size() == 1) {
-    // If there's only 1 input, no need to concat
-    // skip memory copy
-    (*outputs)[0] = std::move(im_tensors[0]);
-  } else {
-    // Else concat the im tensor for each input image
-    // compose a batched input tensor
-    function::Concat(im_tensors, &((*outputs)[0]), 0);
+  // if the size of image less than max_hw, pad to max_hw
+  for (size_t i = 0; i < image_batch->mats->size(); ++i) {
+    FDMat* mat = &(image_batch->mats->at(i));
+    if (mat->Height() < max_hw[0] || mat->Width() < max_hw[1]) {
+      pad_op_->SetWidthHeight(max_hw[1], max_hw[0]);
+      (*pad_op_)(mat);
+    }
   }
+
+  // Get the NCHW tensor
+  FDTensor* tensor = image_batch->Tensor();
+  (*outputs)[0].SetExternalData(tensor->Shape(), tensor->Dtype(),
+                                tensor->Data(), tensor->device,
+                                tensor->device_id);
   return true;
 }
 
 void PaddleDetPreprocessor::DisableNormalize() {
   this->disable_normalize_ = true;
   // the DisableNormalize function will be invalid if the configuration file is
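Note: with PaddleDetPreprocessor now deriving from ProcessorManager, callers keep using Run(); the manager wraps the mats into an FDMatBatch, invokes Apply(), and the three outputs (image, scale_factor, im_shape) come back as before. A hedged usage sketch — WrapMat and the config filename are illustrative, and the commented UseCuda call is an assumed ProcessorManager API for routing the Impl* dispatch to the CUDA/CV-CUDA path:

vision::detection::PaddleDetPreprocessor preprocessor("infer_cfg.yml");
// preprocessor.UseCuda(true, 0);  // assumed: enable the CUDA/CV-CUDA path

std::vector<FDMat> mats = {vision::WrapMat(cv::imread("test.jpg"))};
std::vector<FDTensor> outputs;
if (!preprocessor.Run(&mats, &outputs)) {
  FDERROR << "Failed to preprocess." << std::endl;
}
// outputs[0]: batched image tensor (zero-copy view of the batch cache)
// outputs[1]: scale_factor, outputs[2]: im_shape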
@@ -224,6 +212,7 @@ void PaddleDetPreprocessor::DisableNormalize() {
             << std::endl;
   }
 }
+
 void PaddleDetPreprocessor::DisablePermute() {
   this->disable_permute_ = true;
   // the DisablePermute function will be invalid if the configuration file is
@@ -13,6 +13,7 @@
 // limitations under the License.
 
 #pragma once
+#include "fastdeploy/vision/common/processors/manager.h"
 #include "fastdeploy/vision/common/processors/transform.h"
 #include "fastdeploy/vision/common/result.h"
 
@@ -22,7 +23,7 @@ namespace vision {
 namespace detection {
 /*! @brief Preprocessor object for PaddleDet serials model.
  */
-class FASTDEPLOY_DECL PaddleDetPreprocessor {
+class FASTDEPLOY_DECL PaddleDetPreprocessor : public ProcessorManager {
  public:
   PaddleDetPreprocessor() = default;
   /** \brief Create a preprocessor instance for PaddleDet serials model
@@ -31,13 +32,16 @@ class FASTDEPLOY_DECL PaddleDetPreprocessor {
    */
   explicit PaddleDetPreprocessor(const std::string& config_file);
 
-  /** \brief Process the input image and prepare input tensors for runtime
+  /** \brief Implement the virtual function of ProcessorManager, Apply() is the
+   * body of Run(). Apply() contains the main logic of preprocessing, Run() is
+   * called by users to execute preprocessing
    *
-   * \param[in] images The input image data list, all the elements are returned by cv::imread()
-   * \param[in] outputs The output tensors which will feed in runtime, include image, scale_factor, im_shape
+   * \param[in] image_batch The input image batch
+   * \param[in] outputs The output tensors which will feed in runtime
    * \return true if the preprocess successed, otherwise false
    */
-  bool Run(std::vector<FDMat>* images, std::vector<FDTensor>* outputs);
+  virtual bool Apply(FDMatBatch* image_batch,
+                     std::vector<FDTensor>* outputs);
 
   /// This function will disable normalize in preprocessing step.
   void DisableNormalize();
@@ -51,6 +55,8 @@ class FASTDEPLOY_DECL PaddleDetPreprocessor {
  private:
   bool BuildPreprocessPipelineFromConfig();
   std::vector<std::shared_ptr<Processor>> processors_;
+  std::shared_ptr<PadToSize> pad_op_ =
+      std::make_shared<PadToSize>(0, 0, std::vector<float>(3, 0));
   bool initialized_ = false;
   // for recording the switch of hwc2chw
   bool disable_permute_ = false;