mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-06 17:17:14 +08:00

* Add notes for tensors * Optimize some apis * move some warnings * Support build with Paddle2ONNX * Add protobuf support * Fix compile on mac * add clean package script * Add paddle2onnx code * remove submodule * Add onnx code * remove softlink * add onnx code * fix error * Add cmake file * fix patchelf * update paddle2onnx * Delete .gitmodules --------- Co-authored-by: PaddleCI <paddle_ci@example.com> Co-authored-by: pangyoki <pangyoki@126.com> Co-authored-by: jiangjiajun <jiangjiajun@baidu.lcom>
288 lines
10 KiB
Python
Executable File
288 lines
10 KiB
Python
Executable File
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
from __future__ import absolute_import
|
|
from six import text_type as _text_type
|
|
import argparse
|
|
import ast
|
|
import sys
|
|
import os
|
|
import paddle.fluid as fluid
|
|
from paddle2onnx.utils import logging
|
|
|
|
|
|
def str2list(v):
    """Parse a command-line string into a Python list or dict literal.

    Used as the argparse ``type`` for ``--output_names``.

    Args:
        v: Raw argument text, e.g. ``"['out1', 'out2']"`` or
            ``"{'PaddleOut': 'OnnxOut'}"``.

    Returns:
        The parsed literal, or ``None`` when ``v`` is empty.

    Raises:
        ValueError: if ``v`` is not a valid Python literal (argparse turns
            this into a normal "invalid value" usage error).
    """
    if not v:
        return None
    # Every space is stripped before parsing, including spaces that appear
    # inside quoted names.
    v = v.replace(" ", "")
    # ast.literal_eval accepts only literals, so unlike eval() it cannot run
    # arbitrary code supplied on the command line.
    try:
        return ast.literal_eval(v)
    except SyntaxError:
        raise ValueError("not a valid Python literal: %s" % v)
|
|
|
|
|
|
def arg_parser():
    """Build the argument parser for the ``paddle2onnx`` command line.

    ``--input_shape_dict`` is returned as raw text (its default is the string
    ``"None"``) and must be parsed by the caller. ``--output_names`` is parsed
    by ``str2list``. The boolean switches (``--enable_dev_version``,
    ``--enable_onnx_checker``, ``--enable_paddle_fallback``,
    ``--enable_auto_update_opset``) are parsed with ``ast.literal_eval``, so
    they expect ``True``/``False``.

    Returns:
        argparse.ArgumentParser: the configured parser.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_dir",
        "-m",
        type=_text_type,
        default=None,
        help="PaddlePaddle model directory, if params stored in single file, "
        "you need define '--model_filename' and '--params_filename'.")
    parser.add_argument(
        "--model_filename",
        "-mf",
        type=_text_type,
        default=None,
        help="PaddlePaddle model's network file name, which is under the "
        "directory set by --model_dir")
    parser.add_argument(
        "--params_filename",
        "-pf",
        type=_text_type,
        default=None,
        help="PaddlePaddle model's param file name(param files combined in "
        "single file), which is under the directory set by --model_dir.")
    parser.add_argument(
        "--save_file",
        "-s",
        type=_text_type,
        default=None,
        help="file path to save onnx model")
    parser.add_argument(
        "--opset_version",
        "-ov",
        type=int,
        default=9,
        help="set onnx opset version to export")
    parser.add_argument(
        "--input_shape_dict",
        "-isd",
        type=_text_type,
        default="None",
        # The two fragments are concatenated implicitly; the trailing space
        # after "or" keeps the examples from running together.
        help="define input shapes, e.g --input_shape_dict=\"{'image':[1, 3, 608, 608]}\" or "
        "--input_shape_dict=\"{'image':[1, 3, 608, 608], 'im_shape': [1, 2], 'scale_factor': [1, 2]}\"")
    parser.add_argument(
        "--enable_dev_version",
        type=ast.literal_eval,
        default=False,
        help="whether to use new version of Paddle2ONNX which is under developing, default False"
    )
    parser.add_argument(
        "--enable_onnx_checker",
        type=ast.literal_eval,
        default=True,
        help="whether check onnx model validity, default True")
    parser.add_argument(
        "--enable_paddle_fallback",
        type=ast.literal_eval,
        default=False,
        help="whether use PaddleFallback for custom op, default is False")
    parser.add_argument(
        "--version",
        "-v",
        action="store_true",
        default=False,
        help="get version of paddle2onnx")
    parser.add_argument(
        "--output_names",
        "-on",
        type=str2list,
        default=None,
        # The examples use single quotes inside the double-quoted shell
        # argument so that they survive shell quoting and still parse as
        # Python literals.
        help="define output names, e.g --output_names=\"['output1']\" or "
        "--output_names=\"['output1', 'output2', 'output3']\" or "
        "--output_names=\"{'Paddleoutput':'Onnxoutput'}\"")
    parser.add_argument(
        "--enable_auto_update_opset",
        type=ast.literal_eval,
        default=True,
        help="whether enable auto_update_opset, default is True")
    return parser
|
|
|
|
|
|
def c_paddle_to_onnx(model_file,
                     params_file="",
                     save_file=None,
                     opset_version=7,
                     auto_upgrade_opset=True,
                     verbose=True,
                     enable_onnx_checker=True,
                     enable_experimental_op=True,
                     enable_optimize=True):
    """Convert a Paddle model with the compiled ``paddle2onnx_cpp2py_export``.

    All options are forwarded positionally to ``c_p2o.export``.

    Args:
        model_file: path of the Paddle model (network) file.
        params_file: path of the combined params file, or ``""`` for none.
        save_file: if given, the exported model is written to this path.

    Returns:
        The value produced by ``c_p2o.export`` when ``save_file`` is None
        (it is written in binary mode otherwise, so presumably ``bytes``);
        ``None`` after writing ``save_file``.
    """
    import paddle2onnx.paddle2onnx_cpp2py_export as c_p2o
    exported = c_p2o.export(model_file, params_file, opset_version,
                            auto_upgrade_opset, verbose, enable_onnx_checker,
                            enable_experimental_op, enable_optimize)
    if save_file is None:
        return exported
    with open(save_file, "wb") as out_stream:
        out_stream.write(exported)
|
|
|
|
|
|
# Ops that must be skipped when re-running shape inference after input
# shapes are overridden in program2onnx.
_OP_WITHOUT_KERNEL_SET = frozenset([
    'feed', 'fetch', 'recurrent', 'go', 'rnn_memory_helper_grad',
    'conditional_block', 'while', 'send', 'recv', 'listen_and_serv',
    'fl_listen_and_serv', 'ncclInit', 'select', 'checkpoint_notify',
    'gen_bkcl_id', 'c_gen_bkcl_id', 'gen_nccl_id', 'c_gen_nccl_id',
    'c_comm_init', 'c_sync_calc_stream', 'c_sync_comm_stream',
    'queue_generator', 'dequeue', 'enqueue', 'heter_listen_and_serv',
    'c_wait_comm', 'c_wait_compute', 'c_gen_hccl_id', 'c_comm_init_hccl',
    'copy_cross_scope'
])


def _paddle_version_tuple(version):
    """Return the leading (major, minor, patch) integers of a version string.

    Non-numeric suffixes (e.g. ``rc0``) and extra components (e.g.
    ``.post117``) are ignored, and missing components count as 0, so
    ``"2.4.2.post117" -> (2, 4, 2)`` and ``"2.5.0rc1" -> (2, 5, 0)``.
    """
    numbers = []
    for piece in version.split('.')[:3]:
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        numbers.append(int(digits) if digits else 0)
    while len(numbers) < 3:
        numbers.append(0)
    return tuple(numbers)


def _check_paddle_version(paddle):
    """Warn on a develop build (0.0.0) of paddle; reject paddle < 1.8.0."""
    version = _paddle_version_tuple(paddle.__version__)
    if version == (0, 0, 0):
        logging.warning("You are use develop version of paddlepaddle")
    elif version < (1, 8, 0):
        raise ImportError("paddlepaddle>=1.8.0 is required")


def _apply_input_shapes(paddle, program, input_shape_dict):
    """Override input shapes in ``program.blocks[0]`` and re-infer shapes.

    Also warns when the paddle version that saved the model differs from the
    installed one.

    Args:
        paddle: the imported ``paddle`` module (read for its version string).
        program: the loaded inference program; it is modified in place.
        input_shape_dict: mapping of input variable name -> new shape.
    """
    import paddle2onnx
    paddle2onnx.legacy.process_old_ops_desc(program)
    paddle_version = paddle.__version__
    # The program stores its version as one integer: major*1e6 + minor*1e3 + patch.
    model_version = program.desc._version()
    major_ver = model_version // 1000000
    minor_ver = (model_version - major_ver * 1000000) // 1000
    patch_ver = model_version - major_ver * 1000000 - minor_ver * 1000
    model_version = "{}.{}.{}".format(major_ver, minor_ver, patch_ver)
    if model_version != paddle_version:
        logging.warning(
            "The model is saved by paddlepaddle v{}, but now your paddlepaddle is version of {}, this difference may cause error, it is recommend you reinstall a same version of paddlepaddle for this model".
            format(model_version, paddle_version))

    global_block = program.blocks[0]
    for name, shape in input_shape_dict.items():
        global_block.var(name).desc.set_shape(shape)
    # Propagate the new input shapes through every op that supports it.
    for op in global_block.ops:
        if op.type in _OP_WITHOUT_KERNEL_SET:
            continue
        op.desc.infer_shape(global_block.desc)


def program2onnx(model_dir,
                 save_file,
                 model_filename=None,
                 params_filename=None,
                 opset_version=9,
                 enable_onnx_checker=False,
                 operator_export_type="ONNX",
                 input_shape_dict=None,
                 output_names=None,
                 auto_update_opset=True):
    """Convert a model saved by ``fluid.io.save_inference_model`` to ONNX.

    Args:
        model_dir: directory containing the saved inference model.
        save_file: output path passed to ``paddle2onnx.program2onnx``.
        model_filename: network file name inside ``model_dir``; when both it
            and ``params_filename`` are None, ``load_inference_model`` is
            called without file names.
        params_filename: combined params file name inside ``model_dir``.
        opset_version: ONNX opset version to export.
        enable_onnx_checker: whether to validate the exported model.
        operator_export_type: forwarded as-is; the CLI passes ``"ONNX"`` or
            ``"PaddleFallback"``.
        input_shape_dict: optional mapping of input name -> shape used to
            override input shapes before export.
        output_names: optional list or dict of output names, forwarded as-is.
        auto_update_opset: forwarded as-is to the exporter.

    Raises:
        ImportError: if paddlepaddle is missing or older than 1.8.0.
    """
    try:
        import paddle
    except ImportError:
        logging.error(
            "paddlepaddle not installed, use \"pip install paddlepaddle\"")
        # Previously execution continued here and crashed with a NameError.
        raise

    _check_paddle_version(paddle)

    import paddle2onnx as p2o
    # convert model save with 'paddle.fluid.io.save_inference_model'
    if hasattr(paddle, 'enable_static'):
        paddle.enable_static()
    exe = fluid.Executor(fluid.CPUPlace())
    if model_filename is None and params_filename is None:
        [program, feed_var_names, fetch_vars] = fluid.io.load_inference_model(
            model_dir, exe)
    else:
        [program, feed_var_names, fetch_vars] = fluid.io.load_inference_model(
            model_dir,
            exe,
            model_filename=model_filename,
            params_filename=params_filename)

    if input_shape_dict is not None:
        _apply_input_shapes(paddle, program, input_shape_dict)

    p2o.program2onnx(
        program,
        fluid.global_scope(),
        save_file,
        feed_var_names=feed_var_names,
        target_vars=fetch_vars,
        opset_version=opset_version,
        enable_onnx_checker=enable_onnx_checker,
        operator_export_type=operator_export_type,
        auto_update_opset=auto_update_opset,
        output_names=output_names)
|
|
|
|
|
|
def main():
    """Command-line entry point for ``paddle2onnx``.

    With no arguments, logs a short usage hint and returns. Otherwise it
    parses ``sys.argv`` and exports the model through one of two paths:
    ``c_paddle_to_onnx`` when ``--enable_dev_version True``, or the legacy
    ``program2onnx`` path otherwise.

    Returns:
        ``None`` for the informational paths and the legacy path; the result
        of ``c_paddle_to_onnx`` for the dev-version path.

    Raises:
        SystemExit: via ``parser.error`` for missing or invalid arguments.
        TypeError: if ``--output_names`` is neither a list nor a dict.
    """
    if len(sys.argv) < 2:
        logging.info("Use \"paddle2onnx -h\" to print the help information")
        logging.info(
            "For more information, please follow our github repo below:")
        logging.info("Github: https://github.com/PaddlePaddle/paddle2onnx.git")
        return

    parser = arg_parser()
    args = parser.parse_args()

    if args.version:
        import paddle2onnx
        logging.info("paddle2onnx-{} with python>=2.7, paddlepaddle>=1.8.0".
                     format(paddle2onnx.__version__))
        return

    # parser.error (unlike assert) still runs under "python -O".
    if args.model_dir is None:
        parser.error(
            "--model_dir should be defined while translating paddle model to onnx")
    if args.save_file is None:
        parser.error(
            "--save_file should be defined while translating paddle model to onnx")

    # literal_eval (not eval) so command-line text cannot execute code; the
    # default string "None" still evaluates to None.
    try:
        input_shape_dict = ast.literal_eval(args.input_shape_dict)
    except (ValueError, SyntaxError):
        input_shape_dict = args
    if input_shape_dict is not None and not isinstance(input_shape_dict,
                                                       dict):
        parser.error(
            "--input_shape_dict should be a dict literal such as "
            "\"{'image': [1, 3, 608, 608]}\", but received: %s" %
            args.input_shape_dict)

    operator_export_type = "ONNX"
    if args.enable_paddle_fallback:
        operator_export_type = "PaddleFallback"

    if args.output_names is not None:
        if not isinstance(args.output_names, (list, dict)):
            raise TypeError(
                "The output_names should be 'list' or 'dict', but received type is %s."
                % type(args.output_names))

    if args.enable_dev_version:
        if args.enable_paddle_fallback:
            logging.warning(
                "--enable_paddle_fallback is deprecated while --enable_dev_version=True."
            )
        if args.output_names is not None:
            logging.warning(
                "--output_names is deprecated while --enable_dev_version=True.")
        if input_shape_dict is not None:
            logging.warning(
                "--input_shape_dict is deprecated while --enable_dev_version=True."
            )
        # os.path.join(model_dir, None) would raise an opaque TypeError.
        if args.model_filename is None:
            parser.error(
                "--model_filename should be defined while --enable_dev_version=True")
        model_file = os.path.join(args.model_dir, args.model_filename)
        if args.params_filename is None:
            params_file = ""
        else:
            params_file = os.path.join(args.model_dir, args.params_filename)
        return c_paddle_to_onnx(
            model_file=model_file,
            params_file=params_file,
            save_file=args.save_file,
            opset_version=args.opset_version,
            auto_upgrade_opset=args.enable_auto_update_opset,
            verbose=True,
            enable_onnx_checker=args.enable_onnx_checker,
            enable_experimental_op=True,
            enable_optimize=True)

    program2onnx(
        args.model_dir,
        args.save_file,
        args.model_filename,
        args.params_filename,
        opset_version=args.opset_version,
        enable_onnx_checker=args.enable_onnx_checker,
        operator_export_type=operator_export_type,
        input_shape_dict=input_shape_dict,
        output_names=args.output_names,
        auto_update_opset=args.enable_auto_update_opset)
|
|
|
|
|
|
if __name__ == "__main__":
    # Executed as a script (not imported): run the command-line converter.
    main()
|