mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-07 01:22:59 +08:00

* Add notes for tensors * Optimize some apis * move some warnings * Support build with Paddle2ONNX * Add protobuf support * Fix compile on mac * add clean package script * Add paddle2onnx code * remove submodule * Add onnx code * remove softlink * add onnx code * fix error * Add cmake file * fix patchelf * update paddle2onnx * Delete .gitmodules --------- Co-authored-by: PaddleCI <paddle_ci@example.com> Co-authored-by: pangyoki <pangyoki@126.com> Co-authored-by: jiangjiajun <jiangjiajun@baidu.lcom>
84 lines
3.1 KiB
Python
84 lines
3.1 KiB
Python
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
from __future__ import absolute_import
|
|
|
|
import os
|
|
import paddle
|
|
import numpy as np
|
|
from paddle.fluid import core
|
|
from paddle.fluid.framework import Variable, program_guard
|
|
from paddle2onnx.utils import logging
|
|
|
|
|
|
def prepend_feed_ops(inference_program,
                     feed_target_names,
                     feed_holder_name='feed'):
    """Insert one ``feed`` op per feed target at the front of the global block.

    Every name is validated before the program is modified, so an invalid
    name raises without leaving a stray feed holder variable or a partial
    set of feed ops behind.

    Args:
        inference_program: Program whose global block receives the feed ops.
        feed_target_names: Sized sequence of variable names to feed. The
            position of each name becomes the ``col`` attribute of its op.
        feed_holder_name: Name of the persistable FEED_MINIBATCH holder
            variable created in the global block.

    Raises:
        ValueError: If a name is not a variable of the global block.
    """
    if len(feed_target_names) == 0:
        return

    # Materialize once: the names are walked twice (validate, then insert).
    names = list(feed_target_names)
    global_block = inference_program.global_block()

    # Validate everything before mutating anything so a ValueError leaves
    # the program exactly as it was passed in.
    for i, name in enumerate(names):
        if not global_block.has_var(name):
            raise ValueError(
                "The feed_var_names[{i}]: '{name}' doesn't exist in pruned inference program. "
                "Please check whether '{name}' is a valid feed_var name, or remove it from feed_var_names "
                "if '{name}' is not involved in the fetch_vars calculation.".
                format(
                    i=i, name=name))

    feed_var = global_block.create_var(
        name=feed_holder_name,
        type=core.VarDesc.VarType.FEED_MINIBATCH,
        persistable=True)

    for i, name in enumerate(names):
        out = global_block.var(name)
        # Ops are prepended one at a time; 'col' records each name's index
        # in feed_target_names regardless of the resulting op order.
        global_block._prepend_op(
            type='feed',
            inputs={'X': [feed_var]},
            outputs={'Out': [out]},
            attrs={'col': i})
|
|
|
|
|
|
def append_fetch_ops(inference_program,
                     fetch_target_names,
                     fetch_holder_name='fetch'):
    """Append one ``fetch`` op per fetch target to the global block.

    A persistable FETCH_LIST holder variable named ``fetch_holder_name`` is
    always created, even when ``fetch_target_names`` is empty. Each fetch op
    reads its target by name and writes into that holder. Its ``col``
    attribute is the target's position in ``fetch_target_names``.
    """
    block = inference_program.global_block()
    holder = block.create_var(
        name=fetch_holder_name,
        type=core.VarDesc.VarType.FETCH_LIST,
        persistable=True)

    for col, target in enumerate(fetch_target_names):
        block.append_op(type='fetch',
                        inputs={'X': [target]},
                        outputs={'Out': [holder]},
                        attrs={'col': col})
|
|
|
|
|
|
def get_program(program, feed_var_names, fetch_vars):
    """Build a pruned inference program with fresh feed/fetch ops.

    The work is done on a clone of ``program``, so the caller's program
    keeps its own feed/fetch ops and op target flags.

    Args:
        program: Source program. It is not modified.
        feed_var_names: Names of the variables to feed. Passed to
            ``_prune_with_input`` and used to create the feed ops.
        fetch_vars: Variables (objects exposing ``.name``) used as pruning
            targets and as fetch targets.

    Returns:
        A new program that has been:
        - pruned with ``_prune_with_input``;
        - passed through ``_inference_optimize(prune_read_op=True)``;
        - given one feed op per ``feed_var_names`` entry and one fetch op
          per ``fetch_vars`` entry.
    """
    # Clone first: the loop below removes ops and clears is_target flags,
    # which would otherwise mutate the caller's program in place.
    program = program.clone()
    global_block = program.global_block()

    # Clear target marks and drop any existing feed/fetch ops; fresh ones
    # are added after pruning. Indices are collected first and removed from
    # the end so the remaining indices stay valid.
    stale_op_indices = []
    for i, op in enumerate(global_block.ops):
        op.desc.set_is_target(False)
        if op.type in ("feed", "fetch"):
            stale_op_indices.append(i)
    for index in reversed(stale_op_indices):
        global_block._remove_op(index)
    program.desc.flush()

    program = program._prune_with_input(
        feeded_var_names=feed_var_names, targets=fetch_vars)
    program = program._inference_optimize(prune_read_op=True)

    fetch_var_names = [v.name for v in fetch_vars]
    prepend_feed_ops(program, feed_var_names)
    append_fetch_ops(program, fetch_var_names)
    return program
|