diff --git a/fastdeploy/vision/common/processors/base_pybind.cc b/fastdeploy/vision/common/processors/base_pybind.cc new file mode 100644 index 000000000..9991719d1 --- /dev/null +++ b/fastdeploy/vision/common/processors/base_pybind.cc @@ -0,0 +1,28 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include +#include "fastdeploy/pybind/main.h" + +namespace fastdeploy { +void BindProcessor(pybind11::module& m) { + pybind11::class_(m, "Processor") + .def("__call__", [](vision::Processor& self, + vision::FDMat* mat) { return self(mat); }) + .def("__call__", + [](vision::Processor& self, vision::FDMatBatch* mat_batch) { + return self(mat_batch); + }); +} + +} // namespace fastdeploy \ No newline at end of file diff --git a/fastdeploy/vision/common/processors/center_crop_pybind.cc b/fastdeploy/vision/common/processors/center_crop_pybind.cc new file mode 100644 index 000000000..6c9a30cc0 --- /dev/null +++ b/fastdeploy/vision/common/processors/center_crop_pybind.cc @@ -0,0 +1,23 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include "fastdeploy/pybind/main.h" + +namespace fastdeploy { +void BindCenterCrop(pybind11::module& m) { + pybind11::class_( + m, "CenterCrop") + .def(pybind11::init(), "Default constructor"); +} + +} // namespace fastdeploy \ No newline at end of file diff --git a/fastdeploy/vision/common/processors/manager.cc b/fastdeploy/vision/common/processors/manager.cc index 05167e6fe..89ae4bf15 100644 --- a/fastdeploy/vision/common/processors/manager.cc +++ b/fastdeploy/vision/common/processors/manager.cc @@ -49,48 +49,51 @@ bool ProcessorManager::CudaUsed() { return (proc_lib_ == ProcLib::CUDA || proc_lib_ == ProcLib::CVCUDA); } -bool ProcessorManager::Run(std::vector* images, - std::vector* outputs) { - if (images->size() == 0) { - FDERROR << "The size of input images should be greater than 0." - << std::endl; - return false; +void ProcessorManager::PreApply(FDMatBatch* image_batch) { + FDASSERT(image_batch->mats != nullptr, "The mats is empty."); + FDASSERT(image_batch->mats->size() > 0, + "The size of input images should be greater than 0."); + + if (image_batch->mats->size() > input_caches_.size()) { + input_caches_.resize(image_batch->mats->size()); + output_caches_.resize(image_batch->mats->size()); + } + image_batch->input_cache = &batch_input_cache_; + image_batch->output_cache = &batch_output_cache_; + + if (CudaUsed()) { + SetStream(image_batch); } - if (images->size() > input_caches_.size()) { - input_caches_.resize(images->size()); - output_caches_.resize(images->size()); - } - - FDMatBatch image_batch(images); - image_batch.input_cache = &batch_input_cache_; - image_batch.output_cache = &batch_output_cache_; - image_batch.proc_lib = proc_lib_; - - for (size_t i = 0; i < images->size(); ++i) { - if (CudaUsed()) { - SetStream(&image_batch); - } - (*images)[i].input_cache = &input_caches_[i]; - (*images)[i].output_cache = &output_caches_[i]; - (*images)[i].proc_lib = proc_lib_; - if ((*images)[i].mat_type == ProcLib::CUDA) { + for (size_t i = 0; i < image_batch->mats->size(); ++i) { + FDMat* mat = &(image_batch->mats->at(i)); + mat->input_cache = &input_caches_[i]; + mat->output_cache = &output_caches_[i]; + mat->proc_lib = proc_lib_; + if (mat->mat_type == ProcLib::CUDA) { // Make a copy of the input data ptr, so that the original data ptr of // FDMat won't be modified. auto fd_tensor = std::make_shared(); - fd_tensor->SetExternalData( - (*images)[i].Tensor()->shape, (*images)[i].Tensor()->Dtype(), - (*images)[i].Tensor()->Data(), (*images)[i].Tensor()->device, - (*images)[i].Tensor()->device_id); - (*images)[i].SetTensor(fd_tensor); + fd_tensor->SetExternalData(mat->Tensor()->shape, mat->Tensor()->Dtype(), + mat->Tensor()->Data(), mat->Tensor()->device, + mat->Tensor()->device_id); + mat->SetTensor(fd_tensor); } } +} - bool ret = Apply(&image_batch, outputs); - +void ProcessorManager::PostApply() { if (CudaUsed()) { SyncStream(); } +} + +bool ProcessorManager::Run(std::vector* images, + std::vector* outputs) { + FDMatBatch image_batch(images); + PreApply(&image_batch); + bool ret = Apply(&image_batch, outputs); + PostApply(); return ret; } diff --git a/fastdeploy/vision/common/processors/manager.h b/fastdeploy/vision/common/processors/manager.h index aa6dde56a..c184edf08 100644 --- a/fastdeploy/vision/common/processors/manager.h +++ b/fastdeploy/vision/common/processors/manager.h @@ -78,6 +78,10 @@ class FASTDEPLOY_DECL ProcessorManager { virtual bool Apply(FDMatBatch* image_batch, std::vector* outputs) = 0; + void PreApply(FDMatBatch* image_batch); + + void PostApply(); + protected: ProcLib proc_lib_ = ProcLib::DEFAULT; diff --git a/fastdeploy/vision/common/processors/manager_pybind.cc b/fastdeploy/vision/common/processors/manager_pybind.cc index 65507cce5..ce6418aa9 100644 --- a/fastdeploy/vision/common/processors/manager_pybind.cc +++ b/fastdeploy/vision/common/processors/manager_pybind.cc @@ -14,8 +14,22 @@ #include "fastdeploy/pybind/main.h" namespace fastdeploy { +namespace vision { +// PyProcessorManager is used for pybind11::init() of ProcessorManager +// Because ProcessorManager have a pure Virtual function Apply() +class FASTDEPLOY_DECL PyProcessorManager : public ProcessorManager { + public: + using ProcessorManager::ProcessorManager; + bool Apply(FDMatBatch* image_batch, std::vector* outputs) override { + PYBIND11_OVERRIDE_PURE(bool, ProcessorManager, Apply, image_batch, outputs); + } +}; +} // namespace vision + void BindProcessorManager(pybind11::module& m) { - pybind11::class_(m, "ProcessorManager") + pybind11::class_( + m, "ProcessorManager") + .def(pybind11::init<>()) .def("run", [](vision::ProcessorManager& self, std::vector& im_list) { @@ -34,6 +48,8 @@ void BindProcessorManager(pybind11::module& m) { } return outputs; }) + .def("pre_apply", &vision::ProcessorManager::PreApply) + .def("post_apply", &vision::ProcessorManager::PostApply) .def("use_cuda", [](vision::ProcessorManager& self, bool enable_cv_cuda = false, int gpu_id = -1) { self.UseCuda(enable_cv_cuda, gpu_id); }); diff --git a/fastdeploy/vision/common/processors/mat_batch.h b/fastdeploy/vision/common/processors/mat_batch.h index cd24b4a80..cd3206d1f 100644 --- a/fastdeploy/vision/common/processors/mat_batch.h +++ b/fastdeploy/vision/common/processors/mat_batch.h @@ -57,6 +57,10 @@ struct FASTDEPLOY_DECL FDMatBatch { #endif std::vector* mats = nullptr; + + // Used by pybind, since python cannot pass list as pointer or reference + std::vector mats_holder; + ProcLib mat_type = ProcLib::OPENCV; FDMatBatchLayout layout = FDMatBatchLayout::NHWC; Device device = Device::CPU; diff --git a/fastdeploy/vision/common/processors/mat_batch_pybind.cc b/fastdeploy/vision/common/processors/mat_batch_pybind.cc new file mode 100644 index 000000000..bea8e534b --- /dev/null +++ b/fastdeploy/vision/common/processors/mat_batch_pybind.cc @@ -0,0 +1,30 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include "fastdeploy/pybind/main.h" + +namespace fastdeploy { +void BindFDMatBatch(pybind11::module& m) { + pybind11::class_(m, "FDMatBatch") + .def(pybind11::init<>(), "Default constructor") + .def_readwrite("input_cache", &vision::FDMatBatch::input_cache) + .def_readwrite("output_cache", &vision::FDMatBatch::output_cache) + .def_readwrite("mats", &vision::FDMatBatch::mats) + .def("from_mats", + [](vision::FDMatBatch& self, std::vector& _mats) { + self.mats_holder = _mats; + self.mats = &(self.mats_holder); + }); +} + +} // namespace fastdeploy \ No newline at end of file diff --git a/fastdeploy/vision/common/processors/mat_pybind.cc b/fastdeploy/vision/common/processors/mat_pybind.cc new file mode 100644 index 000000000..26f29b7ae --- /dev/null +++ b/fastdeploy/vision/common/processors/mat_pybind.cc @@ -0,0 +1,29 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include "fastdeploy/pybind/main.h" + +namespace fastdeploy { +void BindFDMat(pybind11::module& m) { + pybind11::class_(m, "FDMat") + .def(pybind11::init<>(), "Default constructor") + .def_readwrite("input_cache", &vision::FDMat::input_cache) + .def_readwrite("output_cache", &vision::FDMat::output_cache) + .def("from_numpy", + [](vision::FDMat& self, pybind11::array& pyarray) { + self = vision::WrapMat(PyArrayToCvMat(pyarray)); + }) + .def("print_info", &vision::FDMat::PrintInfo); +} + +} // namespace fastdeploy \ No newline at end of file diff --git a/fastdeploy/vision/common/processors/normalize_and_permute_pybind.cc b/fastdeploy/vision/common/processors/normalize_and_permute_pybind.cc new file mode 100644 index 000000000..c49ca08ab --- /dev/null +++ b/fastdeploy/vision/common/processors/normalize_and_permute_pybind.cc @@ -0,0 +1,25 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include "fastdeploy/pybind/main.h" + +namespace fastdeploy { +void BindNormalizeAndPermute(pybind11::module& m) { + pybind11::class_( + m, "NormalizeAndPermute") + .def(pybind11::init, std::vector, bool, + std::vector, std::vector, bool>(), + "Default constructor"); +} + +} // namespace fastdeploy \ No newline at end of file diff --git a/fastdeploy/vision/common/processors/pad_pybind.cc b/fastdeploy/vision/common/processors/pad_pybind.cc new file mode 100644 index 000000000..94b7a7978 --- /dev/null +++ b/fastdeploy/vision/common/processors/pad_pybind.cc @@ -0,0 +1,23 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include "fastdeploy/pybind/main.h" + +namespace fastdeploy { +void BindPad(pybind11::module& m) { + pybind11::class_( + m, "Pad") + .def(pybind11::init>(), "Default constructor"); +} + +} // namespace fastdeploy \ No newline at end of file diff --git a/fastdeploy/vision/common/processors/processors_pybind.cc b/fastdeploy/vision/common/processors/processors_pybind.cc new file mode 100644 index 000000000..4d9684aa2 --- /dev/null +++ b/fastdeploy/vision/common/processors/processors_pybind.cc @@ -0,0 +1,36 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "fastdeploy/pybind/main.h" + +namespace fastdeploy { + +void BindProcessorManager(pybind11::module& m); +void BindNormalizeAndPermute(pybind11::module& m); +void BindProcessor(pybind11::module& m); +void BindResizeByShort(pybind11::module& m); +void BindCenterCrop(pybind11::module& m); +void BindPad(pybind11::module& m); + +void BindProcessors(pybind11::module& m) { + auto processors_m = + m.def_submodule("processors", "Module to deploy Processors models"); + BindProcessorManager(processors_m); + BindProcessor(processors_m); + BindNormalizeAndPermute(processors_m); + BindResizeByShort(processors_m); + BindCenterCrop(processors_m); + BindPad(processors_m); +} +} // namespace fastdeploy diff --git a/fastdeploy/vision/common/processors/resize_by_short_pybind.cc b/fastdeploy/vision/common/processors/resize_by_short_pybind.cc new file mode 100644 index 000000000..bcae92d90 --- /dev/null +++ b/fastdeploy/vision/common/processors/resize_by_short_pybind.cc @@ -0,0 +1,23 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include "fastdeploy/pybind/main.h" + +namespace fastdeploy { +void BindResizeByShort(pybind11::module& m) { + pybind11::class_(m, "ResizeByShort") + .def(pybind11::init>(), + "Default constructor"); +} + +} // namespace fastdeploy \ No newline at end of file diff --git a/fastdeploy/vision/vision_pybind.cc b/fastdeploy/vision/vision_pybind.cc index 03e625728..172d9c098 100755 --- a/fastdeploy/vision/vision_pybind.cc +++ b/fastdeploy/vision/vision_pybind.cc @@ -16,7 +16,9 @@ namespace fastdeploy { -void BindProcessorManager(pybind11::module& m); +void BindFDMat(pybind11::module& m); +void BindFDMatBatch(pybind11::module& m); +void BindProcessors(pybind11::module& m); void BindDetection(pybind11::module& m); void BindClassification(pybind11::module& m); void BindSegmentation(pybind11::module& m); @@ -205,7 +207,9 @@ void BindVision(pybind11::module& m) { m.def("disable_flycv", &vision::DisableFlyCV, "Disable image preprocessing by FlyCV, change to use OpenCV."); - BindProcessorManager(m); + BindFDMat(m); + BindFDMatBatch(m); + BindProcessors(m); BindDetection(m); BindClassification(m); BindSegmentation(m); diff --git a/python/fastdeploy/vision/common/__init__.py b/python/fastdeploy/vision/common/__init__.py index 6e010a427..d41296379 100644 --- a/python/fastdeploy/vision/common/__init__.py +++ b/python/fastdeploy/vision/common/__init__.py @@ -14,3 +14,4 @@ from __future__ import absolute_import from .manager import ProcessorManager +from .manager import PyProcessorManager diff --git a/python/fastdeploy/vision/common/manager.py b/python/fastdeploy/vision/common/manager.py index 05da3d68e..7e7f1db4a 100644 --- a/python/fastdeploy/vision/common/manager.py +++ b/python/fastdeploy/vision/common/manager.py @@ -13,6 +13,8 @@ # limitations under the License. from __future__ import absolute_import +from abc import ABC, abstractmethod +from ... import c_lib_wrap as C class ProcessorManager: @@ -34,3 +36,34 @@ class ProcessorManager: :param: gpu_id: GPU device id """ return self._manager.use_cuda(enable_cv_cuda, gpu_id) + + +class PyProcessorManager(ABC): + """ + PyProcessorManager is used to define a customized processor in python + """ + + def __init__(self): + self._manager = C.vision.processors.ProcessorManager() + + def use_cuda(self, enable_cv_cuda=False, gpu_id=-1): + """Use CUDA processors + + :param: enable_cv_cuda: Ture: use CV-CUDA, False: use CUDA only + :param: gpu_id: GPU device id + """ + return self._manager.use_cuda(enable_cv_cuda, gpu_id) + + def __call__(self, images): + image_batch = C.vision.FDMatBatch() + image_batch.from_mats(images) + + self._manager.pre_apply(image_batch) + outputs = self.apply(image_batch) + self._manager.post_apply() + return outputs + + @abstractmethod + def apply(self, image_batch): + print("This function has to be implemented.") + return [] diff --git a/tutorials/README.md b/tutorials/README.md index 6f98b970d..a0477d8fe 100644 --- a/tutorials/README.md +++ b/tutorials/README.md @@ -7,4 +7,5 @@ This directory provides some tutorials for FastDeploy. For other model deploymen - Intel independent graphics card/integrated graphics card deployment [see intel_gpu](intel_gpu) - Model multithreaded call [see multi_thread](multi_thread) - Image decoding, including hardward decoding, e.g. nvJPEG [image_decoder](image_decoder) +- Vision processor (preprocessing, CV-CUDA, FlyCV, etc.)[vision_processor](vision_processor) - Deploy models with C or C# API [use_c_csharp_sdk](use_c_sharp_sdk) diff --git a/tutorials/README_CN.md b/tutorials/README_CN.md index 1a9bc4aff..74e2b6a25 100644 --- a/tutorials/README_CN.md +++ b/tutorials/README_CN.md @@ -8,4 +8,5 @@ - Intel独立显卡/集成显卡部署 [见intel_gpu](intel_gpu) - 模型多线程调用 [见multi_thread](multi_thread) - 图片解码(含nvJPEG硬解码) [见image_decoder](image_decoder) +- 多硬件图像处理库(预处理、CV-CUDA、FlyCV等) [见vision_processor](vision_processor) - 使用C或C# API进行模型部署 [见use_c_csharp_sdk](use_c_sharp_sdk) diff --git a/tutorials/vision_processor/README.md b/tutorials/vision_processor/README.md new file mode 100644 index 000000000..c8362b917 --- /dev/null +++ b/tutorials/vision_processor/README.md @@ -0,0 +1,43 @@ +English | [中文](README_CN.md) + +# Vision Processor + +Vision Processor is used to implement model preprocessing, postprocessing, etc. The following 3rd party vision libraries are integrated: +- OpenCV, general CPU image processing +- FlyCV, mainly optimized for ARM CPU +- CV-CUDA, for NVIDIA GPU + +## C++ + +TODO(guxukai) + +## Python + +Python API, Currently supported operators are as follows: + +- ResizeByShort +- NormalizeAndPermute + +Users can implement a image processing modules by inheriting the `PyProcessorManager` class. The base class `PyProcessorManager` implements GPU memory management, CUDA stream management, etc. Users only need to implement the apply() function by calling vision processors in this library and implements processing logic. For specific implementation, please refer to the demo code. + +### Demo + +- [Python Demo](python) + +### Performance comparison between CV-CUDA and OpenCV: + +CPU: Intel(R) Xeon(R) Gold 6271C CPU @ 2.60GHz + +GPU: T4 + +CUDA: 11.6 + +Processing logic: Resize -> NormalizeAndPermute + +Warmup 100 rounds,tested 1000 rounds and get avg. latency. + +| Input Image Shape | Target shape | Batch Size | OpenCV | CV-CUDA | Gain | +| ----------- | -- | ---------- | ------- | ------ | ------ | +| 1920x1080 | 640x360 | 1 | 1.1572ms | 0.9067ms | 16.44% | +| 1280x720 | 640x360 | 1 | 2.7551ms | 0.5296ms | 80.78% | +| 360x240 | 640x360 | 1 | 3.3450ms | 0.2421ms | 92.76% | diff --git a/tutorials/vision_processor/README_CN.md b/tutorials/vision_processor/README_CN.md new file mode 100644 index 000000000..8d4e9741e --- /dev/null +++ b/tutorials/vision_processor/README_CN.md @@ -0,0 +1,42 @@ +中文 | [English](README.md) +# 多硬件图像处理库 + +多硬件图像处理库(Vision Processor)可用于实现模型的预处理、后处理等图像操作,底层封装了多个第三方图像处理库,包括: +- OpenCV,用于通用CPU图像处理 +- FlyCV,主要针对ARM CPU加速 +- CV-CUDA,用于NVIDIA GPU + +## C++ + +待编写 + +## Python + +Python API目前支持的算子如下: + +- ResizeByShort +- NormalizeAndPermute + +用户可通过继承PyProcessorManager类,实现自己的图像处理模块。基类PyProcessorManager实现了GPU内存管理、CUDA stream管理等,用户仅需要实现apply()函数,在其中调用多硬件图像处理库中的算子、实现处理逻辑即可,具体实现可参考示例代码。 + +### 示例代码 + +- [Python示例](python) + +### CV-CUDA与OpenCV性能对比 + +CPU: Intel(R) Xeon(R) Gold 6271C CPU @ 2.60GHz + +GPU: T4 + +CUDA: 11.6 + +Processing logic: Resize -> NormalizeAndPermute + +Warmup 100 rounds,tested 1000 rounds and get avg. latency. + +| Input Image Shape | Target shape | Batch Size | OpenCV | CV-CUDA | Gain | +| ----------- | -- | ---------- | ------- | ------ | ------ | +| 1920x1080 | 640x360 | 1 | 1.1572ms | 0.9067ms | 16.44% | +| 1280x720 | 640x360 | 1 | 2.7551ms | 0.5296ms | 80.78% | +| 360x240 | 640x360 | 1 | 3.3450ms | 0.2421ms | 92.76% | diff --git a/tutorials/vision_processor/python/README.md b/tutorials/vision_processor/python/README.md new file mode 100644 index 000000000..f99816a0a --- /dev/null +++ b/tutorials/vision_processor/python/README.md @@ -0,0 +1,19 @@ +English | [中文](README_CN.md) + +# Preprocessor Python Demo + +1. [build FastDeploy(Python)](../../../docs/cn/build_and_install), or download[FastDeploy prebuilt library(Python)](../../../docs/cn/build_and_install/download_prebuilt_libraries.md) + +2. Run the Demo +```bash +# Download the test image +wget https://gitee.com/paddlepaddle/PaddleClas/raw/release/2.4/deploy/images/ImageNet/ILSVRC2012_val_00000010.jpeg + +# Run the Demo + +# OpenCV +python preprocess.py + +# CV-CUDA +python preprocess.py --use_cvcuda True +``` diff --git a/tutorials/vision_processor/python/README_CN.md b/tutorials/vision_processor/python/README_CN.md new file mode 100644 index 000000000..0b3d3b737 --- /dev/null +++ b/tutorials/vision_processor/python/README_CN.md @@ -0,0 +1,19 @@ +中文 | [English](README.md) + +# Preprocessor Python 示例代码 + +1. [编译FastDeploy(Python)](../docs/cn/build_and_install), 或直接下载[FastDeploy预编译库(Python)](../docs/cn/build_and_install/download_prebuilt_libraries.md) + +2. 运行示例代码 +```bash +# 下载测试图片 +wget https://gitee.com/paddlepaddle/PaddleClas/raw/release/2.4/deploy/images/ImageNet/ILSVRC2012_val_00000010.jpeg + +# 运行示例代码 + +# OpenCV +python preprocess.py + +# CV-CUDA +python preprocess.py --use_cvcuda True +``` diff --git a/tutorials/vision_processor/python/preprocess.py b/tutorials/vision_processor/python/preprocess.py new file mode 100644 index 000000000..51c378da7 --- /dev/null +++ b/tutorials/vision_processor/python/preprocess.py @@ -0,0 +1,89 @@ +import fastdeploy as fd +import cv2 + +from fastdeploy.vision.common.manager import PyProcessorManager + + +def parse_arguments(): + import argparse + parser = argparse.ArgumentParser() + parser.add_argument( + "--use_cvcuda", + required=False, + type=bool, + help="Use CV-CUDA in preprocess") + return parser.parse_args() + + +# define CustomProcessor +class CustomProcessor(PyProcessorManager): + def __init__(self) -> None: + super().__init__() + # create op + hw = [500, 500] + self.resize_op = fd.C.vision.processors.ResizeByShort(100, 1, True, hw) + + mean = [0.485, 0.456, 0.406] + std = [0.229, 0.224, 0.225] + is_scale = True + min = [] + max = [] + swap_rb = False + self.normalize_permute_op = fd.C.vision.processors.NormalizeAndPermute( + mean, std, is_scale, min, max, swap_rb) + + width = 50 + height = 50 + self.centercrop_op = fd.C.vision.processors.CenterCrop(width, height) + + top = 5 + bottom = 5 + left = 5 + right = 5 + pad_value = [225, 225, 225] + self.pad_op = fd.C.vision.processors.Pad(top, bottom, left, right, + pad_value) + + def apply(self, image_batch): + outputs = [] + self.resize_op(image_batch) + self.centercrop_op(image_batch) + self.pad_op(image_batch) + self.normalize_permute_op(image_batch) + + for i in range(len(image_batch.mats)): + outputs.append(image_batch.mats[i]) + + return outputs + + +if __name__ == "__main__": + + # read jpg + im1 = cv2.imread('ILSVRC2012_val_00000010.jpeg') + im2 = cv2.imread('ILSVRC2012_val_00000010.jpeg') + + mat1 = fd.C.vision.FDMat() + mat1.from_numpy(im1) + mat2 = fd.C.vision.FDMat() + mat2.from_numpy(im2) + images = [mat1, mat2] + + args = parse_arguments() + # creae processor + preprocessor = CustomProcessor() + + # use CV-CUDA + if args.use_cvcuda: + preprocessor.use_cuda(True, -1) + + # show input + for i in range(len(images)): + images[i].print_info('images' + str(i) + ': ') + + # run the Processer with CVCUDA + outputs = preprocessor(images) + + # show output + for i in range(len(outputs)): + outputs[i].print_info('outputs' + str(i) + ': ')