Modify SD demo infer.py to use paddle_kunlunxin_fp16 (#1612)

* Modify SD infer.py to use paddle_kunlunxin_fp16

* Update infer.py
Author: wangguoya
Date: 2023-03-16 19:48:15 +08:00
Committed by: GitHub
Parent: 66275bcbfa
Commit: bf6caeb2ce

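In effect, this commit threads a use_fp16 switch through the KunlunXin runtime helper so the Paddle Lite backend can run in float16. Per the hunks below, only the UNet runtime is wired to the command-line value (args.ues_fp16); the text encoder and VAE decoder stay pinned to FP32, with the flag left in a comment for later enablement. A minimal sketch of the helper after this change (the trailing arguments to option.use_kunlunxin(...) are elided in the hunk, so only the calls visible in the diff appear here):

    import os
    import fastdeploy as fd

    def create_kunlunxin_runtime(model_dir, model_prefix, use_fp16=False, device_id=0):
        option = fd.RuntimeOption()
        # Target the KunlunXin XPU (further tuning arguments elided in the hunk).
        option.use_kunlunxin(device_id)
        model_file = os.path.join(model_dir, model_prefix, "inference.pdmodel")
        params_file = os.path.join(model_dir, model_prefix, "inference.pdiparams")
        option.set_model_path(model_file, params_file)
        if use_fp16:
            option.enable_lite_fp16()  # run the Paddle Lite backend in float16
        return fd.Runtime(option)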

@@ -175,7 +175,7 @@ def create_trt_runtime(model_dir,
     return fd.Runtime(option)


-def create_kunlunxin_runtime(model_dir, model_prefix, device_id=0):
+def create_kunlunxin_runtime(model_dir, model_prefix, use_fp16=False, device_id=0):
     option = fd.RuntimeOption()
     option.use_kunlunxin(
         device_id,
@@ -190,6 +190,8 @@ def create_kunlunxin_runtime(model_dir, model_prefix, device_id=0):
     model_file = os.path.join(model_dir, model_prefix, "inference.pdmodel")
     params_file = os.path.join(model_dir, model_prefix, "inference.pdiparams")
     option.set_model_path(model_file, params_file)
+    if use_fp16:
+        option.enable_lite_fp16()
     return fd.Runtime(option)
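For illustration, a hypothetical call that enables float16 through the new parameter (the directory and prefix strings are placeholders, not names from the commit):

    unet_runtime = create_kunlunxin_runtime(
        "stable-diffusion-v1-4", "unet", use_fp16=True, device_id=0)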
@@ -311,14 +313,19 @@ if __name__ == "__main__":
         text_encoder_runtime = create_kunlunxin_runtime(
             args.model_dir,
             args.text_encoder_model_prefix,
+            use_fp16=False, #args.ues_fp16
             device_id=args.device_id)
         print("=== build vae_decoder_runtime")
         vae_decoder_runtime = create_kunlunxin_runtime(
-            args.model_dir, args.vae_model_prefix, device_id=args.device_id)
+            args.model_dir, args.vae_model_prefix,
+            use_fp16=False, #args.ues_fp16
+            device_id=args.device_id)
         print("=== build unet_runtime")
         start = time.time()
         unet_runtime = create_kunlunxin_runtime(
-            args.model_dir, args.unet_model_prefix, device_id=args.device_id)
+            args.model_dir, args.unet_model_prefix,
+            args.ues_fp16,
+            device_id=args.device_id)
         print(f"Spend {time.time() - start : .2f} s to load unet model.")
     pipe = StableDiffusionFastDeployPipeline(
         vae_decoder_runtime=vae_decoder_runtime,