[Example] Merge Download Paddle Model, Paddle->ONNX, ONNX -> MLIR, MLIR -> BModel into infer.py (#1622)

fix infer.py and README

Co-authored-by: DefTruth <31974251+DefTruth@users.noreply.github.com>
This commit is contained in:
Yi-sir
2023-03-16 19:47:50 +08:00
committed by GitHub
parent 2de2166472
commit 66275bcbfa
3 changed files with 90 additions and 14 deletions

View File

@@ -15,8 +15,11 @@ cd FastDeploy/examples/vision/classification/paddleclas/sophgo/python
# Download images. # Download images.
wget https://gitee.com/paddlepaddle/PaddleClas/raw/release/2.4/deploy/images/ImageNet/ILSVRC2012_val_00000010.jpeg wget https://gitee.com/paddlepaddle/PaddleClas/raw/release/2.4/deploy/images/ImageNet/ILSVRC2012_val_00000010.jpeg
# Inference. # Inference. Need to manually set the model, configuration file and image path used for inference.
python3 infer.py --model_file ./bmodel/resnet50_1684x_f32.bmodel --config_file ResNet50_vd_infer/inference_cls.yaml --image ILSVRC2012_val_00000010.jpeg python3 infer.py --auto False --model_file ./bmodel/resnet50_1684x_f32.bmodel --config_file ResNet50_vd_infer/inference_cls.yaml --image ILSVRC2012_val_00000010.jpeg
# Automatic completion of downloading data - model compilation - inference, no need to set up model, configuration file and image paths.
python3 infer.py --auto True --model '' --config_file '' --image ''
# The returned result. # The returned result.
ClassifyResult( ClassifyResult(

View File

@@ -15,8 +15,12 @@ cd FastDeploy/examples/vision/classification/paddleclas/sophgo/python
# 下载图片 # 下载图片
wget https://gitee.com/paddlepaddle/PaddleClas/raw/release/2.4/deploy/images/ImageNet/ILSVRC2012_val_00000010.jpeg wget https://gitee.com/paddlepaddle/PaddleClas/raw/release/2.4/deploy/images/ImageNet/ILSVRC2012_val_00000010.jpeg
# 推理 # 手动设置推理使用的模型、配置文件和图片路径
python3 infer.py --model_file ./bmodel/resnet50_1684x_f32.bmodel --config_file ResNet50_vd_infer/inference_cls.yaml --image ILSVRC2012_val_00000010.jpeg python3 infer.py --auto False --model_file ./bmodel/resnet50_1684x_f32.bmodel --config_file ResNet50_vd_infer/inference_cls.yaml --image ILSVRC2012_val_00000010.jpeg
# 自动完成下载数据-模型编译-推理,不需要设置模型、配置文件和图片路径
python3 infer.py --auto True --model '' --config_file '' --image ''
# 运行完成后返回结果如下所示 # 运行完成后返回结果如下所示
ClassifyResult( ClassifyResult(

View File

@@ -1,15 +1,15 @@
import fastdeploy as fd import fastdeploy as fd
import cv2 import cv2
import os import os
from subprocess import run
def parse_arguments(): def parse_arguments():
import argparse import argparse
import ast import ast
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--model", required=True, help="Path of model.") parser.add_argument("--auto", required=True, help="Auto download, convert, compile and infer if True")
parser.add_argument( parser.add_argument("--model", required=True, help="Path of bmodel")
"--config_file", required=True, help="Path of config file.") parser.add_argument("--config_file", required=True, help="Path of config file")
parser.add_argument( parser.add_argument(
"--image", type=str, required=True, help="Path of test image file.") "--image", type=str, required=True, help="Path of test image file.")
parser.add_argument( parser.add_argument(
@@ -17,17 +17,86 @@ def parse_arguments():
return parser.parse_args() return parser.parse_args()
def download():
cmd_str = 'wget https://bj.bcebos.com/paddlehub/fastdeploy/ResNet50_vd_infer.tgz'
jpg_str = 'wget https://gitee.com/paddlepaddle/PaddleClas/raw/release/2.4/deploy/images/ImageNet/ILSVRC2012_val_00000010.jpeg'
tar_str = 'tar xvf ResNet50_vd_infer.tgz'
if not os.path.exists('ResNet50_vd_infer.tgz'):
run(cmd_str, shell=True)
if not os.path.exists('ILSVRC2012_val_00000010.jpeg'):
run(jpg_str, shell=True)
run(tar_str, shell=True)
def paddle2onnx():
cmd_str = 'paddle2onnx --model_dir ResNet50_vd_infer \
--model_filename inference.pdmodel \
--params_filename inference.pdiparams \
--save_file ResNet50_vd_infer.onnx \
--enable_dev_version True'
print(cmd_str)
run(cmd_str, shell=True)
def mlir_prepare():
mlir_path = os.getenv("MODEL_ZOO_PATH")
mlir_path = mlir_path[:-13]
cmd_list = ['mkdir ResNet50',
'cp -rf ' + os.path.join(mlir_path, 'regression/dataset/COCO2017/') + ' ./ResNet50',
'cp -rf ' + os.path.join(mlir_path, 'regression/image/') + ' ./ResNet50',
'cp ResNet50_vd_infer.onnx ./ResNet50/',
'mkdir ./ResNet50/workspace']
for str in cmd_list:
print(str)
run(str, shell=True)
def onnx2mlir():
cmd_str = 'model_transform.py \
--model_name ResNet50_vd_infer \
--model_def ../ResNet50_vd_infer.onnx \
--input_shapes [[1,3,224,224]] \
--mean 0.0,0.0,0.0 \
--scale 0.0039216,0.0039216,0.0039216 \
--keep_aspect_ratio \
--pixel_format rgb \
--output_names save_infer_model/scale_0.tmp_1 \
--test_input ../image/dog.jpg \
--test_result ./ResNet50_vd_infer_top_outputs.npz \
--mlir ./ResNet50_vd_infer.mlir'
print(cmd_str)
os.chdir('./ResNet50/workspace/')
run(cmd_str, shell=True)
os.chdir('../../')
def mlir2bmodel():
cmd_str = 'model_deploy.py \
--mlir ./ResNet50_vd_infer.mlir \
--quantize F32 \
--chip bm1684x \
--test_input ./ResNet50_vd_infer_in_f32.npz \
--test_reference ./ResNet50_vd_infer_top_outputs.npz \
--model ./ResNet50_vd_infer_1684x_f32.bmodel'
print(cmd_str)
os.chdir('./ResNet50/workspace')
run(cmd_str, shell=True)
os.chdir('../../')
args = parse_arguments() args = parse_arguments()
# 配置runtime加载模型 if(args.auto):
download()
paddle2onnx()
mlir_prepare()
onnx2mlir()
mlir2bmodel()
# config runtime and load the model
runtime_option = fd.RuntimeOption() runtime_option = fd.RuntimeOption()
runtime_option.use_sophgo() runtime_option.use_sophgo()
model_file = args.model model_file = './ResNet50/workspace/ResNet50_vd_infer_1684x_f32.bmodel' if args.auto else args.model
params_file = "" params_file = ""
config_file = args.config_file config_file = './ResNet50_vd_infer/inference_cls.yaml' if args.auto else args.config_file
image_file = './ILSVRC2012_val_00000010.jpeg' if args.auto else args.image
model = fd.vision.classification.PaddleClasModel( model = fd.vision.classification.PaddleClasModel(
model_file, model_file,
params_file, params_file,
@@ -35,7 +104,7 @@ model = fd.vision.classification.PaddleClasModel(
runtime_option=runtime_option, runtime_option=runtime_option,
model_format=fd.ModelFormat.SOPHGO) model_format=fd.ModelFormat.SOPHGO)
# 预测图片分类结果 # predict the results of image classification
im = cv2.imread(args.image) im = cv2.imread(image_file)
result = model.predict(im, args.topk) result = model.predict(im, args.topk)
print(result) print(result)