// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <map>
#include <memory>
#include <string>
#include <vector>

#include "paddle2onnx/mapper/mapper.h"
#include "paddle2onnx/parser/parser.h"

namespace paddle2onnx {

struct QuantizeModelProcessor {
 public:
  std::vector<QuantizeInfo> quantize_info;
  const PaddleParser* parser_;
  OnnxHelper* helper_;
  std::vector<std::shared_ptr<ONNX_NAMESPACE::NodeProto>>* parameters_;
  std::vector<std::shared_ptr<ONNX_NAMESPACE::ValueInfoProto>>* inputs_;
  std::vector<std::shared_ptr<ONNX_NAMESPACE::ValueInfoProto>>* outputs_;
  std::vector<std::shared_ptr<ONNX_NAMESPACE::NodeProto>>* nodes_;

  // All types that support quantization
  std::vector<std::string> supported_quantize_type_;

  std::map<std::string,
           std::vector<std::shared_ptr<ONNX_NAMESPACE::NodeProto>>>
      name2node_dict_;

  // Records those tensors that need to add quantize and dequantize ops
  std::vector<std::string> tensors_to_be_quantize;

  // Records those tensors that only need to add the dequantize op
  std::vector<std::string> only_dequantize_tensors;

  // Convert to different model formats based on the deploy backend; the
  // backend can be TensorRT, ONNXRuntime or others
  void ProcessQuantizeModel(
      std::vector<std::shared_ptr<ONNX_NAMESPACE::NodeProto>>* parameters,
      std::vector<std::shared_ptr<ONNX_NAMESPACE::ValueInfoProto>>* inputs,
      std::vector<std::shared_ptr<ONNX_NAMESPACE::ValueInfoProto>>* outputs,
      std::vector<std::shared_ptr<ONNX_NAMESPACE::NodeProto>>* nodes,
      OnnxHelper* helper, const std::string& deploy_backend,
      const PaddleParser& parser, std::string* calibration_cache = nullptr);

  // Remove all Quantize and Dequantize ops
  void RemoveAllQuantizeOps();

  // If all tensors in tensor_names have quantize info and all the next nodes
  // can be quantized, return true; otherwise return false
  bool CanBeQuantize(const std::vector<std::string>& tensor_names,
                     const std::vector<int64_t>& output_index = {-1});

  // only_dequantize records those tensors that only need to add the
  // dequantize op
  void AppendQuantizeTensor(const std::string& tensor,
                            const bool& only_dequantize = false);

  // Add QDQ for ORT according to:
  // https://github.com/microsoft/onnxruntime/blob/master/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc
  void AddQDQForORT();

  // Determine if the tensor is directly linked to a graph output by Identity
  bool ConnectToOutput(const std::string& output_name);

  // Generate the calibration cache file for TensorRT 8.X int8 deployment
  void GenerateCache(std::string* calibration_cache);

  // Add QDQ for TRT according to:
  // https://github.com/NVIDIA/TensorRT/tree/main/tools/pytorch-quantization/pytorch_quantization/nn/modules
  void AddTrtQDQ();

  // Add QDQ for RKNN
  void AddQDQForRKNN();

  // Add quantize-related ops in the model according to the tensor names
  void AddQDQInModel(const std::vector<std::string>& tensors_to_be_quantize);

  void QuantizeInfoBroadcast();

  // Merge conv + add
  void MergeConvAdd();

  // Merge conv + BN
  void MergeConvBN();

  // Determine whether a tensor is a graph output
  bool IsGraphOutput(const std::string& name);

  // Because processing the quantize model adds new nodes, which breaks the
  // topological ordering of the nodes, this function sorts the nodes again
  void SortNodes();

  bool GetTensorShape(const std::string& name, std::vector<int64_t>* shape);

  // Return the value of a tensor by name
  template <typename T>
  bool GetTensorByName(const std::string& name, std::vector<T>* value);

  // Perform tensor-wise quantization, returning scale and zero point
  void GetTensorWiseQuantizeInfo(const std::vector<float>& tensor,
                                 std::vector<float>* scale,
                                 std::vector<int64_t>* zero);

  // Perform channel-wise quantization, returning scale and zero point
  void GetChannelWiseQuantizeInfo(const std::vector<float>& tensor,
                                  const std::vector<int64_t>& shape,
                                  const int64_t& quant_axis,
                                  std::vector<float>* scale,
                                  std::vector<int64_t>* zero);

  // Generate name2node_dict_ to map an input name to its related nodes
  void UpdateInputNameToNodes();

  void RemoveNodeByName(const std::string& name, const bool& update_io = true);

  void ReplaceInputOfAllNodes(
      const std::string& old_name, const std::string& new_name,
      const std::vector<std::shared_ptr<ONNX_NAMESPACE::NodeProto>>&
          except_nodes = {});
};

}  // namespace paddle2onnx
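// ---------------------------------------------------------------------------
// Illustrative sketch only (not part of the upstream API): one common way a
// GetTensorWiseQuantizeInfo-style helper can derive a symmetric int8 scale
// and zero point from a float tensor. The function name, the 127.0f divisor
// and the fixed zero point of 0 are assumptions of symmetric int8
// quantization; the actual paddle2onnx implementation may differ.
//
//   #include <algorithm>
//   #include <cmath>
//   #include <vector>
//
//   void TensorWiseScaleSketch(const std::vector<float>& tensor,
//                              std::vector<float>* scale,
//                              std::vector<int64_t>* zero) {
//     // Find the largest absolute value in the tensor
//     float max_abs = 0.0f;
//     for (float v : tensor) {
//       max_abs = std::max(max_abs, std::fabs(v));
//     }
//     scale->push_back(max_abs / 127.0f);  // map [-max_abs, max_abs] to int8
//     zero->push_back(0);                  // symmetric quantization: zero = 0
//   }
// ---------------------------------------------------------------------------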