mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-06 00:57:33 +08:00
Add trt backend
This commit is contained in:
@@ -13,8 +13,8 @@
|
||||
// limitations under the License.
|
||||
|
||||
#include "dpm_solver_multistep_scheduler.h"
|
||||
#include "fastdeploy/vision/common/processors/mat.h"
|
||||
#include "fastdeploy/utils/perf.h"
|
||||
#include "fastdeploy/vision/common/processors/mat.h"
|
||||
#include "opencv2/highgui/highgui.hpp"
|
||||
#include "opencv2/imgproc/imgproc.hpp"
|
||||
#include "pipeline_stable_diffusion_inpaint.h"
|
||||
@@ -22,6 +22,13 @@
|
||||
#include <memory>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
|
||||
#ifdef WIN32
|
||||
const char sep = '\\';
|
||||
#else
|
||||
const char sep = '/';
|
||||
#endif
|
||||
|
||||
template <typename T> std::string Str(const T* value, int size) {
|
||||
std::ostringstream oss;
|
||||
@@ -33,17 +40,40 @@ template <typename T> std::string Str(const T* value, int size) {
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
std::unique_ptr<fastdeploy::Runtime>
|
||||
CreateRuntime(const std::string& model_file, const std::string& params_file,
|
||||
bool use_paddle_backend = true) {
|
||||
std::unique_ptr<fastdeploy::Runtime> CreateRuntime(
|
||||
const std::string& model_file, const std::string& params_file,
|
||||
bool use_trt_backend = false, bool use_fp16 = false,
|
||||
const std::unordered_map<std::string, std::vector<std::vector<int>>>&
|
||||
dynamic_shapes = {},
|
||||
const std::vector<std::string>& disable_paddle_trt_ops = {}) {
|
||||
fastdeploy::RuntimeOption runtime_option;
|
||||
runtime_option.SetModelPath(model_file, params_file,
|
||||
fastdeploy::ModelFormat::PADDLE);
|
||||
runtime_option.UseGpu();
|
||||
if (use_paddle_backend) {
|
||||
if (!use_trt_backend) {
|
||||
runtime_option.UsePaddleBackend();
|
||||
} else {
|
||||
runtime_option.UseOrtBackend();
|
||||
runtime_option.UseTrtBackend();
|
||||
runtime_option.EnablePaddleToTrt();
|
||||
for (auto it = dynamic_shapes.begin(); it != dynamic_shapes.end(); ++it) {
|
||||
if (it->second.size() != 3) {
|
||||
std::cerr << "The size of dynamic_shapes of input `" << it->first
|
||||
<< "` should be 3, but receive " << it->second.size()
|
||||
<< std::endl;
|
||||
continue;
|
||||
}
|
||||
std::vector<int> min_shape = (it->second)[0];
|
||||
std::vector<int> opt_shape = (it->second)[1];
|
||||
std::vector<int> max_shape = (it->second)[2];
|
||||
runtime_option.SetTrtInputShape(it->first, min_shape, opt_shape,
|
||||
max_shape);
|
||||
}
|
||||
runtime_option.SetTrtCacheFile("");
|
||||
runtime_option.EnablePaddleTrtCollectShape();
|
||||
runtime_option.DisablePaddleTrtOPs(disable_paddle_trt_ops);
|
||||
if (use_fp16) {
|
||||
runtime_option.EnableTrtFP16();
|
||||
}
|
||||
}
|
||||
std::unique_ptr<fastdeploy::Runtime> runtime =
|
||||
std::unique_ptr<fastdeploy::Runtime>(new fastdeploy::Runtime());
|
||||
@@ -59,6 +89,13 @@ CreateRuntime(const std::string& model_file, const std::string& params_file,
|
||||
}
|
||||
|
||||
int main() {
|
||||
// 0. Init all configs
|
||||
std::string model_dir = "sd15_inpaint";
|
||||
int max_length = 77;
|
||||
bool use_trt_backend = true;
|
||||
bool use_fp16 = true;
|
||||
int batch_size = 1;
|
||||
|
||||
// 1. Init scheduler
|
||||
std::unique_ptr<fastdeploy::Scheduler> dpm(
|
||||
new fastdeploy::DPMSolverMultistepScheduler(
|
||||
@@ -77,37 +114,74 @@ int main() {
|
||||
/* lower_order_final = */ true));
|
||||
|
||||
// 2. Init text encoder runtime
|
||||
std::string text_model_file = "sd15_inpaint/text_encoder/inference.pdmodel";
|
||||
std::string text_params_file =
|
||||
"sd15_inpaint/text_encoder/inference.pdiparams";
|
||||
std::unordered_map<std::string, std::vector<std::vector<int>>>
|
||||
text_dynamic_shape = {{"input_ids",
|
||||
{/* min_shape */ {1, max_length},
|
||||
/* opt_shape */ {batch_size, max_length},
|
||||
/* max_shape */ {2 * batch_size, max_length}}}};
|
||||
std::string text_model_dir = model_dir + sep + "text_encoder";
|
||||
std::string text_model_file = text_model_dir + sep + "inference.pdmodel";
|
||||
std::string text_params_file = text_model_dir + sep + "inference.pdiparams";
|
||||
std::unique_ptr<fastdeploy::Runtime> text_encoder_runtime =
|
||||
CreateRuntime(text_model_file, text_params_file, false);
|
||||
CreateRuntime(text_model_file, text_params_file, use_trt_backend,
|
||||
use_fp16, text_dynamic_shape);
|
||||
|
||||
// 3. Init vae encoder runtime
|
||||
std::unordered_map<std::string, std::vector<std::vector<int>>>
|
||||
vae_encoder_dynamic_shape = {
|
||||
{"sample",
|
||||
{/* min_shape */ {1, 3, 512, 512},
|
||||
/* opt_shape */ {2 * batch_size, 3, 512, 512},
|
||||
/* max_shape */ {2 * batch_size, 3, 512, 512}}}};
|
||||
std::string vae_encoder_model_dir = model_dir + sep + "vae_encoder";
|
||||
std::string vae_encoder_model_file =
|
||||
"sd15_inpaint/vae_encoder/inference.pdmodel";
|
||||
vae_encoder_model_dir + sep + "inference.pdmodel";
|
||||
std::string vae_encoder_params_file =
|
||||
"sd15_inpaint/vae_encoder/inference.pdiparams";
|
||||
vae_encoder_model_dir + sep + "inference.pdiparams";
|
||||
std::unique_ptr<fastdeploy::Runtime> vae_encoder_runtime =
|
||||
CreateRuntime(vae_encoder_model_file, vae_encoder_params_file);
|
||||
CreateRuntime(vae_encoder_model_file, vae_encoder_params_file,
|
||||
use_trt_backend, use_fp16, vae_encoder_dynamic_shape);
|
||||
|
||||
// 4. Init vae decoder runtime
|
||||
std::unordered_map<std::string, std::vector<std::vector<int>>>
|
||||
vae_decoder_dynamic_shape = {
|
||||
{"latent_sample",
|
||||
{/* min_shape */ {1, 4, 64, 64},
|
||||
/* opt_shape */ {2 * batch_size, 4, 64, 64},
|
||||
/* max_shape */ {2 * batch_size, 4, 64, 64}}}};
|
||||
std::string vae_decoder_model_dir = model_dir + sep + "vae_decoder";
|
||||
std::string vae_decoder_model_file =
|
||||
"sd15_inpaint/vae_decoder/inference.pdmodel";
|
||||
vae_decoder_model_dir + sep + "inference.pdmodel";
|
||||
std::string vae_decoder_params_file =
|
||||
"sd15_inpaint/vae_decoder/inference.pdiparams";
|
||||
vae_decoder_model_dir + sep + "inference.pdiparams";
|
||||
std::unique_ptr<fastdeploy::Runtime> vae_decoder_runtime =
|
||||
CreateRuntime(vae_decoder_model_file, vae_decoder_params_file);
|
||||
CreateRuntime(vae_decoder_model_file, vae_decoder_params_file,
|
||||
use_trt_backend, use_fp16, vae_decoder_dynamic_shape);
|
||||
|
||||
// 5. Init unet runtime
|
||||
std::string unet_model_file = "sd15_inpaint/unet/inference.pdmodel";
|
||||
std::string unet_params_file = "sd15_inpaint/unet/inference.pdiparams";
|
||||
constexpr int unet_inpaint_channels = 9;
|
||||
std::unordered_map<std::string, std::vector<std::vector<int>>>
|
||||
unet_dynamic_shape = {
|
||||
{"sample",
|
||||
{/* min_shape */ {1, unet_inpaint_channels, 64, 64},
|
||||
/* opt_shape */ {2 * batch_size, unet_inpaint_channels, 64, 64},
|
||||
/* max_shape */ {2 * batch_size, unet_inpaint_channels, 64, 64}}},
|
||||
{"timesteps", {{1}, {1}, {1}}},
|
||||
{"encoder_hidden_states",
|
||||
{{1, max_length, 768},
|
||||
{2 * batch_size, max_length, 768},
|
||||
{2 * batch_size, max_length, 768}}}};
|
||||
std::vector<std::string> unet_disable_paddle_trt_ops = {"sin", "cos"};
|
||||
std::string unet_model_dir = model_dir + sep + "unet";
|
||||
std::string unet_model_file = unet_model_dir + sep + "inference.pdmodel";
|
||||
std::string unet_params_file = unet_model_dir + sep + "inference.pdiparams";
|
||||
std::unique_ptr<fastdeploy::Runtime> unet_runtime =
|
||||
CreateRuntime(unet_model_file, unet_params_file);
|
||||
CreateRuntime(unet_model_file, unet_params_file, use_trt_backend,
|
||||
use_fp16, unet_dynamic_shape, unet_disable_paddle_trt_ops);
|
||||
|
||||
// 6. Init fast tokenizer
|
||||
paddlenlp::fast_tokenizer::tokenizers_impl::ClipFastTokenizer tokenizer(
|
||||
"clip/vocab.json", "clip/merges.txt", /* max_length = */ 77);
|
||||
"clip/vocab.json", "clip/merges.txt", /* max_length = */ max_length);
|
||||
fastdeploy::StableDiffusionInpaintPipeline pipe(
|
||||
std::move(vae_encoder_runtime), std::move(vae_decoder_runtime),
|
||||
std::move(text_encoder_runtime), std::move(unet_runtime),
|
||||
|
Reference in New Issue
Block a user