Add paddle quantize model support for ORT, TRT and MKLDNN deploy backend (#257)

* add quantize model support for trt and paddle

* fix bugs

* fix

* update paddle2onnx version

* update version

* add quantize test

Co-authored-by: Jason <jiangjiajun@baidu.com>
This commit is contained in:
yeliang2258
2022-10-09 20:00:05 +08:00
committed by GitHub
parent ff5e798b7f
commit 2a68a23baf
10 changed files with 187 additions and 5 deletions

2
cmake/paddle2onnx.cmake Normal file → Executable file
View File

@@ -43,7 +43,7 @@ else()
endif(WIN32)
set(PADDLE2ONNX_URL_BASE "https://bj.bcebos.com/fastdeploy/third_libs/")
set(PADDLE2ONNX_VERSION "1.0.1rc")
set(PADDLE2ONNX_VERSION "1.0.1")
if(WIN32)
set(PADDLE2ONNX_FILE "paddle2onnx-win-x64-${PADDLE2ONNX_VERSION}.zip")
if(NOT CMAKE_CL_64)

View File

@@ -16,13 +16,23 @@
namespace fastdeploy {
void PaddleBackend::BuildOption(const PaddleBackendOption& option) {
void PaddleBackend::BuildOption(const PaddleBackendOption& option,
const std::string& model_file) {
if (option.use_gpu) {
config_.EnableUseGpu(option.gpu_mem_init_size, option.gpu_id);
} else {
config_.DisableGpu();
if (option.enable_mkldnn) {
config_.EnableMKLDNN();
std::string contents;
if (!ReadBinaryFromFile(model_file, &contents)) {
return;
}
auto reader =
paddle2onnx::PaddleReader(contents.c_str(), contents.size());
if (reader.is_quantize_model) {
config_.EnableMkldnnInt8();
}
config_.SetMkldnnCacheCapacity(option.mkldnn_cache_size);
}
}
@@ -52,7 +62,7 @@ bool PaddleBackend::InitFromPaddle(const std::string& model_file,
return false;
}
config_.SetModel(model_file, params_file);
BuildOption(option);
BuildOption(option, model_file);
predictor_ = paddle_infer::CreatePredictor(config_);
std::vector<std::string> input_names = predictor_->GetInputNames();
std::vector<std::string> output_names = predictor_->GetOutputNames();

4
fastdeploy/backends/paddle/paddle_backend.h Normal file → Executable file
View File

@@ -20,6 +20,7 @@
#include <vector>
#include "fastdeploy/backends/backend.h"
#include "paddle2onnx/converter.h"
#include "paddle_inference_api.h" // NOLINT
namespace fastdeploy {
@@ -61,7 +62,8 @@ class PaddleBackend : public BaseBackend {
public:
PaddleBackend() {}
virtual ~PaddleBackend() = default;
void BuildOption(const PaddleBackendOption& option);
void BuildOption(const PaddleBackendOption& option,
const std::string& model_file);
bool InitFromPaddle(
const std::string& model_file, const std::string& params_file,

View File

@@ -131,10 +131,13 @@ bool TrtBackend::InitFromPaddle(const std::string& model_file,
}
char* model_content_ptr;
int model_content_size = 0;
char* calibration_cache_ptr;
int calibration_cache_size = 0;
if (!paddle2onnx::Export(model_file.c_str(), params_file.c_str(),
&model_content_ptr, &model_content_size, 11, true,
verbose, true, true, true, custom_ops.data(),
custom_ops.size())) {
custom_ops.size(), "tensorrt",
&calibration_cache_ptr, &calibration_cache_size)) {
FDERROR << "Error occured while export PaddlePaddle to ONNX format."
<< std::endl;
return false;
@@ -151,6 +154,13 @@ bool TrtBackend::InitFromPaddle(const std::string& model_file,
delete[] model_content_ptr;
std::string onnx_model_proto(new_model, new_model + new_model_size);
delete[] new_model;
if (calibration_cache_size) {
std::string calibration_str(
calibration_cache_ptr,
calibration_cache_ptr + calibration_cache_size);
calibration_str_ = calibration_str;
delete[] calibration_cache_ptr;
}
return InitFromOnnx(onnx_model_proto, option, true);
}
@@ -158,6 +168,12 @@ bool TrtBackend::InitFromPaddle(const std::string& model_file,
model_content_ptr + model_content_size);
delete[] model_content_ptr;
model_content_ptr = nullptr;
if (calibration_cache_size) {
std::string calibration_str(calibration_cache_ptr,
calibration_cache_ptr + calibration_cache_size);
calibration_str_ = calibration_str;
delete[] calibration_cache_ptr;
}
return InitFromOnnx(onnx_model_proto, option, true);
#else
FDERROR << "Didn't compile with PaddlePaddle frontend, you can try to "
@@ -409,6 +425,7 @@ bool TrtBackend::BuildTrtEngine() {
"will use FP32 instead."
<< std::endl;
} else {
FDINFO << "[TrtBackend] Use FP16 to inference." << std::endl;
config->setFlag(nvinfer1::BuilderFlag::kFP16);
}
}
@@ -459,6 +476,20 @@ bool TrtBackend::BuildTrtEngine() {
}
config->addOptimizationProfile(profile);
if (calibration_str_.size()) {
if (!builder_->platformHasFastInt8()) {
FDWARNING << "Detected INT8 is not supported in the current GPU, "
"will use FP32 instead."
<< std::endl;
} else {
FDINFO << "[TrtBackend] Use INT8 to inference." << std::endl;
config->setFlag(nvinfer1::BuilderFlag::kINT8);
Int8EntropyCalibrator2* calibrator =
new Int8EntropyCalibrator2(calibration_str_);
config->setInt8Calibrator(calibrator);
}
}
FDUniquePtr<nvinfer1::IHostMemory> plan{
builder_->buildSerializedNetwork(*network_, *config)};
if (!plan) {

28
fastdeploy/backends/tensorrt/trt_backend.h Normal file → Executable file
View File

@@ -26,6 +26,32 @@
#include "fastdeploy/backends/backend.h"
#include "fastdeploy/backends/tensorrt/utils.h"
class Int8EntropyCalibrator2 : public nvinfer1::IInt8EntropyCalibrator2 {
public:
explicit Int8EntropyCalibrator2(const std::string& calibration_cache)
: calibration_cache_(calibration_cache) {}
int getBatchSize() const noexcept override { return 0; }
bool getBatch(void* bindings[], const char* names[],
int nbBindings) noexcept override {
return false;
}
const void* readCalibrationCache(size_t& length) noexcept override {
length = calibration_cache_.size();
return length ? calibration_cache_.data() : nullptr;
}
void writeCalibrationCache(const void* cache,
size_t length) noexcept override {
std::cout << "NOT IMPLEMENT." << std::endl;
}
private:
const std::string calibration_cache_;
};
namespace fastdeploy {
struct TrtValueInfo {
@@ -95,6 +121,8 @@ class TrtBackend : public BaseBackend {
std::map<std::string, FDDeviceBuffer> inputs_buffer_;
std::map<std::string, FDDeviceBuffer> outputs_buffer_;
std::string calibration_str_;
// Sometimes while the number of outputs > 1
// the output order of tensorrt may not be same
// with the original onnx model

1
fastdeploy/pybind/runtime.cc Normal file → Executable file
View File

@@ -25,6 +25,7 @@ void BindRuntime(pybind11::module& m) {
.def("set_cpu_thread_num", &RuntimeOption::SetCpuThreadNum)
.def("use_paddle_backend", &RuntimeOption::UsePaddleBackend)
.def("use_ort_backend", &RuntimeOption::UseOrtBackend)
.def("set_ort_graph_opt_level", &RuntimeOption::SetOrtGraphOptLevel)
.def("use_trt_backend", &RuntimeOption::UseTrtBackend)
.def("use_openvino_backend", &RuntimeOption::UseOpenVINOBackend)
.def("use_lite_backend", &RuntimeOption::UseLiteBackend)

7
fastdeploy/runtime.cc Normal file → Executable file
View File

@@ -198,6 +198,13 @@ void RuntimeOption::SetCpuThreadNum(int thread_num) {
cpu_thread_num = thread_num;
}
void RuntimeOption::SetOrtGraphOptLevel(int level) {
std::vector<int> supported_level{-1, 0, 1, 2};
auto valid_level = std::find(supported_level.begin(), supported_level.end(), level) != supported_level.end();
FDASSERT(valid_level, "The level must be -1, 0, 1, 2.");
ort_graph_opt_level = level;
}
// use paddle inference backend
void RuntimeOption::UsePaddleBackend() {
#ifdef ENABLE_PADDLE_BACKEND

4
fastdeploy/runtime.h Normal file → Executable file
View File

@@ -22,6 +22,7 @@
#include <map>
#include <vector>
#include <algorithm>
#include "fastdeploy/backends/backend.h"
#include "fastdeploy/utils/perf.h"
@@ -104,6 +105,9 @@ struct FASTDEPLOY_DECL RuntimeOption {
*/
void SetCpuThreadNum(int thread_num);
/// Use ORT graph opt level
void SetOrtGraphOptLevel(int level = -1);
/// Set Paddle Inference as inference backend, support CPU/GPU
void UsePaddleBackend();

3
python/fastdeploy/runtime.py Normal file → Executable file
View File

@@ -117,6 +117,9 @@ class RuntimeOption:
"""
return self._option.set_cpu_thread_num(thread_num)
def set_ort_graph_opt_level(self, level=-1):
return self._option.set_ort_graph_opt_level(level)
def use_paddle_backend(self):
"""Use Paddle Inference backend, support inference Paddle model on CPU/Nvidia GPU.
"""

View File

@@ -0,0 +1,96 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import fastdeploy as fd
import cv2
import os
import pickle
import numpy as np
model_url = "https://bj.bcebos.com/fastdeploy/tests/yolov6_quant.tgz"
fd.download_and_decompress(model_url, ".")
def test_quant_mkldnn():
model_path = "./yolov6_quant"
model_file = os.path.join(model_path, "model.pdmodel")
params_file = os.path.join(model_path, "model.pdiparams")
input_file = os.path.join(model_path, "input.npy")
output_file = os.path.join(model_path, "mkldnn_output.npy")
option = fd.RuntimeOption()
option.use_paddle_backend()
option.use_cpu()
option.set_model_path(model_file, params_file)
runtime = fd.Runtime(option)
input_name = runtime.get_input_info(0).name
data = np.load(input_file)
outs = runtime.infer({input_name: data})
expected = np.load(output_file)
diff = np.fabs(outs[0] - expected)
thres = 1e-05
assert diff.max() < thres, "The diff is %f, which is bigger than %f" % (
diff.max(), thres)
def test_quant_ort():
model_path = "./yolov6_quant"
model_file = os.path.join(model_path, "model.pdmodel")
params_file = os.path.join(model_path, "model.pdiparams")
input_file = os.path.join(model_path, "input.npy")
output_file = os.path.join(model_path, "ort_output.npy")
option = fd.RuntimeOption()
option.use_ort_backend()
option.use_cpu()
option.set_ort_graph_opt_level(1)
option.set_model_path(model_file, params_file)
runtime = fd.Runtime(option)
input_name = runtime.get_input_info(0).name
data = np.load(input_file)
outs = runtime.infer({input_name: data})
expected = np.load(output_file)
diff = np.fabs(outs[0] - expected)
thres = 1e-05
assert diff.max() < thres, "The diff is %f, which is bigger than %f" % (
diff.max(), thres)
def test_quant_trt():
model_path = "./yolov6_quant"
model_file = os.path.join(model_path, "model.pdmodel")
params_file = os.path.join(model_path, "model.pdiparams")
input_file = os.path.join(model_path, "input.npy")
output_file = os.path.join(model_path, "trt_output.npy")
option = fd.RuntimeOption()
option.use_trt_backend()
option.use_gpu()
option.set_model_path(model_file, params_file)
runtime = fd.Runtime(option)
input_name = runtime.get_input_info(0).name
data = np.load(input_file)
outs = runtime.infer({input_name: data})
expected = np.load(output_file)
diff = np.fabs(outs[0] - expected)
thres = 1e-05
assert diff.max() < thres, "The diff is %f, which is bigger than %f" % (
diff.max(), thres)