diff --git a/fastdeploy/vision/common/processors/cvcuda_utils.cc b/fastdeploy/vision/common/processors/cvcuda_utils.cc index c7d25361b..ff0d5e3ba 100644 --- a/fastdeploy/vision/common/processors/cvcuda_utils.cc +++ b/fastdeploy/vision/common/processors/cvcuda_utils.cc @@ -18,34 +18,36 @@ namespace fastdeploy { namespace vision { #ifdef ENABLE_CVCUDA -nvcv::ImageFormat CreateCvCudaImageFormat(FDDataType type, int channel) { +nvcv::ImageFormat CreateCvCudaImageFormat(FDDataType type, int channel, + bool interleaved) { FDASSERT(channel == 1 || channel == 3 || channel == 4, "Only support channel be 1/3/4 in CV-CUDA."); if (type == FDDataType::UINT8) { if (channel == 1) { return nvcv::FMT_U8; } else if (channel == 3) { - return nvcv::FMT_BGR8; + return (interleaved ? nvcv::FMT_BGR8 : nvcv::FMT_BGR8p); } else { - return nvcv::FMT_BGRA8; + return (interleaved ? nvcv::FMT_BGRA8 : nvcv::FMT_BGRA8p); } } else if (type == FDDataType::FP32) { if (channel == 1) { return nvcv::FMT_F32; } else if (channel == 3) { - return nvcv::FMT_BGRf32; + return (interleaved ? nvcv::FMT_BGRf32 : nvcv::FMT_BGRf32p); } else { - return nvcv::FMT_BGRAf32; + return (interleaved ? nvcv::FMT_BGRAf32 : nvcv::FMT_BGRAf32p); } } FDASSERT(false, "Data type of %s is not supported.", Str(type).c_str()); return nvcv::FMT_BGRf32; } -nvcv::TensorWrapData CreateCvCudaTensorWrapData(const FDTensor& tensor) { +nvcv::TensorWrapData CreateCvCudaTensorWrapData(const FDTensor& tensor, + Layout layout) { FDASSERT(tensor.shape.size() == 3, "When create CVCUDA tensor from FD tensor," - "tensor shape should be 3-Dim, HWC layout"); + "tensor shape should be 3-Dim,"); int batchsize = 1; int h = tensor.Shape()[0]; int w = tensor.Shape()[1]; @@ -56,10 +58,20 @@ nvcv::TensorWrapData CreateCvCudaTensorWrapData(const FDTensor& tensor) { buf.strides[2] = c * buf.strides[3]; buf.strides[1] = w * buf.strides[2]; buf.strides[0] = h * buf.strides[1]; + if (layout == Layout::CHW) { + c = tensor.Shape()[0]; + h = tensor.Shape()[1]; + w = tensor.Shape()[2]; + buf.strides[3] = FDDataTypeSize(tensor.Dtype()); + buf.strides[2] = w * buf.strides[3]; + buf.strides[1] = h * buf.strides[2]; + buf.strides[0] = c * buf.strides[1]; + } buf.basePtr = reinterpret_cast(const_cast(tensor.Data())); nvcv::Tensor::Requirements req = nvcv::Tensor::CalcRequirements( - batchsize, {w, h}, CreateCvCudaImageFormat(tensor.Dtype(), c)); + batchsize, {w, h}, + CreateCvCudaImageFormat(tensor.Dtype(), c, layout == Layout::HWC)); nvcv::TensorDataStridedCuda tensor_data( nvcv::TensorShape{req.shape, req.rank, req.layout}, diff --git a/fastdeploy/vision/common/processors/cvcuda_utils.h b/fastdeploy/vision/common/processors/cvcuda_utils.h index 60971ec49..2c84d073d 100644 --- a/fastdeploy/vision/common/processors/cvcuda_utils.h +++ b/fastdeploy/vision/common/processors/cvcuda_utils.h @@ -15,6 +15,7 @@ #pragma once #include "fastdeploy/core/fd_tensor.h" +#include "fastdeploy/vision/common/processors/mat.h" #ifdef ENABLE_CVCUDA #include "nvcv/Tensor.hpp" @@ -23,8 +24,10 @@ namespace fastdeploy { namespace vision { -nvcv::ImageFormat CreateCvCudaImageFormat(FDDataType type, int channel); -nvcv::TensorWrapData CreateCvCudaTensorWrapData(const FDTensor& tensor); +nvcv::ImageFormat CreateCvCudaImageFormat(FDDataType type, int channel, + bool interleaved = true); +nvcv::TensorWrapData CreateCvCudaTensorWrapData(const FDTensor& tensor, + Layout layout = Layout::HWC); void* GetCvCudaTensorDataPtr(const nvcv::TensorWrapData& tensor); nvcv::ImageWrapData CreateImageWrapData(const FDTensor& tensor); 
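The new `interleaved` and `layout` parameters above let the CV-CUDA wrappers describe planar (CHW) tensors in addition to the previous interleaved (HWC) ones. A minimal, illustrative sketch of the intended call pattern, not taken from the patch (the tensor name and shape are invented; every function and enum used is declared in this header):

    // A planar CHW tensor already resident on the GPU, e.g. produced by HWC2CHW.
    FDTensor chw;
    chw.Resize({3, 720, 1280}, FDDataType::FP32, "chw_example", Device::GPU);
    // Layout::CHW makes CreateCvCudaImageFormat pick the planar format (nvcv::FMT_BGRf32p)
    // and computes plane-major strides instead of pixel-major ones.
    nvcv::TensorWrapData wrapped = CreateCvCudaTensorWrapData(chw, Layout::CHW);
    // The default arguments (interleaved = true, Layout::HWC) preserve the old
    // behavior, so existing call sites compile unchanged.
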
void CreateCvCudaImageBatchVarShape(std::vector& tensors, diff --git a/fastdeploy/vision/common/processors/hwc2chw.cc b/fastdeploy/vision/common/processors/hwc2chw.cc index 9db5c09ff..9c45b396e 100644 --- a/fastdeploy/vision/common/processors/hwc2chw.cc +++ b/fastdeploy/vision/common/processors/hwc2chw.cc @@ -63,6 +63,26 @@ bool HWC2CHW::ImplByFlyCV(Mat* mat) { } #endif +#ifdef ENABLE_CVCUDA +bool HWC2CHW::ImplByCvCuda(FDMat* mat) { + // Prepare input tensor + FDTensor* src = CreateCachedGpuInputTensor(mat); + auto src_tensor = CreateCvCudaTensorWrapData(*src); + + // Prepare output tensor + mat->output_cache->Resize({mat->Channels(), mat->Height(), mat->Width()}, + src->Dtype(), "output_cache", Device::GPU); + auto dst_tensor = + CreateCvCudaTensorWrapData(*(mat->output_cache), Layout::CHW); + + cvcuda_reformat_op_(mat->Stream(), src_tensor, dst_tensor); + + mat->SetTensor(mat->output_cache); + mat->mat_type = ProcLib::CVCUDA; + return true; +} +#endif + bool HWC2CHW::Run(Mat* mat, ProcLib lib) { auto h = HWC2CHW(); return h(mat, lib); diff --git a/fastdeploy/vision/common/processors/hwc2chw.h b/fastdeploy/vision/common/processors/hwc2chw.h index 535a1887b..f68b0c3fb 100644 --- a/fastdeploy/vision/common/processors/hwc2chw.h +++ b/fastdeploy/vision/common/processors/hwc2chw.h @@ -15,6 +15,11 @@ #pragma once #include "fastdeploy/vision/common/processors/base.h" +#ifdef ENABLE_CVCUDA +#include + +#include "fastdeploy/vision/common/processors/cvcuda_utils.h" +#endif namespace fastdeploy { namespace vision { @@ -24,10 +29,17 @@ class FASTDEPLOY_DECL HWC2CHW : public Processor { bool ImplByOpenCV(Mat* mat); #ifdef ENABLE_FLYCV bool ImplByFlyCV(Mat* mat); +#endif +#ifdef ENABLE_CVCUDA + bool ImplByCvCuda(FDMat* mat); #endif std::string Name() { return "HWC2CHW"; } static bool Run(Mat* mat, ProcLib lib = ProcLib::DEFAULT); + private: +#ifdef ENABLE_CVCUDA + cvcuda::Reformat cvcuda_reformat_op_; +#endif }; } // namespace vision } // namespace fastdeploy diff --git a/fastdeploy/vision/common/processors/manager.cc b/fastdeploy/vision/common/processors/manager.cc index 2f751ab80..05167e6fe 100644 --- a/fastdeploy/vision/common/processors/manager.cc +++ b/fastdeploy/vision/common/processors/manager.cc @@ -73,6 +73,7 @@ bool ProcessorManager::Run(std::vector* images, } (*images)[i].input_cache = &input_caches_[i]; (*images)[i].output_cache = &output_caches_[i]; + (*images)[i].proc_lib = proc_lib_; if ((*images)[i].mat_type == ProcLib::CUDA) { // Make a copy of the input data ptr, so that the original data ptr of // FDMat won't be modified. diff --git a/fastdeploy/vision/common/processors/mat.cc b/fastdeploy/vision/common/processors/mat.cc index b78f57436..da1d72ccb 100644 --- a/fastdeploy/vision/common/processors/mat.cc +++ b/fastdeploy/vision/common/processors/mat.cc @@ -272,6 +272,9 @@ std::vector WrapMat(const std::vector& images) { } bool CheckShapeConsistency(std::vector* mats) { + if (mats == nullptr) { + return true; + } for (size_t i = 1; i < mats->size(); ++i) { if ((*mats)[i].Channels() != (*mats)[0].Channels() || (*mats)[i].Width() != (*mats)[0].Width() || @@ -285,21 +288,24 @@ bool CheckShapeConsistency(std::vector* mats) { FDTensor* CreateCachedGpuInputTensor(Mat* mat) { #ifdef WITH_GPU FDTensor* src = mat->Tensor(); + // Need to make sure the tensor is pointed to the input_cache. 
+ if (src->Data() == mat->output_cache->Data()) { + std::swap(mat->input_cache, mat->output_cache); + std::swap(mat->input_cache->name, mat->output_cache->name); + } if (src->device == Device::GPU) { - if (src->Data() == mat->output_cache->Data()) { - std::swap(mat->input_cache, mat->output_cache); - std::swap(mat->input_cache->name, mat->output_cache->name); - } return src; } else if (src->device == Device::CPU) { - // Mats on CPU, we need copy these tensors from CPU to GPU + // Tensor on CPU, we need copy it from CPU to GPU FDASSERT(src->Shape().size() == 3, "The CPU tensor must has 3 dims.") - mat->input_cache->Resize(src->Shape(), src->Dtype(), "input_cache", - Device::GPU); + mat->output_cache->Resize(src->Shape(), src->Dtype(), "output_cache", + Device::GPU); FDASSERT( - cudaMemcpyAsync(mat->input_cache->Data(), src->Data(), src->Nbytes(), + cudaMemcpyAsync(mat->output_cache->Data(), src->Data(), src->Nbytes(), cudaMemcpyHostToDevice, mat->Stream()) == 0, "[ERROR] Error occurs while copy memory from CPU to GPU."); + std::swap(mat->input_cache, mat->output_cache); + std::swap(mat->input_cache->name, mat->output_cache->name); return mat->input_cache; } else { FDASSERT(false, "FDMat is on unsupported device: %d", src->device); diff --git a/fastdeploy/vision/common/processors/mat_batch.cc b/fastdeploy/vision/common/processors/mat_batch.cc index aa154f334..d61c0b704 100644 --- a/fastdeploy/vision/common/processors/mat_batch.cc +++ b/fastdeploy/vision/common/processors/mat_batch.cc @@ -29,10 +29,12 @@ FDTensor* FDMatBatch::Tensor() { if (has_batched_tensor) { return fd_tensor.get(); } - FDASSERT(CheckShapeConsistency(mats), "Mats shapes are not consistent.") + FDASSERT(mats != nullptr, "Failed to get batched tensor, Mats are empty."); + FDASSERT(CheckShapeConsistency(mats), "Mats shapes are not consistent."); // Each mat has its own tensor, // to get a batched tensor, we need copy these tensors to a batched tensor FDTensor* src = (*mats)[0].Tensor(); + device = src->device; auto new_shape = src->Shape(); new_shape.insert(new_shape.begin(), mats->size()); input_cache->Resize(new_shape, src->Dtype(), "batch_input_cache", device); @@ -51,26 +53,34 @@ FDTensor* FDMatBatch::Tensor() { void FDMatBatch::SetTensor(FDTensor* tensor) { fd_tensor->SetExternalData(tensor->Shape(), tensor->Dtype(), tensor->Data(), tensor->device, tensor->device_id); + device = tensor->device; has_batched_tensor = true; } FDTensor* CreateCachedGpuInputTensor(FDMatBatch* mat_batch) { #ifdef WITH_GPU - auto mats = mat_batch->mats; - FDASSERT(CheckShapeConsistency(mats), "Mats shapes are not consistent.") - FDTensor* src = (*mats)[0].Tensor(); - if (mat_batch->device == Device::GPU) { - return mat_batch->Tensor(); - } else if (mat_batch->device == Device::CPU) { - // Mats on CPU, we need copy them to GPU and then get a batched GPU tensor - for (size_t i = 0; i < mats->size(); ++i) { - FDTensor* tensor = CreateCachedGpuInputTensor(&(*mats)[i]); - (*mats)[i].SetTensor(tensor); - } - mat_batch->device = Device::GPU; - return mat_batch->Tensor(); + // Get the batched tensor + FDTensor* src = mat_batch->Tensor(); + // Need to make sure the returned tensor is pointed to the input_cache. 
+ if (src->Data() == mat_batch->output_cache->Data()) { + std::swap(mat_batch->input_cache, mat_batch->output_cache); + std::swap(mat_batch->input_cache->name, mat_batch->output_cache->name); + } + if (src->device == Device::GPU) { + return src; + } else if (src->device == Device::CPU) { + // Batched tensor on CPU, we need copy it to GPU + mat_batch->output_cache->Resize(src->Shape(), src->Dtype(), "output_cache", + Device::GPU); + FDASSERT(cudaMemcpyAsync(mat_batch->output_cache->Data(), src->Data(), + src->Nbytes(), cudaMemcpyHostToDevice, + mat_batch->Stream()) == 0, + "[ERROR] Error occurs while copy memory from CPU to GPU."); + std::swap(mat_batch->input_cache, mat_batch->output_cache); + std::swap(mat_batch->input_cache->name, mat_batch->output_cache->name); + return mat_batch->input_cache; } else { - FDASSERT(false, "FDMat is on unsupported device: %d", src->device); + FDASSERT(false, "FDMatBatch is on unsupported device: %d", src->device); } #else FDASSERT(false, "FastDeploy didn't compile with WITH_GPU."); diff --git a/fastdeploy/vision/common/processors/mat_batch.h b/fastdeploy/vision/common/processors/mat_batch.h index 9d876a911..cd24b4a80 100644 --- a/fastdeploy/vision/common/processors/mat_batch.h +++ b/fastdeploy/vision/common/processors/mat_batch.h @@ -56,7 +56,7 @@ struct FASTDEPLOY_DECL FDMatBatch { void SetStream(cudaStream_t s); #endif - std::vector* mats; + std::vector* mats = nullptr; ProcLib mat_type = ProcLib::OPENCV; FDMatBatchLayout layout = FDMatBatchLayout::NHWC; Device device = Device::CPU; diff --git a/fastdeploy/vision/common/processors/normalize.cu b/fastdeploy/vision/common/processors/normalize.cu new file mode 100644 index 000000000..e4cd7c100 --- /dev/null +++ b/fastdeploy/vision/common/processors/normalize.cu @@ -0,0 +1,116 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifdef WITH_GPU +#include "fastdeploy/vision/common/processors/normalize.h" + +namespace fastdeploy { +namespace vision { + +__global__ void NormalizeKernel(const uint8_t* src, float* dst, + const float* alpha, const float* beta, + int num_channel, bool swap_rb, int batch_size, + int edge) { + int idx = blockDim.x * blockIdx.x + threadIdx.x; + if (idx >= edge) return; + + int img_size = edge / batch_size; + int n = idx / img_size; // batch index + int p = idx - (n * img_size); // pixel index within the image + + for (int i = 0; i < num_channel; ++i) { + int j = i; + if (swap_rb) { + j = 2 - i; + } + dst[num_channel * idx + j] = + src[num_channel * idx + j] * alpha[i] + beta[i]; + } +} + +bool Normalize::ImplByCuda(FDMat* mat) { + if (mat->layout != Layout::HWC) { + FDERROR << "The input data must be NHWC format!" 
<< std::endl; + return false; + } + + // Prepare input tensor + FDTensor* src = CreateCachedGpuInputTensor(mat); + src->ExpandDim(0); + FDMatBatch mat_batch; + mat_batch.SetTensor(src); + mat_batch.mat_type = ProcLib::CUDA; + mat_batch.input_cache = mat->input_cache; + mat_batch.output_cache = mat->output_cache; + + bool ret = ImplByCuda(&mat_batch); + + FDTensor* dst = mat_batch.Tensor(); + dst->Squeeze(0); + mat->SetTensor(dst); + mat->mat_type = ProcLib::CUDA; + return true; +} + +bool Normalize::ImplByCuda(FDMatBatch* mat_batch) { + if (mat_batch->layout != FDMatBatchLayout::NHWC) { + FDERROR << "The input data must be NHWC format!" << std::endl; + return false; + } + // Prepare input tensor + FDTensor* src = CreateCachedGpuInputTensor(mat_batch); + + // Prepare output tensor + mat_batch->output_cache->Resize(src->Shape(), FDDataType::FP32, + "batch_output_cache", Device::GPU); + + // Copy alpha and beta to GPU + gpu_alpha_.Resize({1, 1, static_cast(alpha_.size())}, FDDataType::FP32, + "alpha", Device::GPU); + cudaMemcpy(gpu_alpha_.Data(), alpha_.data(), gpu_alpha_.Nbytes(), + cudaMemcpyHostToDevice); + + gpu_beta_.Resize({1, 1, static_cast(beta_.size())}, FDDataType::FP32, + "beta", Device::GPU); + cudaMemcpy(gpu_beta_.Data(), beta_.data(), gpu_beta_.Nbytes(), + cudaMemcpyHostToDevice); + + int jobs = + mat_batch->output_cache->Numel() / mat_batch->output_cache->shape[3]; + int threads = 256; + int blocks = ceil(jobs / (float)threads); + NormalizeKernel<<Stream()>>>( + reinterpret_cast(src->Data()), + reinterpret_cast(mat_batch->output_cache->Data()), + reinterpret_cast(gpu_alpha_.Data()), + reinterpret_cast(gpu_beta_.Data()), + mat_batch->output_cache->shape[3], swap_rb_, + mat_batch->output_cache->shape[0], jobs); + + mat_batch->SetTensor(mat_batch->output_cache); + mat_batch->mat_type = ProcLib::CUDA; + return true; +} + +#ifdef ENABLE_CVCUDA +bool Normalize::ImplByCvCuda(FDMat* mat) { return ImplByCuda(mat); } + +bool Normalize::ImplByCvCuda(FDMatBatch* mat_batch) { + return ImplByCuda(mat_batch); +} +#endif + +} // namespace vision +} // namespace fastdeploy +#endif diff --git a/fastdeploy/vision/common/processors/normalize.h b/fastdeploy/vision/common/processors/normalize.h index c489207df..f95cc01a8 100644 --- a/fastdeploy/vision/common/processors/normalize.h +++ b/fastdeploy/vision/common/processors/normalize.h @@ -28,6 +28,14 @@ class FASTDEPLOY_DECL Normalize : public Processor { bool ImplByOpenCV(Mat* mat); #ifdef ENABLE_FLYCV bool ImplByFlyCV(Mat* mat); +#endif +#ifdef WITH_GPU + bool ImplByCuda(FDMat* mat); + bool ImplByCuda(FDMatBatch* mat_batch); +#endif +#ifdef ENABLE_CVCUDA + bool ImplByCvCuda(FDMat* mat); + bool ImplByCvCuda(FDMatBatch* mat_batch); #endif std::string Name() { return "Normalize"; } @@ -61,6 +69,8 @@ class FASTDEPLOY_DECL Normalize : public Processor { private: std::vector alpha_; std::vector beta_; + FDTensor gpu_alpha_; + FDTensor gpu_beta_; bool swap_rb_; }; } // namespace vision diff --git a/fastdeploy/vision/common/processors/pad.cc b/fastdeploy/vision/common/processors/pad.cc index 2db1fba20..044668e12 100644 --- a/fastdeploy/vision/common/processors/pad.cc +++ b/fastdeploy/vision/common/processors/pad.cc @@ -126,7 +126,7 @@ bool Pad::ImplByCvCuda(FDMat* mat) { auto src_tensor = CreateCvCudaTensorWrapData(*src); int height = mat->Height() + top_ + bottom_; - int width = mat->Height() + left_ + right_; + int width = mat->Width() + left_ + right_; // Prepare output tensor mat->output_cache->Resize({height, width, mat->Channels()}, mat->Type(), @@ 
-137,9 +137,6 @@ bool Pad::ImplByCvCuda(FDMat* mat) { NVCV_BORDER_CONSTANT, value); mat->SetTensor(mat->output_cache); - mat->SetWidth(width); - mat->SetHeight(height); - mat->device = Device::GPU; mat->mat_type = ProcLib::CVCUDA; return true; } diff --git a/fastdeploy/vision/ocr/ppocr/cls_preprocessor.cc b/fastdeploy/vision/ocr/ppocr/cls_preprocessor.cc index 35f98acc9..63d19b6f2 100644 --- a/fastdeploy/vision/ocr/ppocr/cls_preprocessor.cc +++ b/fastdeploy/vision/ocr/ppocr/cls_preprocessor.cc @@ -22,8 +22,20 @@ namespace fastdeploy { namespace vision { namespace ocr { -void OcrClassifierResizeImage(FDMat* mat, - const std::vector& cls_image_shape) { +ClassifierPreprocessor::ClassifierPreprocessor() { + resize_op_ = std::make_shared(-1, -1); + + std::vector value = {0, 0, 0}; + pad_op_ = std::make_shared(0, 0, 0, 0, value); + + normalize_op_ = + std::make_shared(std::vector({0.5f, 0.5f, 0.5f}), + std::vector({0.5f, 0.5f, 0.5f}), true); + hwc2chw_op_ = std::make_shared(); +} + +void ClassifierPreprocessor::OcrClassifierResizeImage( + FDMat* mat, const std::vector& cls_image_shape) { int img_c = cls_image_shape[0]; int img_h = cls_image_shape[1]; int img_w = cls_image_shape[2]; @@ -36,12 +48,8 @@ void OcrClassifierResizeImage(FDMat* mat, else resize_w = int(ceilf(img_h * ratio)); - Resize::Run(mat, resize_w, img_h); -} - -bool ClassifierPreprocessor::Run(std::vector* images, - std::vector* outputs) { - return Run(images, outputs, 0, images->size()); + resize_op_->SetWidthAndHeight(resize_w, img_h); + (*resize_op_)(mat); } bool ClassifierPreprocessor::Run(std::vector* images, @@ -55,36 +63,37 @@ bool ClassifierPreprocessor::Run(std::vector* images, return false; } + std::vector mats(end_index - start_index); for (size_t i = start_index; i < end_index; ++i) { - FDMat* mat = &(images->at(i)); + mats[i - start_index] = images->at(i); + } + return Run(&mats, outputs); +} + +bool ClassifierPreprocessor::Apply(FDMatBatch* image_batch, + std::vector* outputs) { + for (size_t i = 0; i < image_batch->mats->size(); ++i) { + FDMat* mat = &(image_batch->mats->at(i)); OcrClassifierResizeImage(mat, cls_image_shape_); if (!disable_normalize_) { - Normalize::Run(mat, mean_, scale_, is_scale_); + (*normalize_op_)(mat); } std::vector value = {0, 0, 0}; if (mat->Width() < cls_image_shape_[2]) { - Pad::Run(mat, 0, 0, 0, cls_image_shape_[2] - mat->Width(), value); + pad_op_->SetPaddingSize(0, 0, 0, cls_image_shape_[2] - mat->Width()); + (*pad_op_)(mat); } - if (!disable_permute_) { - HWC2CHW::Run(mat); - Cast::Run(mat, "float"); + (*hwc2chw_op_)(mat); } } - // Only have 1 output Tensor. + // Only have 1 output tensor. 
outputs->resize(1); - // Concat all the preprocessed data to a batch tensor - size_t tensor_size = end_index - start_index; - std::vector tensors(tensor_size); - for (size_t i = 0; i < tensor_size; ++i) { - (*images)[i + start_index].ShareWithTensor(&(tensors[i])); - tensors[i].ExpandDim(0); - } - if (tensors.size() == 1) { - (*outputs)[0] = std::move(tensors[0]); - } else { - function::Concat(tensors, &((*outputs)[0]), 0); - } + // Get the NCHW tensor + FDTensor* tensor = image_batch->Tensor(); + (*outputs)[0].SetExternalData(tensor->Shape(), tensor->Dtype(), + tensor->Data(), tensor->device, + tensor->device_id); return true; } diff --git a/fastdeploy/vision/ocr/ppocr/cls_preprocessor.h b/fastdeploy/vision/ocr/ppocr/cls_preprocessor.h index 921f3f826..15f501a47 100644 --- a/fastdeploy/vision/ocr/ppocr/cls_preprocessor.h +++ b/fastdeploy/vision/ocr/ppocr/cls_preprocessor.h @@ -14,6 +14,7 @@ #pragma once #include "fastdeploy/vision/common/processors/transform.h" +#include "fastdeploy/vision/common/processors/manager.h" #include "fastdeploy/vision/common/result.h" namespace fastdeploy { @@ -22,32 +23,37 @@ namespace vision { namespace ocr { /*! @brief Preprocessor object for Classifier serials model. */ -class FASTDEPLOY_DECL ClassifierPreprocessor { +class FASTDEPLOY_DECL ClassifierPreprocessor : public ProcessorManager { public: + ClassifierPreprocessor(); + using ProcessorManager::Run; /** \brief Process the input image and prepare input tensors for runtime * * \param[in] images The input data list, all the elements are FDMat * \param[in] outputs The output tensors which will be fed into runtime * \return true if the preprocess successed, otherwise false */ - bool Run(std::vector* images, std::vector* outputs); bool Run(std::vector* images, std::vector* outputs, size_t start_index, size_t end_index); - /// Set mean value for the image normalization in classification preprocess - void SetMean(const std::vector& mean) { mean_ = mean; } - /// Get mean value of the image normalization in classification preprocess - std::vector GetMean() const { return mean_; } + /** \brief Implement the virtual function of ProcessorManager, Apply() is the + * body of Run(). Apply() contains the main logic of preprocessing, Run() is + * called by users to execute preprocessing + * + * \param[in] image_batch The input image batch + * \param[in] outputs The output tensors which will feed in runtime + * \return true if the preprocess successed, otherwise false + */ + virtual bool Apply(FDMatBatch* image_batch, std::vector* outputs); - /// Set scale value for the image normalization in classification preprocess - void SetScale(const std::vector& scale) { scale_ = scale; } - /// Get scale value of the image normalization in classification preprocess - std::vector GetScale() const { return scale_; } - - /// Set is_scale for the image normalization in classification preprocess - void SetIsScale(bool is_scale) { is_scale_ = is_scale; } - /// Get is_scale of the image normalization in classification preprocess - bool GetIsScale() const { return is_scale_; } + /// Set preprocess normalize parameters, please call this API to customize + /// the normalize parameters, otherwise it will use the default normalize + /// parameters. 
+ void SetNormalize(const std::vector& mean, + const std::vector& std, + bool is_scale) { + normalize_op_ = std::make_shared(mean, std, is_scale); + } /// Set cls_image_shape for the classification preprocess void SetClsImageShape(const std::vector& cls_image_shape) { @@ -62,14 +68,18 @@ class FASTDEPLOY_DECL ClassifierPreprocessor { void DisablePermute() { disable_normalize_ = true; } private: + void OcrClassifierResizeImage(FDMat* mat, + const std::vector& cls_image_shape); // for recording the switch of hwc2chw bool disable_permute_ = false; // for recording the switch of normalize bool disable_normalize_ = false; - std::vector mean_ = {0.5f, 0.5f, 0.5f}; - std::vector scale_ = {0.5f, 0.5f, 0.5f}; - bool is_scale_ = true; std::vector cls_image_shape_ = {3, 48, 192}; + + std::shared_ptr resize_op_; + std::shared_ptr pad_op_; + std::shared_ptr normalize_op_; + std::shared_ptr hwc2chw_op_; }; } // namespace ocr diff --git a/fastdeploy/vision/ocr/ppocr/det_preprocessor.cc b/fastdeploy/vision/ocr/ppocr/det_preprocessor.cc index 06f47b6ef..b872a581a 100644 --- a/fastdeploy/vision/ocr/ppocr/det_preprocessor.cc +++ b/fastdeploy/vision/ocr/ppocr/det_preprocessor.cc @@ -55,11 +55,9 @@ DBDetectorPreprocessor::DBDetectorPreprocessor() { std::vector value = {0, 0, 0}; pad_op_ = std::make_shared(0, 0, 0, 0, value); - std::vector mean = {0.485f, 0.456f, 0.406f}; - std::vector std = {0.229f, 0.224f, 0.225f}; - bool is_scale = true; - normalize_permute_op_ = - std::make_shared(mean, std, is_scale); + normalize_permute_op_ = std::make_shared( + std::vector({0.485f, 0.456f, 0.406f}), + std::vector({0.229f, 0.224f, 0.225f}), true); } bool DBDetectorPreprocessor::ResizeImage(FDMat* img, int resize_w, int resize_h, diff --git a/fastdeploy/vision/ocr/ppocr/det_preprocessor.h b/fastdeploy/vision/ocr/ppocr/det_preprocessor.h index 32ef80011..e144b1d95 100644 --- a/fastdeploy/vision/ocr/ppocr/det_preprocessor.h +++ b/fastdeploy/vision/ocr/ppocr/det_preprocessor.h @@ -46,9 +46,9 @@ class FASTDEPLOY_DECL DBDetectorPreprocessor : public ProcessorManager { /// Set preprocess normalize parameters, please call this API to customize /// the normalize parameters, otherwise it will use the default normalize /// parameters. 
- void SetNormalize(const std::vector& mean = {0.485f, 0.456f, 0.406f}, - const std::vector& std = {0.229f, 0.224f, 0.225f}, - bool is_scale = true) { + void SetNormalize(const std::vector& mean, + const std::vector& std, + bool is_scale) { normalize_permute_op_ = std::make_shared(mean, std, is_scale); } diff --git a/fastdeploy/vision/ocr/ppocr/ocrmodel_pybind.cc b/fastdeploy/vision/ocr/ppocr/ocrmodel_pybind.cc index a1ebd09c6..ad499eb6e 100644 --- a/fastdeploy/vision/ocr/ppocr/ocrmodel_pybind.cc +++ b/fastdeploy/vision/ocr/ppocr/ocrmodel_pybind.cc @@ -23,8 +23,8 @@ void BindPPOCRModel(pybind11::module& m) { }); // DBDetector - pybind11::class_( - m, "DBDetectorPreprocessor") + pybind11::class_(m, "DBDetectorPreprocessor") .def(pybind11::init<>()) .def_property("static_shape_infer", &vision::ocr::DBDetectorPreprocessor::GetStaticShapeInfer, @@ -133,19 +133,16 @@ void BindPPOCRModel(pybind11::module& m) { }); // Classifier - pybind11::class_( - m, "ClassifierPreprocessor") + pybind11::class_(m, "ClassifierPreprocessor") .def(pybind11::init<>()) .def_property("cls_image_shape", &vision::ocr::ClassifierPreprocessor::GetClsImageShape, &vision::ocr::ClassifierPreprocessor::SetClsImageShape) - .def_property("mean", &vision::ocr::ClassifierPreprocessor::GetMean, - &vision::ocr::ClassifierPreprocessor::SetMean) - .def_property("scale", &vision::ocr::ClassifierPreprocessor::GetScale, - &vision::ocr::ClassifierPreprocessor::SetScale) - .def_property("is_scale", - &vision::ocr::ClassifierPreprocessor::GetIsScale, - &vision::ocr::ClassifierPreprocessor::SetIsScale) + .def("set_normalize", + [](vision::ocr::ClassifierPreprocessor& self, + const std::vector& mean, const std::vector& std, + bool is_scale) { self.SetNormalize(mean, std, is_scale); }) .def("run", [](vision::ocr::ClassifierPreprocessor& self, std::vector& im_list) { @@ -233,8 +230,8 @@ void BindPPOCRModel(pybind11::module& m) { }); // Recognizer - pybind11::class_( - m, "RecognizerPreprocessor") + pybind11::class_(m, "RecognizerPreprocessor") .def(pybind11::init<>()) .def_property("static_shape_infer", &vision::ocr::RecognizerPreprocessor::GetStaticShapeInfer, @@ -242,13 +239,10 @@ void BindPPOCRModel(pybind11::module& m) { .def_property("rec_image_shape", &vision::ocr::RecognizerPreprocessor::GetRecImageShape, &vision::ocr::RecognizerPreprocessor::SetRecImageShape) - .def_property("mean", &vision::ocr::RecognizerPreprocessor::GetMean, - &vision::ocr::RecognizerPreprocessor::SetMean) - .def_property("scale", &vision::ocr::RecognizerPreprocessor::GetScale, - &vision::ocr::RecognizerPreprocessor::SetScale) - .def_property("is_scale", - &vision::ocr::RecognizerPreprocessor::GetIsScale, - &vision::ocr::RecognizerPreprocessor::SetIsScale) + .def("set_normalize", + [](vision::ocr::RecognizerPreprocessor& self, + const std::vector& mean, const std::vector& std, + bool is_scale) { self.SetNormalize(mean, std, is_scale); }) .def("run", [](vision::ocr::RecognizerPreprocessor& self, std::vector& im_list) { diff --git a/fastdeploy/vision/ocr/ppocr/rec_preprocessor.cc b/fastdeploy/vision/ocr/ppocr/rec_preprocessor.cc index 59c7de279..ad4e66f31 100644 --- a/fastdeploy/vision/ocr/ppocr/rec_preprocessor.cc +++ b/fastdeploy/vision/ocr/ppocr/rec_preprocessor.cc @@ -22,9 +22,24 @@ namespace fastdeploy { namespace vision { namespace ocr { -void OcrRecognizerResizeImage(FDMat* mat, float max_wh_ratio, - const std::vector& rec_image_shape, - bool static_shape_infer) { +RecognizerPreprocessor::RecognizerPreprocessor() { + resize_op_ = std::make_shared(-1, -1); + 
+ std::vector value = {127, 127, 127}; + pad_op_ = std::make_shared(0, 0, 0, 0, value); + + std::vector mean = {0.5f, 0.5f, 0.5f}; + std::vector std = {0.5f, 0.5f, 0.5f}; + normalize_permute_op_ = + std::make_shared(mean, std, true); + normalize_op_ = std::make_shared(mean, std, true); + hwc2chw_op_ = std::make_shared(); + cast_op_ = std::make_shared("float"); +} + +void RecognizerPreprocessor::OcrRecognizerResizeImage( + FDMat* mat, float max_wh_ratio, const std::vector& rec_image_shape, + bool static_shape_infer) { int img_h, img_w; img_h = rec_image_shape[1]; img_w = rec_image_shape[2]; @@ -39,25 +54,25 @@ void OcrRecognizerResizeImage(FDMat* mat, float max_wh_ratio, } else { resize_w = int(ceilf(img_h * ratio)); } - Resize::Run(mat, resize_w, img_h); - Pad::Run(mat, 0, 0, 0, int(img_w - mat->Width()), {127, 127, 127}); - + resize_op_->SetWidthAndHeight(resize_w, img_h); + (*resize_op_)(mat); + pad_op_->SetPaddingSize(0, 0, 0, int(img_w - mat->Width())); + (*pad_op_)(mat); } else { if (mat->Width() >= img_w) { - Resize::Run(mat, img_w, img_h); // Reszie W to 320 + // Reszie W to 320 + resize_op_->SetWidthAndHeight(img_w, img_h); + (*resize_op_)(mat); } else { - Resize::Run(mat, mat->Width(), img_h); - Pad::Run(mat, 0, 0, 0, int(img_w - mat->Width()), {127, 127, 127}); + resize_op_->SetWidthAndHeight(mat->Width(), img_h); + (*resize_op_)(mat); // Pad to 320 + pad_op_->SetPaddingSize(0, 0, 0, int(img_w - mat->Width())); + (*pad_op_)(mat); } } } -bool RecognizerPreprocessor::Run(std::vector* images, - std::vector* outputs) { - return Run(images, outputs, 0, images->size(), {}); -} - bool RecognizerPreprocessor::Run(std::vector* images, std::vector* outputs, size_t start_index, size_t end_index, @@ -70,60 +85,55 @@ bool RecognizerPreprocessor::Run(std::vector* images, return false; } + std::vector mats(end_index - start_index); + for (size_t i = start_index; i < end_index; ++i) { + size_t real_index = i; + if (indices.size() != 0) { + real_index = indices[i]; + } + mats[i - start_index] = images->at(real_index); + } + return Run(&mats, outputs); +} + +bool RecognizerPreprocessor::Apply(FDMatBatch* image_batch, + std::vector* outputs) { int img_h = rec_image_shape_[1]; int img_w = rec_image_shape_[2]; float max_wh_ratio = img_w * 1.0 / img_h; float ori_wh_ratio; - for (size_t i = start_index; i < end_index; ++i) { - size_t real_index = i; - if (indices.size() != 0) { - real_index = indices[i]; - } - FDMat* mat = &(images->at(real_index)); + for (size_t i = 0; i < image_batch->mats->size(); ++i) { + FDMat* mat = &(image_batch->mats->at(i)); ori_wh_ratio = mat->Width() * 1.0 / mat->Height(); max_wh_ratio = std::max(max_wh_ratio, ori_wh_ratio); } - for (size_t i = start_index; i < end_index; ++i) { - size_t real_index = i; - if (indices.size() != 0) { - real_index = indices[i]; - } - FDMat* mat = &(images->at(real_index)); + for (size_t i = 0; i < image_batch->mats->size(); ++i) { + FDMat* mat = &(image_batch->mats->at(i)); OcrRecognizerResizeImage(mat, max_wh_ratio, rec_image_shape_, static_shape_infer_); - if (!disable_normalize_ && !disable_permute_) { - NormalizeAndPermute::Run(mat, mean_, scale_, is_scale_); - } else { - if (!disable_normalize_) { - Normalize::Run(mat, mean_, scale_, is_scale_); - } - if (!disable_permute_) { - HWC2CHW::Run(mat); - Cast::Run(mat, "float"); - } + } + + if (!disable_normalize_ && !disable_permute_) { + (*normalize_permute_op_)(image_batch); + } else { + if (!disable_normalize_) { + (*normalize_op_)(image_batch); + } + if (!disable_permute_) { + 
(*hwc2chw_op_)(image_batch); + (*cast_op_)(image_batch); } } + // Only have 1 output Tensor. outputs->resize(1); - size_t tensor_size = end_index - start_index; - // Concat all the preprocessed data to a batch tensor - std::vector tensors(tensor_size); - for (size_t i = 0; i < tensor_size; ++i) { - size_t real_index = i + start_index; - if (indices.size() != 0) { - real_index = indices[i + start_index]; - } - - (*images)[real_index].ShareWithTensor(&(tensors[i])); - tensors[i].ExpandDim(0); - } - if (tensors.size() == 1) { - (*outputs)[0] = std::move(tensors[0]); - } else { - function::Concat(tensors, &((*outputs)[0]), 0); - } + // Get the NCHW tensor + FDTensor* tensor = image_batch->Tensor(); + (*outputs)[0].SetExternalData(tensor->Shape(), tensor->Dtype(), + tensor->Data(), tensor->device, + tensor->device_id); return true; } diff --git a/fastdeploy/vision/ocr/ppocr/rec_preprocessor.h b/fastdeploy/vision/ocr/ppocr/rec_preprocessor.h index c5edb2a80..ca630bcd2 100644 --- a/fastdeploy/vision/ocr/ppocr/rec_preprocessor.h +++ b/fastdeploy/vision/ocr/ppocr/rec_preprocessor.h @@ -14,6 +14,7 @@ #pragma once #include "fastdeploy/vision/common/processors/transform.h" +#include "fastdeploy/vision/common/processors/manager.h" #include "fastdeploy/vision/common/result.h" namespace fastdeploy { @@ -22,19 +23,30 @@ namespace vision { namespace ocr { /*! @brief Preprocessor object for PaddleClas serials model. */ -class FASTDEPLOY_DECL RecognizerPreprocessor { +class FASTDEPLOY_DECL RecognizerPreprocessor : public ProcessorManager { public: + RecognizerPreprocessor(); + using ProcessorManager::Run; /** \brief Process the input image and prepare input tensors for runtime * * \param[in] images The input data list, all the elements are FDMat * \param[in] outputs The output tensors which will be fed into runtime * \return true if the preprocess successed, otherwise false */ - bool Run(std::vector* images, std::vector* outputs); bool Run(std::vector* images, std::vector* outputs, size_t start_index, size_t end_index, const std::vector& indices); + /** \brief Implement the virtual function of ProcessorManager, Apply() is the + * body of Run(). Apply() contains the main logic of preprocessing, Run() is + * called by users to execute preprocessing + * + * \param[in] image_batch The input image batch + * \param[in] outputs The output tensors which will feed in runtime + * \return true if the preprocess successed, otherwise false + */ + virtual bool Apply(FDMatBatch* image_batch, std::vector* outputs); + /// Set static_shape_infer is true or not. When deploy PP-OCR /// on hardware which can not support dynamic input shape very well, /// like Huawei Ascned, static_shape_infer needs to to be true. 
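Since ClassifierPreprocessor and RecognizerPreprocessor now derive from ProcessorManager, batching, cache reuse and backend dispatch (OpenCV / CUDA / CV-CUDA, see manager.cc above) live in the base class, and Apply() only expresses the per-batch logic. A hedged usage sketch of the refactored interface; the image file names are placeholders, and only WrapMat, SetNormalize and Run are taken from this patch:

    namespace vis = fastdeploy::vision;
    vis::ocr::RecognizerPreprocessor preprocessor;
    // Optional: override the default 0.5/0.5/0.5 normalization.
    preprocessor.SetNormalize({0.5f, 0.5f, 0.5f}, {0.5f, 0.5f, 0.5f}, true);
    std::vector<cv::Mat> images = {cv::imread("crop_0.jpg"), cv::imread("crop_1.jpg")};
    std::vector<vis::FDMat> mats = vis::WrapMat(images);
    std::vector<fastdeploy::FDTensor> outputs;
    preprocessor.Run(&mats, &outputs);  // ProcessorManager::Run -> Apply(FDMatBatch*, ...)
    // outputs[0] aliases the preprocessor's internal cache and holds the batched
    // NCHW tensor, so it must be consumed while `preprocessor` is still alive.
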
@@ -44,20 +56,16 @@ class FASTDEPLOY_DECL RecognizerPreprocessor { /// Get static_shape_infer of the recognition preprocess bool GetStaticShapeInfer() const { return static_shape_infer_; } - /// Set mean value for the image normalization in recognition preprocess - void SetMean(const std::vector& mean) { mean_ = mean; } - /// Get mean value of the image normalization in recognition preprocess - std::vector GetMean() const { return mean_; } - - /// Set scale value for the image normalization in recognition preprocess - void SetScale(const std::vector& scale) { scale_ = scale; } - /// Get scale value of the image normalization in recognition preprocess - std::vector GetScale() const { return scale_; } - - /// Set is_scale for the image normalization in recognition preprocess - void SetIsScale(bool is_scale) { is_scale_ = is_scale; } - /// Get is_scale of the image normalization in recognition preprocess - bool GetIsScale() const { return is_scale_; } + /// Set preprocess normalize parameters, please call this API to customize + /// the normalize parameters, otherwise it will use the default normalize + /// parameters. + void SetNormalize(const std::vector& mean, + const std::vector& std, + bool is_scale) { + normalize_permute_op_ = + std::make_shared(mean, std, is_scale); + normalize_op_ = std::make_shared(mean, std, is_scale); + } /// Set rec_image_shape for the recognition preprocess void SetRecImageShape(const std::vector& rec_image_shape) { @@ -72,15 +80,21 @@ class FASTDEPLOY_DECL RecognizerPreprocessor { void DisablePermute() { disable_normalize_ = true; } private: + void OcrRecognizerResizeImage(FDMat* mat, float max_wh_ratio, + const std::vector& rec_image_shape, + bool static_shape_infer); // for recording the switch of hwc2chw bool disable_permute_ = false; // for recording the switch of normalize bool disable_normalize_ = false; std::vector rec_image_shape_ = {3, 48, 320}; - std::vector mean_ = {0.5f, 0.5f, 0.5f}; - std::vector scale_ = {0.5f, 0.5f, 0.5f}; - bool is_scale_ = true; bool static_shape_infer_ = false; + std::shared_ptr resize_op_; + std::shared_ptr pad_op_; + std::shared_ptr normalize_permute_op_; + std::shared_ptr normalize_op_; + std::shared_ptr hwc2chw_op_; + std::shared_ptr cast_op_; }; } // namespace ocr diff --git a/python/fastdeploy/vision/ocr/ppocr/__init__.py b/python/fastdeploy/vision/ocr/ppocr/__init__.py index 1fa39600b..45886d40e 100755 --- a/python/fastdeploy/vision/ocr/ppocr/__init__.py +++ b/python/fastdeploy/vision/ocr/ppocr/__init__.py @@ -52,10 +52,7 @@ class DBDetectorPreprocessor: value, int), "The value to set `max_side_len` must be type of int." self._preprocessor.max_side_len = value - def set_normalize(self, - mean=[0.485, 0.456, 0.406], - std=[0.229, 0.224, 0.225], - is_scale=True): + def set_normalize(self, mean, std, is_scale): """Set preprocess normalize parameters, please call this API to customize the normalize parameters, otherwise it will use the default normalize parameters. @@ -340,35 +337,15 @@ class ClassifierPreprocessor: """ return self._preprocessor.run(input_ims) - @property - def is_scale(self): - return self._preprocessor.is_scale - - @is_scale.setter - def is_scale(self, value): - assert isinstance( - value, bool), "The value to set `is_scale` must be type of bool." - self._preprocessor.is_scale = value - - @property - def scale(self): - return self._preprocessor.scale - - @scale.setter - def scale(self, value): - assert isinstance( - value, list), "The value to set `scale` must be type of list." 
- self._preprocessor.scale = value - - @property - def mean(self): - return self._preprocessor.mean - - @mean.setter - def mean(self, value): - assert isinstance( - value, list), "The value to set `mean` must be type of list." - self._preprocessor.mean = value + def set_normalize(self, mean, std, is_scale): + """Set preprocess normalize parameters, please call this API to + customize the normalize parameters, otherwise it will use the default + normalize parameters. + :param: mean: (list of float) mean values + :param: std: (list of float) std values + :param: is_scale: (boolean) whether to scale + """ + self._preprocessor.set_normalize(mean, std, is_scale) @property def cls_image_shape(self): @@ -496,37 +473,6 @@ class Classifier(FastDeployModel): def postprocessor(self, value): self._model.postprocessor = value - # Cls Preprocessor Property - @property - def is_scale(self): - return self._model.preprocessor.is_scale - - @is_scale.setter - def is_scale(self, value): - assert isinstance( - value, bool), "The value to set `is_scale` must be type of bool." - self._model.preprocessor.is_scale = value - - @property - def scale(self): - return self._model.preprocessor.scale - - @scale.setter - def scale(self, value): - assert isinstance( - value, list), "The value to set `scale` must be type of list." - self._model.preprocessor.scale = value - - @property - def mean(self): - return self._model.preprocessor.mean - - @mean.setter - def mean(self, value): - assert isinstance( - value, list), "The value to set `mean` must be type of list." - self._model.preprocessor.mean = value - @property def cls_image_shape(self): return self._model.preprocessor.cls_image_shape @@ -575,35 +521,15 @@ class RecognizerPreprocessor: bool), "The value to set `static_shape_infer` must be type of bool." self._preprocessor.static_shape_infer = value - @property - def is_scale(self): - return self._preprocessor.is_scale - - @is_scale.setter - def is_scale(self, value): - assert isinstance( - value, bool), "The value to set `is_scale` must be type of bool." - self._preprocessor.is_scale = value - - @property - def scale(self): - return self._preprocessor.scale - - @scale.setter - def scale(self, value): - assert isinstance( - value, list), "The value to set `scale` must be type of list." - self._preprocessor.scale = value - - @property - def mean(self): - return self._preprocessor.mean - - @mean.setter - def mean(self, value): - assert isinstance( - value, list), "The value to set `mean` must be type of list." - self._preprocessor.mean = value + def set_normalize(self, mean, std, is_scale): + """Set preprocess normalize parameters, please call this API to + customize the normalize parameters, otherwise it will use the default + normalize parameters. + :param: mean: (list of float) mean values + :param: std: (list of float) std values + :param: is_scale: (boolean) whether to scale + """ + self._preprocessor.set_normalize(mean, std, is_scale) @property def rec_image_shape(self): @@ -728,36 +654,6 @@ class Recognizer(FastDeployModel): bool), "The value to set `static_shape_infer` must be type of bool." self._model.preprocessor.static_shape_infer = value - @property - def is_scale(self): - return self._model.preprocessor.is_scale - - @is_scale.setter - def is_scale(self, value): - assert isinstance( - value, bool), "The value to set `is_scale` must be type of bool." 
- self._model.preprocessor.is_scale = value - - @property - def scale(self): - return self._model.preprocessor.scale - - @scale.setter - def scale(self, value): - assert isinstance( - value, list), "The value to set `scale` must be type of list." - self._model.preprocessor.scale = value - - @property - def mean(self): - return self._model.preprocessor.mean - - @mean.setter - def mean(self, value): - assert isinstance( - value, list), "The value to set `mean` must be type of list." - self._model.preprocessor.mean = value - @property def rec_image_shape(self): return self._model.preprocessor.rec_image_shape
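
On the Python side, the per-field mean / scale / is_scale properties removed above are replaced by a single set_normalize() call, mirroring the C++ SetNormalize API. A short migration sketch (illustrative; the values shown are the defaults the preprocessors already use):

    import fastdeploy as fd

    rec_preprocessor = fd.vision.ocr.RecognizerPreprocessor()
    cls_preprocessor = fd.vision.ocr.ClassifierPreprocessor()

    # Before this change:
    #   rec_preprocessor.mean = [0.5, 0.5, 0.5]
    #   rec_preprocessor.scale = [0.5, 0.5, 0.5]
    #   rec_preprocessor.is_scale = True
    # After this change, one call configures normalization:
    rec_preprocessor.set_normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], is_scale=True)
    cls_preprocessor.set_normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], is_scale=True)

Note that DBDetectorPreprocessor.set_normalize also loses its default arguments in this patch, so callers must now pass mean, std and is_scale explicitly.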