[CVCUDA] PP-OCR Cls & Rec preprocessor support CV-CUDA (#1470)

* ppocr cls preprocessor use manager

* hwc2chw cvcuda

* ppocr rec preproc use manager

* ocr rec preproc cvcuda

* fix rec preproc bug

* ppocr cls&rec preproc set normalize

* fix pybind

* address comment
Author: Wang Xinyu
Date: 2023-03-02 10:50:44 +08:00
Committed by: GitHub
Parent: fe2882a1ef
Commit: 044ab993d2

19 changed files with 424 additions and 306 deletions

View File

@@ -18,34 +18,36 @@ namespace fastdeploy {
 namespace vision {
 #ifdef ENABLE_CVCUDA
-nvcv::ImageFormat CreateCvCudaImageFormat(FDDataType type, int channel) {
+nvcv::ImageFormat CreateCvCudaImageFormat(FDDataType type, int channel,
+                                          bool interleaved) {
   FDASSERT(channel == 1 || channel == 3 || channel == 4,
            "Only support channel be 1/3/4 in CV-CUDA.");
   if (type == FDDataType::UINT8) {
     if (channel == 1) {
       return nvcv::FMT_U8;
     } else if (channel == 3) {
-      return nvcv::FMT_BGR8;
+      return (interleaved ? nvcv::FMT_BGR8 : nvcv::FMT_BGR8p);
     } else {
-      return nvcv::FMT_BGRA8;
+      return (interleaved ? nvcv::FMT_BGRA8 : nvcv::FMT_BGRA8p);
     }
   } else if (type == FDDataType::FP32) {
     if (channel == 1) {
       return nvcv::FMT_F32;
     } else if (channel == 3) {
-      return nvcv::FMT_BGRf32;
+      return (interleaved ? nvcv::FMT_BGRf32 : nvcv::FMT_BGRf32p);
     } else {
-      return nvcv::FMT_BGRAf32;
+      return (interleaved ? nvcv::FMT_BGRAf32 : nvcv::FMT_BGRAf32p);
     }
   }
   FDASSERT(false, "Data type of %s is not supported.", Str(type).c_str());
   return nvcv::FMT_BGRf32;
 }

-nvcv::TensorWrapData CreateCvCudaTensorWrapData(const FDTensor& tensor) {
+nvcv::TensorWrapData CreateCvCudaTensorWrapData(const FDTensor& tensor,
+                                                Layout layout) {
   FDASSERT(tensor.shape.size() == 3,
            "When create CVCUDA tensor from FD tensor,"
-           "tensor shape should be 3-Dim, HWC layout");
+           "tensor shape should be 3-Dim,");
   int batchsize = 1;
   int h = tensor.Shape()[0];
   int w = tensor.Shape()[1];
@@ -56,10 +58,20 @@ nvcv::TensorWrapData CreateCvCudaTensorWrapData(const FDTensor& tensor) {
   buf.strides[2] = c * buf.strides[3];
   buf.strides[1] = w * buf.strides[2];
   buf.strides[0] = h * buf.strides[1];
+  if (layout == Layout::CHW) {
+    c = tensor.Shape()[0];
+    h = tensor.Shape()[1];
+    w = tensor.Shape()[2];
+    buf.strides[3] = FDDataTypeSize(tensor.Dtype());
+    buf.strides[2] = w * buf.strides[3];
+    buf.strides[1] = h * buf.strides[2];
+    buf.strides[0] = c * buf.strides[1];
+  }
   buf.basePtr = reinterpret_cast<NVCVByte*>(const_cast<void*>(tensor.Data()));
   nvcv::Tensor::Requirements req = nvcv::Tensor::CalcRequirements(
-      batchsize, {w, h}, CreateCvCudaImageFormat(tensor.Dtype(), c));
+      batchsize, {w, h},
+      CreateCvCudaImageFormat(tensor.Dtype(), c, layout == Layout::HWC));
   nvcv::TensorDataStridedCuda tensor_data(
       nvcv::TensorShape{req.shape, req.rank, req.layout},
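A minimal standalone sketch of the two stride layouts the change above distinguishes (illustrative values only, not part of the commit):

#include <cstdint>
#include <cstdio>

int main() {
  // A 3-channel float image, e.g. the 48x320 crops used by the PP-OCR recognizer.
  int64_t h = 48, w = 320, c = 3, elem = sizeof(float);

  // HWC (interleaved): the innermost stride steps over channels.
  int64_t hwc[4];
  hwc[3] = elem;        // one element
  hwc[2] = c * hwc[3];  // one pixel
  hwc[1] = w * hwc[2];  // one row
  hwc[0] = h * hwc[1];  // one image

  // CHW (planar): the innermost stride steps over columns of a single plane.
  int64_t chw[4];
  chw[3] = elem;        // one element
  chw[2] = w * chw[3];  // one row of a plane
  chw[1] = h * chw[2];  // one plane
  chw[0] = c * chw[1];  // one image

  // Same total size, different walk order over the same buffer.
  std::printf("HWC image bytes: %lld, CHW image bytes: %lld\n",
              static_cast<long long>(hwc[0]), static_cast<long long>(chw[0]));
  return 0;
}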

View File

@@ -15,6 +15,7 @@
 #pragma once
 #include "fastdeploy/core/fd_tensor.h"
+#include "fastdeploy/vision/common/processors/mat.h"

 #ifdef ENABLE_CVCUDA
 #include "nvcv/Tensor.hpp"
@@ -23,8 +24,10 @@
 namespace fastdeploy {
 namespace vision {

-nvcv::ImageFormat CreateCvCudaImageFormat(FDDataType type, int channel);
-nvcv::TensorWrapData CreateCvCudaTensorWrapData(const FDTensor& tensor);
+nvcv::ImageFormat CreateCvCudaImageFormat(FDDataType type, int channel,
+                                          bool interleaved = true);
+nvcv::TensorWrapData CreateCvCudaTensorWrapData(const FDTensor& tensor,
+                                                Layout layout = Layout::HWC);
 void* GetCvCudaTensorDataPtr(const nvcv::TensorWrapData& tensor);
 nvcv::ImageWrapData CreateImageWrapData(const FDTensor& tensor);
 void CreateCvCudaImageBatchVarShape(std::vector<FDTensor*>& tensors,

View File

@@ -63,6 +63,26 @@ bool HWC2CHW::ImplByFlyCV(Mat* mat) {
 }
 #endif

+#ifdef ENABLE_CVCUDA
+bool HWC2CHW::ImplByCvCuda(FDMat* mat) {
+  // Prepare input tensor
+  FDTensor* src = CreateCachedGpuInputTensor(mat);
+  auto src_tensor = CreateCvCudaTensorWrapData(*src);
+
+  // Prepare output tensor
+  mat->output_cache->Resize({mat->Channels(), mat->Height(), mat->Width()},
+                            src->Dtype(), "output_cache", Device::GPU);
+  auto dst_tensor =
+      CreateCvCudaTensorWrapData(*(mat->output_cache), Layout::CHW);
+
+  cvcuda_reformat_op_(mat->Stream(), src_tensor, dst_tensor);
+
+  mat->SetTensor(mat->output_cache);
+  mat->mat_type = ProcLib::CVCUDA;
+  return true;
+}
+#endif
+
 bool HWC2CHW::Run(Mat* mat, ProcLib lib) {
   auto h = HWC2CHW();
   return h(mat, lib);
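The GPU Reformat call above performs the same permutation as this small CPU reference (illustrative sketch only, not part of the commit):

#include <cstdint>
#include <vector>

// dst[c][y][x] = src[y][x][c] for an h x w image with c channels.
std::vector<uint8_t> Hwc2ChwReference(const std::vector<uint8_t>& src,
                                      int h, int w, int c) {
  std::vector<uint8_t> dst(src.size());
  for (int ci = 0; ci < c; ++ci)
    for (int y = 0; y < h; ++y)
      for (int x = 0; x < w; ++x)
        dst[(ci * h + y) * w + x] = src[(y * w + x) * c + ci];
  return dst;
}

int main() {
  // 1x2 BGR image: two interleaved pixels.
  std::vector<uint8_t> src = {10, 20, 30, 40, 50, 60};
  std::vector<uint8_t> dst = Hwc2ChwReference(src, 1, 2, 3);
  // dst is now planar: {10, 40, 20, 50, 30, 60}
  return 0;
}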

View File

@@ -15,6 +15,11 @@
 #pragma once
 #include "fastdeploy/vision/common/processors/base.h"
+#ifdef ENABLE_CVCUDA
+#include <cvcuda/OpReformat.hpp>
+
+#include "fastdeploy/vision/common/processors/cvcuda_utils.h"
+#endif

 namespace fastdeploy {
 namespace vision {
@@ -24,10 +29,17 @@ class FASTDEPLOY_DECL HWC2CHW : public Processor {
   bool ImplByOpenCV(Mat* mat);
 #ifdef ENABLE_FLYCV
   bool ImplByFlyCV(Mat* mat);
+#endif
+#ifdef ENABLE_CVCUDA
+  bool ImplByCvCuda(FDMat* mat);
 #endif
   std::string Name() { return "HWC2CHW"; }

   static bool Run(Mat* mat, ProcLib lib = ProcLib::DEFAULT);
+
+ private:
+#ifdef ENABLE_CVCUDA
+  cvcuda::Reformat cvcuda_reformat_op_;
+#endif
 };

 } // namespace vision
 } // namespace fastdeploy

View File

@@ -73,6 +73,7 @@ bool ProcessorManager::Run(std::vector<FDMat>* images,
     }
     (*images)[i].input_cache = &input_caches_[i];
     (*images)[i].output_cache = &output_caches_[i];
+    (*images)[i].proc_lib = proc_lib_;
     if ((*images)[i].mat_type == ProcLib::CUDA) {
       // Make a copy of the input data ptr, so that the original data ptr of
       // FDMat won't be modified.

View File

@@ -272,6 +272,9 @@ std::vector<FDMat> WrapMat(const std::vector<cv::Mat>& images) {
 }

 bool CheckShapeConsistency(std::vector<Mat>* mats) {
+  if (mats == nullptr) {
+    return true;
+  }
   for (size_t i = 1; i < mats->size(); ++i) {
     if ((*mats)[i].Channels() != (*mats)[0].Channels() ||
         (*mats)[i].Width() != (*mats)[0].Width() ||
@@ -285,21 +288,24 @@ bool CheckShapeConsistency(std::vector<Mat>* mats) {
 FDTensor* CreateCachedGpuInputTensor(Mat* mat) {
 #ifdef WITH_GPU
   FDTensor* src = mat->Tensor();
-  if (src->device == Device::GPU) {
-    if (src->Data() == mat->output_cache->Data()) {
-      std::swap(mat->input_cache, mat->output_cache);
-      std::swap(mat->input_cache->name, mat->output_cache->name);
-    }
+  // Need to make sure the tensor is pointed to the input_cache.
+  if (src->Data() == mat->output_cache->Data()) {
+    std::swap(mat->input_cache, mat->output_cache);
+    std::swap(mat->input_cache->name, mat->output_cache->name);
+  }
+  if (src->device == Device::GPU) {
     return src;
   } else if (src->device == Device::CPU) {
-    // Mats on CPU, we need copy these tensors from CPU to GPU
+    // Tensor on CPU, we need copy it from CPU to GPU
     FDASSERT(src->Shape().size() == 3, "The CPU tensor must has 3 dims.")
-    mat->input_cache->Resize(src->Shape(), src->Dtype(), "input_cache",
-                             Device::GPU);
+    mat->output_cache->Resize(src->Shape(), src->Dtype(), "output_cache",
+                              Device::GPU);
     FDASSERT(
-        cudaMemcpyAsync(mat->input_cache->Data(), src->Data(), src->Nbytes(),
+        cudaMemcpyAsync(mat->output_cache->Data(), src->Data(), src->Nbytes(),
                         cudaMemcpyHostToDevice, mat->Stream()) == 0,
         "[ERROR] Error occurs while copy memory from CPU to GPU.");
+    std::swap(mat->input_cache, mat->output_cache);
+    std::swap(mat->input_cache->name, mat->output_cache->name);
     return mat->input_cache;
   } else {
     FDASSERT(false, "FDMat is on unsupported device: %d", src->device);

View File

@@ -29,10 +29,12 @@ FDTensor* FDMatBatch::Tensor() {
   if (has_batched_tensor) {
     return fd_tensor.get();
   }
-  FDASSERT(CheckShapeConsistency(mats), "Mats shapes are not consistent.")
+  FDASSERT(mats != nullptr, "Failed to get batched tensor, Mats are empty.");
+  FDASSERT(CheckShapeConsistency(mats), "Mats shapes are not consistent.");
   // Each mat has its own tensor,
   // to get a batched tensor, we need copy these tensors to a batched tensor
   FDTensor* src = (*mats)[0].Tensor();
+  device = src->device;
   auto new_shape = src->Shape();
   new_shape.insert(new_shape.begin(), mats->size());
   input_cache->Resize(new_shape, src->Dtype(), "batch_input_cache", device);
@@ -51,26 +53,34 @@ FDTensor* FDMatBatch::Tensor() {
 void FDMatBatch::SetTensor(FDTensor* tensor) {
   fd_tensor->SetExternalData(tensor->Shape(), tensor->Dtype(), tensor->Data(),
                              tensor->device, tensor->device_id);
+  device = tensor->device;
   has_batched_tensor = true;
 }

 FDTensor* CreateCachedGpuInputTensor(FDMatBatch* mat_batch) {
 #ifdef WITH_GPU
-  auto mats = mat_batch->mats;
-  FDASSERT(CheckShapeConsistency(mats), "Mats shapes are not consistent.")
-  FDTensor* src = (*mats)[0].Tensor();
-  if (mat_batch->device == Device::GPU) {
-    return mat_batch->Tensor();
-  } else if (mat_batch->device == Device::CPU) {
-    // Mats on CPU, we need copy them to GPU and then get a batched GPU tensor
-    for (size_t i = 0; i < mats->size(); ++i) {
-      FDTensor* tensor = CreateCachedGpuInputTensor(&(*mats)[i]);
-      (*mats)[i].SetTensor(tensor);
-    }
-    mat_batch->device = Device::GPU;
-    return mat_batch->Tensor();
+  // Get the batched tensor
+  FDTensor* src = mat_batch->Tensor();
+  // Need to make sure the returned tensor is pointed to the input_cache.
+  if (src->Data() == mat_batch->output_cache->Data()) {
+    std::swap(mat_batch->input_cache, mat_batch->output_cache);
+    std::swap(mat_batch->input_cache->name, mat_batch->output_cache->name);
+  }
+  if (src->device == Device::GPU) {
+    return src;
+  } else if (src->device == Device::CPU) {
+    // Batched tensor on CPU, we need copy it to GPU
+    mat_batch->output_cache->Resize(src->Shape(), src->Dtype(), "output_cache",
+                                    Device::GPU);
+    FDASSERT(cudaMemcpyAsync(mat_batch->output_cache->Data(), src->Data(),
+                             src->Nbytes(), cudaMemcpyHostToDevice,
+                             mat_batch->Stream()) == 0,
+             "[ERROR] Error occurs while copy memory from CPU to GPU.");
+    std::swap(mat_batch->input_cache, mat_batch->output_cache);
+    std::swap(mat_batch->input_cache->name, mat_batch->output_cache->name);
+    return mat_batch->input_cache;
   } else {
-    FDASSERT(false, "FDMat is on unsupported device: %d", src->device);
+    FDASSERT(false, "FDMatBatch is on unsupported device: %d", src->device);
   }
 #else
   FDASSERT(false, "FastDeploy didn't compile with WITH_GPU.");

View File

@@ -56,7 +56,7 @@ struct FASTDEPLOY_DECL FDMatBatch {
   void SetStream(cudaStream_t s);
 #endif

-  std::vector<FDMat>* mats;
+  std::vector<FDMat>* mats = nullptr;
   ProcLib mat_type = ProcLib::OPENCV;
   FDMatBatchLayout layout = FDMatBatchLayout::NHWC;
   Device device = Device::CPU;

View File

@@ -0,0 +1,116 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifdef WITH_GPU
#include "fastdeploy/vision/common/processors/normalize.h"
namespace fastdeploy {
namespace vision {
__global__ void NormalizeKernel(const uint8_t* src, float* dst,
const float* alpha, const float* beta,
int num_channel, bool swap_rb, int batch_size,
int edge) {
int idx = blockDim.x * blockIdx.x + threadIdx.x;
if (idx >= edge) return;
int img_size = edge / batch_size;
int n = idx / img_size; // batch index
int p = idx - (n * img_size); // pixel index within the image
for (int i = 0; i < num_channel; ++i) {
int j = i;
if (swap_rb) {
j = 2 - i;
}
dst[num_channel * idx + j] =
src[num_channel * idx + j] * alpha[i] + beta[i];
}
}
bool Normalize::ImplByCuda(FDMat* mat) {
if (mat->layout != Layout::HWC) {
FDERROR << "The input data must be NHWC format!" << std::endl;
return false;
}
// Prepare input tensor
FDTensor* src = CreateCachedGpuInputTensor(mat);
src->ExpandDim(0);
FDMatBatch mat_batch;
mat_batch.SetTensor(src);
mat_batch.mat_type = ProcLib::CUDA;
mat_batch.input_cache = mat->input_cache;
mat_batch.output_cache = mat->output_cache;
bool ret = ImplByCuda(&mat_batch);
FDTensor* dst = mat_batch.Tensor();
dst->Squeeze(0);
mat->SetTensor(dst);
mat->mat_type = ProcLib::CUDA;
return true;
}
bool Normalize::ImplByCuda(FDMatBatch* mat_batch) {
if (mat_batch->layout != FDMatBatchLayout::NHWC) {
FDERROR << "The input data must be NHWC format!" << std::endl;
return false;
}
// Prepare input tensor
FDTensor* src = CreateCachedGpuInputTensor(mat_batch);
// Prepare output tensor
mat_batch->output_cache->Resize(src->Shape(), FDDataType::FP32,
"batch_output_cache", Device::GPU);
// Copy alpha and beta to GPU
gpu_alpha_.Resize({1, 1, static_cast<int>(alpha_.size())}, FDDataType::FP32,
"alpha", Device::GPU);
cudaMemcpy(gpu_alpha_.Data(), alpha_.data(), gpu_alpha_.Nbytes(),
cudaMemcpyHostToDevice);
gpu_beta_.Resize({1, 1, static_cast<int>(beta_.size())}, FDDataType::FP32,
"beta", Device::GPU);
cudaMemcpy(gpu_beta_.Data(), beta_.data(), gpu_beta_.Nbytes(),
cudaMemcpyHostToDevice);
int jobs =
mat_batch->output_cache->Numel() / mat_batch->output_cache->shape[3];
int threads = 256;
int blocks = ceil(jobs / (float)threads);
NormalizeKernel<<<blocks, threads, 0, mat_batch->Stream()>>>(
reinterpret_cast<uint8_t*>(src->Data()),
reinterpret_cast<float*>(mat_batch->output_cache->Data()),
reinterpret_cast<float*>(gpu_alpha_.Data()),
reinterpret_cast<float*>(gpu_beta_.Data()),
mat_batch->output_cache->shape[3], swap_rb_,
mat_batch->output_cache->shape[0], jobs);
mat_batch->SetTensor(mat_batch->output_cache);
mat_batch->mat_type = ProcLib::CUDA;
return true;
}
#ifdef ENABLE_CVCUDA
bool Normalize::ImplByCvCuda(FDMat* mat) { return ImplByCuda(mat); }
bool Normalize::ImplByCvCuda(FDMatBatch* mat_batch) {
return ImplByCuda(mat_batch);
}
#endif
} // namespace vision
} // namespace fastdeploy
#endif
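The launch configuration above assigns one thread per pixel. Spelled out for a hypothetical batch (illustrative numbers only):

#include <cmath>
#include <cstdio>

int main() {
  // Hypothetical NHWC batch: 8 images of 48x320 with 3 channels.
  int n = 8, h = 48, w = 320, c = 3;
  int numel = n * h * w * c;  // elements in the FP32 output tensor
  int jobs = numel / c;       // one thread per pixel: 8 * 48 * 320 = 122880
  int threads = 256;
  int blocks = static_cast<int>(std::ceil(jobs / (float)threads));  // 480
  std::printf("jobs=%d blocks=%d\n", jobs, blocks);
  // Each thread then loops over the 3 channels of its pixel and applies
  // dst = src * alpha[channel] + beta[channel], optionally swapping R and B.
  return 0;
}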

View File

@@ -28,6 +28,14 @@ class FASTDEPLOY_DECL Normalize : public Processor {
   bool ImplByOpenCV(Mat* mat);
 #ifdef ENABLE_FLYCV
   bool ImplByFlyCV(Mat* mat);
+#endif
+#ifdef WITH_GPU
+  bool ImplByCuda(FDMat* mat);
+  bool ImplByCuda(FDMatBatch* mat_batch);
+#endif
+#ifdef ENABLE_CVCUDA
+  bool ImplByCvCuda(FDMat* mat);
+  bool ImplByCvCuda(FDMatBatch* mat_batch);
 #endif
   std::string Name() { return "Normalize"; }

@@ -61,6 +69,8 @@ class FASTDEPLOY_DECL Normalize : public Processor {
  private:
   std::vector<float> alpha_;
   std::vector<float> beta_;
+  FDTensor gpu_alpha_;
+  FDTensor gpu_beta_;
   bool swap_rb_;
 };

 } // namespace vision
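The new gpu_alpha_/gpu_beta_ tensors carry the per-channel affine coefficients used by NormalizeKernel. Assuming the usual folding of (x / 255 - mean) / std into x * alpha + beta (an assumption about how alpha_/beta_ are derived; the kernel itself only applies the affine form), a quick numeric check:

#include <cstdio>

int main() {
  float mean = 0.5f, stdv = 0.5f;        // the PP-OCR cls/rec defaults
  float alpha = 1.0f / (255.0f * stdv);  // assumed folding with is_scale == true
  float beta = -mean / stdv;
  unsigned char src = 255;
  float folded = src * alpha + beta;               // 1.0
  float reference = (src / 255.0f - mean) / stdv;  // 1.0
  std::printf("%f %f\n", folded, reference);
  return 0;
}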

View File

@@ -126,7 +126,7 @@ bool Pad::ImplByCvCuda(FDMat* mat) {
   auto src_tensor = CreateCvCudaTensorWrapData(*src);

   int height = mat->Height() + top_ + bottom_;
-  int width = mat->Height() + left_ + right_;
+  int width = mat->Width() + left_ + right_;

   // Prepare output tensor
   mat->output_cache->Resize({height, width, mat->Channels()}, mat->Type(),
@@ -137,9 +137,6 @@ bool Pad::ImplByCvCuda(FDMat* mat) {
                 NVCV_BORDER_CONSTANT, value);

   mat->SetTensor(mat->output_cache);
-  mat->SetWidth(width);
-  mat->SetHeight(height);
-  mat->device = Device::GPU;
   mat->mat_type = ProcLib::CVCUDA;
   return true;
 }
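A quick check of why the one-character fix above matters, with hypothetical crop and padding sizes:

#include <cstdio>

int main() {
  // Hypothetical 48x100 crop padded on the right to the classifier width 192.
  int h = 48, w = 100, top = 0, bottom = 0, left = 0, right = 92;
  int out_h = h + top + bottom;  // 48
  int out_w = w + left + right;  // 192; the old line computed 48 + 92 = 140
  std::printf("output buffer: %dx%d\n", out_h, out_w);
  return 0;
}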

View File

@@ -22,8 +22,20 @@ namespace fastdeploy {
 namespace vision {
 namespace ocr {

-void OcrClassifierResizeImage(FDMat* mat,
-                              const std::vector<int>& cls_image_shape) {
+ClassifierPreprocessor::ClassifierPreprocessor() {
+  resize_op_ = std::make_shared<Resize>(-1, -1);
+
+  std::vector<float> value = {0, 0, 0};
+  pad_op_ = std::make_shared<Pad>(0, 0, 0, 0, value);
+
+  normalize_op_ =
+      std::make_shared<Normalize>(std::vector<float>({0.5f, 0.5f, 0.5f}),
+                                  std::vector<float>({0.5f, 0.5f, 0.5f}), true);
+  hwc2chw_op_ = std::make_shared<HWC2CHW>();
+}
+
+void ClassifierPreprocessor::OcrClassifierResizeImage(
+    FDMat* mat, const std::vector<int>& cls_image_shape) {
   int img_c = cls_image_shape[0];
   int img_h = cls_image_shape[1];
   int img_w = cls_image_shape[2];
@@ -36,12 +48,8 @@ void OcrClassifierResizeImage(FDMat* mat,
   else
     resize_w = int(ceilf(img_h * ratio));

-  Resize::Run(mat, resize_w, img_h);
-}
-
-bool ClassifierPreprocessor::Run(std::vector<FDMat>* images,
-                                 std::vector<FDTensor>* outputs) {
-  return Run(images, outputs, 0, images->size());
+  resize_op_->SetWidthAndHeight(resize_w, img_h);
+  (*resize_op_)(mat);
 }

 bool ClassifierPreprocessor::Run(std::vector<FDMat>* images,
@@ -55,36 +63,37 @@ bool ClassifierPreprocessor::Run(std::vector<FDMat>* images,
     return false;
   }

+  std::vector<FDMat> mats(end_index - start_index);
   for (size_t i = start_index; i < end_index; ++i) {
-    FDMat* mat = &(images->at(i));
+    mats[i - start_index] = images->at(i);
+  }
+  return Run(&mats, outputs);
+}
+
+bool ClassifierPreprocessor::Apply(FDMatBatch* image_batch,
+                                   std::vector<FDTensor>* outputs) {
+  for (size_t i = 0; i < image_batch->mats->size(); ++i) {
+    FDMat* mat = &(image_batch->mats->at(i));
     OcrClassifierResizeImage(mat, cls_image_shape_);
     if (!disable_normalize_) {
-      Normalize::Run(mat, mean_, scale_, is_scale_);
+      (*normalize_op_)(mat);
     }
     std::vector<float> value = {0, 0, 0};
     if (mat->Width() < cls_image_shape_[2]) {
-      Pad::Run(mat, 0, 0, 0, cls_image_shape_[2] - mat->Width(), value);
+      pad_op_->SetPaddingSize(0, 0, 0, cls_image_shape_[2] - mat->Width());
+      (*pad_op_)(mat);
     }
     if (!disable_permute_) {
-      HWC2CHW::Run(mat);
-      Cast::Run(mat, "float");
+      (*hwc2chw_op_)(mat);
     }
   }
-  // Only have 1 output Tensor.
+  // Only have 1 output tensor.
   outputs->resize(1);
-  // Concat all the preprocessed data to a batch tensor
-  size_t tensor_size = end_index - start_index;
-  std::vector<FDTensor> tensors(tensor_size);
-  for (size_t i = 0; i < tensor_size; ++i) {
-    (*images)[i + start_index].ShareWithTensor(&(tensors[i]));
-    tensors[i].ExpandDim(0);
-  }
-  if (tensors.size() == 1) {
-    (*outputs)[0] = std::move(tensors[0]);
-  } else {
-    function::Concat(tensors, &((*outputs)[0]), 0);
-  }
+  // Get the NCHW tensor
+  FDTensor* tensor = image_batch->Tensor();
+  (*outputs)[0].SetExternalData(tensor->Shape(), tensor->Dtype(),
+                                tensor->Data(), tensor->device,
+                                tensor->device_id);
   return true;
 }

View File

@@ -14,6 +14,7 @@
 #pragma once
 #include "fastdeploy/vision/common/processors/transform.h"
+#include "fastdeploy/vision/common/processors/manager.h"
 #include "fastdeploy/vision/common/result.h"

 namespace fastdeploy {
@@ -22,32 +23,37 @@ namespace vision {
 namespace ocr {
 /*! @brief Preprocessor object for Classifier serials model.
 */
-class FASTDEPLOY_DECL ClassifierPreprocessor {
+class FASTDEPLOY_DECL ClassifierPreprocessor : public ProcessorManager {
  public:
+  ClassifierPreprocessor();
+  using ProcessorManager::Run;
   /** \brief Process the input image and prepare input tensors for runtime
    *
    * \param[in] images The input data list, all the elements are FDMat
    * \param[in] outputs The output tensors which will be fed into runtime
    * \return true if the preprocess successed, otherwise false
    */
-  bool Run(std::vector<FDMat>* images, std::vector<FDTensor>* outputs);
   bool Run(std::vector<FDMat>* images, std::vector<FDTensor>* outputs,
            size_t start_index, size_t end_index);

-  /// Set mean value for the image normalization in classification preprocess
-  void SetMean(const std::vector<float>& mean) { mean_ = mean; }
-  /// Get mean value of the image normalization in classification preprocess
-  std::vector<float> GetMean() const { return mean_; }
-
-  /// Set scale value for the image normalization in classification preprocess
-  void SetScale(const std::vector<float>& scale) { scale_ = scale; }
-  /// Get scale value of the image normalization in classification preprocess
-  std::vector<float> GetScale() const { return scale_; }
-
-  /// Set is_scale for the image normalization in classification preprocess
-  void SetIsScale(bool is_scale) { is_scale_ = is_scale; }
-  /// Get is_scale of the image normalization in classification preprocess
-  bool GetIsScale() const { return is_scale_; }
+  /** \brief Implement the virtual function of ProcessorManager, Apply() is the
+   * body of Run(). Apply() contains the main logic of preprocessing, Run() is
+   * called by users to execute preprocessing
+   *
+   * \param[in] image_batch The input image batch
+   * \param[in] outputs The output tensors which will feed in runtime
+   * \return true if the preprocess successed, otherwise false
+   */
+  virtual bool Apply(FDMatBatch* image_batch, std::vector<FDTensor>* outputs);
+
+  /// Set preprocess normalize parameters, please call this API to customize
+  /// the normalize parameters, otherwise it will use the default normalize
+  /// parameters.
+  void SetNormalize(const std::vector<float>& mean,
+                    const std::vector<float>& std,
+                    bool is_scale) {
+    normalize_op_ = std::make_shared<Normalize>(mean, std, is_scale);
+  }

   /// Set cls_image_shape for the classification preprocess
   void SetClsImageShape(const std::vector<int>& cls_image_shape) {
@@ -62,14 +68,18 @@ class FASTDEPLOY_DECL ClassifierPreprocessor {
   void DisablePermute() { disable_normalize_ = true; }

  private:
+  void OcrClassifierResizeImage(FDMat* mat,
+                                const std::vector<int>& cls_image_shape);
   // for recording the switch of hwc2chw
   bool disable_permute_ = false;
   // for recording the switch of normalize
   bool disable_normalize_ = false;
-  std::vector<float> mean_ = {0.5f, 0.5f, 0.5f};
-  std::vector<float> scale_ = {0.5f, 0.5f, 0.5f};
-  bool is_scale_ = true;
   std::vector<int> cls_image_shape_ = {3, 48, 192};
+  std::shared_ptr<Resize> resize_op_;
+  std::shared_ptr<Pad> pad_op_;
+  std::shared_ptr<Normalize> normalize_op_;
+  std::shared_ptr<HWC2CHW> hwc2chw_op_;
 };

 } // namespace ocr
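A hedged usage sketch of the reworked ClassifierPreprocessor API (assumes a FastDeploy build with the OCR module and the usual headers; the values shown are the previous defaults):

// Sketch only, not part of the commit.
fastdeploy::vision::ocr::ClassifierPreprocessor preprocessor;
preprocessor.SetNormalize({0.5f, 0.5f, 0.5f},  // mean
                          {0.5f, 0.5f, 0.5f},  // std
                          true);               // is_scale
preprocessor.SetClsImageShape({3, 48, 192});

// Run() is inherited from ProcessorManager and drives Apply() internally.
// `images` is assumed to be a std::vector<cv::Mat> decoded beforehand.
std::vector<fastdeploy::vision::FDMat> mats = fastdeploy::vision::WrapMat(images);
std::vector<fastdeploy::FDTensor> outputs;
preprocessor.Run(&mats, &outputs);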

View File

@@ -55,11 +55,9 @@ DBDetectorPreprocessor::DBDetectorPreprocessor() {
   std::vector<float> value = {0, 0, 0};
   pad_op_ = std::make_shared<Pad>(0, 0, 0, 0, value);
-  std::vector<float> mean = {0.485f, 0.456f, 0.406f};
-  std::vector<float> std = {0.229f, 0.224f, 0.225f};
-  bool is_scale = true;
-  normalize_permute_op_ =
-      std::make_shared<NormalizeAndPermute>(mean, std, is_scale);
+  normalize_permute_op_ = std::make_shared<NormalizeAndPermute>(
+      std::vector<float>({0.485f, 0.456f, 0.406f}),
+      std::vector<float>({0.229f, 0.224f, 0.225f}), true);
 }

 bool DBDetectorPreprocessor::ResizeImage(FDMat* img, int resize_w, int resize_h,

View File

@@ -46,9 +46,9 @@ class FASTDEPLOY_DECL DBDetectorPreprocessor : public ProcessorManager {
   /// Set preprocess normalize parameters, please call this API to customize
   /// the normalize parameters, otherwise it will use the default normalize
   /// parameters.
-  void SetNormalize(const std::vector<float>& mean = {0.485f, 0.456f, 0.406f},
-                    const std::vector<float>& std = {0.229f, 0.224f, 0.225f},
-                    bool is_scale = true) {
+  void SetNormalize(const std::vector<float>& mean,
+                    const std::vector<float>& std,
+                    bool is_scale) {
     normalize_permute_op_ =
         std::make_shared<NormalizeAndPermute>(mean, std, is_scale);
   }

View File

@@ -23,8 +23,8 @@ void BindPPOCRModel(pybind11::module& m) {
       });

   // DBDetector
-  pybind11::class_<vision::ocr::DBDetectorPreprocessor>(
-      m, "DBDetectorPreprocessor")
+  pybind11::class_<vision::ocr::DBDetectorPreprocessor,
+                   vision::ProcessorManager>(m, "DBDetectorPreprocessor")
       .def(pybind11::init<>())
       .def_property("static_shape_infer",
                     &vision::ocr::DBDetectorPreprocessor::GetStaticShapeInfer,
@@ -133,19 +133,16 @@ void BindPPOCRModel(pybind11::module& m) {
       });

   // Classifier
-  pybind11::class_<vision::ocr::ClassifierPreprocessor>(
-      m, "ClassifierPreprocessor")
+  pybind11::class_<vision::ocr::ClassifierPreprocessor,
+                   vision::ProcessorManager>(m, "ClassifierPreprocessor")
       .def(pybind11::init<>())
       .def_property("cls_image_shape",
                     &vision::ocr::ClassifierPreprocessor::GetClsImageShape,
                     &vision::ocr::ClassifierPreprocessor::SetClsImageShape)
-      .def_property("mean", &vision::ocr::ClassifierPreprocessor::GetMean,
-                    &vision::ocr::ClassifierPreprocessor::SetMean)
-      .def_property("scale", &vision::ocr::ClassifierPreprocessor::GetScale,
-                    &vision::ocr::ClassifierPreprocessor::SetScale)
-      .def_property("is_scale",
-                    &vision::ocr::ClassifierPreprocessor::GetIsScale,
-                    &vision::ocr::ClassifierPreprocessor::SetIsScale)
+      .def("set_normalize",
+           [](vision::ocr::ClassifierPreprocessor& self,
+              const std::vector<float>& mean, const std::vector<float>& std,
+              bool is_scale) { self.SetNormalize(mean, std, is_scale); })
       .def("run",
            [](vision::ocr::ClassifierPreprocessor& self,
               std::vector<pybind11::array>& im_list) {
@@ -233,8 +230,8 @@ void BindPPOCRModel(pybind11::module& m) {
       });

   // Recognizer
-  pybind11::class_<vision::ocr::RecognizerPreprocessor>(
-      m, "RecognizerPreprocessor")
+  pybind11::class_<vision::ocr::RecognizerPreprocessor,
+                   vision::ProcessorManager>(m, "RecognizerPreprocessor")
       .def(pybind11::init<>())
       .def_property("static_shape_infer",
                     &vision::ocr::RecognizerPreprocessor::GetStaticShapeInfer,
@@ -242,13 +239,10 @@ void BindPPOCRModel(pybind11::module& m) {
       .def_property("rec_image_shape",
                     &vision::ocr::RecognizerPreprocessor::GetRecImageShape,
                     &vision::ocr::RecognizerPreprocessor::SetRecImageShape)
-      .def_property("mean", &vision::ocr::RecognizerPreprocessor::GetMean,
-                    &vision::ocr::RecognizerPreprocessor::SetMean)
-      .def_property("scale", &vision::ocr::RecognizerPreprocessor::GetScale,
-                    &vision::ocr::RecognizerPreprocessor::SetScale)
-      .def_property("is_scale",
-                    &vision::ocr::RecognizerPreprocessor::GetIsScale,
-                    &vision::ocr::RecognizerPreprocessor::SetIsScale)
+      .def("set_normalize",
+           [](vision::ocr::RecognizerPreprocessor& self,
+              const std::vector<float>& mean, const std::vector<float>& std,
+              bool is_scale) { self.SetNormalize(mean, std, is_scale); })
       .def("run",
            [](vision::ocr::RecognizerPreprocessor& self,
               std::vector<pybind11::array>& im_list) {

View File

@@ -22,8 +22,23 @@ namespace fastdeploy {
 namespace vision {
 namespace ocr {

-void OcrRecognizerResizeImage(FDMat* mat, float max_wh_ratio,
-                              const std::vector<int>& rec_image_shape,
+RecognizerPreprocessor::RecognizerPreprocessor() {
+  resize_op_ = std::make_shared<Resize>(-1, -1);
+
+  std::vector<float> value = {127, 127, 127};
+  pad_op_ = std::make_shared<Pad>(0, 0, 0, 0, value);
+
+  std::vector<float> mean = {0.5f, 0.5f, 0.5f};
+  std::vector<float> std = {0.5f, 0.5f, 0.5f};
+  normalize_permute_op_ =
+      std::make_shared<NormalizeAndPermute>(mean, std, true);
+  normalize_op_ = std::make_shared<Normalize>(mean, std, true);
+  hwc2chw_op_ = std::make_shared<HWC2CHW>();
+  cast_op_ = std::make_shared<Cast>("float");
+}
+
+void RecognizerPreprocessor::OcrRecognizerResizeImage(
+    FDMat* mat, float max_wh_ratio, const std::vector<int>& rec_image_shape,
     bool static_shape_infer) {
   int img_h, img_w;
   img_h = rec_image_shape[1];
@@ -39,25 +54,25 @@ void OcrRecognizerResizeImage(FDMat* mat, float max_wh_ratio,
     } else {
       resize_w = int(ceilf(img_h * ratio));
     }
-    Resize::Run(mat, resize_w, img_h);
-    Pad::Run(mat, 0, 0, 0, int(img_w - mat->Width()), {127, 127, 127});
+    resize_op_->SetWidthAndHeight(resize_w, img_h);
+    (*resize_op_)(mat);
+    pad_op_->SetPaddingSize(0, 0, 0, int(img_w - mat->Width()));
+    (*pad_op_)(mat);
   } else {
     if (mat->Width() >= img_w) {
-      Resize::Run(mat, img_w, img_h);  // Reszie W to 320
+      // Reszie W to 320
+      resize_op_->SetWidthAndHeight(img_w, img_h);
+      (*resize_op_)(mat);
     } else {
-      Resize::Run(mat, mat->Width(), img_h);
-      Pad::Run(mat, 0, 0, 0, int(img_w - mat->Width()), {127, 127, 127});
+      resize_op_->SetWidthAndHeight(mat->Width(), img_h);
+      (*resize_op_)(mat);
       // Pad to 320
+      pad_op_->SetPaddingSize(0, 0, 0, int(img_w - mat->Width()));
+      (*pad_op_)(mat);
     }
   }
 }

-bool RecognizerPreprocessor::Run(std::vector<FDMat>* images,
-                                 std::vector<FDTensor>* outputs) {
-  return Run(images, outputs, 0, images->size(), {});
-}
-
 bool RecognizerPreprocessor::Run(std::vector<FDMat>* images,
                                  std::vector<FDTensor>* outputs,
                                  size_t start_index, size_t end_index,
@@ -70,60 +85,55 @@ bool RecognizerPreprocessor::Run(std::vector<FDMat>* images,
     return false;
   }

+  std::vector<FDMat> mats(end_index - start_index);
+  for (size_t i = start_index; i < end_index; ++i) {
+    size_t real_index = i;
+    if (indices.size() != 0) {
+      real_index = indices[i];
+    }
+    mats[i - start_index] = images->at(real_index);
+  }
+  return Run(&mats, outputs);
+}
+
+bool RecognizerPreprocessor::Apply(FDMatBatch* image_batch,
+                                   std::vector<FDTensor>* outputs) {
   int img_h = rec_image_shape_[1];
   int img_w = rec_image_shape_[2];
   float max_wh_ratio = img_w * 1.0 / img_h;
   float ori_wh_ratio;

-  for (size_t i = start_index; i < end_index; ++i) {
-    size_t real_index = i;
-    if (indices.size() != 0) {
-      real_index = indices[i];
-    }
-    FDMat* mat = &(images->at(real_index));
+  for (size_t i = 0; i < image_batch->mats->size(); ++i) {
+    FDMat* mat = &(image_batch->mats->at(i));
     ori_wh_ratio = mat->Width() * 1.0 / mat->Height();
     max_wh_ratio = std::max(max_wh_ratio, ori_wh_ratio);
   }

-  for (size_t i = start_index; i < end_index; ++i) {
-    size_t real_index = i;
-    if (indices.size() != 0) {
-      real_index = indices[i];
-    }
-    FDMat* mat = &(images->at(real_index));
+  for (size_t i = 0; i < image_batch->mats->size(); ++i) {
+    FDMat* mat = &(image_batch->mats->at(i));
     OcrRecognizerResizeImage(mat, max_wh_ratio, rec_image_shape_,
                              static_shape_infer_);
-    if (!disable_normalize_ && !disable_permute_) {
-      NormalizeAndPermute::Run(mat, mean_, scale_, is_scale_);
-    } else {
-      if (!disable_normalize_) {
-        Normalize::Run(mat, mean_, scale_, is_scale_);
-      }
-      if (!disable_permute_) {
-        HWC2CHW::Run(mat);
-        Cast::Run(mat, "float");
-      }
-    }
   }
-  // Only have 1 output Tensor.
-  outputs->resize(1);
-  size_t tensor_size = end_index - start_index;
-  // Concat all the preprocessed data to a batch tensor
-  std::vector<FDTensor> tensors(tensor_size);
-  for (size_t i = 0; i < tensor_size; ++i) {
-    size_t real_index = i + start_index;
-    if (indices.size() != 0) {
-      real_index = indices[i + start_index];
-    }
-    (*images)[real_index].ShareWithTensor(&(tensors[i]));
-    tensors[i].ExpandDim(0);
-  }
-  if (tensors.size() == 1) {
-    (*outputs)[0] = std::move(tensors[0]);
-  } else {
-    function::Concat(tensors, &((*outputs)[0]), 0);
+  if (!disable_normalize_ && !disable_permute_) {
+    (*normalize_permute_op_)(image_batch);
+  } else {
+    if (!disable_normalize_) {
+      (*normalize_op_)(image_batch);
+    }
+    if (!disable_permute_) {
+      (*hwc2chw_op_)(image_batch);
+      (*cast_op_)(image_batch);
+    }
   }
+
+  // Only have 1 output Tensor.
+  outputs->resize(1);
+  // Get the NCHW tensor
+  FDTensor* tensor = image_batch->Tensor();
+  (*outputs)[0].SetExternalData(tensor->Shape(), tensor->Dtype(),
+                                tensor->Data(), tensor->device,
+                                tensor->device_id);
   return true;
 }
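A small arithmetic sketch of the batch-width logic above (illustrative crops; it assumes, as in PP-OCR, that the dynamic-shape path derives the shared target width from img_h * max_wh_ratio):

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <utility>
#include <vector>

int main() {
  int img_h = 48, img_w = 320;                // rec_image_shape_ = {3, 48, 320}
  float max_wh_ratio = img_w * 1.0f / img_h;  // starting ratio 320 / 48
  // Width/height of three hypothetical text crops in one batch.
  std::vector<std::pair<int, int>> crops = {{200, 40}, {500, 50}, {90, 45}};
  for (const auto& wh : crops)
    max_wh_ratio = std::max(max_wh_ratio, wh.first * 1.0f / wh.second);
  // The widest crop has ratio 10, so each crop is resized to height 48 and
  // right-padded up to the shared width ceil(48 * 10) = 480.
  std::printf("shared width = %d\n",
              static_cast<int>(std::ceil(img_h * max_wh_ratio)));
  return 0;
}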

View File

@@ -14,6 +14,7 @@
 #pragma once
 #include "fastdeploy/vision/common/processors/transform.h"
+#include "fastdeploy/vision/common/processors/manager.h"
 #include "fastdeploy/vision/common/result.h"

 namespace fastdeploy {
@@ -22,19 +23,30 @@ namespace vision {
 namespace ocr {
 /*! @brief Preprocessor object for PaddleClas serials model.
 */
-class FASTDEPLOY_DECL RecognizerPreprocessor {
+class FASTDEPLOY_DECL RecognizerPreprocessor : public ProcessorManager {
  public:
+  RecognizerPreprocessor();
+  using ProcessorManager::Run;
   /** \brief Process the input image and prepare input tensors for runtime
    *
    * \param[in] images The input data list, all the elements are FDMat
    * \param[in] outputs The output tensors which will be fed into runtime
    * \return true if the preprocess successed, otherwise false
    */
-  bool Run(std::vector<FDMat>* images, std::vector<FDTensor>* outputs);
   bool Run(std::vector<FDMat>* images, std::vector<FDTensor>* outputs,
            size_t start_index, size_t end_index,
            const std::vector<int>& indices);
+
+  /** \brief Implement the virtual function of ProcessorManager, Apply() is the
+   * body of Run(). Apply() contains the main logic of preprocessing, Run() is
+   * called by users to execute preprocessing
+   *
+   * \param[in] image_batch The input image batch
+   * \param[in] outputs The output tensors which will feed in runtime
+   * \return true if the preprocess successed, otherwise false
+   */
+  virtual bool Apply(FDMatBatch* image_batch, std::vector<FDTensor>* outputs);

   /// Set static_shape_infer is true or not. When deploy PP-OCR
   /// on hardware which can not support dynamic input shape very well,
   /// like Huawei Ascned, static_shape_infer needs to to be true.
@@ -44,20 +56,16 @@ class FASTDEPLOY_DECL RecognizerPreprocessor {
   /// Get static_shape_infer of the recognition preprocess
   bool GetStaticShapeInfer() const { return static_shape_infer_; }

-  /// Set mean value for the image normalization in recognition preprocess
-  void SetMean(const std::vector<float>& mean) { mean_ = mean; }
-  /// Get mean value of the image normalization in recognition preprocess
-  std::vector<float> GetMean() const { return mean_; }
-
-  /// Set scale value for the image normalization in recognition preprocess
-  void SetScale(const std::vector<float>& scale) { scale_ = scale; }
-  /// Get scale value of the image normalization in recognition preprocess
-  std::vector<float> GetScale() const { return scale_; }
-
-  /// Set is_scale for the image normalization in recognition preprocess
-  void SetIsScale(bool is_scale) { is_scale_ = is_scale; }
-  /// Get is_scale of the image normalization in recognition preprocess
-  bool GetIsScale() const { return is_scale_; }
+  /// Set preprocess normalize parameters, please call this API to customize
+  /// the normalize parameters, otherwise it will use the default normalize
+  /// parameters.
+  void SetNormalize(const std::vector<float>& mean,
+                    const std::vector<float>& std,
+                    bool is_scale) {
+    normalize_permute_op_ =
+        std::make_shared<NormalizeAndPermute>(mean, std, is_scale);
+    normalize_op_ = std::make_shared<Normalize>(mean, std, is_scale);
+  }

   /// Set rec_image_shape for the recognition preprocess
   void SetRecImageShape(const std::vector<int>& rec_image_shape) {
@@ -72,15 +80,21 @@ class FASTDEPLOY_DECL RecognizerPreprocessor {
   void DisablePermute() { disable_normalize_ = true; }

  private:
+  void OcrRecognizerResizeImage(FDMat* mat, float max_wh_ratio,
+                                const std::vector<int>& rec_image_shape,
+                                bool static_shape_infer);
   // for recording the switch of hwc2chw
   bool disable_permute_ = false;
   // for recording the switch of normalize
   bool disable_normalize_ = false;
   std::vector<int> rec_image_shape_ = {3, 48, 320};
-  std::vector<float> mean_ = {0.5f, 0.5f, 0.5f};
-  std::vector<float> scale_ = {0.5f, 0.5f, 0.5f};
-  bool is_scale_ = true;
   bool static_shape_infer_ = false;
+  std::shared_ptr<Resize> resize_op_;
+  std::shared_ptr<Pad> pad_op_;
+  std::shared_ptr<NormalizeAndPermute> normalize_permute_op_;
+  std::shared_ptr<Normalize> normalize_op_;
+  std::shared_ptr<HWC2CHW> hwc2chw_op_;
+  std::shared_ptr<Cast> cast_op_;
 };

 } // namespace ocr

View File

@@ -52,10 +52,7 @@ class DBDetectorPreprocessor:
             value, int), "The value to set `max_side_len` must be type of int."
         self._preprocessor.max_side_len = value

-    def set_normalize(self,
-                      mean=[0.485, 0.456, 0.406],
-                      std=[0.229, 0.224, 0.225],
-                      is_scale=True):
+    def set_normalize(self, mean, std, is_scale):
         """Set preprocess normalize parameters, please call this API to
        customize the normalize parameters, otherwise it will use the default
        normalize parameters.
@@ -340,35 +337,15 @@ class ClassifierPreprocessor:
         """
         return self._preprocessor.run(input_ims)

-    @property
-    def is_scale(self):
-        return self._preprocessor.is_scale
-
-    @is_scale.setter
-    def is_scale(self, value):
-        assert isinstance(
-            value, bool), "The value to set `is_scale` must be type of bool."
-        self._preprocessor.is_scale = value
-
-    @property
-    def scale(self):
-        return self._preprocessor.scale
-
-    @scale.setter
-    def scale(self, value):
-        assert isinstance(
-            value, list), "The value to set `scale` must be type of list."
-        self._preprocessor.scale = value
-
-    @property
-    def mean(self):
-        return self._preprocessor.mean
-
-    @mean.setter
-    def mean(self, value):
-        assert isinstance(
-            value, list), "The value to set `mean` must be type of list."
-        self._preprocessor.mean = value
+    def set_normalize(self, mean, std, is_scale):
+        """Set preprocess normalize parameters, please call this API to
+        customize the normalize parameters, otherwise it will use the default
+        normalize parameters.
+
+        :param: mean: (list of float) mean values
+        :param: std: (list of float) std values
+        :param: is_scale: (boolean) whether to scale
+        """
+        self._preprocessor.set_normalize(mean, std, is_scale)

     @property
     def cls_image_shape(self):
@@ -496,37 +473,6 @@ class Classifier(FastDeployModel):
     def postprocessor(self, value):
         self._model.postprocessor = value

-    # Cls Preprocessor Property
-    @property
-    def is_scale(self):
-        return self._model.preprocessor.is_scale
-
-    @is_scale.setter
-    def is_scale(self, value):
-        assert isinstance(
-            value, bool), "The value to set `is_scale` must be type of bool."
-        self._model.preprocessor.is_scale = value
-
-    @property
-    def scale(self):
-        return self._model.preprocessor.scale
-
-    @scale.setter
-    def scale(self, value):
-        assert isinstance(
-            value, list), "The value to set `scale` must be type of list."
-        self._model.preprocessor.scale = value
-
-    @property
-    def mean(self):
-        return self._model.preprocessor.mean
-
-    @mean.setter
-    def mean(self, value):
-        assert isinstance(
-            value, list), "The value to set `mean` must be type of list."
-        self._model.preprocessor.mean = value
-
     @property
     def cls_image_shape(self):
         return self._model.preprocessor.cls_image_shape
@@ -575,35 +521,15 @@ class RecognizerPreprocessor:
             bool), "The value to set `static_shape_infer` must be type of bool."
         self._preprocessor.static_shape_infer = value

-    @property
-    def is_scale(self):
-        return self._preprocessor.is_scale
-
-    @is_scale.setter
-    def is_scale(self, value):
-        assert isinstance(
-            value, bool), "The value to set `is_scale` must be type of bool."
-        self._preprocessor.is_scale = value
-
-    @property
-    def scale(self):
-        return self._preprocessor.scale
-
-    @scale.setter
-    def scale(self, value):
-        assert isinstance(
-            value, list), "The value to set `scale` must be type of list."
-        self._preprocessor.scale = value
-
-    @property
-    def mean(self):
-        return self._preprocessor.mean
-
-    @mean.setter
-    def mean(self, value):
-        assert isinstance(
-            value, list), "The value to set `mean` must be type of list."
-        self._preprocessor.mean = value
+    def set_normalize(self, mean, std, is_scale):
+        """Set preprocess normalize parameters, please call this API to
+        customize the normalize parameters, otherwise it will use the default
+        normalize parameters.
+
+        :param: mean: (list of float) mean values
+        :param: std: (list of float) std values
+        :param: is_scale: (boolean) whether to scale
+        """
+        self._preprocessor.set_normalize(mean, std, is_scale)

     @property
     def rec_image_shape(self):
@@ -728,36 +654,6 @@ class Recognizer(FastDeployModel):
             bool), "The value to set `static_shape_infer` must be type of bool."
         self._model.preprocessor.static_shape_infer = value

-    @property
-    def is_scale(self):
-        return self._model.preprocessor.is_scale
-
-    @is_scale.setter
-    def is_scale(self, value):
-        assert isinstance(
-            value, bool), "The value to set `is_scale` must be type of bool."
-        self._model.preprocessor.is_scale = value
-
-    @property
-    def scale(self):
-        return self._model.preprocessor.scale
-
-    @scale.setter
-    def scale(self, value):
-        assert isinstance(
-            value, list), "The value to set `scale` must be type of list."
-        self._model.preprocessor.scale = value
-
-    @property
-    def mean(self):
-        return self._model.preprocessor.mean
-
-    @mean.setter
-    def mean(self, value):
-        assert isinstance(
-            value, list), "The value to set `mean` must be type of list."
-        self._model.preprocessor.mean = value
-
     @property
     def rec_image_shape(self):
         return self._model.preprocessor.rec_image_shape