# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import

import os
import paddle
import numpy as np
from paddle.fluid import core
from paddle.fluid.framework import Variable, program_guard
from paddle2onnx.utils import logging


def prepend_feed_ops(inference_program,
                     feed_target_names,
                     feed_holder_name='feed'):
    """Insert one ``feed`` op per input at the front of the global block.

    Args:
        inference_program: the Program to modify in place.
        feed_target_names (list[str]): names of the variables to be fed.
            Each feed op gets ``col`` set to the name's index in this list.
        feed_holder_name (str): name of the FEED_MINIBATCH holder variable
            created in the global block.

    Raises:
        ValueError: if a name in ``feed_target_names`` is not a variable of
            the global block. All names are checked before anything is
            created, so a failure leaves ``inference_program`` unmodified.
    """
    # len() (rather than truthiness) keeps rejecting one-shot iterators,
    # which could not be walked twice below anyway.
    if len(feed_target_names) == 0:
        return

    global_block = inference_program.global_block()

    # Validate every name up front: raising halfway through the insertion
    # loop would leave a feed holder and a partial set of feed ops behind.
    for i, name in enumerate(feed_target_names):
        if not global_block.has_var(name):
            raise ValueError(
                "The feed_var_names[{i}]: '{name}' doesn't exist in pruned inference program. "
                "Please check whether '{name}' is a valid feed_var name, or remove it from feed_var_names "
                "if '{name}' is not involved in the fetch_vars calculation.".
                format(
                    i=i, name=name))

    feed_var = global_block.create_var(
        name=feed_holder_name,
        type=core.VarDesc.VarType.FEED_MINIBATCH,
        persistable=True)

    for i, name in enumerate(feed_target_names):
        out = global_block.var(name)
        # _prepend_op inserts at the front, so the ops end up in reverse
        # order of feed_target_names; each op records its input position
        # explicitly via the `col` attribute.
        global_block._prepend_op(
            type='feed',
            inputs={'X': [feed_var]},
            outputs={'Out': [out]},
            attrs={'col': i})


def append_fetch_ops(inference_program,
                     fetch_target_names,
                     fetch_holder_name='fetch'):
    """Append one ``fetch`` op per output to the end of the global block.

    Args:
        inference_program: the Program to modify in place.
        fetch_target_names (list[str]): names of the variables to fetch.
            Each fetch op gets ``col`` set to the name's index in this list.
        fetch_holder_name (str): name of the FETCH_LIST holder variable
            created in the global block. Note: the holder is created even
            when ``fetch_target_names`` is empty.
    """
    global_block = inference_program.global_block()
    fetch_var = global_block.create_var(
        name=fetch_holder_name,
        type=core.VarDesc.VarType.FETCH_LIST,
        persistable=True)

    for i, name in enumerate(fetch_target_names):
        global_block.append_op(
            type='fetch',
            inputs={'X': [name]},
            outputs={'Out': [fetch_var]},
            attrs={'col': i})


def get_program(program, feed_var_names, fetch_vars):
    """Build an inference Program with fresh feed/fetch ops.

    The steps are:

    1. Clone the program.
    2. Strip any existing ``feed``/``fetch`` ops and clear every op's
       ``is_target`` flag.
    3. Prune the program with ``_prune_with_input`` against
       ``feed_var_names``/``fetch_vars``.
    4. Run ``_inference_optimize(prune_read_op=True)``.
    5. Re-insert feed ops for ``feed_var_names`` and fetch ops for
       ``fetch_vars``.

    Args:
        program: the source Program. It is cloned first, so the caller's
            Program is not modified.
        feed_var_names (list[str]): names of the input variables.
        fetch_vars (list[Variable]): output variables; only their ``.name``
            is used after pruning.

    Returns:
        A new Program containing the pruned graph with feed ops prepended
        and fetch ops appended.

    Raises:
        ValueError: propagated from ``prepend_feed_ops`` when a feed name
            does not survive pruning.
    """
    # Work on a copy: the op removal and is_target reset below mutate the
    # global block in place and would otherwise alter the caller's Program.
    # Variable targets passed to _prune_with_input are resolved by name, so
    # fetch_vars taken from the original program still apply to the clone.
    program = program.clone()
    global_block = program.global_block()

    need_to_remove_op_index = []
    for i, op in enumerate(global_block.ops):
        op.desc.set_is_target(False)
        if op.type == "feed" or op.type == "fetch":
            need_to_remove_op_index.append(i)

    # Remove from the highest index down so earlier indices stay valid.
    for index in need_to_remove_op_index[::-1]:
        global_block._remove_op(index)

    program.desc.flush()

    program = program._prune_with_input(
        feeded_var_names=feed_var_names, targets=fetch_vars)
    program = program._inference_optimize(prune_read_op=True)

    fetch_var_names = [v.name for v in fetch_vars]
    prepend_feed_ops(program, feed_var_names)
    append_fetch_ops(program, fetch_var_names)

    return program