[CVCUDA] Vision Processor Python API and Tutorial (#1394)

* bind success

* bind success fix

* FDMat pybind, ResizeByShort pybind

* FDMat pybind, ResizeByShort pybind, remove initialized_

* override BindProcessorManager::Run in python is available

* PyProcessorManager done

* vision_pybind fix

* manager.py fix

* add tutorials

* remove Apply() bind

* remove Apply() bind and fix

* fix reviewed problem

* fix reviewed problem

* fix reviewed problem readme

* fix reviewed problem readme etc

* apply return outputs

* nits

* update readme

* fix FDMatbatch

* add op pybind: CenterCrop, Pad

* add op overload for pass FDMatBatch

---------

Co-authored-by: Wang Xinyu <shaywxy@gmail.com>
This commit is contained in:
guxukai
2023-03-10 14:42:32 +08:00
committed by GitHub
parent cb7c8a07d4
commit c6480de736
22 changed files with 530 additions and 34 deletions

View File

@@ -0,0 +1,28 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <pybind11/operators.h>
#include "fastdeploy/pybind/main.h"
namespace fastdeploy {
void BindProcessor(pybind11::module& m) {
pybind11::class_<vision::Processor>(m, "Processor")
.def("__call__", [](vision::Processor& self,
vision::FDMat* mat) { return self(mat); })
.def("__call__",
[](vision::Processor& self, vision::FDMatBatch* mat_batch) {
return self(mat_batch);
});
}
} // namespace fastdeploy

View File

@@ -0,0 +1,23 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/pybind/main.h"
namespace fastdeploy {
void BindCenterCrop(pybind11::module& m) {
pybind11::class_<vision::CenterCrop, vision::Processor>(
m, "CenterCrop")
.def(pybind11::init<int, int>(), "Default constructor");
}
} // namespace fastdeploy

View File

@@ -49,48 +49,51 @@ bool ProcessorManager::CudaUsed() {
return (proc_lib_ == ProcLib::CUDA || proc_lib_ == ProcLib::CVCUDA); return (proc_lib_ == ProcLib::CUDA || proc_lib_ == ProcLib::CVCUDA);
} }
bool ProcessorManager::Run(std::vector<FDMat>* images, void ProcessorManager::PreApply(FDMatBatch* image_batch) {
std::vector<FDTensor>* outputs) { FDASSERT(image_batch->mats != nullptr, "The mats is empty.");
if (images->size() == 0) { FDASSERT(image_batch->mats->size() > 0,
FDERROR << "The size of input images should be greater than 0." "The size of input images should be greater than 0.");
<< std::endl;
return false; if (image_batch->mats->size() > input_caches_.size()) {
input_caches_.resize(image_batch->mats->size());
output_caches_.resize(image_batch->mats->size());
} }
image_batch->input_cache = &batch_input_cache_;
image_batch->output_cache = &batch_output_cache_;
if (images->size() > input_caches_.size()) {
input_caches_.resize(images->size());
output_caches_.resize(images->size());
}
FDMatBatch image_batch(images);
image_batch.input_cache = &batch_input_cache_;
image_batch.output_cache = &batch_output_cache_;
image_batch.proc_lib = proc_lib_;
for (size_t i = 0; i < images->size(); ++i) {
if (CudaUsed()) { if (CudaUsed()) {
SetStream(&image_batch); SetStream(image_batch);
} }
(*images)[i].input_cache = &input_caches_[i];
(*images)[i].output_cache = &output_caches_[i]; for (size_t i = 0; i < image_batch->mats->size(); ++i) {
(*images)[i].proc_lib = proc_lib_; FDMat* mat = &(image_batch->mats->at(i));
if ((*images)[i].mat_type == ProcLib::CUDA) { mat->input_cache = &input_caches_[i];
mat->output_cache = &output_caches_[i];
mat->proc_lib = proc_lib_;
if (mat->mat_type == ProcLib::CUDA) {
// Make a copy of the input data ptr, so that the original data ptr of // Make a copy of the input data ptr, so that the original data ptr of
// FDMat won't be modified. // FDMat won't be modified.
auto fd_tensor = std::make_shared<FDTensor>(); auto fd_tensor = std::make_shared<FDTensor>();
fd_tensor->SetExternalData( fd_tensor->SetExternalData(mat->Tensor()->shape, mat->Tensor()->Dtype(),
(*images)[i].Tensor()->shape, (*images)[i].Tensor()->Dtype(), mat->Tensor()->Data(), mat->Tensor()->device,
(*images)[i].Tensor()->Data(), (*images)[i].Tensor()->device, mat->Tensor()->device_id);
(*images)[i].Tensor()->device_id); mat->SetTensor(fd_tensor);
(*images)[i].SetTensor(fd_tensor);
} }
} }
}
bool ret = Apply(&image_batch, outputs); void ProcessorManager::PostApply() {
if (CudaUsed()) { if (CudaUsed()) {
SyncStream(); SyncStream();
} }
}
bool ProcessorManager::Run(std::vector<FDMat>* images,
std::vector<FDTensor>* outputs) {
FDMatBatch image_batch(images);
PreApply(&image_batch);
bool ret = Apply(&image_batch, outputs);
PostApply();
return ret; return ret;
} }

View File

@@ -78,6 +78,10 @@ class FASTDEPLOY_DECL ProcessorManager {
virtual bool Apply(FDMatBatch* image_batch, virtual bool Apply(FDMatBatch* image_batch,
std::vector<FDTensor>* outputs) = 0; std::vector<FDTensor>* outputs) = 0;
void PreApply(FDMatBatch* image_batch);
void PostApply();
protected: protected:
ProcLib proc_lib_ = ProcLib::DEFAULT; ProcLib proc_lib_ = ProcLib::DEFAULT;

View File

@@ -14,8 +14,22 @@
#include "fastdeploy/pybind/main.h" #include "fastdeploy/pybind/main.h"
namespace fastdeploy { namespace fastdeploy {
namespace vision {
// PyProcessorManager is used for pybind11::init() of ProcessorManager
// Because ProcessorManager have a pure Virtual function Apply()
class FASTDEPLOY_DECL PyProcessorManager : public ProcessorManager {
public:
using ProcessorManager::ProcessorManager;
bool Apply(FDMatBatch* image_batch, std::vector<FDTensor>* outputs) override {
PYBIND11_OVERRIDE_PURE(bool, ProcessorManager, Apply, image_batch, outputs);
}
};
} // namespace vision
void BindProcessorManager(pybind11::module& m) { void BindProcessorManager(pybind11::module& m) {
pybind11::class_<vision::ProcessorManager>(m, "ProcessorManager") pybind11::class_<vision::ProcessorManager, vision::PyProcessorManager>(
m, "ProcessorManager")
.def(pybind11::init<>())
.def("run", .def("run",
[](vision::ProcessorManager& self, [](vision::ProcessorManager& self,
std::vector<pybind11::array>& im_list) { std::vector<pybind11::array>& im_list) {
@@ -34,6 +48,8 @@ void BindProcessorManager(pybind11::module& m) {
} }
return outputs; return outputs;
}) })
.def("pre_apply", &vision::ProcessorManager::PreApply)
.def("post_apply", &vision::ProcessorManager::PostApply)
.def("use_cuda", .def("use_cuda",
[](vision::ProcessorManager& self, bool enable_cv_cuda = false, [](vision::ProcessorManager& self, bool enable_cv_cuda = false,
int gpu_id = -1) { self.UseCuda(enable_cv_cuda, gpu_id); }); int gpu_id = -1) { self.UseCuda(enable_cv_cuda, gpu_id); });

View File

@@ -57,6 +57,10 @@ struct FASTDEPLOY_DECL FDMatBatch {
#endif #endif
std::vector<FDMat>* mats = nullptr; std::vector<FDMat>* mats = nullptr;
// Used by pybind, since python cannot pass list as pointer or reference
std::vector<FDMat> mats_holder;
ProcLib mat_type = ProcLib::OPENCV; ProcLib mat_type = ProcLib::OPENCV;
FDMatBatchLayout layout = FDMatBatchLayout::NHWC; FDMatBatchLayout layout = FDMatBatchLayout::NHWC;
Device device = Device::CPU; Device device = Device::CPU;

View File

@@ -0,0 +1,30 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/pybind/main.h"
namespace fastdeploy {
void BindFDMatBatch(pybind11::module& m) {
pybind11::class_<vision::FDMatBatch>(m, "FDMatBatch")
.def(pybind11::init<>(), "Default constructor")
.def_readwrite("input_cache", &vision::FDMatBatch::input_cache)
.def_readwrite("output_cache", &vision::FDMatBatch::output_cache)
.def_readwrite("mats", &vision::FDMatBatch::mats)
.def("from_mats",
[](vision::FDMatBatch& self, std::vector<vision::FDMat>& _mats) {
self.mats_holder = _mats;
self.mats = &(self.mats_holder);
});
}
} // namespace fastdeploy

View File

@@ -0,0 +1,29 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/pybind/main.h"
namespace fastdeploy {
void BindFDMat(pybind11::module& m) {
pybind11::class_<vision::FDMat>(m, "FDMat")
.def(pybind11::init<>(), "Default constructor")
.def_readwrite("input_cache", &vision::FDMat::input_cache)
.def_readwrite("output_cache", &vision::FDMat::output_cache)
.def("from_numpy",
[](vision::FDMat& self, pybind11::array& pyarray) {
self = vision::WrapMat(PyArrayToCvMat(pyarray));
})
.def("print_info", &vision::FDMat::PrintInfo);
}
} // namespace fastdeploy

View File

@@ -0,0 +1,25 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/pybind/main.h"
namespace fastdeploy {
void BindNormalizeAndPermute(pybind11::module& m) {
pybind11::class_<vision::NormalizeAndPermute, vision::Processor>(
m, "NormalizeAndPermute")
.def(pybind11::init<std::vector<float>, std::vector<float>, bool,
std::vector<float>, std::vector<float>, bool>(),
"Default constructor");
}
} // namespace fastdeploy

View File

@@ -0,0 +1,23 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/pybind/main.h"
namespace fastdeploy {
void BindPad(pybind11::module& m) {
pybind11::class_<vision::Pad, vision::Processor>(
m, "Pad")
.def(pybind11::init<int, int, int, int, std::vector<float>>(), "Default constructor");
}
} // namespace fastdeploy

View File

@@ -0,0 +1,36 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/pybind/main.h"
namespace fastdeploy {
void BindProcessorManager(pybind11::module& m);
void BindNormalizeAndPermute(pybind11::module& m);
void BindProcessor(pybind11::module& m);
void BindResizeByShort(pybind11::module& m);
void BindCenterCrop(pybind11::module& m);
void BindPad(pybind11::module& m);
void BindProcessors(pybind11::module& m) {
auto processors_m =
m.def_submodule("processors", "Module to deploy Processors models");
BindProcessorManager(processors_m);
BindProcessor(processors_m);
BindNormalizeAndPermute(processors_m);
BindResizeByShort(processors_m);
BindCenterCrop(processors_m);
BindPad(processors_m);
}
} // namespace fastdeploy

View File

@@ -0,0 +1,23 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/pybind/main.h"
namespace fastdeploy {
void BindResizeByShort(pybind11::module& m) {
pybind11::class_<vision::ResizeByShort, vision::Processor>(m, "ResizeByShort")
.def(pybind11::init<int, int, bool, std::vector<int>>(),
"Default constructor");
}
} // namespace fastdeploy

View File

@@ -16,7 +16,9 @@
namespace fastdeploy { namespace fastdeploy {
void BindProcessorManager(pybind11::module& m); void BindFDMat(pybind11::module& m);
void BindFDMatBatch(pybind11::module& m);
void BindProcessors(pybind11::module& m);
void BindDetection(pybind11::module& m); void BindDetection(pybind11::module& m);
void BindClassification(pybind11::module& m); void BindClassification(pybind11::module& m);
void BindSegmentation(pybind11::module& m); void BindSegmentation(pybind11::module& m);
@@ -205,7 +207,9 @@ void BindVision(pybind11::module& m) {
m.def("disable_flycv", &vision::DisableFlyCV, m.def("disable_flycv", &vision::DisableFlyCV,
"Disable image preprocessing by FlyCV, change to use OpenCV."); "Disable image preprocessing by FlyCV, change to use OpenCV.");
BindProcessorManager(m); BindFDMat(m);
BindFDMatBatch(m);
BindProcessors(m);
BindDetection(m); BindDetection(m);
BindClassification(m); BindClassification(m);
BindSegmentation(m); BindSegmentation(m);

View File

@@ -14,3 +14,4 @@
from __future__ import absolute_import from __future__ import absolute_import
from .manager import ProcessorManager from .manager import ProcessorManager
from .manager import PyProcessorManager

View File

@@ -13,6 +13,8 @@
# limitations under the License. # limitations under the License.
from __future__ import absolute_import from __future__ import absolute_import
from abc import ABC, abstractmethod
from ... import c_lib_wrap as C
class ProcessorManager: class ProcessorManager:
@@ -34,3 +36,34 @@ class ProcessorManager:
:param: gpu_id: GPU device id :param: gpu_id: GPU device id
""" """
return self._manager.use_cuda(enable_cv_cuda, gpu_id) return self._manager.use_cuda(enable_cv_cuda, gpu_id)
class PyProcessorManager(ABC):
"""
PyProcessorManager is used to define a customized processor in python
"""
def __init__(self):
self._manager = C.vision.processors.ProcessorManager()
def use_cuda(self, enable_cv_cuda=False, gpu_id=-1):
"""Use CUDA processors
:param: enable_cv_cuda: Ture: use CV-CUDA, False: use CUDA only
:param: gpu_id: GPU device id
"""
return self._manager.use_cuda(enable_cv_cuda, gpu_id)
def __call__(self, images):
image_batch = C.vision.FDMatBatch()
image_batch.from_mats(images)
self._manager.pre_apply(image_batch)
outputs = self.apply(image_batch)
self._manager.post_apply()
return outputs
@abstractmethod
def apply(self, image_batch):
print("This function has to be implemented.")
return []

View File

@@ -7,4 +7,5 @@ This directory provides some tutorials for FastDeploy. For other model deploymen
- Intel independent graphics card/integrated graphics card deployment [see intel_gpu](intel_gpu) - Intel independent graphics card/integrated graphics card deployment [see intel_gpu](intel_gpu)
- Model multithreaded call [see multi_thread](multi_thread) - Model multithreaded call [see multi_thread](multi_thread)
- Image decoding, including hardward decoding, e.g. nvJPEG [image_decoder](image_decoder) - Image decoding, including hardward decoding, e.g. nvJPEG [image_decoder](image_decoder)
- Vision processor (preprocessing, CV-CUDA, FlyCV, etc.[vision_processor](vision_processor)
- Deploy models with C or C# API [use_c_csharp_sdk](use_c_sharp_sdk) - Deploy models with C or C# API [use_c_csharp_sdk](use_c_sharp_sdk)

View File

@@ -8,4 +8,5 @@
- Intel独立显卡/集成显卡部署 [见intel_gpu](intel_gpu) - Intel独立显卡/集成显卡部署 [见intel_gpu](intel_gpu)
- 模型多线程调用 [见multi_thread](multi_thread) - 模型多线程调用 [见multi_thread](multi_thread)
- 图片解码含nvJPEG硬解码 [见image_decoder](image_decoder) - 图片解码含nvJPEG硬解码 [见image_decoder](image_decoder)
- 多硬件图像处理库预处理、CV-CUDA、FlyCV等 [见vision_processor](vision_processor)
- 使用C或C# API进行模型部署 [见use_c_csharp_sdk](use_c_sharp_sdk) - 使用C或C# API进行模型部署 [见use_c_csharp_sdk](use_c_sharp_sdk)

View File

@@ -0,0 +1,43 @@
English | [中文](README_CN.md)
# Vision Processor
Vision Processor is used to implement model preprocessing, postprocessing, etc. The following 3rd party vision libraries are integrated:
- OpenCV, general CPU image processing
- FlyCV, mainly optimized for ARM CPU
- CV-CUDA, for NVIDIA GPU
## C++
TODO(guxukai)
## Python
Python API, Currently supported operators are as follows:
- ResizeByShort
- NormalizeAndPermute
Users can implement a image processing modules by inheriting the `PyProcessorManager` class. The base class `PyProcessorManager` implements GPU memory management, CUDA stream management, etc. Users only need to implement the apply() function by calling vision processors in this library and implements processing logic. For specific implementation, please refer to the demo code.
### Demo
- [Python Demo](python)
### Performance comparison between CV-CUDA and OpenCV:
CPU: Intel(R) Xeon(R) Gold 6271C CPU @ 2.60GHz
GPU: T4
CUDA: 11.6
Processing logic: Resize -> NormalizeAndPermute
Warmup 100 roundstested 1000 rounds and get avg. latency.
| Input Image Shape | Target shape | Batch Size | OpenCV | CV-CUDA | Gain |
| ----------- | -- | ---------- | ------- | ------ | ------ |
| 1920x1080 | 640x360 | 1 | 1.1572ms | 0.9067ms | 16.44% |
| 1280x720 | 640x360 | 1 | 2.7551ms | 0.5296ms | 80.78% |
| 360x240 | 640x360 | 1 | 3.3450ms | 0.2421ms | 92.76% |

View File

@@ -0,0 +1,42 @@
中文 | [English](README.md)
# 多硬件图像处理库
多硬件图像处理库Vision Processor可用于实现模型的预处理、后处理等图像操作底层封装了多个第三方图像处理库包括
- OpenCV用于通用CPU图像处理
- FlyCV主要针对ARM CPU加速
- CV-CUDA用于NVIDIA GPU
## C++
待编写
## Python
Python API目前支持的算子如下
- ResizeByShort
- NormalizeAndPermute
用户可通过继承PyProcessorManager类实现自己的图像处理模块。基类PyProcessorManager实现了GPU内存管理、CUDA stream管理等用户仅需要实现apply()函数,在其中调用多硬件图像处理库中的算子、实现处理逻辑即可,具体实现可参考示例代码。
### 示例代码
- [Python示例](python)
### CV-CUDA与OpenCV性能对比
CPU: Intel(R) Xeon(R) Gold 6271C CPU @ 2.60GHz
GPU: T4
CUDA: 11.6
Processing logic: Resize -> NormalizeAndPermute
Warmup 100 roundstested 1000 rounds and get avg. latency.
| Input Image Shape | Target shape | Batch Size | OpenCV | CV-CUDA | Gain |
| ----------- | -- | ---------- | ------- | ------ | ------ |
| 1920x1080 | 640x360 | 1 | 1.1572ms | 0.9067ms | 16.44% |
| 1280x720 | 640x360 | 1 | 2.7551ms | 0.5296ms | 80.78% |
| 360x240 | 640x360 | 1 | 3.3450ms | 0.2421ms | 92.76% |

View File

@@ -0,0 +1,19 @@
English | [中文](README_CN.md)
# Preprocessor Python Demo
1. [build FastDeployPython](../../../docs/cn/build_and_install), or download[FastDeploy prebuilt libraryPython](../../../docs/cn/build_and_install/download_prebuilt_libraries.md)
2. Run the Demo
```bash
# Download the test image
wget https://gitee.com/paddlepaddle/PaddleClas/raw/release/2.4/deploy/images/ImageNet/ILSVRC2012_val_00000010.jpeg
# Run the Demo
# OpenCV
python preprocess.py
# CV-CUDA
python preprocess.py --use_cvcuda True
```

View File

@@ -0,0 +1,19 @@
中文 | [English](README.md)
# Preprocessor Python 示例代码
1. [编译FastDeployPython](../docs/cn/build_and_install), 或直接下载[FastDeploy预编译库Python](../docs/cn/build_and_install/download_prebuilt_libraries.md)
2. 运行示例代码
```bash
# 下载测试图片
wget https://gitee.com/paddlepaddle/PaddleClas/raw/release/2.4/deploy/images/ImageNet/ILSVRC2012_val_00000010.jpeg
# 运行示例代码
# OpenCV
python preprocess.py
# CV-CUDA
python preprocess.py --use_cvcuda True
```

View File

@@ -0,0 +1,89 @@
import fastdeploy as fd
import cv2
from fastdeploy.vision.common.manager import PyProcessorManager
def parse_arguments():
import argparse
parser = argparse.ArgumentParser()
parser.add_argument(
"--use_cvcuda",
required=False,
type=bool,
help="Use CV-CUDA in preprocess")
return parser.parse_args()
# define CustomProcessor
class CustomProcessor(PyProcessorManager):
def __init__(self) -> None:
super().__init__()
# create op
hw = [500, 500]
self.resize_op = fd.C.vision.processors.ResizeByShort(100, 1, True, hw)
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]
is_scale = True
min = []
max = []
swap_rb = False
self.normalize_permute_op = fd.C.vision.processors.NormalizeAndPermute(
mean, std, is_scale, min, max, swap_rb)
width = 50
height = 50
self.centercrop_op = fd.C.vision.processors.CenterCrop(width, height)
top = 5
bottom = 5
left = 5
right = 5
pad_value = [225, 225, 225]
self.pad_op = fd.C.vision.processors.Pad(top, bottom, left, right,
pad_value)
def apply(self, image_batch):
outputs = []
self.resize_op(image_batch)
self.centercrop_op(image_batch)
self.pad_op(image_batch)
self.normalize_permute_op(image_batch)
for i in range(len(image_batch.mats)):
outputs.append(image_batch.mats[i])
return outputs
if __name__ == "__main__":
# read jpg
im1 = cv2.imread('ILSVRC2012_val_00000010.jpeg')
im2 = cv2.imread('ILSVRC2012_val_00000010.jpeg')
mat1 = fd.C.vision.FDMat()
mat1.from_numpy(im1)
mat2 = fd.C.vision.FDMat()
mat2.from_numpy(im2)
images = [mat1, mat2]
args = parse_arguments()
# creae processor
preprocessor = CustomProcessor()
# use CV-CUDA
if args.use_cvcuda:
preprocessor.use_cuda(True, -1)
# show input
for i in range(len(images)):
images[i].print_info('images' + str(i) + ': ')
# run the Processer with CVCUDA
outputs = preprocessor(images)
# show output
for i in range(len(outputs)):
outputs[i].print_info('outputs' + str(i) + ': ')