diff --git a/examples/vision/generation/anemigan/README.md b/examples/vision/generation/anemigan/README.md new file mode 100644 index 000000000..721ed5644 --- /dev/null +++ b/examples/vision/generation/anemigan/README.md @@ -0,0 +1,36 @@ +# 图像生成模型 + +FastDeploy目前支持PaddleHub预训练模型库中如下风格迁移模型的部署 + +| 模型 | 说明 | 模型格式 | +| :--- | :--- | :------- | +|[animegan_v1_hayao_60](https://www.paddlepaddle.org.cn/hubdetail?name=animegan_v1_hayao_60&en_category=GANs)|可将输入的图像转换成宫崎骏动漫风格,模型权重转换自AnimeGAN V1官方开源项目|paddle| +|[animegan_v2_paprika_97](https://www.paddlepaddle.org.cn/hubdetail?name=animegan_v2_paprika_97&en_category=GANs)|可将输入的图像转换成今敏红辣椒动漫风格,模型权重转换自AnimeGAN V2官方开源项目|paddle| +|[animegan_v2_hayao_64](https://www.paddlepaddle.org.cn/hubdetail?name=animegan_v2_hayao_64&en_category=GANs)|可将输入的图像转换成宫崎骏动漫风格,模型权重转换自AnimeGAN V2官方开源项目|paddle| +|[animegan_v2_shinkai_53](https://www.paddlepaddle.org.cn/hubdetail?name=animegan_v2_shinkai_53&en_category=GANs)|可将输入的图像转换成新海诚动漫风格,模型权重转换自AnimeGAN V2官方开源项目|paddle| +|[animegan_v2_shinkai_33](https://www.paddlepaddle.org.cn/hubdetail?name=animegan_v2_shinkai_33&en_category=GANs)|可将输入的图像转换成新海诚动漫风格,模型权重转换自AnimeGAN V2官方开源项目|paddle| +|[animegan_v2_paprika_54](https://www.paddlepaddle.org.cn/hubdetail?name=animegan_v2_paprika_54&en_category=GANs)|可将输入的图像转换成今敏红辣椒动漫风格,模型权重转换自AnimeGAN V2官方开源项目|paddle| +|[animegan_v2_hayao_99](https://www.paddlepaddle.org.cn/hubdetail?name=animegan_v2_hayao_99&en_category=GANs)|可将输入的图像转换成宫崎骏动漫风格,模型权重转换自AnimeGAN V2官方开源项目|paddle| +|[animegan_v2_paprika_74](https://www.paddlepaddle.org.cn/hubdetail?name=animegan_v2_paprika_74&en_category=GANs)|可将输入的图像转换成今敏红辣椒动漫风格,模型权重转换自AnimeGAN V2官方开源项目|paddle| +|[animegan_v2_paprika_98](https://www.paddlepaddle.org.cn/hubdetail?name=animegan_v2_paprika_98&en_category=GANs)|可将输入的图像转换成今敏红辣椒动漫风格,模型权重转换自AnimeGAN V2官方开源项目|paddle| + +## FastDeploy paddle backend部署和hub速度对比(ips, 越高越好) +| Device | FastDeploy | Hub | +| :--- | :--- | :------- | +| CPU | 0.075 | 0.069| +| GPU | 8.33 | 8.26 | + + + +## 下载预训练模型 +使用fastdeploy.download_model即可以下载模型, 例如下载animegan_v1_hayao_60 +```python +import fastdeploy as fd +fd.download_model(name='animegan_v1_hayao_60', path='./', format='paddle') +``` +将会在当前目录获得animegan_v1_hayao_60的预训练模型。 + +## 详细部署文档 + +- [Python部署](python) +- [C++部署](cpp) diff --git a/examples/vision/generation/anemigan/cpp/CMakeLists.txt b/examples/vision/generation/anemigan/cpp/CMakeLists.txt new file mode 100755 index 000000000..7d1bd2ee1 --- /dev/null +++ b/examples/vision/generation/anemigan/cpp/CMakeLists.txt @@ -0,0 +1,13 @@ +PROJECT(infer_demo C CXX) +CMAKE_MINIMUM_REQUIRED (VERSION 3.10) + +# 指定下载解压后的fastdeploy库路径 +option(FASTDEPLOY_INSTALL_DIR "Path of downloaded fastdeploy sdk.") +include(${FASTDEPLOY_INSTALL_DIR}/utils/gflags.cmake) +include(${FASTDEPLOY_INSTALL_DIR}/FastDeploy.cmake) + +# 添加FastDeploy依赖头文件 +include_directories(${FASTDEPLOY_INCS}) + +add_executable(infer_demo ${PROJECT_SOURCE_DIR}/infer.cc) +target_link_libraries(infer_demo ${FASTDEPLOY_LIBS} ${GFLAGS_LIBRARIES}) diff --git a/examples/vision/generation/anemigan/cpp/README.md b/examples/vision/generation/anemigan/cpp/README.md new file mode 100755 index 000000000..9d58c6ad3 --- /dev/null +++ b/examples/vision/generation/anemigan/cpp/README.md @@ -0,0 +1,84 @@ +# AnimeGAN C++部署示例 + +本目录下提供`infer.cc`快速完成AnimeGAN在CPU/GPU部署的示例。 + +在部署前,需确认以下两个步骤 + +- 1. 软硬件环境满足要求,参考[FastDeploy环境要求](../../../../../docs/cn/build_and_install/download_prebuilt_libraries.md) +- 2. 根据开发环境,下载预编译部署库和samples代码,参考[FastDeploy预编译库](../../../../../docs/cn/build_and_install/download_prebuilt_libraries.md) + +以Linux上AnimeGAN推理为例,在本目录执行如下命令即可完成编译测试,支持此模型需保证FastDeploy版本1.0.2以上(x.x.x>=1.0.2) + +```bash +mkdir build +cd build +# 下载FastDeploy预编译库,用户可在上文提到的`FastDeploy预编译库`中自行选择合适的版本使用 +wget https://bj.bcebos.com/fastdeploy/release/cpp/fastdeploy-linux-x64-x.x.x.tgz +tar xvf fastdeploy-linux-x64-x.x.x.tgz +cmake .. -DFASTDEPLOY_INSTALL_DIR=${PWD}/fastdeploy-linux-x64-x.x.x +make -j + +# 下载准备好的模型文件和测试图片 +wget https://bj.bcebos.com/paddlehub/fastdeploy/style_transfer_testimg.jpg +wget https://bj.bcebos.com/paddlehub/fastdeploy/animegan_v1_hayao_60_v1.0.0.tgz +tar xvfz animegan_v1_hayao_60_v1.0.0.tgz + +# CPU推理 +./infer_demo --model animegan_v1_hayao_60 --image style_transfer_testimg.jpg --device cpu +# GPU推理 +./infer_demo --model animegan_v1_hayao_60 --image style_transfer_testimg.jpg --device gpu +``` + +以上命令只适用于Linux或MacOS, Windows下SDK的使用方式请参考: +- [如何在Windows中使用FastDeploy C++ SDK](../../../../../docs/cn/faq/use_sdk_on_windows.md) + +## AnimeGAN C++接口 + +### AnimeGAN类 + +```c++ +fastdeploy::vision::generation::AnimeGAN( + const string& model_file, + const string& params_file = "", + const RuntimeOption& runtime_option = RuntimeOption(), + const ModelFormat& model_format = ModelFormat::PADDLE) +``` + +AnimeGAN模型加载和初始化,其中model_file为导出的Paddle模型结构文件,params_file为模型参数文件。 + +**参数** + +> * **model_file**(str): 模型文件路径 +> * **params_file**(str): 参数文件路径 +> * **runtime_option**(RuntimeOption): 后端推理配置,默认为None,即采用默认配置 +> * **model_format**(ModelFormat): 模型格式,默认为Paddle格式 + +#### Predict函数 + +> ```c++ +> bool AnimeGAN::Predict(cv::Mat& image, cv::Mat* result) +> ``` +> +> 模型预测入口,输入图像输出风格迁移后的结果。 +> +> **参数** +> +> > * **image**: 输入数据,注意需为HWC,BGR格式 +> > * **result**: 风格转换后的图像,BGR格式 + +#### BatchPredict函数 + +> ```c++ +> bool AnimeGAN::BatchPredict(const std::vector& images, std::vector* results); +> ``` +> +> 模型预测入口,输入一组图像并输出风格迁移后的结果。 +> +> **参数** +> +> > * **images**: 输入数据,一组图像数据,注意需为HWC,BGR格式 +> > * **results**: 风格转换后的一组图像,BGR格式 + +- [模型介绍](../../) +- [Python部署](../python) +- [如何切换模型推理后端引擎](../../../../../docs/cn/faq/how_to_change_backend.md) diff --git a/examples/vision/generation/anemigan/cpp/infer.cc b/examples/vision/generation/anemigan/cpp/infer.cc new file mode 100644 index 000000000..ad10797e9 --- /dev/null +++ b/examples/vision/generation/anemigan/cpp/infer.cc @@ -0,0 +1,69 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "fastdeploy/vision.h" +#include "gflags/gflags.h" + +DEFINE_string(model, "", "Directory of the inference model."); +DEFINE_string(image, "", "Path of the image file."); +DEFINE_string(device, "cpu", + "Type of inference device, support 'cpu' or 'gpu'."); + +void PrintUsage() { + std::cout << "Usage: infer_demo --model model_path --image img_path --device [cpu|gpu]" + << std::endl; + std::cout << "Default value of device: cpu" << std::endl; +} + +bool CreateRuntimeOption(fastdeploy::RuntimeOption* option) { + if (FLAGS_device == "gpu") { + option->UseGpu(); + } + else if (FLAGS_device == "cpu") { + option->SetPaddleMKLDNN(false); + return true; + } else { + std::cerr << "Only support device CPU/GPU now, " << FLAGS_device << " is not supported." << std::endl; + return false; + } + + return true; +} + +int main(int argc, char* argv[]) { + google::ParseCommandLineFlags(&argc, &argv, true); + auto option = fastdeploy::RuntimeOption(); + if (!CreateRuntimeOption(&option)) { + PrintUsage(); + return -1; + } + + auto model = fastdeploy::vision::generation::AnimeGAN(FLAGS_model+"/model.pdmodel", FLAGS_model+"/model.pdiparams", option); + if (!model.Initialized()) { + std::cerr << "Failed to initialize." << std::endl; + return -1; + } + + auto im = cv::imread(FLAGS_image); + cv::Mat res; + if (!model.Predict(im, &res)) { + std::cerr << "Failed to predict." << std::endl; + return -1; + } + + cv::imwrite("style_transfer_result.png", res); + std::cout << "Visualized result saved in ./style_transfer_result.png" << std::endl; + + return 0; +} diff --git a/examples/vision/generation/anemigan/python/README.md b/examples/vision/generation/anemigan/python/README.md new file mode 100644 index 000000000..9c4562402 --- /dev/null +++ b/examples/vision/generation/anemigan/python/README.md @@ -0,0 +1,70 @@ +# AnimeGAN Python部署示例 + +在部署前,需确认以下两个步骤 + +- 1. 软硬件环境满足要求,参考[FastDeploy环境要求](../../../../../docs/cn/build_and_install/download_prebuilt_libraries.md) +- 2. FastDeploy Python whl包安装,参考[FastDeploy Python安装](../../../../../docs/cn/build_and_install/download_prebuilt_libraries.md) + +本目录下提供`infer.py`快速完成AnimeGAN在CPU/GPU,以及GPU上通过TensorRT加速部署的示例。执行如下脚本即可完成 + +```bash +# 下载部署示例代码 +git clone https://github.com/PaddlePaddle/FastDeploy.git +cd FastDeploy/examples/vision/generation/anemigan/python +# 下载准备好的测试图片 +wget https://bj.bcebos.com/paddlehub/fastdeploy/style_transfer_testimg.jpg + +# CPU推理 +python infer.py --model animegan_v1_hayao_60 --image style_transfer_testimg.jpg --device cpu +# GPU推理 +python infer.py --model animegan_v1_hayao_60 --image style_transfer_testimg.jpg --device gpu +``` + +## AnimeGAN Python接口 + +```python +fd.vision.generation.AnimeGAN(model_file, params_file, runtime_option=None, model_format=ModelFormat.PADDLE) +``` + +AnimeGAN模型加载和初始化,其中model_file和params_file为用于Paddle inference的模型结构文件和参数文件。 + +**参数** + +> * **model_file**(str): 模型文件路径 +> * **params_file**(str): 参数文件路径 +> * **runtime_option**(RuntimeOption): 后端推理配置,默认为None,即采用默认配置 +> * **model_format**(ModelFormat): 模型格式,默认为Paddle格式 + + +### predict函数 + +> ```python +> AnimeGAN.predict(input_image) +> ``` +> +> 模型预测入口,输入图像输出风格迁移后的结果。 +> +> **参数** +> +> > * **input_image**(np.ndarray): 输入数据,注意需为HWC,BGR格式 + +> **返回** np.ndarray, 风格转换后的图像,BGR格式 + +### batch_predict函数 +> ```python +> AnimeGAN.batch_predict函数(input_images) +> ``` +> +> 模型预测入口,输入一组图像并输出风格迁移后的结果。 +> +> **参数** +> +> > * **input_images**(list(np.ndarray)): 输入数据,一组图像数据,注意需为HWC,BGR格式 + +> **返回** list(np.ndarray), 风格转换后的一组图像,BGR格式 + +## 其它文档 + +- [风格迁移 模型介绍](..) +- [C++部署](../cpp) +- [如何切换模型推理后端引擎](../../../../../docs/cn/faq/how_to_change_backend.md) diff --git a/examples/vision/generation/anemigan/python/infer.py b/examples/vision/generation/anemigan/python/infer.py new file mode 100644 index 000000000..69f610eda --- /dev/null +++ b/examples/vision/generation/anemigan/python/infer.py @@ -0,0 +1,43 @@ +import cv2 +import os +import fastdeploy as fd + + +def parse_arguments(): + import argparse + import ast + parser = argparse.ArgumentParser() + parser.add_argument("--model", required=True, help="Name of the model.") + parser.add_argument( + "--image", type=str, required=True, help="Path of test image file.") + parser.add_argument( + "--device", + type=str, + default='cpu', + help="Type of inference device, support 'cpu' or 'gpu'.") + return parser.parse_args() + + +def build_option(args): + option = fd.RuntimeOption() + if args.device.lower() == "gpu": + option.use_gpu() + else: + option.set_paddle_mkldnn(False) + return option + + +args = parse_arguments() + +# 配置runtime,加载模型 +runtime_option = build_option(args) +fd.download_model(name=args.model, path='./', format='paddle') +model_file = os.path.join(args.model, "model.pdmodel") +params_file = os.path.join(args.model, "model.pdiparams") +model = fd.vision.generation.AnimeGAN( + model_file, params_file, runtime_option=runtime_option) + +# 预测图片并保存结果 +im = cv2.imread(args.image) +result = model.predict(im) +cv2.imwrite('style_transfer_result.png', result) diff --git a/fastdeploy/vision.h b/fastdeploy/vision.h index ef2fc90a6..0714a9766 100644 --- a/fastdeploy/vision.h +++ b/fastdeploy/vision.h @@ -55,6 +55,7 @@ #include "fastdeploy/vision/segmentation/ppseg/model.h" #include "fastdeploy/vision/sr/ppsr/model.h" #include "fastdeploy/vision/tracking/pptracking/model.h" +#include "fastdeploy/vision/generation/contrib/animegan.h" #endif diff --git a/fastdeploy/vision/generation/contrib/animegan.cc b/fastdeploy/vision/generation/contrib/animegan.cc new file mode 100644 index 000000000..22962daa1 --- /dev/null +++ b/fastdeploy/vision/generation/contrib/animegan.cc @@ -0,0 +1,80 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "fastdeploy/vision/generation/contrib/animegan.h" +#include "fastdeploy/function/functions.h" + +namespace fastdeploy { +namespace vision { +namespace generation { + +AnimeGAN::AnimeGAN(const std::string& model_file, const std::string& params_file, + const RuntimeOption& custom_option, + const ModelFormat& model_format) { + + valid_cpu_backends = {Backend::PDINFER}; + valid_gpu_backends = {Backend::PDINFER}; + + runtime_option = custom_option; + runtime_option.model_format = model_format; + runtime_option.model_file = model_file; + runtime_option.params_file = params_file; + + initialized = Initialize(); +} + +bool AnimeGAN::Initialize() { + if (!InitRuntime()) { + FDERROR << "Failed to initialize fastdeploy backend." << std::endl; + return false; + } + return true; +} + + +bool AnimeGAN::Predict(cv::Mat& img, cv::Mat* result) { + std::vector results; + if (!BatchPredict({img}, &results)) { + return false; + } + *result = std::move(results[0]); + return true; +} + +bool AnimeGAN::BatchPredict(const std::vector& images, std::vector* results) { + std::vector fd_images = WrapMat(images); + std::vector processed_data(1); + if (!preprocessor_.Run(fd_images, &(processed_data))) { + FDERROR << "Failed to preprocess input data while using model:" + << ModelName() << "." << std::endl; + return false; + } + std::vector infer_result(1); + processed_data[0].name = InputInfoOfRuntime(0).name; + + if (!Infer(processed_data, &infer_result)) { + FDERROR << "Failed to inference by runtime." << std::endl; + return false; + } + if (!postprocessor_.Run(infer_result, results)) { + FDERROR << "Failed to postprocess while using model:" << ModelName() << "." + << std::endl; + return false; + } + return true; +} + +} // namespace generation +} // namespace vision +} // namespace fastdeploy \ No newline at end of file diff --git a/fastdeploy/vision/generation/contrib/animegan.h b/fastdeploy/vision/generation/contrib/animegan.h new file mode 100644 index 000000000..9d1f9aa27 --- /dev/null +++ b/fastdeploy/vision/generation/contrib/animegan.h @@ -0,0 +1,79 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include "fastdeploy/fastdeploy_model.h" +#include "fastdeploy/vision/common/processors/transform.h" +#include "fastdeploy/vision/generation/contrib/preprocessor.h" +#include "fastdeploy/vision/generation/contrib/postprocessor.h" + +namespace fastdeploy { + +namespace vision { + +namespace generation { +/*! @brief AnimeGAN model object is used when load a AnimeGAN model. + */ +class FASTDEPLOY_DECL AnimeGAN : public FastDeployModel { + public: + /** \brief Set path of model file and the configuration of runtime. + * + * \param[in] model_file Path of model file, e.g ./model.pdmodel + * \param[in] params_file Path of parameter file, e.g ./model.pdiparams, if the model format is ONNX, this parameter will be ignored + * \param[in] custom_option RuntimeOption for inference, the default will use cpu, and choose the backend defined in "valid_cpu_backends" + * \param[in] model_format Model format of the loaded model, default is PADDLE format + */ + AnimeGAN(const std::string& model_file, const std::string& params_file = "", + const RuntimeOption& custom_option = RuntimeOption(), + const ModelFormat& model_format = ModelFormat::PADDLE); + + std::string ModelName() const { return "styletransfer/animegan"; } + + /** \brief Predict the style transfer result for an input image + * + * \param[in] im The input image data, comes from cv::imread(), is a 3-D array with layout HWC, BGR format + * \param[in] result The output style transfer result will be writen to this structure + * \return true if the prediction successed, otherwise false + */ + bool Predict(cv::Mat& img, cv::Mat* result); + + /** \brief Predict the style transfer result for a batch of input images + * + * \param[in] images The list of input images, each element comes from cv::imread(), is a 3-D array with layout HWC, BGR format + * \param[in] results The list of output style transfer results will be writen to this structure + * \return true if the batch prediction successed, otherwise false + */ + bool BatchPredict(const std::vector& images, + std::vector* results); + + // Get preprocessor reference of AnimeGAN + AnimeGANPreprocessor& GetPreprocessor() { + return preprocessor_; + } + + // Get postprocessor reference of AnimeGAN + AnimeGANPostprocessor& GetPostprocessor() { + return postprocessor_; + } + + private: + bool Initialize(); + + AnimeGANPreprocessor preprocessor_; + AnimeGANPostprocessor postprocessor_; +}; + +} // namespace generation +} // namespace vision +} // namespace fastdeploy diff --git a/fastdeploy/vision/generation/contrib/animegan_pybind.cc b/fastdeploy/vision/generation/contrib/animegan_pybind.cc new file mode 100644 index 000000000..853069d71 --- /dev/null +++ b/fastdeploy/vision/generation/contrib/animegan_pybind.cc @@ -0,0 +1,78 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "fastdeploy/pybind/main.h" + +namespace fastdeploy { +void BindAnimeGAN(pybind11::module& m) { + pybind11::class_(m, "AnimeGAN") + .def(pybind11::init()) + .def("predict", + [](vision::generation::AnimeGAN& self, pybind11::array& data) { + auto mat = PyArrayToCvMat(data); + cv::Mat res; + self.Predict(mat, &res); + auto ret = pybind11::array_t( + {res.rows, res.cols, res.channels()}, res.data); + return ret; + }) + .def("batch_predict", + [](vision::generation::AnimeGAN& self, std::vector& data) { + std::vector images; + for (size_t i = 0; i < data.size(); ++i) { + images.push_back(PyArrayToCvMat(data[i])); + } + std::vector results; + self.BatchPredict(images, &results); + std::vector> ret; + for(size_t i = 0; i < results.size(); ++i){ + ret.push_back(pybind11::array_t( + {results[i].rows, results[i].cols, results[i].channels()}, results[i].data)); + } + return ret; + }) + .def_property_readonly("preprocessor", &vision::generation::AnimeGAN::GetPreprocessor) + .def_property_readonly("postprocessor", &vision::generation::AnimeGAN::GetPostprocessor); + + pybind11::class_( + m, "AnimeGANPreprocessor") + .def(pybind11::init<>()) + .def("run", [](vision::generation::AnimeGANPreprocessor& self, std::vector& im_list) { + std::vector images; + for (size_t i = 0; i < im_list.size(); ++i) { + images.push_back(vision::WrapMat(PyArrayToCvMat(im_list[i]))); + } + std::vector outputs; + if (!self.Run(images, &outputs)) { + throw std::runtime_error("Failed to preprocess the input data in PaddleClasPreprocessor."); + } + for (size_t i = 0; i < outputs.size(); ++i) { + outputs[i].StopSharing(); + } + return outputs; + }); + pybind11::class_( + m, "AnimeGANPostprocessor") + .def(pybind11::init<>()) + .def("run", [](vision::generation::AnimeGANPostprocessor& self, std::vector& inputs) { + std::vector results; + if (!self.Run(inputs, &results)) { + throw std::runtime_error("Failed to postprocess the runtime result in YOLOv5Postprocessor."); + } + return results; + }); + +} +} // namespace fastdeploy \ No newline at end of file diff --git a/fastdeploy/vision/generation/contrib/postprocessor.cc b/fastdeploy/vision/generation/contrib/postprocessor.cc new file mode 100644 index 000000000..68dbaf8f3 --- /dev/null +++ b/fastdeploy/vision/generation/contrib/postprocessor.cc @@ -0,0 +1,49 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "fastdeploy/vision/generation/contrib/postprocessor.h" + +namespace fastdeploy { +namespace vision { +namespace generation { + +bool AnimeGANPostprocessor::Run(std::vector& infer_results, + std::vector* results) { + // 1. Reverse normalization + // 2. RGB2BGR + FDTensor& output_tensor = infer_results.at(0); + std::vector shape = output_tensor.Shape(); // n, h, w, c + int size = shape[1] * shape[2] * shape[3]; + results->resize(shape[0]); + float* infer_result_data = reinterpret_cast(output_tensor.Data()); + for(size_t i = 0; i < results->size(); ++i){ + Mat result_mat = Mat::Create(shape[1], shape[2], 3, FDDataType::FP32, infer_result_data+i*size); + std::vector mean{127.5f, 127.5f, 127.5f}; + std::vector std{127.5f, 127.5f, 127.5f}; + Convert::Run(&result_mat, mean, std); + // tmp data type is float[0-1.0],convert to uint type + auto temp = result_mat.GetOpenCVMat(); + cv::Mat res = cv::Mat::zeros(temp->size(), CV_8UC3); + temp->convertTo(res, CV_8UC3, 1); + Mat fd_image = WrapMat(res); + BGR2RGB::Run(&fd_image); + res = *(fd_image.GetOpenCVMat()); + res.copyTo(results->at(i)); + } + return true; +} + +} // namespace generation +} // namespace vision +} // namespace fastdeploy diff --git a/fastdeploy/vision/generation/contrib/postprocessor.h b/fastdeploy/vision/generation/contrib/postprocessor.h new file mode 100644 index 000000000..3f3a7728b --- /dev/null +++ b/fastdeploy/vision/generation/contrib/postprocessor.h @@ -0,0 +1,43 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include "fastdeploy/vision/common/processors/transform.h" +#include "fastdeploy/function/functions.h" + +namespace fastdeploy { +namespace vision { + +namespace generation { +/*! @brief Postprocessor object for AnimeGAN serials model. + */ +class FASTDEPLOY_DECL AnimeGANPostprocessor { + public: + /** \brief Create a postprocessor instance for AnimeGAN serials model + */ + AnimeGANPostprocessor() {} + + /** \brief Process the result of runtime + * + * \param[in] infer_results The inference results from runtime + * \param[in] results The output results of style transfer + * \return true if the postprocess successed, otherwise false + */ + bool Run(std::vector& infer_results, + std::vector* results); +}; + +} // namespace generation +} // namespace vision +} // namespace fastdeploy diff --git a/fastdeploy/vision/generation/contrib/preprocessor.cc b/fastdeploy/vision/generation/contrib/preprocessor.cc new file mode 100644 index 000000000..24e75fdc3 --- /dev/null +++ b/fastdeploy/vision/generation/contrib/preprocessor.cc @@ -0,0 +1,63 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "fastdeploy/vision/generation/contrib/preprocessor.h" + +namespace fastdeploy { +namespace vision { +namespace generation { + +bool AnimeGANPreprocessor::Run(std::vector& images, std::vector* outputs) { + // 1. BGR2RGB + // 2. Convert(opencv style) or Normalize + for (size_t i = 0; i < images.size(); ++i) { + auto ret = BGR2RGB::Run(&images[i]); + if (!ret) { + FDERROR << "Failed to processs image:" << i << " in " + << "BGR2RGB" << "." << std::endl; + return false; + } + ret = Cast::Run(&images[i], "float"); + if (!ret) { + FDERROR << "Failed to processs image:" << i << " in " + << "Cast" << "." << std::endl; + return false; + } + std::vector mean{1.f / 127.5f, 1.f / 127.5f, 1.f / 127.5f}; + std::vector std {-1.f, -1.f, -1.f}; + ret = Convert::Run(&images[i], mean, std); + if (!ret) { + FDERROR << "Failed to processs image:" << i << " in " + << "Cast" << "." << std::endl; + return false; + } + } + outputs->resize(1); + // Concat all the preprocessed data to a batch tensor + std::vector tensors(images.size()); + for (size_t i = 0; i < images.size(); ++i) { + images[i].ShareWithTensor(&(tensors[i])); + tensors[i].ExpandDim(0); + } + if (tensors.size() == 1) { + (*outputs)[0] = std::move(tensors[0]); + } else { + function::Concat(tensors, &((*outputs)[0]), 0); + } + return true; +} + +} // namespace generation +} // namespace vision +} // namespace fastdeploy diff --git a/fastdeploy/vision/generation/contrib/preprocessor.h b/fastdeploy/vision/generation/contrib/preprocessor.h new file mode 100644 index 000000000..4fcf94a15 --- /dev/null +++ b/fastdeploy/vision/generation/contrib/preprocessor.h @@ -0,0 +1,42 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include "fastdeploy/vision/common/processors/transform.h" +#include "fastdeploy/function/functions.h" + +namespace fastdeploy { +namespace vision { + +namespace generation { +/*! @brief Preprocessor object for AnimeGAN serials model. + */ +class FASTDEPLOY_DECL AnimeGANPreprocessor { + public: + /** \brief Create a preprocessor instance for AnimeGAN serials model + */ + AnimeGANPreprocessor() {} + + /** \brief Process the input image and prepare input tensors for runtime + * + * \param[in] images The input image data list, all the elements are returned wrapped by FDMat. + * \param[in] output The output tensors which will feed in runtime + * \return true if the preprocess successed, otherwise false + */ + bool Run(std::vector& images, std::vector* output); +}; + +} // namespace generation +} // namespace vision +} // namespace fastdeploy diff --git a/fastdeploy/vision/generation/generation_pybind.cc b/fastdeploy/vision/generation/generation_pybind.cc new file mode 100644 index 000000000..d4f02612e --- /dev/null +++ b/fastdeploy/vision/generation/generation_pybind.cc @@ -0,0 +1,25 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "fastdeploy/pybind/main.h" + +namespace fastdeploy { + +void BindAnimeGAN(pybind11::module& m); + +void BindGeneration(pybind11::module& m) { + auto generation_module = m.def_submodule("generation", "image generation submodule"); + BindAnimeGAN(generation_module); +} +} // namespace fastdeploy diff --git a/fastdeploy/vision/vision_pybind.cc b/fastdeploy/vision/vision_pybind.cc index cecd4f7c3..aa387b430 100644 --- a/fastdeploy/vision/vision_pybind.cc +++ b/fastdeploy/vision/vision_pybind.cc @@ -28,6 +28,7 @@ void BindTracking(pybind11::module& m); void BindKeyPointDetection(pybind11::module& m); void BindHeadPose(pybind11::module& m); void BindSR(pybind11::module& m); +void BindGeneration(pybind11::module& m); #ifdef ENABLE_VISION_VISUALIZE void BindVisualize(pybind11::module& m); #endif @@ -213,6 +214,7 @@ void BindVision(pybind11::module& m) { BindKeyPointDetection(m); BindHeadPose(m); BindSR(m); + BindGeneration(m); #ifdef ENABLE_VISION_VISUALIZE BindVisualize(m); #endif diff --git a/python/fastdeploy/vision/__init__.py b/python/fastdeploy/vision/__init__.py index a5531a8a9..ba9a2d0ca 100755 --- a/python/fastdeploy/vision/__init__.py +++ b/python/fastdeploy/vision/__init__.py @@ -26,6 +26,7 @@ from . import ocr from . import headpose from . import sr from . import evaluation +from . import generation from .utils import fd_result_to_json from .visualize import * from .. import C diff --git a/python/fastdeploy/vision/generation/__init__.py b/python/fastdeploy/vision/generation/__init__.py new file mode 100644 index 000000000..f568ed84d --- /dev/null +++ b/python/fastdeploy/vision/generation/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from .contrib.anemigan import AnimeGAN diff --git a/python/fastdeploy/vision/generation/contrib/__init__.py b/python/fastdeploy/vision/generation/contrib/__init__.py new file mode 100644 index 000000000..8034e10bf --- /dev/null +++ b/python/fastdeploy/vision/generation/contrib/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import diff --git a/python/fastdeploy/vision/generation/contrib/anemigan.py b/python/fastdeploy/vision/generation/contrib/anemigan.py new file mode 100644 index 000000000..eaed21c5e --- /dev/null +++ b/python/fastdeploy/vision/generation/contrib/anemigan.py @@ -0,0 +1,102 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +import logging +from .... import FastDeployModel, ModelFormat +from .... import c_lib_wrap as C + + +class AnimeGANPreprocessor: + def __init__(self, config_file): + """Create a preprocessor for AnimeGAN. + """ + self._preprocessor = C.vision.generation.AnimeGANPreprocessor() + + def run(self, input_ims): + """Preprocess input images for AnimeGAN. + + :param: input_ims: (list of numpy.ndarray)The input image + :return: list of FDTensor + """ + return self._preprocessor.run(input_ims) + + +class AnimeGANPostprocessor: + def __init__(self): + """Create a postprocessor for AnimeGAN. + """ + self._postprocessor = C.vision.generation.AnimeGANPostprocessor() + + def run(self, runtime_results): + """Postprocess the runtime results for AnimeGAN + + :param: runtime_results: (list of FDTensor)The output FDTensor results from runtime + :return: results: (list) Final results + """ + return self._postprocessor.run(runtime_results) + + +class AnimeGAN(FastDeployModel): + def __init__(self, + model_file, + params_file="", + runtime_option=None, + model_format=ModelFormat.PADDLE): + """Load a AnimeGAN model. + + :param model_file: (str)Path of model file, e.g ./model.pdmodel + :param params_file: (str)Path of parameters file, e.g ./model.pdiparams, if the model_fomat is ModelFormat.ONNX, this param will be ignored, can be set as empty string + :param runtime_option: (fastdeploy.RuntimeOption)RuntimeOption for inference this model, if it's None, will use the default backend on CPU + :param model_format: (fastdeploy.ModelForamt)Model format of the loaded model + """ + # call super constructor to initialize self._runtime_option + super(AnimeGAN, self).__init__(runtime_option) + + self._model = C.vision.generation.AnimeGAN( + model_file, params_file, self._runtime_option, model_format) + # assert self.initialized to confirm initialization successfully. + assert self.initialized, "AnimeGAN initialize failed." + + def predict(self, input_image): + """ Predict the style transfer result for an input image + + :param input_image: (numpy.ndarray)The input image data, 3-D array with layout HWC, BGR format + :return: style transfer result + """ + return self._model.predict(input_image) + + def batch_predict(self, input_images): + """ Predict the style transfer result for multiple input images + + :param input_images: (list of numpy.ndarray)The list of input image data, each image is a 3-D array with layout HWC, BGR format + :return: a list of style transfer results + """ + return self._model.batch_predict(input_images) + + @property + def preprocessor(self): + """Get AnimeGANPreprocessor object of the loaded model + + :return AnimeGANPreprocessor + """ + return self._model.preprocessor + + @property + def postprocessor(self): + """Get AnimeGANPostprocessor object of the loaded model + + :return AnimeGANPostprocessor + """ + return self._model.postprocessor diff --git a/tests/models/test_animegan.py b/tests/models/test_animegan.py new file mode 100644 index 000000000..d698b05a8 --- /dev/null +++ b/tests/models/test_animegan.py @@ -0,0 +1,46 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import fastdeploy as fd +import cv2 +import os +import numpy as np + + +def test_animegan(): + model_name = 'animegan_v1_hayao_60' + model_path = fd.download_model( + name=model_name, path='./resources', format='paddle') + test_img = 'https://bj.bcebos.com/paddlehub/fastdeploy/style_transfer_testimg.jpg' + label_img = 'https://bj.bcebos.com/paddlehub/fastdeploy/style_transfer_result.png' + fd.download(test_img, "./resources") + fd.download(label_img, "./resources") + # use default backend + runtime_option = fd.RuntimeOption() + runtime_option.set_paddle_mkldnn(False) + model_file = os.path.join(model_path, "model.pdmodel") + params_file = os.path.join(model_path, "model.pdiparams") + animegan = fd.vision.generation.AnimeGAN( + model_file, params_file, runtime_option=runtime_option) + + src_img = cv2.imread("./resources/style_transfer_testimg.jpg") + label_img = cv2.imread("./resources/style_transfer_result.png") + res = animegan.predict(src_img) + + diff = np.fabs(res.astype(np.float32) - label_img.astype(np.float32)) / 255 + assert diff.max() < 1e-04, "There's diff in prediction." + + +if __name__ == "__main__": + test_animegan() diff --git a/tests/models/test_basicvsr.py b/tests/models/test_basicvsr.py index 479343444..9aeabc509 100644 --- a/tests/models/test_basicvsr.py +++ b/tests/models/test_basicvsr.py @@ -69,3 +69,7 @@ def test_basicvsr(): if t >= 10: break capture.release() + + +if __name__ == "__main__": + test_basicvsr() diff --git a/tests/models/test_edvr.py b/tests/models/test_edvr.py index a9f9517e7..a874c7d3b 100644 --- a/tests/models/test_edvr.py +++ b/tests/models/test_edvr.py @@ -1,4 +1,4 @@ -test_pptracking.py # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -74,3 +74,7 @@ def test_edvr(): if t >= 10: break capture.release() + + +if __name__ == "__main__": + test_edvr()