[CVCUDA] PP-OCR Cls & Rec preprocessor support CV-CUDA (#1470)

* ppocr cls preprocessor use manager

* hwc2chw cvcuda

* ppocr rec preproc use manager

* ocr rec preproc cvcuda

* fix rec preproc bug

* ppocr cls&rec preproc set normalize

* fix pybind

* address comment
Author: Wang Xinyu
Date: 2023-03-02 10:50:44 +08:00
Committed by: GitHub
Parent: fe2882a1ef
Commit: 044ab993d2
19 changed files with 424 additions and 306 deletions

View File

@@ -18,34 +18,36 @@ namespace fastdeploy {
namespace vision {
#ifdef ENABLE_CVCUDA
nvcv::ImageFormat CreateCvCudaImageFormat(FDDataType type, int channel) {
nvcv::ImageFormat CreateCvCudaImageFormat(FDDataType type, int channel,
bool interleaved) {
FDASSERT(channel == 1 || channel == 3 || channel == 4,
"Only 1/3/4 channels are supported in CV-CUDA.");
if (type == FDDataType::UINT8) {
if (channel == 1) {
return nvcv::FMT_U8;
} else if (channel == 3) {
return nvcv::FMT_BGR8;
return (interleaved ? nvcv::FMT_BGR8 : nvcv::FMT_BGR8p);
} else {
return nvcv::FMT_BGRA8;
return (interleaved ? nvcv::FMT_BGRA8 : nvcv::FMT_BGRA8p);
}
} else if (type == FDDataType::FP32) {
if (channel == 1) {
return nvcv::FMT_F32;
} else if (channel == 3) {
return nvcv::FMT_BGRf32;
return (interleaved ? nvcv::FMT_BGRf32 : nvcv::FMT_BGRf32p);
} else {
return nvcv::FMT_BGRAf32;
return (interleaved ? nvcv::FMT_BGRAf32 : nvcv::FMT_BGRAf32p);
}
}
FDASSERT(false, "Data type of %s is not supported.", Str(type).c_str());
return nvcv::FMT_BGRf32;
}
nvcv::TensorWrapData CreateCvCudaTensorWrapData(const FDTensor& tensor) {
nvcv::TensorWrapData CreateCvCudaTensorWrapData(const FDTensor& tensor,
Layout layout) {
FDASSERT(tensor.shape.size() == 3,
"When creating CVCUDA tensor from FD tensor,"
"tensor shape should be 3-Dim, HWC layout");
"tensor shape should be 3-Dim.");
int batchsize = 1;
int h = tensor.Shape()[0];
int w = tensor.Shape()[1];
@@ -56,10 +58,20 @@ nvcv::TensorWrapData CreateCvCudaTensorWrapData(const FDTensor& tensor) {
buf.strides[2] = c * buf.strides[3];
buf.strides[1] = w * buf.strides[2];
buf.strides[0] = h * buf.strides[1];
if (layout == Layout::CHW) {
c = tensor.Shape()[0];
h = tensor.Shape()[1];
w = tensor.Shape()[2];
buf.strides[3] = FDDataTypeSize(tensor.Dtype());
buf.strides[2] = w * buf.strides[3];
buf.strides[1] = h * buf.strides[2];
buf.strides[0] = c * buf.strides[1];
}
buf.basePtr = reinterpret_cast<NVCVByte*>(const_cast<void*>(tensor.Data()));
nvcv::Tensor::Requirements req = nvcv::Tensor::CalcRequirements(
batchsize, {w, h}, CreateCvCudaImageFormat(tensor.Dtype(), c));
batchsize, {w, h},
CreateCvCudaImageFormat(tensor.Dtype(), c, layout == Layout::HWC));
nvcv::TensorDataStridedCuda tensor_data(
nvcv::TensorShape{req.shape, req.rank, req.layout},
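The new CHW branch reorders the same byte-stride chain that the HWC path builds. A standalone sketch of that arithmetic, with an illustrative helper that is not part of the patch:

```cpp
#include <cstddef>
#include <cstdio>

// Byte strides for a single image, innermost dimension first, mirroring
// the two branches of CreateCvCudaTensorWrapData (illustration only).
void ComputeStrides(int h, int w, int c, size_t elem_size, bool chw,
                    size_t strides[4]) {
  if (!chw) {                         // HWC (interleaved): N, H, W, C
    strides[3] = elem_size;           // between channels of one pixel
    strides[2] = c * strides[3];      // between pixels in a row
    strides[1] = w * strides[2];      // between rows
    strides[0] = h * strides[1];      // between images in the batch
  } else {                            // CHW (planar): N, C, H, W
    strides[3] = elem_size;           // between pixels in a row
    strides[2] = w * strides[3];      // between rows
    strides[1] = h * strides[2];      // between channel planes
    strides[0] = c * strides[1];      // between images in the batch
  }
}

int main() {
  size_t s[4];
  ComputeStrides(48, 192, 3, sizeof(float), /*chw=*/true, s);
  std::printf("CHW strides: %zu %zu %zu %zu\n", s[0], s[1], s[2], s[3]);
  // prints: CHW strides: 110592 36864 768 4
}
```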

View File

@@ -15,6 +15,7 @@
#pragma once
#include "fastdeploy/core/fd_tensor.h"
#include "fastdeploy/vision/common/processors/mat.h"
#ifdef ENABLE_CVCUDA
#include "nvcv/Tensor.hpp"
@@ -23,8 +24,10 @@
namespace fastdeploy {
namespace vision {
nvcv::ImageFormat CreateCvCudaImageFormat(FDDataType type, int channel);
nvcv::TensorWrapData CreateCvCudaTensorWrapData(const FDTensor& tensor);
nvcv::ImageFormat CreateCvCudaImageFormat(FDDataType type, int channel,
bool interleaved = true);
nvcv::TensorWrapData CreateCvCudaTensorWrapData(const FDTensor& tensor,
Layout layout = Layout::HWC);
void* GetCvCudaTensorDataPtr(const nvcv::TensorWrapData& tensor);
nvcv::ImageWrapData CreateImageWrapData(const FDTensor& tensor);
void CreateCvCudaImageBatchVarShape(std::vector<FDTensor*>& tensors,

View File

@@ -63,6 +63,26 @@ bool HWC2CHW::ImplByFlyCV(Mat* mat) {
}
#endif
#ifdef ENABLE_CVCUDA
bool HWC2CHW::ImplByCvCuda(FDMat* mat) {
// Prepare input tensor
FDTensor* src = CreateCachedGpuInputTensor(mat);
auto src_tensor = CreateCvCudaTensorWrapData(*src);
// Prepare output tensor
mat->output_cache->Resize({mat->Channels(), mat->Height(), mat->Width()},
src->Dtype(), "output_cache", Device::GPU);
auto dst_tensor =
CreateCvCudaTensorWrapData(*(mat->output_cache), Layout::CHW);
cvcuda_reformat_op_(mat->Stream(), src_tensor, dst_tensor);
mat->SetTensor(mat->output_cache);
mat->mat_type = ProcLib::CVCUDA;
return true;
}
#endif
bool HWC2CHW::Run(Mat* mat, ProcLib lib) {
auto h = HWC2CHW();
return h(mat, lib);
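For reference, the repacking that the cvcuda::Reformat call performs on the GPU, written as a CPU sketch (our own helper, illustration only):

```cpp
#include <vector>

// CPU reference for HWC -> CHW: channel k of pixel (y, x) moves from the
// interleaved position (y * w + x) * c + k to plane k of the planar output.
std::vector<float> HwcToChw(const std::vector<float>& src, int h, int w,
                            int c) {
  std::vector<float> dst(src.size());
  for (int y = 0; y < h; ++y)
    for (int x = 0; x < w; ++x)
      for (int k = 0; k < c; ++k)
        dst[(k * h + y) * w + x] = src[(y * w + x) * c + k];
  return dst;
}
```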

View File

@@ -15,6 +15,11 @@
#pragma once
#include "fastdeploy/vision/common/processors/base.h"
#ifdef ENABLE_CVCUDA
#include <cvcuda/OpReformat.hpp>
#include "fastdeploy/vision/common/processors/cvcuda_utils.h"
#endif
namespace fastdeploy {
namespace vision {
@@ -24,10 +29,17 @@ class FASTDEPLOY_DECL HWC2CHW : public Processor {
bool ImplByOpenCV(Mat* mat);
#ifdef ENABLE_FLYCV
bool ImplByFlyCV(Mat* mat);
#endif
#ifdef ENABLE_CVCUDA
bool ImplByCvCuda(FDMat* mat);
#endif
std::string Name() { return "HWC2CHW"; }
static bool Run(Mat* mat, ProcLib lib = ProcLib::DEFAULT);
private:
#ifdef ENABLE_CVCUDA
cvcuda::Reformat cvcuda_reformat_op_;
#endif
};
} // namespace vision
} // namespace fastdeploy

View File

@@ -73,6 +73,7 @@ bool ProcessorManager::Run(std::vector<FDMat>* images,
}
(*images)[i].input_cache = &input_caches_[i];
(*images)[i].output_cache = &output_caches_[i];
(*images)[i].proc_lib = proc_lib_;
if ((*images)[i].mat_type == ProcLib::CUDA) {
// Make a copy of the input data ptr, so that the original data ptr of
// FDMat won't be modified.
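The added assignment forwards the manager's configured backend to every FDMat, so each processor can choose its implementation per image. A simplified picture of that dispatch, using HWC2CHW as the example; the real logic in FastDeploy's Processor base class is more involved:

```cpp
#include "fastdeploy/vision/common/processors/hwc2chw.h"

// Simplified per-mat backend selection (illustration only).
bool Dispatch(fastdeploy::vision::HWC2CHW& op,
              fastdeploy::vision::FDMat* mat) {
  switch (mat->proc_lib) {
#ifdef ENABLE_CVCUDA
    case fastdeploy::vision::ProcLib::CVCUDA:
      return op.ImplByCvCuda(mat);  // the path added in this PR
#endif
    default:
      return op.ImplByOpenCV(mat);  // CPU fallback
  }
}
```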

View File

@@ -272,6 +272,9 @@ std::vector<FDMat> WrapMat(const std::vector<cv::Mat>& images) {
}
bool CheckShapeConsistency(std::vector<Mat>* mats) {
if (mats == nullptr) {
return true;
}
for (size_t i = 1; i < mats->size(); ++i) {
if ((*mats)[i].Channels() != (*mats)[0].Channels() ||
(*mats)[i].Width() != (*mats)[0].Width() ||
@@ -285,21 +288,24 @@ bool CheckShapeConsistency(std::vector<Mat>* mats) {
FDTensor* CreateCachedGpuInputTensor(Mat* mat) {
#ifdef WITH_GPU
FDTensor* src = mat->Tensor();
if (src->device == Device::GPU) {
// Make sure the tensor points to the input_cache.
if (src->Data() == mat->output_cache->Data()) {
std::swap(mat->input_cache, mat->output_cache);
std::swap(mat->input_cache->name, mat->output_cache->name);
}
if (src->device == Device::GPU) {
return src;
} else if (src->device == Device::CPU) {
// Mats on CPU, we need copy these tensors from CPU to GPU
// Tensor on CPU, we need to copy it from CPU to GPU
FDASSERT(src->Shape().size() == 3, "The CPU tensor must have 3 dims.")
mat->input_cache->Resize(src->Shape(), src->Dtype(), "input_cache",
mat->output_cache->Resize(src->Shape(), src->Dtype(), "output_cache",
Device::GPU);
FDASSERT(
cudaMemcpyAsync(mat->input_cache->Data(), src->Data(), src->Nbytes(),
cudaMemcpyAsync(mat->output_cache->Data(), src->Data(), src->Nbytes(),
cudaMemcpyHostToDevice, mat->Stream()) == 0,
"[ERROR] Error occurs while copy memory from CPU to GPU.");
std::swap(mat->input_cache, mat->output_cache);
std::swap(mat->input_cache->name, mat->output_cache->name);
return mat->input_cache;
} else {
FDASSERT(false, "FDMat is on unsupported device: %d", src->device);

View File

@@ -29,10 +29,12 @@ FDTensor* FDMatBatch::Tensor() {
if (has_batched_tensor) {
return fd_tensor.get();
}
FDASSERT(CheckShapeConsistency(mats), "Mats shapes are not consistent.")
FDASSERT(mats != nullptr, "Failed to get batched tensor, Mats are empty.");
FDASSERT(CheckShapeConsistency(mats), "Mats shapes are not consistent.");
// Each mat has its own tensor;
// to get a batched tensor, we need to copy them into one batched tensor
FDTensor* src = (*mats)[0].Tensor();
device = src->device;
auto new_shape = src->Shape();
new_shape.insert(new_shape.begin(), mats->size());
input_cache->Resize(new_shape, src->Dtype(), "batch_input_cache", device);
@@ -51,26 +53,34 @@ FDTensor* FDMatBatch::Tensor() {
void FDMatBatch::SetTensor(FDTensor* tensor) {
fd_tensor->SetExternalData(tensor->Shape(), tensor->Dtype(), tensor->Data(),
tensor->device, tensor->device_id);
device = tensor->device;
has_batched_tensor = true;
}
FDTensor* CreateCachedGpuInputTensor(FDMatBatch* mat_batch) {
#ifdef WITH_GPU
auto mats = mat_batch->mats;
FDASSERT(CheckShapeConsistency(mats), "Mats shapes are not consistent.")
FDTensor* src = (*mats)[0].Tensor();
if (mat_batch->device == Device::GPU) {
return mat_batch->Tensor();
} else if (mat_batch->device == Device::CPU) {
// Mats on CPU, we need copy them to GPU and then get a batched GPU tensor
for (size_t i = 0; i < mats->size(); ++i) {
FDTensor* tensor = CreateCachedGpuInputTensor(&(*mats)[i]);
(*mats)[i].SetTensor(tensor);
// Get the batched tensor
FDTensor* src = mat_batch->Tensor();
// Make sure the returned tensor points to the input_cache.
if (src->Data() == mat_batch->output_cache->Data()) {
std::swap(mat_batch->input_cache, mat_batch->output_cache);
std::swap(mat_batch->input_cache->name, mat_batch->output_cache->name);
}
mat_batch->device = Device::GPU;
return mat_batch->Tensor();
if (src->device == Device::GPU) {
return src;
} else if (src->device == Device::CPU) {
// Batched tensor on CPU, we need to copy it to GPU
mat_batch->output_cache->Resize(src->Shape(), src->Dtype(), "output_cache",
Device::GPU);
FDASSERT(cudaMemcpyAsync(mat_batch->output_cache->Data(), src->Data(),
src->Nbytes(), cudaMemcpyHostToDevice,
mat_batch->Stream()) == 0,
"[ERROR] Error occurs while copy memory from CPU to GPU.");
std::swap(mat_batch->input_cache, mat_batch->output_cache);
std::swap(mat_batch->input_cache->name, mat_batch->output_cache->name);
return mat_batch->input_cache;
} else {
FDASSERT(false, "FDMat is on unsupported device: %d", src->device);
FDASSERT(false, "FDMatBatch is on unsupported device: %d", src->device);
}
#else
FDASSERT(false, "FastDeploy didn't compile with WITH_GPU.");

View File

@@ -56,7 +56,7 @@ struct FASTDEPLOY_DECL FDMatBatch {
void SetStream(cudaStream_t s);
#endif
std::vector<FDMat>* mats;
std::vector<FDMat>* mats = nullptr;
ProcLib mat_type = ProcLib::OPENCV;
FDMatBatchLayout layout = FDMatBatchLayout::NHWC;
Device device = Device::CPU;

View File

@@ -0,0 +1,116 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifdef WITH_GPU
#include "fastdeploy/vision/common/processors/normalize.h"
namespace fastdeploy {
namespace vision {
__global__ void NormalizeKernel(const uint8_t* src, float* dst,
const float* alpha, const float* beta,
int num_channel, bool swap_rb, int batch_size,
int edge) {
int idx = blockDim.x * blockIdx.x + threadIdx.x;
if (idx >= edge) return;
int img_size = edge / batch_size;
int n = idx / img_size; // batch index
int p = idx - (n * img_size); // pixel index within the image
for (int i = 0; i < num_channel; ++i) {
int j = i;
if (swap_rb) {
j = 2 - i;
}
dst[num_channel * idx + i] =
src[num_channel * idx + j] * alpha[i] + beta[i];
}
}
bool Normalize::ImplByCuda(FDMat* mat) {
if (mat->layout != Layout::HWC) {
FDERROR << "The input data must be NHWC format!" << std::endl;
return false;
}
// Prepare input tensor
FDTensor* src = CreateCachedGpuInputTensor(mat);
src->ExpandDim(0);
FDMatBatch mat_batch;
mat_batch.SetTensor(src);
mat_batch.mat_type = ProcLib::CUDA;
mat_batch.input_cache = mat->input_cache;
mat_batch.output_cache = mat->output_cache;
if (!ImplByCuda(&mat_batch)) {
return false;
}
FDTensor* dst = mat_batch.Tensor();
dst->Squeeze(0);
mat->SetTensor(dst);
mat->mat_type = ProcLib::CUDA;
return true;
}
bool Normalize::ImplByCuda(FDMatBatch* mat_batch) {
if (mat_batch->layout != FDMatBatchLayout::NHWC) {
FDERROR << "The input data must be NHWC format!" << std::endl;
return false;
}
// Prepare input tensor
FDTensor* src = CreateCachedGpuInputTensor(mat_batch);
// Prepare output tensor
mat_batch->output_cache->Resize(src->Shape(), FDDataType::FP32,
"batch_output_cache", Device::GPU);
// Copy alpha and beta to GPU
gpu_alpha_.Resize({1, 1, static_cast<int>(alpha_.size())}, FDDataType::FP32,
"alpha", Device::GPU);
cudaMemcpy(gpu_alpha_.Data(), alpha_.data(), gpu_alpha_.Nbytes(),
cudaMemcpyHostToDevice);
gpu_beta_.Resize({1, 1, static_cast<int>(beta_.size())}, FDDataType::FP32,
"beta", Device::GPU);
cudaMemcpy(gpu_beta_.Data(), beta_.data(), gpu_beta_.Nbytes(),
cudaMemcpyHostToDevice);
int jobs =
mat_batch->output_cache->Numel() / mat_batch->output_cache->shape[3];
int threads = 256;
int blocks = ceil(jobs / (float)threads);
NormalizeKernel<<<blocks, threads, 0, mat_batch->Stream()>>>(
reinterpret_cast<uint8_t*>(src->Data()),
reinterpret_cast<float*>(mat_batch->output_cache->Data()),
reinterpret_cast<float*>(gpu_alpha_.Data()),
reinterpret_cast<float*>(gpu_beta_.Data()),
mat_batch->output_cache->shape[3], swap_rb_,
mat_batch->output_cache->shape[0], jobs);
mat_batch->SetTensor(mat_batch->output_cache);
mat_batch->mat_type = ProcLib::CUDA;
return true;
}
#ifdef ENABLE_CVCUDA
bool Normalize::ImplByCvCuda(FDMat* mat) { return ImplByCuda(mat); }
bool Normalize::ImplByCvCuda(FDMatBatch* mat_batch) {
return ImplByCuda(mat_batch);
}
#endif
} // namespace vision
} // namespace fastdeploy
#endif
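The kernel is a per-channel fused multiply-add, so everything interesting is in how alpha and beta are precomputed. Assuming the usual derivation normalized = (x * scale - mean) / std with scale = 1/255 when is_scale is true (our reading; the actual Normalize constructor also supports min/max ranges, ignored here):

```cpp
#include <cstdio>
#include <vector>

// Derive the fused multiply-add coefficients used by NormalizeKernel:
//   (x * scale - mean) / std  ==  x * alpha + beta
int main() {
  std::vector<float> mean = {0.5f, 0.5f, 0.5f};
  std::vector<float> stdv = {0.5f, 0.5f, 0.5f};
  bool is_scale = true;
  float scale = is_scale ? 1.0f / 255.0f : 1.0f;
  for (size_t i = 0; i < mean.size(); ++i) {
    float alpha = scale / stdv[i];
    float beta = -mean[i] / stdv[i];
    std::printf("c%zu: alpha=%.6f beta=%.2f\n", i, alpha, beta);
  }
  // With mean = std = 0.5 this maps uint8 0..255 onto roughly -1..1.
}
```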

View File

@@ -28,6 +28,14 @@ class FASTDEPLOY_DECL Normalize : public Processor {
bool ImplByOpenCV(Mat* mat);
#ifdef ENABLE_FLYCV
bool ImplByFlyCV(Mat* mat);
#endif
#ifdef WITH_GPU
bool ImplByCuda(FDMat* mat);
bool ImplByCuda(FDMatBatch* mat_batch);
#endif
#ifdef ENABLE_CVCUDA
bool ImplByCvCuda(FDMat* mat);
bool ImplByCvCuda(FDMatBatch* mat_batch);
#endif
std::string Name() { return "Normalize"; }
@@ -61,6 +69,8 @@ class FASTDEPLOY_DECL Normalize : public Processor {
private:
std::vector<float> alpha_;
std::vector<float> beta_;
FDTensor gpu_alpha_;
FDTensor gpu_beta_;
bool swap_rb_;
};
} // namespace vision

View File

@@ -126,7 +126,7 @@ bool Pad::ImplByCvCuda(FDMat* mat) {
auto src_tensor = CreateCvCudaTensorWrapData(*src);
int height = mat->Height() + top_ + bottom_;
int width = mat->Height() + left_ + right_;
int width = mat->Width() + left_ + right_;
// Prepare output tensor
mat->output_cache->Resize({height, width, mat->Channels()}, mat->Type(),
@@ -137,9 +137,6 @@ bool Pad::ImplByCvCuda(FDMat* mat) {
NVCV_BORDER_CONSTANT, value);
mat->SetTensor(mat->output_cache);
mat->SetWidth(width);
mat->SetHeight(height);
mat->device = Device::GPU;
mat->mat_type = ProcLib::CVCUDA;
return true;
}

View File

@@ -22,8 +22,20 @@ namespace fastdeploy {
namespace vision {
namespace ocr {
void OcrClassifierResizeImage(FDMat* mat,
const std::vector<int>& cls_image_shape) {
ClassifierPreprocessor::ClassifierPreprocessor() {
resize_op_ = std::make_shared<Resize>(-1, -1);
std::vector<float> value = {0, 0, 0};
pad_op_ = std::make_shared<Pad>(0, 0, 0, 0, value);
normalize_op_ =
std::make_shared<Normalize>(std::vector<float>({0.5f, 0.5f, 0.5f}),
std::vector<float>({0.5f, 0.5f, 0.5f}), true);
hwc2chw_op_ = std::make_shared<HWC2CHW>();
}
void ClassifierPreprocessor::OcrClassifierResizeImage(
FDMat* mat, const std::vector<int>& cls_image_shape) {
int img_c = cls_image_shape[0];
int img_h = cls_image_shape[1];
int img_w = cls_image_shape[2];
@@ -36,12 +48,8 @@ void OcrClassifierResizeImage(FDMat* mat,
else
resize_w = int(ceilf(img_h * ratio));
Resize::Run(mat, resize_w, img_h);
}
bool ClassifierPreprocessor::Run(std::vector<FDMat>* images,
std::vector<FDTensor>* outputs) {
return Run(images, outputs, 0, images->size());
resize_op_->SetWidthAndHeight(resize_w, img_h);
(*resize_op_)(mat);
}
bool ClassifierPreprocessor::Run(std::vector<FDMat>* images,
@@ -55,36 +63,37 @@ bool ClassifierPreprocessor::Run(std::vector<FDMat>* images,
return false;
}
std::vector<FDMat> mats(end_index - start_index);
for (size_t i = start_index; i < end_index; ++i) {
FDMat* mat = &(images->at(i));
mats[i - start_index] = images->at(i);
}
return Run(&mats, outputs);
}
bool ClassifierPreprocessor::Apply(FDMatBatch* image_batch,
std::vector<FDTensor>* outputs) {
for (size_t i = 0; i < image_batch->mats->size(); ++i) {
FDMat* mat = &(image_batch->mats->at(i));
OcrClassifierResizeImage(mat, cls_image_shape_);
if (!disable_normalize_) {
Normalize::Run(mat, mean_, scale_, is_scale_);
(*normalize_op_)(mat);
}
std::vector<float> value = {0, 0, 0};
if (mat->Width() < cls_image_shape_[2]) {
Pad::Run(mat, 0, 0, 0, cls_image_shape_[2] - mat->Width(), value);
pad_op_->SetPaddingSize(0, 0, 0, cls_image_shape_[2] - mat->Width());
(*pad_op_)(mat);
}
if (!disable_permute_) {
HWC2CHW::Run(mat);
Cast::Run(mat, "float");
(*hwc2chw_op_)(mat);
}
}
// Only have 1 output Tensor.
// Only have 1 output tensor.
outputs->resize(1);
// Concat all the preprocessed data to a batch tensor
size_t tensor_size = end_index - start_index;
std::vector<FDTensor> tensors(tensor_size);
for (size_t i = 0; i < tensor_size; ++i) {
(*images)[i + start_index].ShareWithTensor(&(tensors[i]));
tensors[i].ExpandDim(0);
}
if (tensors.size() == 1) {
(*outputs)[0] = std::move(tensors[0]);
} else {
function::Concat(tensors, &((*outputs)[0]), 0);
}
// Get the NCHW tensor
FDTensor* tensor = image_batch->Tensor();
(*outputs)[0].SetExternalData(tensor->Shape(), tensor->Dtype(),
tensor->Data(), tensor->device,
tensor->device_id);
return true;
}

View File

@@ -14,6 +14,7 @@
#pragma once
#include "fastdeploy/vision/common/processors/transform.h"
#include "fastdeploy/vision/common/processors/manager.h"
#include "fastdeploy/vision/common/result.h"
namespace fastdeploy {
@@ -22,32 +23,37 @@ namespace vision {
namespace ocr {
/*! @brief Preprocessor object for the Classifier series models.
*/
class FASTDEPLOY_DECL ClassifierPreprocessor {
class FASTDEPLOY_DECL ClassifierPreprocessor : public ProcessorManager {
public:
ClassifierPreprocessor();
using ProcessorManager::Run;
/** \brief Process the input image and prepare input tensors for runtime
*
* \param[in] images The input data list, all the elements are FDMat
* \param[in] outputs The output tensors which will be fed into runtime
* \return true if the preprocessing succeeded, otherwise false
*/
bool Run(std::vector<FDMat>* images, std::vector<FDTensor>* outputs);
bool Run(std::vector<FDMat>* images, std::vector<FDTensor>* outputs,
size_t start_index, size_t end_index);
/// Set mean value for the image normalization in classification preprocess
void SetMean(const std::vector<float>& mean) { mean_ = mean; }
/// Get mean value of the image normalization in classification preprocess
std::vector<float> GetMean() const { return mean_; }
/** \brief Implement the virtual function of ProcessorManager. Apply() is the
* body of Run() and contains the main preprocessing logic; users call
* Run() to execute the preprocessing
*
* \param[in] image_batch The input image batch
* \param[in] outputs The output tensors which will be fed into runtime
* \return true if the preprocessing succeeded, otherwise false
*/
virtual bool Apply(FDMatBatch* image_batch, std::vector<FDTensor>* outputs);
/// Set scale value for the image normalization in classification preprocess
void SetScale(const std::vector<float>& scale) { scale_ = scale; }
/// Get scale value of the image normalization in classification preprocess
std::vector<float> GetScale() const { return scale_; }
/// Set is_scale for the image normalization in classification preprocess
void SetIsScale(bool is_scale) { is_scale_ = is_scale; }
/// Get is_scale of the image normalization in classification preprocess
bool GetIsScale() const { return is_scale_; }
/// Set the normalize parameters for preprocessing. Call this API to
/// customize them; otherwise the default normalize parameters will
/// be used.
void SetNormalize(const std::vector<float>& mean,
const std::vector<float>& std,
bool is_scale) {
normalize_op_ = std::make_shared<Normalize>(mean, std, is_scale);
}
/// Set cls_image_shape for the classification preprocess
void SetClsImageShape(const std::vector<int>& cls_image_shape) {
@@ -62,14 +68,18 @@ class FASTDEPLOY_DECL ClassifierPreprocessor {
void DisablePermute() { disable_permute_ = true; }
private:
void OcrClassifierResizeImage(FDMat* mat,
const std::vector<int>& cls_image_shape);
// for recording the switch of hwc2chw
bool disable_permute_ = false;
// for recording the switch of normalize
bool disable_normalize_ = false;
std::vector<float> mean_ = {0.5f, 0.5f, 0.5f};
std::vector<float> scale_ = {0.5f, 0.5f, 0.5f};
bool is_scale_ = true;
std::vector<int> cls_image_shape_ = {3, 48, 192};
std::shared_ptr<Resize> resize_op_;
std::shared_ptr<Pad> pad_op_;
std::shared_ptr<Normalize> normalize_op_;
std::shared_ptr<HWC2CHW> hwc2chw_op_;
};
} // namespace ocr
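Putting the new API together, a hedged usage sketch of the classifier preprocessor after this change; SetMean/SetScale/SetIsScale are gone and normalization is configured in one call (header path assumed from the repo layout; setup details such as backend selection may differ):

```cpp
#include <vector>

#include "fastdeploy/vision/ocr/ppocr/cls_preprocessor.h"

void PreprocessExample(std::vector<fastdeploy::vision::FDMat>* images) {
  fastdeploy::vision::ocr::ClassifierPreprocessor preprocessor;
  // Replaces the removed SetMean()/SetScale()/SetIsScale() accessors.
  preprocessor.SetNormalize({0.5f, 0.5f, 0.5f},  // mean
                            {0.5f, 0.5f, 0.5f},  // std
                            /*is_scale=*/true);
  preprocessor.SetClsImageShape({3, 48, 192});
  std::vector<fastdeploy::FDTensor> outputs;
  preprocessor.Run(images, &outputs);  // ProcessorManager::Run entry point
}
```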

View File

@@ -55,11 +55,9 @@ DBDetectorPreprocessor::DBDetectorPreprocessor() {
std::vector<float> value = {0, 0, 0};
pad_op_ = std::make_shared<Pad>(0, 0, 0, 0, value);
std::vector<float> mean = {0.485f, 0.456f, 0.406f};
std::vector<float> std = {0.229f, 0.224f, 0.225f};
bool is_scale = true;
normalize_permute_op_ =
std::make_shared<NormalizeAndPermute>(mean, std, is_scale);
normalize_permute_op_ = std::make_shared<NormalizeAndPermute>(
std::vector<float>({0.485f, 0.456f, 0.406f}),
std::vector<float>({0.229f, 0.224f, 0.225f}), true);
}
bool DBDetectorPreprocessor::ResizeImage(FDMat* img, int resize_w, int resize_h,

View File

@@ -46,9 +46,9 @@ class FASTDEPLOY_DECL DBDetectorPreprocessor : public ProcessorManager {
/// Set the normalize parameters for preprocessing. Call this API to
/// customize them; otherwise the default normalize parameters will
/// be used.
void SetNormalize(const std::vector<float>& mean = {0.485f, 0.456f, 0.406f},
const std::vector<float>& std = {0.229f, 0.224f, 0.225f},
bool is_scale = true) {
void SetNormalize(const std::vector<float>& mean,
const std::vector<float>& std,
bool is_scale) {
normalize_permute_op_ =
std::make_shared<NormalizeAndPermute>(mean, std, is_scale);
}

View File

@@ -23,8 +23,8 @@ void BindPPOCRModel(pybind11::module& m) {
});
// DBDetector
pybind11::class_<vision::ocr::DBDetectorPreprocessor>(
m, "DBDetectorPreprocessor")
pybind11::class_<vision::ocr::DBDetectorPreprocessor,
vision::ProcessorManager>(m, "DBDetectorPreprocessor")
.def(pybind11::init<>())
.def_property("static_shape_infer",
&vision::ocr::DBDetectorPreprocessor::GetStaticShapeInfer,
@@ -133,19 +133,16 @@ void BindPPOCRModel(pybind11::module& m) {
});
// Classifier
pybind11::class_<vision::ocr::ClassifierPreprocessor>(
m, "ClassifierPreprocessor")
pybind11::class_<vision::ocr::ClassifierPreprocessor,
vision::ProcessorManager>(m, "ClassifierPreprocessor")
.def(pybind11::init<>())
.def_property("cls_image_shape",
&vision::ocr::ClassifierPreprocessor::GetClsImageShape,
&vision::ocr::ClassifierPreprocessor::SetClsImageShape)
.def_property("mean", &vision::ocr::ClassifierPreprocessor::GetMean,
&vision::ocr::ClassifierPreprocessor::SetMean)
.def_property("scale", &vision::ocr::ClassifierPreprocessor::GetScale,
&vision::ocr::ClassifierPreprocessor::SetScale)
.def_property("is_scale",
&vision::ocr::ClassifierPreprocessor::GetIsScale,
&vision::ocr::ClassifierPreprocessor::SetIsScale)
.def("set_normalize",
[](vision::ocr::ClassifierPreprocessor& self,
const std::vector<float>& mean, const std::vector<float>& std,
bool is_scale) { self.SetNormalize(mean, std, is_scale); })
.def("run",
[](vision::ocr::ClassifierPreprocessor& self,
std::vector<pybind11::array>& im_list) {
@@ -233,8 +230,8 @@ void BindPPOCRModel(pybind11::module& m) {
});
// Recognizer
pybind11::class_<vision::ocr::RecognizerPreprocessor>(
m, "RecognizerPreprocessor")
pybind11::class_<vision::ocr::RecognizerPreprocessor,
vision::ProcessorManager>(m, "RecognizerPreprocessor")
.def(pybind11::init<>())
.def_property("static_shape_infer",
&vision::ocr::RecognizerPreprocessor::GetStaticShapeInfer,
@@ -242,13 +239,10 @@ void BindPPOCRModel(pybind11::module& m) {
.def_property("rec_image_shape",
&vision::ocr::RecognizerPreprocessor::GetRecImageShape,
&vision::ocr::RecognizerPreprocessor::SetRecImageShape)
.def_property("mean", &vision::ocr::RecognizerPreprocessor::GetMean,
&vision::ocr::RecognizerPreprocessor::SetMean)
.def_property("scale", &vision::ocr::RecognizerPreprocessor::GetScale,
&vision::ocr::RecognizerPreprocessor::SetScale)
.def_property("is_scale",
&vision::ocr::RecognizerPreprocessor::GetIsScale,
&vision::ocr::RecognizerPreprocessor::SetIsScale)
.def("set_normalize",
[](vision::ocr::RecognizerPreprocessor& self,
const std::vector<float>& mean, const std::vector<float>& std,
bool is_scale) { self.SetNormalize(mean, std, is_scale); })
.def("run",
[](vision::ocr::RecognizerPreprocessor& self,
std::vector<pybind11::array>& im_list) {

View File

@@ -22,8 +22,23 @@ namespace fastdeploy {
namespace vision {
namespace ocr {
void OcrRecognizerResizeImage(FDMat* mat, float max_wh_ratio,
const std::vector<int>& rec_image_shape,
RecognizerPreprocessor::RecognizerPreprocessor() {
resize_op_ = std::make_shared<Resize>(-1, -1);
std::vector<float> value = {127, 127, 127};
pad_op_ = std::make_shared<Pad>(0, 0, 0, 0, value);
std::vector<float> mean = {0.5f, 0.5f, 0.5f};
std::vector<float> std = {0.5f, 0.5f, 0.5f};
normalize_permute_op_ =
std::make_shared<NormalizeAndPermute>(mean, std, true);
normalize_op_ = std::make_shared<Normalize>(mean, std, true);
hwc2chw_op_ = std::make_shared<HWC2CHW>();
cast_op_ = std::make_shared<Cast>("float");
}
void RecognizerPreprocessor::OcrRecognizerResizeImage(
FDMat* mat, float max_wh_ratio, const std::vector<int>& rec_image_shape,
bool static_shape_infer) {
int img_h, img_w;
img_h = rec_image_shape[1];
@@ -39,25 +54,25 @@ void OcrRecognizerResizeImage(FDMat* mat, float max_wh_ratio,
} else {
resize_w = int(ceilf(img_h * ratio));
}
Resize::Run(mat, resize_w, img_h);
Pad::Run(mat, 0, 0, 0, int(img_w - mat->Width()), {127, 127, 127});
resize_op_->SetWidthAndHeight(resize_w, img_h);
(*resize_op_)(mat);
pad_op_->SetPaddingSize(0, 0, 0, int(img_w - mat->Width()));
(*pad_op_)(mat);
} else {
if (mat->Width() >= img_w) {
Resize::Run(mat, img_w, img_h); // Resize W to 320
// Resize W to 320
resize_op_->SetWidthAndHeight(img_w, img_h);
(*resize_op_)(mat);
} else {
Resize::Run(mat, mat->Width(), img_h);
Pad::Run(mat, 0, 0, 0, int(img_w - mat->Width()), {127, 127, 127});
resize_op_->SetWidthAndHeight(mat->Width(), img_h);
(*resize_op_)(mat);
// Pad to 320
pad_op_->SetPaddingSize(0, 0, 0, int(img_w - mat->Width()));
(*pad_op_)(mat);
}
}
}
bool RecognizerPreprocessor::Run(std::vector<FDMat>* images,
std::vector<FDTensor>* outputs) {
return Run(images, outputs, 0, images->size(), {});
}
bool RecognizerPreprocessor::Run(std::vector<FDMat>* images,
std::vector<FDTensor>* outputs,
size_t start_index, size_t end_index,
@@ -70,60 +85,55 @@ bool RecognizerPreprocessor::Run(std::vector<FDMat>* images,
return false;
}
std::vector<FDMat> mats(end_index - start_index);
for (size_t i = start_index; i < end_index; ++i) {
size_t real_index = i;
if (indices.size() != 0) {
real_index = indices[i];
}
mats[i - start_index] = images->at(real_index);
}
return Run(&mats, outputs);
}
bool RecognizerPreprocessor::Apply(FDMatBatch* image_batch,
std::vector<FDTensor>* outputs) {
int img_h = rec_image_shape_[1];
int img_w = rec_image_shape_[2];
float max_wh_ratio = img_w * 1.0 / img_h;
float ori_wh_ratio;
for (size_t i = start_index; i < end_index; ++i) {
size_t real_index = i;
if (indices.size() != 0) {
real_index = indices[i];
}
FDMat* mat = &(images->at(real_index));
for (size_t i = 0; i < image_batch->mats->size(); ++i) {
FDMat* mat = &(image_batch->mats->at(i));
ori_wh_ratio = mat->Width() * 1.0 / mat->Height();
max_wh_ratio = std::max(max_wh_ratio, ori_wh_ratio);
}
for (size_t i = start_index; i < end_index; ++i) {
size_t real_index = i;
if (indices.size() != 0) {
real_index = indices[i];
}
FDMat* mat = &(images->at(real_index));
for (size_t i = 0; i < image_batch->mats->size(); ++i) {
FDMat* mat = &(image_batch->mats->at(i));
OcrRecognizerResizeImage(mat, max_wh_ratio, rec_image_shape_,
static_shape_infer_);
if (!disable_normalize_ && !disable_permute_) {
NormalizeAndPermute::Run(mat, mean_, scale_, is_scale_);
} else {
if (!disable_normalize_) {
Normalize::Run(mat, mean_, scale_, is_scale_);
}
if (!disable_permute_) {
HWC2CHW::Run(mat);
Cast::Run(mat, "float");
}
}
}
// Only have 1 output Tensor.
outputs->resize(1);
size_t tensor_size = end_index - start_index;
// Concat all the preprocessed data to a batch tensor
std::vector<FDTensor> tensors(tensor_size);
for (size_t i = 0; i < tensor_size; ++i) {
size_t real_index = i + start_index;
if (indices.size() != 0) {
real_index = indices[i + start_index];
}
(*images)[real_index].ShareWithTensor(&(tensors[i]));
tensors[i].ExpandDim(0);
}
if (tensors.size() == 1) {
(*outputs)[0] = std::move(tensors[0]);
if (!disable_normalize_ && !disable_permute_) {
(*normalize_permute_op_)(image_batch);
} else {
function::Concat(tensors, &((*outputs)[0]), 0);
if (!disable_normalize_) {
(*normalize_op_)(image_batch);
}
if (!disable_permute_) {
(*hwc2chw_op_)(image_batch);
(*cast_op_)(image_batch);
}
}
// Only have 1 output tensor.
outputs->resize(1);
// Get the NCHW tensor
FDTensor* tensor = image_batch->Tensor();
(*outputs)[0].SetExternalData(tensor->Shape(), tensor->Dtype(),
tensor->Data(), tensor->device,
tensor->device_id);
return true;
}
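The Apply() loop first scans the whole batch for the widest aspect ratio, then resizes and right-pads every image to that shared canvas. The width arithmetic, restated as a standalone helper of our own for illustration:

```cpp
#include <algorithm>
#include <cmath>

// Illustrative restatement of the recognizer's width logic (helper name
// is ours): the widest aspect ratio in the batch fixes a common canvas
// width, each image is resized to its own ratio, then right-padded.
int TargetResizeWidth(float wh_ratio, int img_h, float max_wh_ratio) {
  int img_w = static_cast<int>(img_h * max_wh_ratio);   // shared canvas
  float ratio = std::min(wh_ratio, max_wh_ratio);
  int resize_w = static_cast<int>(std::ceil(img_h * ratio));
  return std::min(resize_w, img_w);  // pad_right = img_w - resize_w
}
// E.g. rec_image_shape {3, 48, 320} seeds max_wh_ratio with 320/48 ~ 6.67;
// an image with ratio 10 in the batch raises img_w to 480, and an image
// with ratio 2 resizes to ceil(48 * 2) = 96 columns plus 384 of padding.
```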

View File

@@ -14,6 +14,7 @@
#pragma once
#include "fastdeploy/vision/common/processors/transform.h"
#include "fastdeploy/vision/common/processors/manager.h"
#include "fastdeploy/vision/common/result.h"
namespace fastdeploy {
@@ -22,19 +23,30 @@ namespace vision {
namespace ocr {
/*! @brief Preprocessor object for the Recognizer series models.
*/
class FASTDEPLOY_DECL RecognizerPreprocessor {
class FASTDEPLOY_DECL RecognizerPreprocessor : public ProcessorManager {
public:
RecognizerPreprocessor();
using ProcessorManager::Run;
/** \brief Process the input image and prepare input tensors for runtime
*
* \param[in] images The input data list, all the elements are FDMat
* \param[in] outputs The output tensors which will be fed into runtime
* \return true if the preprocessing succeeded, otherwise false
*/
bool Run(std::vector<FDMat>* images, std::vector<FDTensor>* outputs);
bool Run(std::vector<FDMat>* images, std::vector<FDTensor>* outputs,
size_t start_index, size_t end_index,
const std::vector<int>& indices);
/** \brief Implement the virtual function of ProcessorManager. Apply() is the
* body of Run() and contains the main preprocessing logic; users call
* Run() to execute the preprocessing
*
* \param[in] image_batch The input image batch
* \param[in] outputs The output tensors which will be fed into runtime
* \return true if the preprocessing succeeded, otherwise false
*/
virtual bool Apply(FDMatBatch* image_batch, std::vector<FDTensor>* outputs);
/// Set whether static_shape_infer is true. When deploying PP-OCR
/// on hardware that does not support dynamic input shapes well,
/// such as Huawei Ascend, static_shape_infer needs to be true.
@@ -44,20 +56,16 @@ class FASTDEPLOY_DECL RecognizerPreprocessor {
/// Get static_shape_infer of the recognition preprocess
bool GetStaticShapeInfer() const { return static_shape_infer_; }
/// Set mean value for the image normalization in recognition preprocess
void SetMean(const std::vector<float>& mean) { mean_ = mean; }
/// Get mean value of the image normalization in recognition preprocess
std::vector<float> GetMean() const { return mean_; }
/// Set scale value for the image normalization in recognition preprocess
void SetScale(const std::vector<float>& scale) { scale_ = scale; }
/// Get scale value of the image normalization in recognition preprocess
std::vector<float> GetScale() const { return scale_; }
/// Set is_scale for the image normalization in recognition preprocess
void SetIsScale(bool is_scale) { is_scale_ = is_scale; }
/// Get is_scale of the image normalization in recognition preprocess
bool GetIsScale() const { return is_scale_; }
/// Set the normalize parameters for preprocessing. Call this API to
/// customize them; otherwise the default normalize parameters will
/// be used.
void SetNormalize(const std::vector<float>& mean,
const std::vector<float>& std,
bool is_scale) {
normalize_permute_op_ =
std::make_shared<NormalizeAndPermute>(mean, std, is_scale);
normalize_op_ = std::make_shared<Normalize>(mean, std, is_scale);
}
/// Set rec_image_shape for the recognition preprocess
void SetRecImageShape(const std::vector<int>& rec_image_shape) {
@@ -72,15 +80,21 @@ class FASTDEPLOY_DECL RecognizerPreprocessor {
void DisablePermute() { disable_permute_ = true; }
private:
void OcrRecognizerResizeImage(FDMat* mat, float max_wh_ratio,
const std::vector<int>& rec_image_shape,
bool static_shape_infer);
// for recording the switch of hwc2chw
bool disable_permute_ = false;
// for recording the switch of normalize
bool disable_normalize_ = false;
std::vector<int> rec_image_shape_ = {3, 48, 320};
std::vector<float> mean_ = {0.5f, 0.5f, 0.5f};
std::vector<float> scale_ = {0.5f, 0.5f, 0.5f};
bool is_scale_ = true;
bool static_shape_infer_ = false;
std::shared_ptr<Resize> resize_op_;
std::shared_ptr<Pad> pad_op_;
std::shared_ptr<NormalizeAndPermute> normalize_permute_op_;
std::shared_ptr<Normalize> normalize_op_;
std::shared_ptr<HWC2CHW> hwc2chw_op_;
std::shared_ptr<Cast> cast_op_;
};
} // namespace ocr

View File

@@ -52,10 +52,7 @@ class DBDetectorPreprocessor:
value, int), "The value to set `max_side_len` must be type of int."
self._preprocessor.max_side_len = value
def set_normalize(self,
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225],
is_scale=True):
def set_normalize(self, mean, std, is_scale):
"""Set preprocess normalize parameters, please call this API to
customize the normalize parameters, otherwise it will use the default
normalize parameters.
@@ -340,35 +337,15 @@ class ClassifierPreprocessor:
"""
return self._preprocessor.run(input_ims)
@property
def is_scale(self):
return self._preprocessor.is_scale
@is_scale.setter
def is_scale(self, value):
assert isinstance(
value, bool), "The value to set `is_scale` must be type of bool."
self._preprocessor.is_scale = value
@property
def scale(self):
return self._preprocessor.scale
@scale.setter
def scale(self, value):
assert isinstance(
value, list), "The value to set `scale` must be type of list."
self._preprocessor.scale = value
@property
def mean(self):
return self._preprocessor.mean
@mean.setter
def mean(self, value):
assert isinstance(
value, list), "The value to set `mean` must be type of list."
self._preprocessor.mean = value
def set_normalize(self, mean, std, is_scale):
"""Set preprocess normalize parameters, please call this API to
customize the normalize parameters, otherwise it will use the default
normalize parameters.
:param: mean: (list of float) mean values
:param: std: (list of float) std values
:param: is_scale: (boolean) whether to scale
"""
self._preprocessor.set_normalize(mean, std, is_scale)
@property
def cls_image_shape(self):
@@ -496,37 +473,6 @@ class Classifier(FastDeployModel):
def postprocessor(self, value):
self._model.postprocessor = value
# Cls Preprocessor Property
@property
def is_scale(self):
return self._model.preprocessor.is_scale
@is_scale.setter
def is_scale(self, value):
assert isinstance(
value, bool), "The value to set `is_scale` must be type of bool."
self._model.preprocessor.is_scale = value
@property
def scale(self):
return self._model.preprocessor.scale
@scale.setter
def scale(self, value):
assert isinstance(
value, list), "The value to set `scale` must be type of list."
self._model.preprocessor.scale = value
@property
def mean(self):
return self._model.preprocessor.mean
@mean.setter
def mean(self, value):
assert isinstance(
value, list), "The value to set `mean` must be type of list."
self._model.preprocessor.mean = value
@property
def cls_image_shape(self):
return self._model.preprocessor.cls_image_shape
@@ -575,35 +521,15 @@ class RecognizerPreprocessor:
bool), "The value to set `static_shape_infer` must be type of bool."
self._preprocessor.static_shape_infer = value
@property
def is_scale(self):
return self._preprocessor.is_scale
@is_scale.setter
def is_scale(self, value):
assert isinstance(
value, bool), "The value to set `is_scale` must be type of bool."
self._preprocessor.is_scale = value
@property
def scale(self):
return self._preprocessor.scale
@scale.setter
def scale(self, value):
assert isinstance(
value, list), "The value to set `scale` must be type of list."
self._preprocessor.scale = value
@property
def mean(self):
return self._preprocessor.mean
@mean.setter
def mean(self, value):
assert isinstance(
value, list), "The value to set `mean` must be type of list."
self._preprocessor.mean = value
def set_normalize(self, mean, std, is_scale):
"""Set preprocess normalize parameters, please call this API to
customize the normalize parameters, otherwise it will use the default
normalize parameters.
:param: mean: (list of float) mean values
:param: std: (list of float) std values
:param: is_scale: (boolean) whether to scale
"""
self._preprocessor.set_normalize(mean, std, is_scale)
@property
def rec_image_shape(self):
@@ -728,36 +654,6 @@ class Recognizer(FastDeployModel):
bool), "The value to set `static_shape_infer` must be type of bool."
self._model.preprocessor.static_shape_infer = value
@property
def is_scale(self):
return self._model.preprocessor.is_scale
@is_scale.setter
def is_scale(self, value):
assert isinstance(
value, bool), "The value to set `is_scale` must be type of bool."
self._model.preprocessor.is_scale = value
@property
def scale(self):
return self._model.preprocessor.scale
@scale.setter
def scale(self, value):
assert isinstance(
value, list), "The value to set `scale` must be type of list."
self._model.preprocessor.scale = value
@property
def mean(self):
return self._model.preprocessor.mean
@mean.setter
def mean(self, value):
assert isinstance(
value, list), "The value to set `mean` must be type of list."
self._model.preprocessor.mean = value
@property
def rec_image_shape(self):
return self._model.preprocessor.rec_image_shape