mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-24 00:53:22 +08:00

* Add notes for tensors * Optimize some apis * move some warnings * Support build with Paddle2ONNX * Add protobuf support * Fix compile on mac * add clearn package script * Add paddle2onnx code * remove submodule * Add onnx ocde * remove softlink * add onnx code * fix error * Add cmake file * fix patchelf * update paddle2onnx * Delete .gitmodules --------- Co-authored-by: PaddleCI <paddle_ci@example.com> Co-authored-by: pangyoki <pangyoki@126.com> Co-authored-by: jiangjiajun <jiangjiajun@baidu.lcom>
103 lines
3.8 KiB
Python
103 lines
3.8 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
import argparse
|
|
import config
|
|
import gc
|
|
import onnx
|
|
import os
|
|
from pathlib import Path
|
|
import subprocess
|
|
import sys
|
|
import time
|
|
|
|
# Working directory captured once at import time; the git-lfs helper
# functions in this file pass it as ``cwd`` to every subprocess they spawn.
cwd_path = Path.cwd()
|
|
|
|
|
|
def run_lfs_install():
    """Run ``git lfs install`` inside ``cwd_path`` and print its exit code.

    Both stdout and stderr of the command are captured through pipes and are
    not shown. ``check`` is not set, so a non-zero exit code is only reported
    in the printed message.
    """
    command = ['git', 'lfs', 'install']
    completed = subprocess.run(
        command,
        cwd=cwd_path,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
    )
    print('Git LFS install completed with return code= {}'.format(completed.returncode))
|
|
|
|
|
|
def pull_lfs_file(file_name):
    """Pull one Git LFS tracked file into the working tree.

    Runs ``git lfs pull --include <file_name> --exclude ''`` inside
    ``cwd_path`` with stderr merged into the captured stdout, then prints the
    return code followed by the ``CompletedProcess`` object itself.

    Args:
        file_name: Value handed to ``--include`` (the path of the file to pull).
    """
    # NOTE(review): no shell is involved, so the exclude argument is the
    # literal two-character string '' rather than an empty string --
    # presumably an "exclude nothing" placeholder; confirm against git-lfs.
    command = ['git', 'lfs', 'pull', '--include', file_name, '--exclude', '\'\'']
    completed = subprocess.run(
        command,
        cwd=cwd_path,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
    )
    print('LFS pull completed with return code= {}'.format(completed.returncode))
    print(completed)
|
|
|
|
|
|
def run_lfs_prune():
    """Run ``git lfs prune`` inside ``cwd_path`` and print its exit code.

    Both stdout and stderr of the command are captured through pipes and are
    not shown. ``check`` is not set, so a non-zero exit code is only reported
    in the printed message.
    """
    command = ['git', 'lfs', 'prune']
    completed = subprocess.run(
        command,
        cwd=cwd_path,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
    )
    print('LFS prune completed with return code= {}'.format(completed.returncode))
|
|
|
|
|
|
def main():
    """Run the ONNX checker over every ``.onnx`` file found on disk.

    Command line:
        --test_dir  Optional directory to restrict the search to (e.g.
                    ``text`` or ``vision``). When omitted, every directory
                    directly under the current working directory is walked.

    Each discovered model whose path (with backslashes normalised to forward
    slashes) is not listed in ``config.SKIP_CHECKER_MODELS`` is pulled from
    Git LFS, loaded, and validated with ``onnx.checker.check_model(model,
    True)``. A model that passes is deleted from disk and the LFS cache is
    pruned. Any exception raised while handling a model marks it as failed
    and the run continues with the next one. A summary is printed at the end.

    Raises:
        SystemExit: with status 1 when at least one model failed.
    """
    parser = argparse.ArgumentParser(description='Test settings')
    # default: test all models in the repo
    # if test_dir is specified, only test files under that specified path
    parser.add_argument('--test_dir', required=False, default='', type=str,
                        help='Directory path for testing. e.g., text, vision')
    args = parser.parse_args()

    if args.test_dir:
        parent_dir = [args.test_dir]
    else:
        # if not set, go through each directory in the current working directory
        parent_dir = [entry for entry in os.listdir() if os.path.isdir(entry)]

    model_list = []
    for directory in parent_dir:
        for root, _, files in os.walk(directory):
            for file_name in files:
                if file_name.endswith('.onnx'):
                    onnx_model_path = os.path.join(root, file_name)
                    model_list.append(onnx_model_path)
                    print(onnx_model_path)

    # run lfs install before starting the tests
    run_lfs_install()

    print('=== Running ONNX Checker on {} models ==='.format(len(model_list)))
    # run checker on each model
    failed_models = []
    failed_messages = []
    skip_models = []
    for model_path in model_list:
        start = time.time()
        # os.path.basename honours the platform separator, so the name is
        # extracted correctly even when os.path.join produced backslashes
        # (the skip-list lookup below already anticipates such paths).
        model_name = os.path.basename(model_path)
        # if the model_path exists in the skip list, simply skip it
        if model_path.replace('\\', '/') in config.SKIP_CHECKER_MODELS:
            print('Skip model: {}'.format(model_path))
            skip_models.append(model_path)
            continue
        print('-----------------Testing: {}-----------------'.format(model_name))
        model = None
        try:
            pull_lfs_file(model_path)
            model = onnx.load(model_path)
            # stricter onnx.checker with onnx.shape_inference
            onnx.checker.check_model(model, True)
            # remove the model to save space in CIs
            if os.path.exists(model_path):
                os.remove(model_path)
            # clean git lfs cache
            run_lfs_prune()
            print('[PASS]: {} is checked by onnx. '.format(model_name))

        except Exception as e:
            # Deliberately broad: one bad model must not abort the whole run.
            print('[FAIL]: {}'.format(e))
            failed_models.append(model_path)
            failed_messages.append((model_name, e))
        end = time.time()
        print('--------------Time used: {} secs-------------'.format(end - start))
        # Drop our reference first; otherwise the model stays alive through
        # the collection below and while the next model is being loaded.
        model = None
        # enable gc collection to prevent MemoryError by loading too many large models
        gc.collect()

    if not failed_models:
        print('{} models have been checked. {} models were skipped.'.format(len(model_list), len(skip_models)))
    else:
        print('In all {} models, {} models failed, {} models were skipped'.format(len(model_list), len(failed_models), len(skip_models)))
        for model, error in failed_messages:
            print('{} failed because: {}'.format(model, error))
        sys.exit(1)
|
|
|
|
|
|
# Only run the checker when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
|