From bf6caeb2ce44ab75756faecbf0a5eaebdef0a934 Mon Sep 17 00:00:00 2001
From: wangguoya <39376046+wgy0804@users.noreply.github.com>
Date: Thu, 16 Mar 2023 19:48:15 +0800
Subject: [PATCH] modify sd demo infer.py for using paddle_kunlunxin_fp16
 (#1612)

* modify sd infer.py for using paddle_kunlunxin_fp16

* Update infer.py
---
 examples/multimodal/stable_diffusion/infer.py | 13 ++++++++++---
 1 file changed, 10 insertions(+), 3 deletions(-)

diff --git a/examples/multimodal/stable_diffusion/infer.py b/examples/multimodal/stable_diffusion/infer.py
index a22b569c1..001f67864 100755
--- a/examples/multimodal/stable_diffusion/infer.py
+++ b/examples/multimodal/stable_diffusion/infer.py
@@ -175,7 +175,7 @@ def create_trt_runtime(model_dir,
     return fd.Runtime(option)
 
 
-def create_kunlunxin_runtime(model_dir, model_prefix, device_id=0):
+def create_kunlunxin_runtime(model_dir, model_prefix, use_fp16=False, device_id=0):
     option = fd.RuntimeOption()
     option.use_kunlunxin(
         device_id,
@@ -190,6 +190,8 @@ def create_kunlunxin_runtime(model_dir, model_prefix, device_id=0):
     model_file = os.path.join(model_dir, model_prefix, "inference.pdmodel")
     params_file = os.path.join(model_dir, model_prefix, "inference.pdiparams")
     option.set_model_path(model_file, params_file)
+    if use_fp16:
+        option.enable_lite_fp16()
     return fd.Runtime(option)
 
 
@@ -311,14 +313,19 @@ if __name__ == "__main__":
         text_encoder_runtime = create_kunlunxin_runtime(
             args.model_dir,
             args.text_encoder_model_prefix,
+            use_fp16=False,  # args.use_fp16
             device_id=args.device_id)
         print("=== build vae_decoder_runtime")
         vae_decoder_runtime = create_kunlunxin_runtime(
-            args.model_dir, args.vae_model_prefix, device_id=args.device_id)
+            args.model_dir, args.vae_model_prefix,
+            use_fp16=False,  # args.use_fp16
+            device_id=args.device_id)
         print("=== build unet_runtime")
         start = time.time()
         unet_runtime = create_kunlunxin_runtime(
-            args.model_dir, args.unet_model_prefix, device_id=args.device_id)
+            args.model_dir, args.unet_model_prefix,
+            args.use_fp16,
+            device_id=args.device_id)
         print(f"Spend {time.time() - start : .2f} s to load unet model.")
         pipe = StableDiffusionFastDeployPipeline(
             vae_decoder_runtime=vae_decoder_runtime,
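
For readers who want to exercise the new toggle outside the full demo, the minimal sketch below mirrors the pattern this patch introduces: fp16 is opted into per sub-model via enable_lite_fp16(), so the unet can run in fp16 while the text encoder and VAE decoder stay in fp32 (as the hard-coded use_fp16=False in the hunks above reflects). The model directory and prefixes are hypothetical, and use_kunlunxin() is called here with only the device id, leaving the extra tuning arguments that infer.py passes at their assumed defaults.

import os

import fastdeploy as fd


def create_kunlunxin_runtime(model_dir, model_prefix, use_fp16=False, device_id=0):
    """Build a FastDeploy Runtime for one sub-model on a Kunlunxin XPU."""
    option = fd.RuntimeOption()
    # The demo's infer.py passes additional tuning arguments here; this
    # sketch assumes the defaults are acceptable.
    option.use_kunlunxin(device_id)
    model_file = os.path.join(model_dir, model_prefix, "inference.pdmodel")
    params_file = os.path.join(model_dir, model_prefix, "inference.pdiparams")
    option.set_model_path(model_file, params_file)
    if use_fp16:
        # Run this sub-model through the Paddle Lite backend in fp16.
        option.enable_lite_fp16()
    return fd.Runtime(option)


# Hypothetical layout: <model_dir>/<prefix>/inference.pdmodel
model_dir = "./stable-diffusion-v1-4"
text_encoder = create_kunlunxin_runtime(model_dir, "text_encoder")   # fp32
vae_decoder = create_kunlunxin_runtime(model_dir, "vae_decoder")     # fp32
unet = create_kunlunxin_runtime(model_dir, "unet", use_fp16=True)    # fp16

Restricting fp16 to the unet appears to be a deliberate choice in the patch: the unet executes once per denoising step and so dominates latency, while the text encoder and VAE decoder each run only once per image and are left in fp32.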