/* * SPDX-License-Identifier: Apache-2.0 */ // ATTENTION: The code in this file is highly EXPERIMENTAL. // Adventurous users should note that the APIs will probably change. #include "onnxoptimizer/optimize.h" namespace ONNX_NAMESPACE { namespace optimization { GlobalPassRegistry Optimizer::passes; Optimizer::Optimizer( const std::vector& names, const bool fixed_point) { if (fixed_point) { this->pass_manager = std::shared_ptr(new FixedPointPassManager()); } else { this->pass_manager = std::shared_ptr(new GeneralPassManager()); } for (const auto& name : names) { auto pass = passes.find(name); this->pass_manager->add(pass); } } Optimizer::~Optimizer() {} ModelProto Optimize( const ModelProto& mp_in, const std::vector& names) { Optimizer current_opt(names, false); return current_opt.optimize(mp_in); } ModelProto OptimizeFixed( const ModelProto& mp_in, const std::vector& names) { Optimizer current_opt(names, true); return current_opt.optimize(mp_in); } const std::vector GetAvailablePasses() { return Optimizer::passes.GetAvailablePasses(); } const std::vector GetFuseAndEliminationPass() { return Optimizer::passes.GetFuseAndEliminationPass(); } } // namespace optimization } // namespace ONNX_NAMESPACE