[Model] Add stable diffusion model based on fastdeploy (#297)

* Add stable diffusion model based on fastdeploy

* Add sd infer

* pipelines->multimodal

* add create_ort_runtime

* use fp16 input

* fix pil

* Add optimize unet model

* add hf license

* Add workspace args

* Add profile func

* Add schedulers

* Replace torch.Tensor with np.ndarray

* Add readme

* Add trt shape setting

* add dynamic shape

* Add dynamic shape for stable diffusion

* fix max shape setting

* rename tensorrt file suffix

* update dynamic shape setting

* Add scheduler output

* Add inference_steps and benchmark steps

* add diffuser benchmark

* Add paddle infer script

* Rename 1

* Rename infer.py to torch_onnx_infer.py

* Add export torch to onnx model

* Remove export model

* Add paddle export model for diffusion

* Fix export model

* mv torch onnx infer to infer

* Fix export model

* Fix infer

* Modify create_trt_runtime and create_ort_runtime

* update export torch

* update requirements

* add paddle inference backend

* Fix unet pp run

* remove print

* Add paddle model export and infer

* Add device id

* Move profile to utils

* Add -1 device id

* Add safety checker args

* remove safety checker temporarily

* Add export model description

* Add predict description

* Fix readme

* Fix device_id description

* add timestep shape

* add use fp16 precision

* move use gpu

* Add EulerAncestralDiscreteScheduler

* Use EulerAncestralDiscreteScheduler with v1-5 model

* Add export model readme

* Add link of exported model

* Update scheduler on README

* Add stable-diffusion-v1-5
Jack Zhou
2022-11-10 14:59:07 +08:00
committed by GitHub
parent fa807340be
commit d4995e5468
13 changed files with 2301 additions and 0 deletions


@@ -0,0 +1,59 @@
# High-Performance Deployment of Diffusion Models with FastDeploy

This example uses ⚡️`FastDeploy` to deploy Diffusion models with high performance, building on the `DiffusionPipeline` design from the Hugging Face team's [Diffusers](https://github.com/huggingface/diffusers) project.

### Preparing the deployment model

This example requires a deployment model exported from the trained model. There are two ways to obtain one:

- Export it yourself, following the [model export guide](./export.md).
- Download a pre-exported model. To let developers test this example quickly, we have already exported some `Diffusion` models; just download one to get started:

| Model | Scheduler |
|----------|--------------|
| [CompVis/stable-diffusion-v1-4](https://bj.bcebos.com/fastdeploy/models/stable-diffusion/CompVis/stable-diffusion-v1-4.tgz) | PNDM |
| [runwayml/stable-diffusion-v1-5](https://bj.bcebos.com/fastdeploy/models/stable-diffusion/runwayml/stable-diffusion-v1-5.tgz) | EulerAncestral |

## Environment dependencies

The example uses the tokenizer of PaddleNLP's CLIP model, so install the dependencies with the following command.

```shell
pip install paddlenlp paddlepaddle-gpu
```

### Quick start

With the deployment model prepared, we can start testing. Specify the model directory and the inference backend, then run the `infer.py` script to perform inference.

```
python infer.py --model_dir stable-diffusion-v1-4/ --scheduler "pndm" --backend paddle
```

The generated image is saved as fd_astronaut_rides_horse.png. A sample output is shown below (every run produces a different image; the sample is for reference only):

![fd_astronaut_rides_horse.png](https://user-images.githubusercontent.com/10826371/200261112-68e53389-e0a0-42d1-8c3a-f35faa6627d7.png)

To run inference with the stable-diffusion-v1-5 model instead, execute:

```
python infer.py --model_dir stable-diffusion-v1-5/ --scheduler "euler_ancestral" --backend paddle
```

#### Parameter description

Besides the arguments shown in the examples above, `infer.py` supports more command-line options, described below.

| Argument | Description |
|----------|--------------|
| --model_dir | Directory of the exported model. |
| --model_format | Model format. Defaults to `'paddle'`; options: `['paddle', 'onnx']`. |
| --backend | Inference runtime backend. Defaults to `paddle`; options: `['onnx_runtime', 'paddle']`. When the model format is `onnx`, only `['onnx_runtime']` is available. |
| --scheduler | Scheduler of the StableDiffusion model. Defaults to `'pndm'`; options: `['pndm', 'euler_ancestral']`. For the scheduler that matches each StableDiffusion model, see the [ppdiffusers model list](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/ppdiffusers/examples/textual_inversion). |
| --unet_model_prefix | Prefix of the UNet model files. Defaults to `unet`. |
| --vae_model_prefix | Prefix of the VAE model files. Defaults to `vae_decoder`. |
| --text_encoder_model_prefix | Prefix of the TextEncoder model files. Defaults to `text_encoder`. |
| --inference_steps | Number of UNet denoising steps. Defaults to 100. |
| --image_path | Path of the generated image. Defaults to `fd_astronaut_rides_horse.png`. |
| --device_id | GPU device id. If `device_id` is -1, inference runs on the CPU. |
| --use_fp16 | Whether to use FP16 precision. Defaults to `False`. Can be set to `True` when using the tensorrt or paddle-tensorrt backend. |
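
Besides the command line, the same components can be assembled directly in Python. The sketch below mirrors what `infer.py` (included in this PR) does internally, assuming the exported Paddle model above, the default model prefixes, and that it is run from the example directory so that `pipeline_stable_diffusion` and `scheduling_utils` are importable; it is illustrative rather than a supported entry point:

```python
import fastdeploy as fd
from paddlenlp.transformers import CLIPTokenizer
from pipeline_stable_diffusion import StableDiffusionFastDeployPipeline
from scheduling_utils import PNDMScheduler

def ort_runtime(model_dir, prefix, device_id=0):
    # ONNX Runtime backend on top of one exported Paddle sub-model.
    option = fd.RuntimeOption()
    option.use_ort_backend()
    option.use_gpu(device_id)
    option.set_model_path(f"{model_dir}/{prefix}/inference.pdmodel",
                          f"{model_dir}/{prefix}/inference.pdiparams")
    return fd.Runtime(option)

model_dir = "stable-diffusion-v1-4"
pipe = StableDiffusionFastDeployPipeline(
    vae_decoder_runtime=ort_runtime(model_dir, "vae_decoder"),
    text_encoder_runtime=ort_runtime(model_dir, "text_encoder"),
    unet_runtime=ort_runtime(model_dir, "unet"),
    tokenizer=CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14"),
    scheduler=PNDMScheduler(
        beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear",
        num_train_timesteps=1000, skip_prk_steps=True))
image = pipe("a photo of an astronaut riding a horse on mars",
             num_inference_steps=50)[0]
image.save("fd_astronaut_rides_horse.png")
```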


@@ -0,0 +1,156 @@
# Copyright 2022 The HuggingFace Inc. team.
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import inspect
from collections import OrderedDict
from typing import Any, Dict, Tuple, Union
import logging

# Module-level logger used by ConfigMixin below.
logger = logging.getLogger(__name__)
class ConfigMixin:
r"""
    Base class for all configuration classes. Stores all configuration parameters under `self.config`. Also handles all
methods for loading/downloading/saving classes inheriting from [`ConfigMixin`] with
- [`~ConfigMixin.from_config`]
- [`~ConfigMixin.save_config`]
Class attributes:
        - **config_name** (`str`) -- A filename under which the config should be stored when calling
[`~ConfigMixin.save_config`] (should be overridden by parent class).
- **ignore_for_config** (`List[str]`) -- A list of attributes that should not be saved in the config (should be
overridden by parent class).
"""
config_name = None
ignore_for_config = []
def register_to_config(self, **kwargs):
if self.config_name is None:
raise NotImplementedError(
f"Make sure that {self.__class__} has defined a class name `config_name`"
)
kwargs["_class_name"] = self.__class__.__name__
# Special case for `kwargs` used in deprecation warning added to schedulers
# TODO: remove this when we remove the deprecation warning, and the `kwargs` argument,
# or solve in a more general way.
kwargs.pop("kwargs", None)
for key, value in kwargs.items():
try:
setattr(self, key, value)
except AttributeError as err:
logger.error(f"Can't set {key} with value {value} for {self}")
raise err
if not hasattr(self, "_internal_dict"):
internal_dict = kwargs
else:
previous_dict = dict(self._internal_dict)
            internal_dict = {**self._internal_dict, **kwargs}
logger.debug(
f"Updating config from {previous_dict} to {internal_dict}")
self._internal_dict = FrozenDict(internal_dict)
@property
def config(self) -> Dict[str, Any]:
return self._internal_dict
class FrozenDict(OrderedDict):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
for key, value in self.items():
setattr(self, key, value)
self.__frozen = True
def __delitem__(self, *args, **kwargs):
raise Exception(
f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance."
)
def setdefault(self, *args, **kwargs):
raise Exception(
f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance."
)
def pop(self, *args, **kwargs):
raise Exception(
f"You cannot use ``pop`` on a {self.__class__.__name__} instance.")
def update(self, *args, **kwargs):
raise Exception(
f"You cannot use ``update`` on a {self.__class__.__name__} instance."
)
def __setattr__(self, name, value):
if hasattr(self, "__frozen") and self.__frozen:
raise Exception(
f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance."
)
super().__setattr__(name, value)
def __setitem__(self, name, value):
if hasattr(self, "__frozen") and self.__frozen:
            raise Exception(
                f"You cannot use ``__setitem__`` on a {self.__class__.__name__} instance."
            )
super().__setitem__(name, value)
def register_to_config(init):
r"""
Decorator to apply on the init of classes inheriting from [`ConfigMixin`] so that all the arguments are
automatically sent to `self.register_for_config`. To ignore a specific argument accepted by the init but that
shouldn't be registered in the config, use the `ignore_for_config` class variable
Warning: Once decorated, all private arguments (beginning with an underscore) are trashed and not sent to the init!
"""
@functools.wraps(init)
def inner_init(self, *args, **kwargs):
# Ignore private kwargs in the init.
init_kwargs = {
k: v
for k, v in kwargs.items() if not k.startswith("_")
}
init(self, *args, **init_kwargs)
if not isinstance(self, ConfigMixin):
raise RuntimeError(
f"`@register_for_config` was applied to {self.__class__.__name__} init method, but this class does "
"not inherit from `ConfigMixin`.")
ignore = getattr(self, "ignore_for_config", [])
# Get positional arguments aligned with kwargs
new_kwargs = {}
signature = inspect.signature(init)
parameters = {
name: p.default
for i, (name, p) in enumerate(signature.parameters.items())
if i > 0 and name not in ignore
}
for arg, name in zip(args, parameters.keys()):
new_kwargs[name] = arg
# Then add all kwargs
new_kwargs.update({
k: init_kwargs.get(k, default)
for k, default in parameters.items()
if k not in ignore and k not in new_kwargs
})
getattr(self, "register_to_config")(**new_kwargs)
return inner_init
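
A minimal, purely illustrative sketch of how a class could combine `ConfigMixin` with the `register_to_config` decorator above (the `DemoScheduler` name and its arguments are assumptions, not part of this PR):

```python
class DemoScheduler(ConfigMixin):
    config_name = "scheduler_config.json"

    @register_to_config
    def __init__(self, num_train_timesteps=1000, beta_start=0.0001):
        # The decorator records every init argument (positional or keyword,
        # filled in from defaults when omitted) into self.config.
        pass

s = DemoScheduler(num_train_timesteps=500)
print(s.config["num_train_timesteps"])  # 500
print(s.config["beta_start"])           # 0.0001
print(s.config["_class_name"])          # "DemoScheduler"
```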


@@ -0,0 +1,105 @@
# Diffusion Model Export Guide

This project supports two ways of exporting models: from [PPDiffusers](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/ppdiffusers) and from [Diffusers](https://github.com/huggingface/diffusers). Both are described below.

## PPDiffusers model export

[PPDiffusers](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/ppdiffusers) is a toolbox for training and inference of cross-modal (e.g. image and audio) diffusion models. It borrows the excellent design of the Hugging Face team's [Diffusers](https://github.com/huggingface/diffusers) and is built on the [PaddlePaddle](https://github.com/PaddlePaddle/Paddle) framework and the [PaddleNLP](https://github.com/PaddlePaddle/PaddleNLP) NLP library. This section shows how to use FastDeploy to deploy the Diffusion models provided by PPDiffusers with high performance.

### Installing dependencies

Model export depends on `paddlepaddle`, `paddlenlp` and `ppdiffusers`, which can be installed quickly with `pip`:

```shell
pip install -r requirements_paddle.txt
```

### Exporting the model

___Note: the StableDiffusion model is downloaded during export. To use the model and its weights you must accept its license: visit the HuggingFace [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5), read the license carefully, and then sign the agreement.___

___Tips: Stable Diffusion is released under the following license: The CreativeML OpenRAIL M license is an Open RAIL M license, adapted from the work that BigScience and the RAIL Initiative are jointly carrying in the area of responsible AI licensing. See also the article about the BLOOM Open RAIL license on which this license is based.___

Run the following command to export the model:

```shell
python export_model.py --pretrained_model_name_or_path CompVis/stable-diffusion-v1-4 --output_path stable-diffusion-v1-4
```

The exported model directory is laid out as follows:

```shell
stable-diffusion-v1-4/
├── text_encoder
│   ├── inference.pdiparams
│   ├── inference.pdiparams.info
│   └── inference.pdmodel
├── unet
│   ├── inference.pdiparams
│   ├── inference.pdiparams.info
│   └── inference.pdmodel
└── vae_decoder
    ├── inference.pdiparams
    ├── inference.pdiparams.info
    └── inference.pdmodel
```

#### Parameter description

Command-line arguments of `export_model.py`:

| Argument | Description |
|----------|--------------|
|<div style="width: 230pt">--pretrained_model_name_or_path </div> | Pretrained diffusion model provided by ppdiffusers. Defaults to "CompVis/stable-diffusion-v1-4". For more pretrained diffusion models, see the [ppdiffusers model list](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/ppdiffusers/examples/textual_inversion). |
|--output_path | Output directory of the exported model. |

## Diffusers model export

[Diffusers](https://github.com/huggingface/diffusers) is a toolbox built by HuggingFace for training and inference of cross-modal (e.g. image and audio) diffusion models. Its underlying model code comes in two flavors: a PyTorch implementation and a Flax implementation. This section shows how to use FastDeploy to deploy the PyTorch implementation of the Diffusion models with high performance.

### Installing dependencies

Model export depends on `onnx`, `torch`, `diffusers` and `transformers`, which can be installed quickly with `pip`:

```shell
pip install -r requirements_torch.txt
```

### Exporting the model

___Note: the StableDiffusion model is downloaded during export. To use the model and its weights you must accept its license and obtain a token granted by the HF Hub: visit the HuggingFace [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5), read the license carefully, and then sign the agreement.___

___Tips: Stable Diffusion is released under the following license: The CreativeML OpenRAIL M license is an Open RAIL M license, adapted from the work that BigScience and the RAIL Initiative are jointly carrying in the area of responsible AI licensing. See also the article about the BLOOM Open RAIL license on which this license is based.___

If you are exporting the model for the first time, log in to the HuggingFace client first:

```shell
huggingface-cli login
```

After logging in, run the following command to export the model:

```shell
python export_torch_to_onnx_model.py --pretrained_model_name_or_path CompVis/stable-diffusion-v1-4 --output_path torch_diffusion_model
```

The exported model directory is laid out as follows:

```shell
torch_diffusion_model/
├── text_encoder
│   └── inference.onnx
├── unet
│   └── inference.onnx
└── vae_decoder
    └── inference.onnx
```

#### Parameter description

Command-line arguments of `export_torch_to_onnx_model.py`:

| Argument | Description |
|----------|--------------|
|<div style="width: 230pt">--pretrained_model_name_or_path </div> | Pretrained diffusion model to export. Defaults to "CompVis/stable-diffusion-v1-4". For more pretrained diffusion models, see the [HuggingFace model list](https://huggingface.co/CompVis/stable-diffusion-v1-4). |
|--output_path | Output directory of the exported model. |
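
After export, the resulting graphs can optionally be validated before deployment. A minimal sketch using the `onnx` package already listed in `requirements_torch.txt`; the output directory and loop are illustrative:

```python
import os
import onnx

# Illustrative sanity check on the three exported sub-models.
for name in ["text_encoder", "unet", "vae_decoder"]:
    path = os.path.join("torch_diffusion_model", name, "inference.onnx")
    onnx.checker.check_model(path)  # passing the path also covers large graphs
    print(f"{name}: ONNX graph is valid")
```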


@@ -0,0 +1,100 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import paddle
import paddlenlp
from ppdiffusers import UNet2DConditionModel, AutoencoderKL
from ppdiffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
from paddlenlp.transformers import CLIPTextModel
def parse_arguments():
import argparse
import ast
parser = argparse.ArgumentParser()
parser.add_argument(
"--pretrained_model_name_or_path",
default='CompVis/stable-diffusion-v1-4',
help="The pretrained diffusion model.")
parser.add_argument(
"--output_path",
type=str,
required=True,
        help="The output directory of the exported model.")
return parser.parse_args()
class VAEDecoder(AutoencoderKL):
def forward(self, z):
return self.decode(z, True).sample
if __name__ == "__main__":
paddle.set_device('cpu')
args = parse_arguments()
# Load models and create wrapper for stable diffusion
text_encoder = CLIPTextModel.from_pretrained(
os.path.join(args.pretrained_model_name_or_path, "text_encoder"))
vae_decoder = VAEDecoder.from_pretrained(
args.pretrained_model_name_or_path, subfolder="vae")
unet = UNet2DConditionModel.from_pretrained(
args.pretrained_model_name_or_path, subfolder="unet")
# Convert to static graph with specific input description
text_encoder = paddle.jit.to_static(
text_encoder,
input_spec=[
paddle.static.InputSpec(
shape=[None, None], dtype="int64",
name="input_ids") # input_ids
])
# Save text_encoder in static graph model.
save_path = os.path.join(args.output_path, "text_encoder", "inference")
paddle.jit.save(text_encoder, save_path)
print(f"Save text_encoder model in {save_path} successfully.")
# Convert to static graph with specific input description
vae_decoder = paddle.jit.to_static(
vae_decoder,
input_spec=[
paddle.static.InputSpec(
shape=[None, 4, 64, 64], dtype="float32",
name="latent"), # latent
])
# Save vae_decoder in static graph model.
save_path = os.path.join(args.output_path, "vae_decoder", "inference")
paddle.jit.save(vae_decoder, save_path)
print(f"Save vae_decoder model in {save_path} successfully.")
# Convert to static graph with specific input description
unet = paddle.jit.to_static(
unet,
input_spec=[
paddle.static.InputSpec(
shape=[None, 4, None, None],
dtype="float32",
name="latent_input"), # latent
paddle.static.InputSpec(
shape=[1], dtype="int64", name="timestep"), # timesteps
paddle.static.InputSpec(
shape=[None, None, 768],
dtype="float32",
name="encoder_embedding") # encoder_embedding
])
save_path = os.path.join(args.output_path, "unet", "inference")
paddle.jit.save(unet, save_path)
print(f"Save unet model in {save_path} successfully.")
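
As an optional sanity check (not part of the script above), an exported static-graph sub-model can be reloaded by its file prefix; a minimal sketch, assuming the default output path used in `export.md`:

```python
import paddle

# Reload one exported sub-model by its prefix to confirm the export is usable.
unet = paddle.jit.load("stable-diffusion-v1-4/unet/inference")
print("Reloaded:", type(unet).__name__)
```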


@@ -0,0 +1,159 @@
# Copyright 2022 The HuggingFace Inc. team.
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import onnx
import torch
import onnxsim
from typing import Optional, Tuple, Union
from diffusers import UNet2DConditionModel, AutoencoderKL
from transformers import CLIPTextModel
import os
def parse_arguments():
import argparse
import ast
parser = argparse.ArgumentParser()
parser.add_argument(
"--pretrained_model_name_or_path",
default='CompVis/stable-diffusion-v1-4',
help="The pretrained diffusion model.")
parser.add_argument(
"--output_path",
type=str,
required=True,
        help="The output directory of the exported model.")
return parser.parse_args()
class VAEDecoder(AutoencoderKL):
def forward(self, z):
return self.decode(z, True).sample
if __name__ == "__main__":
args = parse_arguments()
# 1. Load VAE model
vae_decoder = VAEDecoder.from_pretrained(
args.pretrained_model_name_or_path,
torch_dtype=torch.float16,
revision="fp16",
subfolder="vae",
use_auth_token=True)
# 2. Load UNet model
unet = UNet2DConditionModel.from_pretrained(
args.pretrained_model_name_or_path,
torch_dtype=torch.float16,
revision="fp16",
subfolder="unet",
use_auth_token=True)
# 3. Load CLIP model
text_encoder = CLIPTextModel.from_pretrained(
"openai/clip-vit-large-patch14")
vae_decoder.cuda()
unet.cuda()
text_encoder.cuda()
os.makedirs(args.output_path, exist_ok=True)
vae_decoder_path = os.path.join(args.output_path, "vae_decoder")
text_encoder_path = os.path.join(args.output_path, "text_encoder")
unet_path = os.path.join(args.output_path, "unet")
for p in [vae_decoder_path, text_encoder_path, unet_path]:
os.makedirs(p, exist_ok=True)
with torch.inference_mode():
# Export vae decoder model
vae_inputs = (torch.randn(
1, 4, 64, 64, dtype=torch.half, device='cuda'), )
torch.onnx.export(
vae_decoder, # model being run
vae_inputs, # model input (or a tuple for multiple inputs)
os.path.join(
vae_decoder_path, "inference.onnx"
), # where to save the model (can be a file or file-like object)
export_params=True, # store the trained parameter weights inside the model file
opset_version=12, # the ONNX version to export the model to
do_constant_folding=True, # whether to execute constant folding for optimization
input_names=['latent'],
dynamic_axes={
'latent': {
0: 'batch_size',
},
'image': {
0: 'batch_size',
},
},
output_names=['image'])
print("Finish exporting vae decoder.")
# Export the unet model
unet_inputs = (torch.randn(
2, 4, 64, 64, dtype=torch.half, device='cuda'), torch.randn(
1, dtype=torch.half, device='cuda'), torch.randn(
2, 77, 768, dtype=torch.half, device='cuda'))
torch.onnx.export(
unet, # model being run
unet_inputs, # model input (or a tuple for multiple inputs)
os.path.join(
unet_path, "inference.onnx"
), # where to save the model (can be a file or file-like object)
export_params=True, # store the trained parameter weights inside the model file
opset_version=12, # the ONNX version to export the model to
do_constant_folding=True, # whether to execute constant folding for optimization
input_names=['latent_input', 'timestep', 'encoder_embedding'],
dynamic_axes={
'latent_input': {
0: 'batch_size',
},
'encoder_embedding': {
0: 'batch_size',
1: 'sequence'
},
'latent_output': {
0: 'batch_size',
},
},
output_names=['latent_output'])
print("Finish exporting unet.")
# Export the text_encoder
text_encoder_inputs = (torch.randint(0, 1, (2, 77), device='cuda'), )
torch.onnx.export(
text_encoder, # model being run
text_encoder_inputs, # model input (or a tuple for multiple inputs)
os.path.join(
text_encoder_path, "inference.onnx"
), # where to save the model (can be a file or file-like object)
export_params=True, # store the trained parameter weights inside the model file
opset_version=14, # the ONNX version to export the model to
do_constant_folding=True, # whether to execute constant folding for optimization
input_names=['input_ids'],
dynamic_axes={
'input_ids': {
0: 'batch_size',
1: 'sequence'
},
'logits': {
0: 'batch_size',
1: 'sequence'
}
},
output_names=['logits'])
print("Finish exporting text encoder.")
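
`onnxsim` is imported at the top of this script but not used in the code shown above; a hedged sketch of how the exported UNet graph could optionally be simplified with it (whether simplification is beneficial here is an assumption):

```python
import os
import onnx
import onnxsim

unet_onnx = os.path.join("torch_diffusion_model", "unet", "inference.onnx")
simplified, ok = onnxsim.simplify(onnx.load(unet_onnx))
if ok:
    onnx.save(simplified, unet_onnx)
```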


@@ -0,0 +1,320 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import time
import os
from pipeline_stable_diffusion import StableDiffusionFastDeployPipeline
from scheduling_utils import PNDMScheduler, EulerAncestralDiscreteScheduler
try:
from paddlenlp.transformers import CLIPTokenizer
except ImportError:
from transformers import CLIPTokenizer
import fastdeploy as fd
from fastdeploy import ModelFormat
import numpy as np
import distutils.util
def parse_arguments():
import argparse
import ast
parser = argparse.ArgumentParser()
parser.add_argument(
"--model_dir",
default="paddle_diffusion_model",
help="The model directory of diffusion_model.")
parser.add_argument(
"--model_format",
default="paddle",
choices=['paddle', 'onnx'],
help="The model format.")
parser.add_argument(
"--unet_model_prefix",
default='unet',
help="The file prefix of unet model.")
parser.add_argument(
"--vae_model_prefix",
default='vae_decoder',
help="The file prefix of vae model.")
parser.add_argument(
"--text_encoder_model_prefix",
default='text_encoder',
help="The file prefix of text_encoder model.")
parser.add_argument(
"--inference_steps",
type=int,
default=100,
help="The number of unet inference steps.")
parser.add_argument(
"--benchmark_steps",
type=int,
default=1,
help="The number of performance benchmark steps.")
parser.add_argument(
"--backend",
type=str,
default='paddle',
# Note(zhoushunjie): Will support 'tensorrt', 'paddle-tensorrt' soon.
choices=[
'onnx_runtime',
'paddle',
],
help="The inference runtime backend of unet model and text encoder model."
)
parser.add_argument(
"--image_path",
default="fd_astronaut_rides_horse.png",
        help="The output path of the generated image.")
parser.add_argument(
"--use_fp16",
type=distutils.util.strtobool,
default=False,
        help="Whether to use FP16 mode.")
parser.add_argument(
"--device_id",
type=int,
default=0,
help="The selected gpu id. -1 means use cpu")
parser.add_argument(
"--scheduler",
type=str,
default='pndm',
choices=['pndm', 'euler_ancestral'],
help="The scheduler type of stable diffusion.")
return parser.parse_args()
def create_ort_runtime(model_dir, model_prefix, model_format, device_id=0):
option = fd.RuntimeOption()
option.use_ort_backend()
option.use_gpu(device_id)
if model_format == "paddle":
model_file = os.path.join(model_dir, model_prefix, "inference.pdmodel")
params_file = os.path.join(model_dir, model_prefix,
"inference.pdiparams")
option.set_model_path(model_file, params_file)
else:
onnx_file = os.path.join(model_dir, model_prefix, "inference.onnx")
option.set_model_path(onnx_file, model_format=ModelFormat.ONNX)
return fd.Runtime(option)
def create_paddle_inference_runtime(model_dir,
model_prefix,
use_trt=False,
dynamic_shape=None,
use_fp16=False,
device_id=0):
option = fd.RuntimeOption()
option.use_paddle_backend()
if device_id == -1:
option.use_cpu()
else:
option.use_gpu(device_id)
if use_trt:
option.use_trt_backend()
option.enable_paddle_to_trt()
if use_fp16:
option.enable_trt_fp16()
cache_file = os.path.join(model_dir, model_prefix, "inference.trt")
option.set_trt_cache_file(cache_file)
        # Collect dynamic shape information for the Paddle-TRT backend.
if dynamic_shape is not None:
option.enable_paddle_trt_collect_shape()
for key, shape_dict in dynamic_shape.items():
option.set_trt_input_shape(
key,
min_shape=shape_dict["min_shape"],
opt_shape=shape_dict.get("opt_shape", None),
max_shape=shape_dict.get("max_shape", None))
model_file = os.path.join(model_dir, model_prefix, "inference.pdmodel")
params_file = os.path.join(model_dir, model_prefix, "inference.pdiparams")
option.set_model_path(model_file, params_file)
return fd.Runtime(option)
def create_trt_runtime(model_dir,
model_prefix,
model_format,
workspace=(1 << 31),
dynamic_shape=None,
device_id=0):
option = fd.RuntimeOption()
option.use_trt_backend()
option.use_gpu(device_id)
option.enable_trt_fp16()
option.set_trt_max_workspace_size(workspace)
if dynamic_shape is not None:
for key, shape_dict in dynamic_shape.items():
option.set_trt_input_shape(
key,
min_shape=shape_dict["min_shape"],
opt_shape=shape_dict.get("opt_shape", None),
max_shape=shape_dict.get("max_shape", None))
if model_format == "paddle":
model_file = os.path.join(model_dir, model_prefix, "inference.pdmodel")
params_file = os.path.join(model_dir, model_prefix,
"inference.pdiparams")
option.set_model_path(model_file, params_file)
else:
onnx_file = os.path.join(model_dir, model_prefix, "inference.onnx")
option.set_model_path(onnx_file, model_format=ModelFormat.ONNX)
cache_file = os.path.join(model_dir, model_prefix, "inference.trt")
option.set_trt_cache_file(cache_file)
return fd.Runtime(option)
def get_scheduler(args):
if args.scheduler == "pndm":
scheduler = PNDMScheduler(
beta_end=0.012,
beta_schedule="scaled_linear",
beta_start=0.00085,
num_train_timesteps=1000,
skip_prk_steps=True)
elif args.scheduler == "euler_ancestral":
scheduler = EulerAncestralDiscreteScheduler(
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear")
else:
        raise ValueError(
            f"Scheduler '{args.scheduler}' is not supported right now.")
return scheduler
if __name__ == "__main__":
args = parse_arguments()
# 1. Init scheduler
scheduler = get_scheduler(args)
# 2. Init tokenizer
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
# 3. Set dynamic shape for trt backend
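    # For the unet, the max/opt batch size of 2 accounts for classifier-free guidance,
    # which concatenates the unconditional and conditional inputs into one batch.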
vae_dynamic_shape = {
"latent": {
"min_shape": [1, 4, 64, 64],
"max_shape": [2, 4, 64, 64],
"opt_shape": [2, 4, 64, 64],
}
}
unet_dynamic_shape = {
"latent_input": {
"min_shape": [1, 4, 64, 64],
"max_shape": [2, 4, 64, 64],
"opt_shape": [2, 4, 64, 64],
},
"timestep": {
"min_shape": [1],
"max_shape": [1],
"opt_shape": [1],
},
"encoder_embedding": {
"min_shape": [1, 77, 768],
"max_shape": [2, 77, 768],
"opt_shape": [2, 77, 768],
},
}
# 4. Init runtime
if args.backend == "onnx_runtime":
text_encoder_runtime = create_ort_runtime(
args.model_dir,
args.text_encoder_model_prefix,
args.model_format,
device_id=args.device_id)
vae_decoder_runtime = create_ort_runtime(
args.model_dir,
args.vae_model_prefix,
args.model_format,
device_id=args.device_id)
start = time.time()
unet_runtime = create_ort_runtime(
args.model_dir,
args.unet_model_prefix,
args.model_format,
device_id=args.device_id)
        print(f"Took {time.time() - start:.2f} s to load the unet model.")
elif args.backend == "paddle" or args.backend == "paddle-tensorrt":
use_trt = True if args.backend == "paddle-tensorrt" else False
# Note(zhoushunjie): Will change to paddle runtime later
text_encoder_runtime = create_ort_runtime(
args.model_dir,
args.text_encoder_model_prefix,
args.model_format,
device_id=args.device_id)
vae_decoder_runtime = create_paddle_inference_runtime(
args.model_dir,
args.vae_model_prefix,
use_trt,
vae_dynamic_shape,
use_fp16=args.use_fp16,
device_id=args.device_id)
start = time.time()
unet_runtime = create_paddle_inference_runtime(
args.model_dir,
args.unet_model_prefix,
use_trt,
unet_dynamic_shape,
use_fp16=args.use_fp16,
device_id=args.device_id)
        print(f"Took {time.time() - start:.2f} s to load the unet model.")
elif args.backend == "tensorrt":
text_encoder_runtime = create_ort_runtime(
args.model_dir, args.text_encoder_model_prefix, args.model_format)
vae_decoder_runtime = create_trt_runtime(
args.model_dir,
args.vae_model_prefix,
args.model_format,
workspace=(1 << 30),
dynamic_shape=vae_dynamic_shape,
device_id=args.device_id)
start = time.time()
unet_runtime = create_trt_runtime(
args.model_dir,
args.unet_model_prefix,
args.model_format,
dynamic_shape=unet_dynamic_shape,
device_id=args.device_id)
        print(f"Took {time.time() - start:.2f} s to load the unet model.")
pipe = StableDiffusionFastDeployPipeline(
vae_decoder_runtime=vae_decoder_runtime,
text_encoder_runtime=text_encoder_runtime,
tokenizer=tokenizer,
unet_runtime=unet_runtime,
scheduler=scheduler)
prompt = "a photo of an astronaut riding a horse on mars"
# Warm up
pipe(prompt, num_inference_steps=10)
time_costs = []
print(
f"Run the stable diffusion pipeline {args.benchmark_steps} times to test the performance."
)
for step in range(args.benchmark_steps):
start = time.time()
image = pipe(prompt, num_inference_steps=args.inference_steps)[0]
latency = time.time() - start
time_costs += [latency]
        print(f"No {step:3d} time cost: {latency:.2f} s")
    print(
        f"Mean latency: {np.mean(time_costs):.2f} s, p50 latency: {np.percentile(time_costs, 50):.2f} s, "
        f"p90 latency: {np.percentile(time_costs, 90):.2f} s, p95 latency: {np.percentile(time_costs, 95):.2f} s."
    )
image.save(args.image_path)
print(f"Image saved in {args.image_path}!")


@@ -0,0 +1,236 @@
# Copyright 2022 The HuggingFace Inc. team.
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from typing import Callable, List, Optional, Union
import numpy as np
from paddlenlp.transformers import CLIPTokenizer
import fastdeploy as fd
from scheduling_utils import PNDMScheduler, LMSDiscreteScheduler, DDIMScheduler, EulerAncestralDiscreteScheduler
import PIL
from PIL import Image
import logging

logger = logging.getLogger(__name__)
class StableDiffusionFastDeployPipeline(object):
vae_decoder_runtime: fd.Runtime
text_encoder_runtime: fd.Runtime
tokenizer: CLIPTokenizer
unet_runtime: fd.Runtime
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler,
EulerAncestralDiscreteScheduler]
def __init__(self,
vae_decoder_runtime: fd.Runtime,
text_encoder_runtime: fd.Runtime,
tokenizer: CLIPTokenizer,
unet_runtime: fd.Runtime,
scheduler: Union[DDIMScheduler, PNDMScheduler,
LMSDiscreteScheduler]):
self.vae_decoder_runtime = vae_decoder_runtime
self.text_encoder_runtime = text_encoder_runtime
self.unet_runtime = unet_runtime
self.scheduler = scheduler
self.tokenizer = tokenizer
def __call__(
self,
prompt: Union[str, List[str]],
height: Optional[int]=512,
width: Optional[int]=512,
num_inference_steps: Optional[int]=50,
guidance_scale: Optional[float]=7.5,
negative_prompt: Optional[Union[str, List[str]]]=None,
num_images_per_prompt: Optional[int]=1,
eta: Optional[float]=0.0,
generator: Optional[np.random.RandomState]=None,
latents: Optional[np.ndarray]=None,
output_type: Optional[str]="pil",
return_dict: bool=True,
callback: Optional[Callable[[int, int, np.ndarray], None]]=None,
callback_steps: Optional[int]=1,
**kwargs, ):
if isinstance(prompt, str):
batch_size = 1
elif isinstance(prompt, list):
batch_size = len(prompt)
else:
raise ValueError(
f"`prompt` has to be of type `str` or `list` but is {type(prompt)}"
)
if height % 8 != 0 or width % 8 != 0:
raise ValueError(
f"`height` and `width` have to be divisible by 8 but are {height} and {width}."
)
if (callback_steps is None) or (callback_steps is not None and (
not isinstance(callback_steps, int) or callback_steps <= 0)):
raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}.")
if generator is None:
generator = np.random
# get prompt text embeddings
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=self.tokenizer.model_max_length,
return_tensors="np", )
text_input_ids = text_inputs.input_ids
if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
removed_text = self.tokenizer.batch_decode(
text_input_ids[:, self.tokenizer.model_max_length:])
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {self.tokenizer.model_max_length} tokens: {removed_text}")
text_input_ids = text_input_ids[:, :
self.tokenizer.model_max_length]
input_name = self.text_encoder_runtime.get_input_info(0).name
text_embeddings = self.text_encoder_runtime.infer({
input_name: text_input_ids.astype(np.int64)
})[0]
text_embeddings = np.repeat(
text_embeddings, num_images_per_prompt, axis=0)
do_classifier_free_guidance = guidance_scale > 1.0
if do_classifier_free_guidance:
uncond_tokens: List[str]
if negative_prompt is None:
uncond_tokens = [""] * batch_size
elif type(prompt) is not type(negative_prompt):
raise TypeError(
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
f" {type(prompt)}.")
elif isinstance(negative_prompt, str):
uncond_tokens = [negative_prompt] * batch_size
elif batch_size != len(negative_prompt):
raise ValueError(
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
" the batch size of `prompt`.")
else:
uncond_tokens = negative_prompt
max_length = text_input_ids.shape[-1]
uncond_input = self.tokenizer(
uncond_tokens,
padding="max_length",
max_length=max_length,
truncation=True,
return_tensors="np")
uncond_embeddings = self.text_encoder_runtime.infer({
input_name: uncond_input.input_ids.astype(np.int64)
})[0]
uncond_embeddings = np.repeat(
uncond_embeddings, num_images_per_prompt, axis=0)
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
text_embeddings = np.concatenate(
[uncond_embeddings, text_embeddings])
# get the initial random noise unless the user supplied it
latents_dtype = text_embeddings.dtype
latents_shape = (batch_size * num_images_per_prompt, 4, height // 8,
width // 8)
if latents is None:
latents = generator.randn(*latents_shape).astype(latents_dtype)
elif latents.shape != latents_shape:
raise ValueError(
f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}"
)
# set timesteps
self.scheduler.set_timesteps(num_inference_steps)
latents = latents * self.scheduler.init_noise_sigma
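        # `eta` only applies to DDIM-style schedulers, so it is forwarded only
        # when the scheduler's step() signature actually accepts it.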
accepts_eta = "eta" in set(
inspect.signature(self.scheduler.step).parameters.keys())
extra_step_kwargs = {}
if accepts_eta:
extra_step_kwargs["eta"] = eta
for i, t in enumerate(self.scheduler.timesteps):
# expand the latents if we are doing classifier free guidance
latent_model_input = np.concatenate(
[latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(
latent_model_input, t)
# predict the noise residual
sample_name = self.unet_runtime.get_input_info(0).name
timestep_name = self.unet_runtime.get_input_info(1).name
encoder_hidden_states_name = self.unet_runtime.get_input_info(
2).name
# Required fp16 input.
input_type = [np.float16, np.float16, np.float16]
if self.unet_runtime.get_input_info(0).dtype == fd.FDDataType.FP32:
input_type = [np.float32, np.int64, np.float32]
noise_pred = self.unet_runtime.infer({
sample_name: latent_model_input.astype(input_type[0]),
timestep_name: np.array(
[t], dtype=input_type[1]),
encoder_hidden_states_name:
text_embeddings.astype(input_type[2]),
})[0]
# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2)
noise_pred = noise_pred_uncond + guidance_scale * (
noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents,
**extra_step_kwargs).prev_sample
latents = np.array(latents)
# call the callback, if provided
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
# scale and decode the image latents with vae
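        # 0.18215 is the latent scaling factor the Stable Diffusion VAE was trained with.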
latents = 1 / 0.18215 * latents
sample_name = self.vae_decoder_runtime.get_input_info(0).name
input_dtype = np.float16
if self.vae_decoder_runtime.get_input_info(
0).dtype == fd.FDDataType.FP32:
input_dtype = np.float32
image = self.vae_decoder_runtime.infer({
sample_name: latents.astype(input_dtype)
})[0]
image = np.clip(image / 2 + 0.5, 0, 1)
image = image.transpose((0, 2, 3, 1))
if output_type == "pil":
image = self.numpy_to_pil(image)
return image
@staticmethod
def numpy_to_pil(images):
"""
Convert a numpy image or a batch of images to a PIL image.
"""
if images.ndim == 3:
images = images[None, ...]
images = (images * 255).round().astype("uint8")
pil_images = [Image.fromarray(image) for image in images]
return pil_images
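
The `callback` and `callback_steps` arguments of `__call__` make it possible to observe intermediate latents during denoising. A hedged usage sketch; it assumes a `pipe` object constructed as in `infer.py`, and the prompt, negative prompt, and output filename are illustrative:

```python
import numpy as np

def log_step(step: int, timestep, latents: np.ndarray):
    # Invoked every `callback_steps` denoising steps with the current latents.
    print(f"step {step}: timestep={timestep}, latent mean={latents.mean():.4f}")

images = pipe(
    "a photo of an astronaut riding a horse on mars",
    num_inference_steps=50,
    guidance_scale=7.5,
    negative_prompt="blurry, low quality",
    callback=log_step,
    callback_steps=10)
images[0].save("fd_callback_demo.png")
```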


@@ -0,0 +1,3 @@
ppdiffusers
paddlenlp
paddlepaddle-gpu


@@ -0,0 +1,5 @@
onnx
torch
diffusers
transformers
scipy

File diff suppressed because it is too large


@@ -37,3 +37,4 @@ from . import vision
from . import pipeline
from . import text
from .download import download, download_and_decompress, download_model
from .utils import profile


@@ -11,3 +11,4 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .profile import profile


@@ -0,0 +1,28 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import cProfile, pstats, io
from pstats import SortKey
def profile(func, *args, **kwargs):
pr = cProfile.Profile()
pr.enable()
func(*args, **kwargs)
pr.disable()
s = io.StringIO()
sortby = SortKey.CUMULATIVE
ps = pstats.Stats(pr, stream=s).sort_stats(sortby)
ps.print_stats()
print(s.getvalue())
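
A hedged usage sketch of the new helper; it assumes the `from .utils import profile` re-export added to `fastdeploy/__init__.py` above and an already constructed `pipe` object such as the one built in `infer.py`:

```python
from fastdeploy import profile

# Prints a cProfile report, sorted by cumulative time, for one short pipeline run.
profile(pipe, "a photo of an astronaut riding a horse on mars", num_inference_steps=10)
```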