mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-05 16:48:03 +08:00
[Model] add style transfer model (#922)
* add style transfer model * add examples for generation model * add unit test * add speed comparison * add speed comparison * add variable for constant * add preprocessor and postprocessor * add preprocessor and postprocessor * fix * fix according to review Co-authored-by: DefTruth <31974251+DefTruth@users.noreply.github.com>
This commit is contained in:
36
examples/vision/generation/anemigan/README.md
Normal file
36
examples/vision/generation/anemigan/README.md
Normal file
@@ -0,0 +1,36 @@
|
|||||||
|
# 图像生成模型
|
||||||
|
|
||||||
|
FastDeploy目前支持PaddleHub预训练模型库中如下风格迁移模型的部署
|
||||||
|
|
||||||
|
| 模型 | 说明 | 模型格式 |
|
||||||
|
| :--- | :--- | :------- |
|
||||||
|
|[animegan_v1_hayao_60](https://www.paddlepaddle.org.cn/hubdetail?name=animegan_v1_hayao_60&en_category=GANs)|可将输入的图像转换成宫崎骏动漫风格,模型权重转换自AnimeGAN V1官方开源项目|paddle|
|
||||||
|
|[animegan_v2_paprika_97](https://www.paddlepaddle.org.cn/hubdetail?name=animegan_v2_paprika_97&en_category=GANs)|可将输入的图像转换成今敏红辣椒动漫风格,模型权重转换自AnimeGAN V2官方开源项目|paddle|
|
||||||
|
|[animegan_v2_hayao_64](https://www.paddlepaddle.org.cn/hubdetail?name=animegan_v2_hayao_64&en_category=GANs)|可将输入的图像转换成宫崎骏动漫风格,模型权重转换自AnimeGAN V2官方开源项目|paddle|
|
||||||
|
|[animegan_v2_shinkai_53](https://www.paddlepaddle.org.cn/hubdetail?name=animegan_v2_shinkai_53&en_category=GANs)|可将输入的图像转换成新海诚动漫风格,模型权重转换自AnimeGAN V2官方开源项目|paddle|
|
||||||
|
|[animegan_v2_shinkai_33](https://www.paddlepaddle.org.cn/hubdetail?name=animegan_v2_shinkai_33&en_category=GANs)|可将输入的图像转换成新海诚动漫风格,模型权重转换自AnimeGAN V2官方开源项目|paddle|
|
||||||
|
|[animegan_v2_paprika_54](https://www.paddlepaddle.org.cn/hubdetail?name=animegan_v2_paprika_54&en_category=GANs)|可将输入的图像转换成今敏红辣椒动漫风格,模型权重转换自AnimeGAN V2官方开源项目|paddle|
|
||||||
|
|[animegan_v2_hayao_99](https://www.paddlepaddle.org.cn/hubdetail?name=animegan_v2_hayao_99&en_category=GANs)|可将输入的图像转换成宫崎骏动漫风格,模型权重转换自AnimeGAN V2官方开源项目|paddle|
|
||||||
|
|[animegan_v2_paprika_74](https://www.paddlepaddle.org.cn/hubdetail?name=animegan_v2_paprika_74&en_category=GANs)|可将输入的图像转换成今敏红辣椒动漫风格,模型权重转换自AnimeGAN V2官方开源项目|paddle|
|
||||||
|
|[animegan_v2_paprika_98](https://www.paddlepaddle.org.cn/hubdetail?name=animegan_v2_paprika_98&en_category=GANs)|可将输入的图像转换成今敏红辣椒动漫风格,模型权重转换自AnimeGAN V2官方开源项目|paddle|
|
||||||
|
|
||||||
|
## FastDeploy paddle backend部署和hub速度对比(ips, 越高越好)
|
||||||
|
| Device | FastDeploy | Hub |
|
||||||
|
| :--- | :--- | :------- |
|
||||||
|
| CPU | 0.075 | 0.069|
|
||||||
|
| GPU | 8.33 | 8.26 |
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
## 下载预训练模型
|
||||||
|
使用fastdeploy.download_model即可以下载模型, 例如下载animegan_v1_hayao_60
|
||||||
|
```python
|
||||||
|
import fastdeploy as fd
|
||||||
|
fd.download_model(name='animegan_v1_hayao_60', path='./', format='paddle')
|
||||||
|
```
|
||||||
|
将会在当前目录获得animegan_v1_hayao_60的预训练模型。
|
||||||
|
|
||||||
|
## 详细部署文档
|
||||||
|
|
||||||
|
- [Python部署](python)
|
||||||
|
- [C++部署](cpp)
|
13
examples/vision/generation/anemigan/cpp/CMakeLists.txt
Executable file
13
examples/vision/generation/anemigan/cpp/CMakeLists.txt
Executable file
@@ -0,0 +1,13 @@
|
|||||||
|
PROJECT(infer_demo C CXX)
|
||||||
|
CMAKE_MINIMUM_REQUIRED (VERSION 3.10)
|
||||||
|
|
||||||
|
# 指定下载解压后的fastdeploy库路径
|
||||||
|
option(FASTDEPLOY_INSTALL_DIR "Path of downloaded fastdeploy sdk.")
|
||||||
|
include(${FASTDEPLOY_INSTALL_DIR}/utils/gflags.cmake)
|
||||||
|
include(${FASTDEPLOY_INSTALL_DIR}/FastDeploy.cmake)
|
||||||
|
|
||||||
|
# 添加FastDeploy依赖头文件
|
||||||
|
include_directories(${FASTDEPLOY_INCS})
|
||||||
|
|
||||||
|
add_executable(infer_demo ${PROJECT_SOURCE_DIR}/infer.cc)
|
||||||
|
target_link_libraries(infer_demo ${FASTDEPLOY_LIBS} ${GFLAGS_LIBRARIES})
|
84
examples/vision/generation/anemigan/cpp/README.md
Executable file
84
examples/vision/generation/anemigan/cpp/README.md
Executable file
@@ -0,0 +1,84 @@
|
|||||||
|
# AnimeGAN C++部署示例
|
||||||
|
|
||||||
|
本目录下提供`infer.cc`快速完成AnimeGAN在CPU/GPU部署的示例。
|
||||||
|
|
||||||
|
在部署前,需确认以下两个步骤
|
||||||
|
|
||||||
|
- 1. 软硬件环境满足要求,参考[FastDeploy环境要求](../../../../../docs/cn/build_and_install/download_prebuilt_libraries.md)
|
||||||
|
- 2. 根据开发环境,下载预编译部署库和samples代码,参考[FastDeploy预编译库](../../../../../docs/cn/build_and_install/download_prebuilt_libraries.md)
|
||||||
|
|
||||||
|
以Linux上AnimeGAN推理为例,在本目录执行如下命令即可完成编译测试,支持此模型需保证FastDeploy版本1.0.2以上(x.x.x>=1.0.2)
|
||||||
|
|
||||||
|
```bash
|
||||||
|
mkdir build
|
||||||
|
cd build
|
||||||
|
# 下载FastDeploy预编译库,用户可在上文提到的`FastDeploy预编译库`中自行选择合适的版本使用
|
||||||
|
wget https://bj.bcebos.com/fastdeploy/release/cpp/fastdeploy-linux-x64-x.x.x.tgz
|
||||||
|
tar xvf fastdeploy-linux-x64-x.x.x.tgz
|
||||||
|
cmake .. -DFASTDEPLOY_INSTALL_DIR=${PWD}/fastdeploy-linux-x64-x.x.x
|
||||||
|
make -j
|
||||||
|
|
||||||
|
# 下载准备好的模型文件和测试图片
|
||||||
|
wget https://bj.bcebos.com/paddlehub/fastdeploy/style_transfer_testimg.jpg
|
||||||
|
wget https://bj.bcebos.com/paddlehub/fastdeploy/animegan_v1_hayao_60_v1.0.0.tgz
|
||||||
|
tar xvfz animegan_v1_hayao_60_v1.0.0.tgz
|
||||||
|
|
||||||
|
# CPU推理
|
||||||
|
./infer_demo --model animegan_v1_hayao_60 --image style_transfer_testimg.jpg --device cpu
|
||||||
|
# GPU推理
|
||||||
|
./infer_demo --model animegan_v1_hayao_60 --image style_transfer_testimg.jpg --device gpu
|
||||||
|
```
|
||||||
|
|
||||||
|
以上命令只适用于Linux或MacOS, Windows下SDK的使用方式请参考:
|
||||||
|
- [如何在Windows中使用FastDeploy C++ SDK](../../../../../docs/cn/faq/use_sdk_on_windows.md)
|
||||||
|
|
||||||
|
## AnimeGAN C++接口
|
||||||
|
|
||||||
|
### AnimeGAN类
|
||||||
|
|
||||||
|
```c++
|
||||||
|
fastdeploy::vision::generation::AnimeGAN(
|
||||||
|
const string& model_file,
|
||||||
|
const string& params_file = "",
|
||||||
|
const RuntimeOption& runtime_option = RuntimeOption(),
|
||||||
|
const ModelFormat& model_format = ModelFormat::PADDLE)
|
||||||
|
```
|
||||||
|
|
||||||
|
AnimeGAN模型加载和初始化,其中model_file为导出的Paddle模型结构文件,params_file为模型参数文件。
|
||||||
|
|
||||||
|
**参数**
|
||||||
|
|
||||||
|
> * **model_file**(str): 模型文件路径
|
||||||
|
> * **params_file**(str): 参数文件路径
|
||||||
|
> * **runtime_option**(RuntimeOption): 后端推理配置,默认为None,即采用默认配置
|
||||||
|
> * **model_format**(ModelFormat): 模型格式,默认为Paddle格式
|
||||||
|
|
||||||
|
#### Predict函数
|
||||||
|
|
||||||
|
> ```c++
|
||||||
|
> bool AnimeGAN::Predict(cv::Mat& image, cv::Mat* result)
|
||||||
|
> ```
|
||||||
|
>
|
||||||
|
> 模型预测入口,输入图像输出风格迁移后的结果。
|
||||||
|
>
|
||||||
|
> **参数**
|
||||||
|
>
|
||||||
|
> > * **image**: 输入数据,注意需为HWC,BGR格式
|
||||||
|
> > * **result**: 风格转换后的图像,BGR格式
|
||||||
|
|
||||||
|
#### BatchPredict函数
|
||||||
|
|
||||||
|
> ```c++
|
||||||
|
> bool AnimeGAN::BatchPredict(const std::vector<cv::Mat>& images, std::vector<cv::Mat>* results);
|
||||||
|
> ```
|
||||||
|
>
|
||||||
|
> 模型预测入口,输入一组图像并输出风格迁移后的结果。
|
||||||
|
>
|
||||||
|
> **参数**
|
||||||
|
>
|
||||||
|
> > * **images**: 输入数据,一组图像数据,注意需为HWC,BGR格式
|
||||||
|
> > * **results**: 风格转换后的一组图像,BGR格式
|
||||||
|
|
||||||
|
- [模型介绍](../../)
|
||||||
|
- [Python部署](../python)
|
||||||
|
- [如何切换模型推理后端引擎](../../../../../docs/cn/faq/how_to_change_backend.md)
|
69
examples/vision/generation/anemigan/cpp/infer.cc
Normal file
69
examples/vision/generation/anemigan/cpp/infer.cc
Normal file
@@ -0,0 +1,69 @@
|
|||||||
|
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#include "fastdeploy/vision.h"
|
||||||
|
#include "gflags/gflags.h"
|
||||||
|
|
||||||
|
DEFINE_string(model, "", "Directory of the inference model.");
|
||||||
|
DEFINE_string(image, "", "Path of the image file.");
|
||||||
|
DEFINE_string(device, "cpu",
|
||||||
|
"Type of inference device, support 'cpu' or 'gpu'.");
|
||||||
|
|
||||||
|
void PrintUsage() {
|
||||||
|
std::cout << "Usage: infer_demo --model model_path --image img_path --device [cpu|gpu]"
|
||||||
|
<< std::endl;
|
||||||
|
std::cout << "Default value of device: cpu" << std::endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool CreateRuntimeOption(fastdeploy::RuntimeOption* option) {
|
||||||
|
if (FLAGS_device == "gpu") {
|
||||||
|
option->UseGpu();
|
||||||
|
}
|
||||||
|
else if (FLAGS_device == "cpu") {
|
||||||
|
option->SetPaddleMKLDNN(false);
|
||||||
|
return true;
|
||||||
|
} else {
|
||||||
|
std::cerr << "Only support device CPU/GPU now, " << FLAGS_device << " is not supported." << std::endl;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
int main(int argc, char* argv[]) {
|
||||||
|
google::ParseCommandLineFlags(&argc, &argv, true);
|
||||||
|
auto option = fastdeploy::RuntimeOption();
|
||||||
|
if (!CreateRuntimeOption(&option)) {
|
||||||
|
PrintUsage();
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto model = fastdeploy::vision::generation::AnimeGAN(FLAGS_model+"/model.pdmodel", FLAGS_model+"/model.pdiparams", option);
|
||||||
|
if (!model.Initialized()) {
|
||||||
|
std::cerr << "Failed to initialize." << std::endl;
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto im = cv::imread(FLAGS_image);
|
||||||
|
cv::Mat res;
|
||||||
|
if (!model.Predict(im, &res)) {
|
||||||
|
std::cerr << "Failed to predict." << std::endl;
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
|
||||||
|
cv::imwrite("style_transfer_result.png", res);
|
||||||
|
std::cout << "Visualized result saved in ./style_transfer_result.png" << std::endl;
|
||||||
|
|
||||||
|
return 0;
|
||||||
|
}
|
70
examples/vision/generation/anemigan/python/README.md
Normal file
70
examples/vision/generation/anemigan/python/README.md
Normal file
@@ -0,0 +1,70 @@
|
|||||||
|
# AnimeGAN Python部署示例
|
||||||
|
|
||||||
|
在部署前,需确认以下两个步骤
|
||||||
|
|
||||||
|
- 1. 软硬件环境满足要求,参考[FastDeploy环境要求](../../../../../docs/cn/build_and_install/download_prebuilt_libraries.md)
|
||||||
|
- 2. FastDeploy Python whl包安装,参考[FastDeploy Python安装](../../../../../docs/cn/build_and_install/download_prebuilt_libraries.md)
|
||||||
|
|
||||||
|
本目录下提供`infer.py`快速完成AnimeGAN在CPU/GPU,以及GPU上通过TensorRT加速部署的示例。执行如下脚本即可完成
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# 下载部署示例代码
|
||||||
|
git clone https://github.com/PaddlePaddle/FastDeploy.git
|
||||||
|
cd FastDeploy/examples/vision/generation/anemigan/python
|
||||||
|
# 下载准备好的测试图片
|
||||||
|
wget https://bj.bcebos.com/paddlehub/fastdeploy/style_transfer_testimg.jpg
|
||||||
|
|
||||||
|
# CPU推理
|
||||||
|
python infer.py --model animegan_v1_hayao_60 --image style_transfer_testimg.jpg --device cpu
|
||||||
|
# GPU推理
|
||||||
|
python infer.py --model animegan_v1_hayao_60 --image style_transfer_testimg.jpg --device gpu
|
||||||
|
```
|
||||||
|
|
||||||
|
## AnimeGAN Python接口
|
||||||
|
|
||||||
|
```python
|
||||||
|
fd.vision.generation.AnimeGAN(model_file, params_file, runtime_option=None, model_format=ModelFormat.PADDLE)
|
||||||
|
```
|
||||||
|
|
||||||
|
AnimeGAN模型加载和初始化,其中model_file和params_file为用于Paddle inference的模型结构文件和参数文件。
|
||||||
|
|
||||||
|
**参数**
|
||||||
|
|
||||||
|
> * **model_file**(str): 模型文件路径
|
||||||
|
> * **params_file**(str): 参数文件路径
|
||||||
|
> * **runtime_option**(RuntimeOption): 后端推理配置,默认为None,即采用默认配置
|
||||||
|
> * **model_format**(ModelFormat): 模型格式,默认为Paddle格式
|
||||||
|
|
||||||
|
|
||||||
|
### predict函数
|
||||||
|
|
||||||
|
> ```python
|
||||||
|
> AnimeGAN.predict(input_image)
|
||||||
|
> ```
|
||||||
|
>
|
||||||
|
> 模型预测入口,输入图像输出风格迁移后的结果。
|
||||||
|
>
|
||||||
|
> **参数**
|
||||||
|
>
|
||||||
|
> > * **input_image**(np.ndarray): 输入数据,注意需为HWC,BGR格式
|
||||||
|
|
||||||
|
> **返回** np.ndarray, 风格转换后的图像,BGR格式
|
||||||
|
|
||||||
|
### batch_predict函数
|
||||||
|
> ```python
|
||||||
|
> AnimeGAN.batch_predict函数(input_images)
|
||||||
|
> ```
|
||||||
|
>
|
||||||
|
> 模型预测入口,输入一组图像并输出风格迁移后的结果。
|
||||||
|
>
|
||||||
|
> **参数**
|
||||||
|
>
|
||||||
|
> > * **input_images**(list(np.ndarray)): 输入数据,一组图像数据,注意需为HWC,BGR格式
|
||||||
|
|
||||||
|
> **返回** list(np.ndarray), 风格转换后的一组图像,BGR格式
|
||||||
|
|
||||||
|
## 其它文档
|
||||||
|
|
||||||
|
- [风格迁移 模型介绍](..)
|
||||||
|
- [C++部署](../cpp)
|
||||||
|
- [如何切换模型推理后端引擎](../../../../../docs/cn/faq/how_to_change_backend.md)
|
43
examples/vision/generation/anemigan/python/infer.py
Normal file
43
examples/vision/generation/anemigan/python/infer.py
Normal file
@@ -0,0 +1,43 @@
|
|||||||
|
import cv2
|
||||||
|
import os
|
||||||
|
import fastdeploy as fd
|
||||||
|
|
||||||
|
|
||||||
|
def parse_arguments():
|
||||||
|
import argparse
|
||||||
|
import ast
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument("--model", required=True, help="Name of the model.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--image", type=str, required=True, help="Path of test image file.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--device",
|
||||||
|
type=str,
|
||||||
|
default='cpu',
|
||||||
|
help="Type of inference device, support 'cpu' or 'gpu'.")
|
||||||
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
def build_option(args):
|
||||||
|
option = fd.RuntimeOption()
|
||||||
|
if args.device.lower() == "gpu":
|
||||||
|
option.use_gpu()
|
||||||
|
else:
|
||||||
|
option.set_paddle_mkldnn(False)
|
||||||
|
return option
|
||||||
|
|
||||||
|
|
||||||
|
args = parse_arguments()
|
||||||
|
|
||||||
|
# 配置runtime,加载模型
|
||||||
|
runtime_option = build_option(args)
|
||||||
|
fd.download_model(name=args.model, path='./', format='paddle')
|
||||||
|
model_file = os.path.join(args.model, "model.pdmodel")
|
||||||
|
params_file = os.path.join(args.model, "model.pdiparams")
|
||||||
|
model = fd.vision.generation.AnimeGAN(
|
||||||
|
model_file, params_file, runtime_option=runtime_option)
|
||||||
|
|
||||||
|
# 预测图片并保存结果
|
||||||
|
im = cv2.imread(args.image)
|
||||||
|
result = model.predict(im)
|
||||||
|
cv2.imwrite('style_transfer_result.png', result)
|
@@ -55,6 +55,7 @@
|
|||||||
#include "fastdeploy/vision/segmentation/ppseg/model.h"
|
#include "fastdeploy/vision/segmentation/ppseg/model.h"
|
||||||
#include "fastdeploy/vision/sr/ppsr/model.h"
|
#include "fastdeploy/vision/sr/ppsr/model.h"
|
||||||
#include "fastdeploy/vision/tracking/pptracking/model.h"
|
#include "fastdeploy/vision/tracking/pptracking/model.h"
|
||||||
|
#include "fastdeploy/vision/generation/contrib/animegan.h"
|
||||||
|
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
80
fastdeploy/vision/generation/contrib/animegan.cc
Normal file
80
fastdeploy/vision/generation/contrib/animegan.cc
Normal file
@@ -0,0 +1,80 @@
|
|||||||
|
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#include "fastdeploy/vision/generation/contrib/animegan.h"
|
||||||
|
#include "fastdeploy/function/functions.h"
|
||||||
|
|
||||||
|
namespace fastdeploy {
|
||||||
|
namespace vision {
|
||||||
|
namespace generation {
|
||||||
|
|
||||||
|
AnimeGAN::AnimeGAN(const std::string& model_file, const std::string& params_file,
|
||||||
|
const RuntimeOption& custom_option,
|
||||||
|
const ModelFormat& model_format) {
|
||||||
|
|
||||||
|
valid_cpu_backends = {Backend::PDINFER};
|
||||||
|
valid_gpu_backends = {Backend::PDINFER};
|
||||||
|
|
||||||
|
runtime_option = custom_option;
|
||||||
|
runtime_option.model_format = model_format;
|
||||||
|
runtime_option.model_file = model_file;
|
||||||
|
runtime_option.params_file = params_file;
|
||||||
|
|
||||||
|
initialized = Initialize();
|
||||||
|
}
|
||||||
|
|
||||||
|
bool AnimeGAN::Initialize() {
|
||||||
|
if (!InitRuntime()) {
|
||||||
|
FDERROR << "Failed to initialize fastdeploy backend." << std::endl;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
bool AnimeGAN::Predict(cv::Mat& img, cv::Mat* result) {
|
||||||
|
std::vector<cv::Mat> results;
|
||||||
|
if (!BatchPredict({img}, &results)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
*result = std::move(results[0]);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool AnimeGAN::BatchPredict(const std::vector<cv::Mat>& images, std::vector<cv::Mat>* results) {
|
||||||
|
std::vector<FDMat> fd_images = WrapMat(images);
|
||||||
|
std::vector<FDTensor> processed_data(1);
|
||||||
|
if (!preprocessor_.Run(fd_images, &(processed_data))) {
|
||||||
|
FDERROR << "Failed to preprocess input data while using model:"
|
||||||
|
<< ModelName() << "." << std::endl;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
std::vector<FDTensor> infer_result(1);
|
||||||
|
processed_data[0].name = InputInfoOfRuntime(0).name;
|
||||||
|
|
||||||
|
if (!Infer(processed_data, &infer_result)) {
|
||||||
|
FDERROR << "Failed to inference by runtime." << std::endl;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (!postprocessor_.Run(infer_result, results)) {
|
||||||
|
FDERROR << "Failed to postprocess while using model:" << ModelName() << "."
|
||||||
|
<< std::endl;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace generation
|
||||||
|
} // namespace vision
|
||||||
|
} // namespace fastdeploy
|
79
fastdeploy/vision/generation/contrib/animegan.h
Normal file
79
fastdeploy/vision/generation/contrib/animegan.h
Normal file
@@ -0,0 +1,79 @@
|
|||||||
|
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
#include "fastdeploy/fastdeploy_model.h"
|
||||||
|
#include "fastdeploy/vision/common/processors/transform.h"
|
||||||
|
#include "fastdeploy/vision/generation/contrib/preprocessor.h"
|
||||||
|
#include "fastdeploy/vision/generation/contrib/postprocessor.h"
|
||||||
|
|
||||||
|
namespace fastdeploy {
|
||||||
|
|
||||||
|
namespace vision {
|
||||||
|
|
||||||
|
namespace generation {
|
||||||
|
/*! @brief AnimeGAN model object is used when load a AnimeGAN model.
|
||||||
|
*/
|
||||||
|
class FASTDEPLOY_DECL AnimeGAN : public FastDeployModel {
|
||||||
|
public:
|
||||||
|
/** \brief Set path of model file and the configuration of runtime.
|
||||||
|
*
|
||||||
|
* \param[in] model_file Path of model file, e.g ./model.pdmodel
|
||||||
|
* \param[in] params_file Path of parameter file, e.g ./model.pdiparams, if the model format is ONNX, this parameter will be ignored
|
||||||
|
* \param[in] custom_option RuntimeOption for inference, the default will use cpu, and choose the backend defined in "valid_cpu_backends"
|
||||||
|
* \param[in] model_format Model format of the loaded model, default is PADDLE format
|
||||||
|
*/
|
||||||
|
AnimeGAN(const std::string& model_file, const std::string& params_file = "",
|
||||||
|
const RuntimeOption& custom_option = RuntimeOption(),
|
||||||
|
const ModelFormat& model_format = ModelFormat::PADDLE);
|
||||||
|
|
||||||
|
std::string ModelName() const { return "styletransfer/animegan"; }
|
||||||
|
|
||||||
|
/** \brief Predict the style transfer result for an input image
|
||||||
|
*
|
||||||
|
* \param[in] im The input image data, comes from cv::imread(), is a 3-D array with layout HWC, BGR format
|
||||||
|
* \param[in] result The output style transfer result will be writen to this structure
|
||||||
|
* \return true if the prediction successed, otherwise false
|
||||||
|
*/
|
||||||
|
bool Predict(cv::Mat& img, cv::Mat* result);
|
||||||
|
|
||||||
|
/** \brief Predict the style transfer result for a batch of input images
|
||||||
|
*
|
||||||
|
* \param[in] images The list of input images, each element comes from cv::imread(), is a 3-D array with layout HWC, BGR format
|
||||||
|
* \param[in] results The list of output style transfer results will be writen to this structure
|
||||||
|
* \return true if the batch prediction successed, otherwise false
|
||||||
|
*/
|
||||||
|
bool BatchPredict(const std::vector<cv::Mat>& images,
|
||||||
|
std::vector<cv::Mat>* results);
|
||||||
|
|
||||||
|
// Get preprocessor reference of AnimeGAN
|
||||||
|
AnimeGANPreprocessor& GetPreprocessor() {
|
||||||
|
return preprocessor_;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get postprocessor reference of AnimeGAN
|
||||||
|
AnimeGANPostprocessor& GetPostprocessor() {
|
||||||
|
return postprocessor_;
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
bool Initialize();
|
||||||
|
|
||||||
|
AnimeGANPreprocessor preprocessor_;
|
||||||
|
AnimeGANPostprocessor postprocessor_;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace generation
|
||||||
|
} // namespace vision
|
||||||
|
} // namespace fastdeploy
|
78
fastdeploy/vision/generation/contrib/animegan_pybind.cc
Normal file
78
fastdeploy/vision/generation/contrib/animegan_pybind.cc
Normal file
@@ -0,0 +1,78 @@
|
|||||||
|
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#include "fastdeploy/pybind/main.h"
|
||||||
|
|
||||||
|
namespace fastdeploy {
|
||||||
|
void BindAnimeGAN(pybind11::module& m) {
|
||||||
|
pybind11::class_<vision::generation::AnimeGAN, FastDeployModel>(m, "AnimeGAN")
|
||||||
|
.def(pybind11::init<std::string, std::string, RuntimeOption,
|
||||||
|
ModelFormat>())
|
||||||
|
.def("predict",
|
||||||
|
[](vision::generation::AnimeGAN& self, pybind11::array& data) {
|
||||||
|
auto mat = PyArrayToCvMat(data);
|
||||||
|
cv::Mat res;
|
||||||
|
self.Predict(mat, &res);
|
||||||
|
auto ret = pybind11::array_t<unsigned char>(
|
||||||
|
{res.rows, res.cols, res.channels()}, res.data);
|
||||||
|
return ret;
|
||||||
|
})
|
||||||
|
.def("batch_predict",
|
||||||
|
[](vision::generation::AnimeGAN& self, std::vector<pybind11::array>& data) {
|
||||||
|
std::vector<cv::Mat> images;
|
||||||
|
for (size_t i = 0; i < data.size(); ++i) {
|
||||||
|
images.push_back(PyArrayToCvMat(data[i]));
|
||||||
|
}
|
||||||
|
std::vector<cv::Mat> results;
|
||||||
|
self.BatchPredict(images, &results);
|
||||||
|
std::vector<pybind11::array_t<unsigned char>> ret;
|
||||||
|
for(size_t i = 0; i < results.size(); ++i){
|
||||||
|
ret.push_back(pybind11::array_t<unsigned char>(
|
||||||
|
{results[i].rows, results[i].cols, results[i].channels()}, results[i].data));
|
||||||
|
}
|
||||||
|
return ret;
|
||||||
|
})
|
||||||
|
.def_property_readonly("preprocessor", &vision::generation::AnimeGAN::GetPreprocessor)
|
||||||
|
.def_property_readonly("postprocessor", &vision::generation::AnimeGAN::GetPostprocessor);
|
||||||
|
|
||||||
|
pybind11::class_<vision::generation::AnimeGANPreprocessor>(
|
||||||
|
m, "AnimeGANPreprocessor")
|
||||||
|
.def(pybind11::init<>())
|
||||||
|
.def("run", [](vision::generation::AnimeGANPreprocessor& self, std::vector<pybind11::array>& im_list) {
|
||||||
|
std::vector<vision::FDMat> images;
|
||||||
|
for (size_t i = 0; i < im_list.size(); ++i) {
|
||||||
|
images.push_back(vision::WrapMat(PyArrayToCvMat(im_list[i])));
|
||||||
|
}
|
||||||
|
std::vector<FDTensor> outputs;
|
||||||
|
if (!self.Run(images, &outputs)) {
|
||||||
|
throw std::runtime_error("Failed to preprocess the input data in PaddleClasPreprocessor.");
|
||||||
|
}
|
||||||
|
for (size_t i = 0; i < outputs.size(); ++i) {
|
||||||
|
outputs[i].StopSharing();
|
||||||
|
}
|
||||||
|
return outputs;
|
||||||
|
});
|
||||||
|
pybind11::class_<vision::generation::AnimeGANPostprocessor>(
|
||||||
|
m, "AnimeGANPostprocessor")
|
||||||
|
.def(pybind11::init<>())
|
||||||
|
.def("run", [](vision::generation::AnimeGANPostprocessor& self, std::vector<FDTensor>& inputs) {
|
||||||
|
std::vector<cv::Mat> results;
|
||||||
|
if (!self.Run(inputs, &results)) {
|
||||||
|
throw std::runtime_error("Failed to postprocess the runtime result in YOLOv5Postprocessor.");
|
||||||
|
}
|
||||||
|
return results;
|
||||||
|
});
|
||||||
|
|
||||||
|
}
|
||||||
|
} // namespace fastdeploy
|
49
fastdeploy/vision/generation/contrib/postprocessor.cc
Normal file
49
fastdeploy/vision/generation/contrib/postprocessor.cc
Normal file
@@ -0,0 +1,49 @@
|
|||||||
|
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#include "fastdeploy/vision/generation/contrib/postprocessor.h"
|
||||||
|
|
||||||
|
namespace fastdeploy {
|
||||||
|
namespace vision {
|
||||||
|
namespace generation {
|
||||||
|
|
||||||
|
bool AnimeGANPostprocessor::Run(std::vector<FDTensor>& infer_results,
|
||||||
|
std::vector<cv::Mat>* results) {
|
||||||
|
// 1. Reverse normalization
|
||||||
|
// 2. RGB2BGR
|
||||||
|
FDTensor& output_tensor = infer_results.at(0);
|
||||||
|
std::vector<int64_t> shape = output_tensor.Shape(); // n, h, w, c
|
||||||
|
int size = shape[1] * shape[2] * shape[3];
|
||||||
|
results->resize(shape[0]);
|
||||||
|
float* infer_result_data = reinterpret_cast<float*>(output_tensor.Data());
|
||||||
|
for(size_t i = 0; i < results->size(); ++i){
|
||||||
|
Mat result_mat = Mat::Create(shape[1], shape[2], 3, FDDataType::FP32, infer_result_data+i*size);
|
||||||
|
std::vector<float> mean{127.5f, 127.5f, 127.5f};
|
||||||
|
std::vector<float> std{127.5f, 127.5f, 127.5f};
|
||||||
|
Convert::Run(&result_mat, mean, std);
|
||||||
|
// tmp data type is float[0-1.0],convert to uint type
|
||||||
|
auto temp = result_mat.GetOpenCVMat();
|
||||||
|
cv::Mat res = cv::Mat::zeros(temp->size(), CV_8UC3);
|
||||||
|
temp->convertTo(res, CV_8UC3, 1);
|
||||||
|
Mat fd_image = WrapMat(res);
|
||||||
|
BGR2RGB::Run(&fd_image);
|
||||||
|
res = *(fd_image.GetOpenCVMat());
|
||||||
|
res.copyTo(results->at(i));
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace generation
|
||||||
|
} // namespace vision
|
||||||
|
} // namespace fastdeploy
|
43
fastdeploy/vision/generation/contrib/postprocessor.h
Normal file
43
fastdeploy/vision/generation/contrib/postprocessor.h
Normal file
@@ -0,0 +1,43 @@
|
|||||||
|
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
#include "fastdeploy/vision/common/processors/transform.h"
|
||||||
|
#include "fastdeploy/function/functions.h"
|
||||||
|
|
||||||
|
namespace fastdeploy {
|
||||||
|
namespace vision {
|
||||||
|
|
||||||
|
namespace generation {
|
||||||
|
/*! @brief Postprocessor object for AnimeGAN serials model.
|
||||||
|
*/
|
||||||
|
class FASTDEPLOY_DECL AnimeGANPostprocessor {
|
||||||
|
public:
|
||||||
|
/** \brief Create a postprocessor instance for AnimeGAN serials model
|
||||||
|
*/
|
||||||
|
AnimeGANPostprocessor() {}
|
||||||
|
|
||||||
|
/** \brief Process the result of runtime
|
||||||
|
*
|
||||||
|
* \param[in] infer_results The inference results from runtime
|
||||||
|
* \param[in] results The output results of style transfer
|
||||||
|
* \return true if the postprocess successed, otherwise false
|
||||||
|
*/
|
||||||
|
bool Run(std::vector<FDTensor>& infer_results,
|
||||||
|
std::vector<cv::Mat>* results);
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace generation
|
||||||
|
} // namespace vision
|
||||||
|
} // namespace fastdeploy
|
63
fastdeploy/vision/generation/contrib/preprocessor.cc
Normal file
63
fastdeploy/vision/generation/contrib/preprocessor.cc
Normal file
@@ -0,0 +1,63 @@
|
|||||||
|
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#include "fastdeploy/vision/generation/contrib/preprocessor.h"
|
||||||
|
|
||||||
|
namespace fastdeploy {
|
||||||
|
namespace vision {
|
||||||
|
namespace generation {
|
||||||
|
|
||||||
|
bool AnimeGANPreprocessor::Run(std::vector<Mat>& images, std::vector<FDTensor>* outputs) {
|
||||||
|
// 1. BGR2RGB
|
||||||
|
// 2. Convert(opencv style) or Normalize
|
||||||
|
for (size_t i = 0; i < images.size(); ++i) {
|
||||||
|
auto ret = BGR2RGB::Run(&images[i]);
|
||||||
|
if (!ret) {
|
||||||
|
FDERROR << "Failed to processs image:" << i << " in "
|
||||||
|
<< "BGR2RGB" << "." << std::endl;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
ret = Cast::Run(&images[i], "float");
|
||||||
|
if (!ret) {
|
||||||
|
FDERROR << "Failed to processs image:" << i << " in "
|
||||||
|
<< "Cast" << "." << std::endl;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
std::vector<float> mean{1.f / 127.5f, 1.f / 127.5f, 1.f / 127.5f};
|
||||||
|
std::vector<float> std {-1.f, -1.f, -1.f};
|
||||||
|
ret = Convert::Run(&images[i], mean, std);
|
||||||
|
if (!ret) {
|
||||||
|
FDERROR << "Failed to processs image:" << i << " in "
|
||||||
|
<< "Cast" << "." << std::endl;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
outputs->resize(1);
|
||||||
|
// Concat all the preprocessed data to a batch tensor
|
||||||
|
std::vector<FDTensor> tensors(images.size());
|
||||||
|
for (size_t i = 0; i < images.size(); ++i) {
|
||||||
|
images[i].ShareWithTensor(&(tensors[i]));
|
||||||
|
tensors[i].ExpandDim(0);
|
||||||
|
}
|
||||||
|
if (tensors.size() == 1) {
|
||||||
|
(*outputs)[0] = std::move(tensors[0]);
|
||||||
|
} else {
|
||||||
|
function::Concat(tensors, &((*outputs)[0]), 0);
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace generation
|
||||||
|
} // namespace vision
|
||||||
|
} // namespace fastdeploy
|
42
fastdeploy/vision/generation/contrib/preprocessor.h
Normal file
42
fastdeploy/vision/generation/contrib/preprocessor.h
Normal file
@@ -0,0 +1,42 @@
|
|||||||
|
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
#include "fastdeploy/vision/common/processors/transform.h"
|
||||||
|
#include "fastdeploy/function/functions.h"
|
||||||
|
|
||||||
|
namespace fastdeploy {
|
||||||
|
namespace vision {
|
||||||
|
|
||||||
|
namespace generation {
|
||||||
|
/*! @brief Preprocessor object for AnimeGAN serials model.
|
||||||
|
*/
|
||||||
|
class FASTDEPLOY_DECL AnimeGANPreprocessor {
|
||||||
|
public:
|
||||||
|
/** \brief Create a preprocessor instance for AnimeGAN serials model
|
||||||
|
*/
|
||||||
|
AnimeGANPreprocessor() {}
|
||||||
|
|
||||||
|
/** \brief Process the input image and prepare input tensors for runtime
|
||||||
|
*
|
||||||
|
* \param[in] images The input image data list, all the elements are returned wrapped by FDMat.
|
||||||
|
* \param[in] output The output tensors which will feed in runtime
|
||||||
|
* \return true if the preprocess successed, otherwise false
|
||||||
|
*/
|
||||||
|
bool Run(std::vector<Mat>& images, std::vector<FDTensor>* output);
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace generation
|
||||||
|
} // namespace vision
|
||||||
|
} // namespace fastdeploy
|
25
fastdeploy/vision/generation/generation_pybind.cc
Normal file
25
fastdeploy/vision/generation/generation_pybind.cc
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#include "fastdeploy/pybind/main.h"
|
||||||
|
|
||||||
|
namespace fastdeploy {
|
||||||
|
|
||||||
|
void BindAnimeGAN(pybind11::module& m);
|
||||||
|
|
||||||
|
void BindGeneration(pybind11::module& m) {
|
||||||
|
auto generation_module = m.def_submodule("generation", "image generation submodule");
|
||||||
|
BindAnimeGAN(generation_module);
|
||||||
|
}
|
||||||
|
} // namespace fastdeploy
|
@@ -28,6 +28,7 @@ void BindTracking(pybind11::module& m);
|
|||||||
void BindKeyPointDetection(pybind11::module& m);
|
void BindKeyPointDetection(pybind11::module& m);
|
||||||
void BindHeadPose(pybind11::module& m);
|
void BindHeadPose(pybind11::module& m);
|
||||||
void BindSR(pybind11::module& m);
|
void BindSR(pybind11::module& m);
|
||||||
|
void BindGeneration(pybind11::module& m);
|
||||||
#ifdef ENABLE_VISION_VISUALIZE
|
#ifdef ENABLE_VISION_VISUALIZE
|
||||||
void BindVisualize(pybind11::module& m);
|
void BindVisualize(pybind11::module& m);
|
||||||
#endif
|
#endif
|
||||||
@@ -213,6 +214,7 @@ void BindVision(pybind11::module& m) {
|
|||||||
BindKeyPointDetection(m);
|
BindKeyPointDetection(m);
|
||||||
BindHeadPose(m);
|
BindHeadPose(m);
|
||||||
BindSR(m);
|
BindSR(m);
|
||||||
|
BindGeneration(m);
|
||||||
#ifdef ENABLE_VISION_VISUALIZE
|
#ifdef ENABLE_VISION_VISUALIZE
|
||||||
BindVisualize(m);
|
BindVisualize(m);
|
||||||
#endif
|
#endif
|
||||||
|
@@ -26,6 +26,7 @@ from . import ocr
|
|||||||
from . import headpose
|
from . import headpose
|
||||||
from . import sr
|
from . import sr
|
||||||
from . import evaluation
|
from . import evaluation
|
||||||
|
from . import generation
|
||||||
from .utils import fd_result_to_json
|
from .utils import fd_result_to_json
|
||||||
from .visualize import *
|
from .visualize import *
|
||||||
from .. import C
|
from .. import C
|
||||||
|
16
python/fastdeploy/vision/generation/__init__.py
Normal file
16
python/fastdeploy/vision/generation/__init__.py
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from .contrib.anemigan import AnimeGAN
|
15
python/fastdeploy/vision/generation/contrib/__init__.py
Normal file
15
python/fastdeploy/vision/generation/contrib/__init__.py
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
102
python/fastdeploy/vision/generation/contrib/anemigan.py
Normal file
102
python/fastdeploy/vision/generation/contrib/anemigan.py
Normal file
@@ -0,0 +1,102 @@
|
|||||||
|
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
import logging
|
||||||
|
from .... import FastDeployModel, ModelFormat
|
||||||
|
from .... import c_lib_wrap as C
|
||||||
|
|
||||||
|
|
||||||
|
class AnimeGANPreprocessor:
|
||||||
|
def __init__(self, config_file):
|
||||||
|
"""Create a preprocessor for AnimeGAN.
|
||||||
|
"""
|
||||||
|
self._preprocessor = C.vision.generation.AnimeGANPreprocessor()
|
||||||
|
|
||||||
|
def run(self, input_ims):
|
||||||
|
"""Preprocess input images for AnimeGAN.
|
||||||
|
|
||||||
|
:param: input_ims: (list of numpy.ndarray)The input image
|
||||||
|
:return: list of FDTensor
|
||||||
|
"""
|
||||||
|
return self._preprocessor.run(input_ims)
|
||||||
|
|
||||||
|
|
||||||
|
class AnimeGANPostprocessor:
|
||||||
|
def __init__(self):
|
||||||
|
"""Create a postprocessor for AnimeGAN.
|
||||||
|
"""
|
||||||
|
self._postprocessor = C.vision.generation.AnimeGANPostprocessor()
|
||||||
|
|
||||||
|
def run(self, runtime_results):
|
||||||
|
"""Postprocess the runtime results for AnimeGAN
|
||||||
|
|
||||||
|
:param: runtime_results: (list of FDTensor)The output FDTensor results from runtime
|
||||||
|
:return: results: (list) Final results
|
||||||
|
"""
|
||||||
|
return self._postprocessor.run(runtime_results)
|
||||||
|
|
||||||
|
|
||||||
|
class AnimeGAN(FastDeployModel):
|
||||||
|
def __init__(self,
|
||||||
|
model_file,
|
||||||
|
params_file="",
|
||||||
|
runtime_option=None,
|
||||||
|
model_format=ModelFormat.PADDLE):
|
||||||
|
"""Load a AnimeGAN model.
|
||||||
|
|
||||||
|
:param model_file: (str)Path of model file, e.g ./model.pdmodel
|
||||||
|
:param params_file: (str)Path of parameters file, e.g ./model.pdiparams, if the model_fomat is ModelFormat.ONNX, this param will be ignored, can be set as empty string
|
||||||
|
:param runtime_option: (fastdeploy.RuntimeOption)RuntimeOption for inference this model, if it's None, will use the default backend on CPU
|
||||||
|
:param model_format: (fastdeploy.ModelForamt)Model format of the loaded model
|
||||||
|
"""
|
||||||
|
# call super constructor to initialize self._runtime_option
|
||||||
|
super(AnimeGAN, self).__init__(runtime_option)
|
||||||
|
|
||||||
|
self._model = C.vision.generation.AnimeGAN(
|
||||||
|
model_file, params_file, self._runtime_option, model_format)
|
||||||
|
# assert self.initialized to confirm initialization successfully.
|
||||||
|
assert self.initialized, "AnimeGAN initialize failed."
|
||||||
|
|
||||||
|
def predict(self, input_image):
|
||||||
|
""" Predict the style transfer result for an input image
|
||||||
|
|
||||||
|
:param input_image: (numpy.ndarray)The input image data, 3-D array with layout HWC, BGR format
|
||||||
|
:return: style transfer result
|
||||||
|
"""
|
||||||
|
return self._model.predict(input_image)
|
||||||
|
|
||||||
|
def batch_predict(self, input_images):
|
||||||
|
""" Predict the style transfer result for multiple input images
|
||||||
|
|
||||||
|
:param input_images: (list of numpy.ndarray)The list of input image data, each image is a 3-D array with layout HWC, BGR format
|
||||||
|
:return: a list of style transfer results
|
||||||
|
"""
|
||||||
|
return self._model.batch_predict(input_images)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def preprocessor(self):
|
||||||
|
"""Get AnimeGANPreprocessor object of the loaded model
|
||||||
|
|
||||||
|
:return AnimeGANPreprocessor
|
||||||
|
"""
|
||||||
|
return self._model.preprocessor
|
||||||
|
|
||||||
|
@property
|
||||||
|
def postprocessor(self):
|
||||||
|
"""Get AnimeGANPostprocessor object of the loaded model
|
||||||
|
|
||||||
|
:return AnimeGANPostprocessor
|
||||||
|
"""
|
||||||
|
return self._model.postprocessor
|
46
tests/models/test_animegan.py
Normal file
46
tests/models/test_animegan.py
Normal file
@@ -0,0 +1,46 @@
|
|||||||
|
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import fastdeploy as fd
|
||||||
|
import cv2
|
||||||
|
import os
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
def test_animegan():
|
||||||
|
model_name = 'animegan_v1_hayao_60'
|
||||||
|
model_path = fd.download_model(
|
||||||
|
name=model_name, path='./resources', format='paddle')
|
||||||
|
test_img = 'https://bj.bcebos.com/paddlehub/fastdeploy/style_transfer_testimg.jpg'
|
||||||
|
label_img = 'https://bj.bcebos.com/paddlehub/fastdeploy/style_transfer_result.png'
|
||||||
|
fd.download(test_img, "./resources")
|
||||||
|
fd.download(label_img, "./resources")
|
||||||
|
# use default backend
|
||||||
|
runtime_option = fd.RuntimeOption()
|
||||||
|
runtime_option.set_paddle_mkldnn(False)
|
||||||
|
model_file = os.path.join(model_path, "model.pdmodel")
|
||||||
|
params_file = os.path.join(model_path, "model.pdiparams")
|
||||||
|
animegan = fd.vision.generation.AnimeGAN(
|
||||||
|
model_file, params_file, runtime_option=runtime_option)
|
||||||
|
|
||||||
|
src_img = cv2.imread("./resources/style_transfer_testimg.jpg")
|
||||||
|
label_img = cv2.imread("./resources/style_transfer_result.png")
|
||||||
|
res = animegan.predict(src_img)
|
||||||
|
|
||||||
|
diff = np.fabs(res.astype(np.float32) - label_img.astype(np.float32)) / 255
|
||||||
|
assert diff.max() < 1e-04, "There's diff in prediction."
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
test_animegan()
|
@@ -69,3 +69,7 @@ def test_basicvsr():
|
|||||||
if t >= 10:
|
if t >= 10:
|
||||||
break
|
break
|
||||||
capture.release()
|
capture.release()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
test_basicvsr()
|
||||||
|
@@ -1,4 +1,4 @@
|
|||||||
test_pptracking.py # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
@@ -74,3 +74,7 @@ def test_edvr():
|
|||||||
if t >= 10:
|
if t >= 10:
|
||||||
break
|
break
|
||||||
capture.release()
|
capture.release()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
test_edvr()
|
||||||
|
Reference in New Issue
Block a user