mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-05 16:48:03 +08:00
Add paddle quantize model support for ORT, TRT and MKLDNN deploy backend (#257)
* add quantize model support for trt and paddle * fix bugs * fix * update paddle2onnx version * update version * add quantize test Co-authored-by: Jason <jiangjiajun@baidu.com>
This commit is contained in:
2
cmake/paddle2onnx.cmake
Normal file → Executable file
2
cmake/paddle2onnx.cmake
Normal file → Executable file
@@ -43,7 +43,7 @@ else()
|
||||
endif(WIN32)
|
||||
|
||||
set(PADDLE2ONNX_URL_BASE "https://bj.bcebos.com/fastdeploy/third_libs/")
|
||||
set(PADDLE2ONNX_VERSION "1.0.1rc")
|
||||
set(PADDLE2ONNX_VERSION "1.0.1")
|
||||
if(WIN32)
|
||||
set(PADDLE2ONNX_FILE "paddle2onnx-win-x64-${PADDLE2ONNX_VERSION}.zip")
|
||||
if(NOT CMAKE_CL_64)
|
||||
|
@@ -16,13 +16,23 @@
|
||||
|
||||
namespace fastdeploy {
|
||||
|
||||
void PaddleBackend::BuildOption(const PaddleBackendOption& option) {
|
||||
void PaddleBackend::BuildOption(const PaddleBackendOption& option,
|
||||
const std::string& model_file) {
|
||||
if (option.use_gpu) {
|
||||
config_.EnableUseGpu(option.gpu_mem_init_size, option.gpu_id);
|
||||
} else {
|
||||
config_.DisableGpu();
|
||||
if (option.enable_mkldnn) {
|
||||
config_.EnableMKLDNN();
|
||||
std::string contents;
|
||||
if (!ReadBinaryFromFile(model_file, &contents)) {
|
||||
return;
|
||||
}
|
||||
auto reader =
|
||||
paddle2onnx::PaddleReader(contents.c_str(), contents.size());
|
||||
if (reader.is_quantize_model) {
|
||||
config_.EnableMkldnnInt8();
|
||||
}
|
||||
config_.SetMkldnnCacheCapacity(option.mkldnn_cache_size);
|
||||
}
|
||||
}
|
||||
@@ -52,7 +62,7 @@ bool PaddleBackend::InitFromPaddle(const std::string& model_file,
|
||||
return false;
|
||||
}
|
||||
config_.SetModel(model_file, params_file);
|
||||
BuildOption(option);
|
||||
BuildOption(option, model_file);
|
||||
predictor_ = paddle_infer::CreatePredictor(config_);
|
||||
std::vector<std::string> input_names = predictor_->GetInputNames();
|
||||
std::vector<std::string> output_names = predictor_->GetOutputNames();
|
||||
|
4
fastdeploy/backends/paddle/paddle_backend.h
Normal file → Executable file
4
fastdeploy/backends/paddle/paddle_backend.h
Normal file → Executable file
@@ -20,6 +20,7 @@
|
||||
#include <vector>
|
||||
|
||||
#include "fastdeploy/backends/backend.h"
|
||||
#include "paddle2onnx/converter.h"
|
||||
#include "paddle_inference_api.h" // NOLINT
|
||||
|
||||
namespace fastdeploy {
|
||||
@@ -61,7 +62,8 @@ class PaddleBackend : public BaseBackend {
|
||||
public:
|
||||
PaddleBackend() {}
|
||||
virtual ~PaddleBackend() = default;
|
||||
void BuildOption(const PaddleBackendOption& option);
|
||||
void BuildOption(const PaddleBackendOption& option,
|
||||
const std::string& model_file);
|
||||
|
||||
bool InitFromPaddle(
|
||||
const std::string& model_file, const std::string& params_file,
|
||||
|
@@ -131,10 +131,13 @@ bool TrtBackend::InitFromPaddle(const std::string& model_file,
|
||||
}
|
||||
char* model_content_ptr;
|
||||
int model_content_size = 0;
|
||||
char* calibration_cache_ptr;
|
||||
int calibration_cache_size = 0;
|
||||
if (!paddle2onnx::Export(model_file.c_str(), params_file.c_str(),
|
||||
&model_content_ptr, &model_content_size, 11, true,
|
||||
verbose, true, true, true, custom_ops.data(),
|
||||
custom_ops.size())) {
|
||||
custom_ops.size(), "tensorrt",
|
||||
&calibration_cache_ptr, &calibration_cache_size)) {
|
||||
FDERROR << "Error occured while export PaddlePaddle to ONNX format."
|
||||
<< std::endl;
|
||||
return false;
|
||||
@@ -151,6 +154,13 @@ bool TrtBackend::InitFromPaddle(const std::string& model_file,
|
||||
delete[] model_content_ptr;
|
||||
std::string onnx_model_proto(new_model, new_model + new_model_size);
|
||||
delete[] new_model;
|
||||
if (calibration_cache_size) {
|
||||
std::string calibration_str(
|
||||
calibration_cache_ptr,
|
||||
calibration_cache_ptr + calibration_cache_size);
|
||||
calibration_str_ = calibration_str;
|
||||
delete[] calibration_cache_ptr;
|
||||
}
|
||||
return InitFromOnnx(onnx_model_proto, option, true);
|
||||
}
|
||||
|
||||
@@ -158,6 +168,12 @@ bool TrtBackend::InitFromPaddle(const std::string& model_file,
|
||||
model_content_ptr + model_content_size);
|
||||
delete[] model_content_ptr;
|
||||
model_content_ptr = nullptr;
|
||||
if (calibration_cache_size) {
|
||||
std::string calibration_str(calibration_cache_ptr,
|
||||
calibration_cache_ptr + calibration_cache_size);
|
||||
calibration_str_ = calibration_str;
|
||||
delete[] calibration_cache_ptr;
|
||||
}
|
||||
return InitFromOnnx(onnx_model_proto, option, true);
|
||||
#else
|
||||
FDERROR << "Didn't compile with PaddlePaddle frontend, you can try to "
|
||||
@@ -409,6 +425,7 @@ bool TrtBackend::BuildTrtEngine() {
|
||||
"will use FP32 instead."
|
||||
<< std::endl;
|
||||
} else {
|
||||
FDINFO << "[TrtBackend] Use FP16 to inference." << std::endl;
|
||||
config->setFlag(nvinfer1::BuilderFlag::kFP16);
|
||||
}
|
||||
}
|
||||
@@ -459,6 +476,20 @@ bool TrtBackend::BuildTrtEngine() {
|
||||
}
|
||||
config->addOptimizationProfile(profile);
|
||||
|
||||
if (calibration_str_.size()) {
|
||||
if (!builder_->platformHasFastInt8()) {
|
||||
FDWARNING << "Detected INT8 is not supported in the current GPU, "
|
||||
"will use FP32 instead."
|
||||
<< std::endl;
|
||||
} else {
|
||||
FDINFO << "[TrtBackend] Use INT8 to inference." << std::endl;
|
||||
config->setFlag(nvinfer1::BuilderFlag::kINT8);
|
||||
Int8EntropyCalibrator2* calibrator =
|
||||
new Int8EntropyCalibrator2(calibration_str_);
|
||||
config->setInt8Calibrator(calibrator);
|
||||
}
|
||||
}
|
||||
|
||||
FDUniquePtr<nvinfer1::IHostMemory> plan{
|
||||
builder_->buildSerializedNetwork(*network_, *config)};
|
||||
if (!plan) {
|
||||
|
28
fastdeploy/backends/tensorrt/trt_backend.h
Normal file → Executable file
28
fastdeploy/backends/tensorrt/trt_backend.h
Normal file → Executable file
@@ -26,6 +26,32 @@
|
||||
#include "fastdeploy/backends/backend.h"
|
||||
#include "fastdeploy/backends/tensorrt/utils.h"
|
||||
|
||||
class Int8EntropyCalibrator2 : public nvinfer1::IInt8EntropyCalibrator2 {
|
||||
public:
|
||||
explicit Int8EntropyCalibrator2(const std::string& calibration_cache)
|
||||
: calibration_cache_(calibration_cache) {}
|
||||
|
||||
int getBatchSize() const noexcept override { return 0; }
|
||||
|
||||
bool getBatch(void* bindings[], const char* names[],
|
||||
int nbBindings) noexcept override {
|
||||
return false;
|
||||
}
|
||||
|
||||
const void* readCalibrationCache(size_t& length) noexcept override {
|
||||
length = calibration_cache_.size();
|
||||
return length ? calibration_cache_.data() : nullptr;
|
||||
}
|
||||
|
||||
void writeCalibrationCache(const void* cache,
|
||||
size_t length) noexcept override {
|
||||
std::cout << "NOT IMPLEMENT." << std::endl;
|
||||
}
|
||||
|
||||
private:
|
||||
const std::string calibration_cache_;
|
||||
};
|
||||
|
||||
namespace fastdeploy {
|
||||
|
||||
struct TrtValueInfo {
|
||||
@@ -95,6 +121,8 @@ class TrtBackend : public BaseBackend {
|
||||
std::map<std::string, FDDeviceBuffer> inputs_buffer_;
|
||||
std::map<std::string, FDDeviceBuffer> outputs_buffer_;
|
||||
|
||||
std::string calibration_str_;
|
||||
|
||||
// Sometimes while the number of outputs > 1
|
||||
// the output order of tensorrt may not be same
|
||||
// with the original onnx model
|
||||
|
1
fastdeploy/pybind/runtime.cc
Normal file → Executable file
1
fastdeploy/pybind/runtime.cc
Normal file → Executable file
@@ -25,6 +25,7 @@ void BindRuntime(pybind11::module& m) {
|
||||
.def("set_cpu_thread_num", &RuntimeOption::SetCpuThreadNum)
|
||||
.def("use_paddle_backend", &RuntimeOption::UsePaddleBackend)
|
||||
.def("use_ort_backend", &RuntimeOption::UseOrtBackend)
|
||||
.def("set_ort_graph_opt_level", &RuntimeOption::SetOrtGraphOptLevel)
|
||||
.def("use_trt_backend", &RuntimeOption::UseTrtBackend)
|
||||
.def("use_openvino_backend", &RuntimeOption::UseOpenVINOBackend)
|
||||
.def("use_lite_backend", &RuntimeOption::UseLiteBackend)
|
||||
|
7
fastdeploy/runtime.cc
Normal file → Executable file
7
fastdeploy/runtime.cc
Normal file → Executable file
@@ -198,6 +198,13 @@ void RuntimeOption::SetCpuThreadNum(int thread_num) {
|
||||
cpu_thread_num = thread_num;
|
||||
}
|
||||
|
||||
void RuntimeOption::SetOrtGraphOptLevel(int level) {
|
||||
std::vector<int> supported_level{-1, 0, 1, 2};
|
||||
auto valid_level = std::find(supported_level.begin(), supported_level.end(), level) != supported_level.end();
|
||||
FDASSERT(valid_level, "The level must be -1, 0, 1, 2.");
|
||||
ort_graph_opt_level = level;
|
||||
}
|
||||
|
||||
// use paddle inference backend
|
||||
void RuntimeOption::UsePaddleBackend() {
|
||||
#ifdef ENABLE_PADDLE_BACKEND
|
||||
|
4
fastdeploy/runtime.h
Normal file → Executable file
4
fastdeploy/runtime.h
Normal file → Executable file
@@ -22,6 +22,7 @@
|
||||
|
||||
#include <map>
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
|
||||
#include "fastdeploy/backends/backend.h"
|
||||
#include "fastdeploy/utils/perf.h"
|
||||
@@ -104,6 +105,9 @@ struct FASTDEPLOY_DECL RuntimeOption {
|
||||
*/
|
||||
void SetCpuThreadNum(int thread_num);
|
||||
|
||||
/// Use ORT graph opt level
|
||||
void SetOrtGraphOptLevel(int level = -1);
|
||||
|
||||
/// Set Paddle Inference as inference backend, support CPU/GPU
|
||||
void UsePaddleBackend();
|
||||
|
||||
|
3
python/fastdeploy/runtime.py
Normal file → Executable file
3
python/fastdeploy/runtime.py
Normal file → Executable file
@@ -117,6 +117,9 @@ class RuntimeOption:
|
||||
"""
|
||||
return self._option.set_cpu_thread_num(thread_num)
|
||||
|
||||
def set_ort_graph_opt_level(self, level=-1):
|
||||
return self._option.set_ort_graph_opt_level(level)
|
||||
|
||||
def use_paddle_backend(self):
|
||||
"""Use Paddle Inference backend, support inference Paddle model on CPU/Nvidia GPU.
|
||||
"""
|
||||
|
96
tests/eval_example/test_quantize_diff.py
Executable file
96
tests/eval_example/test_quantize_diff.py
Executable file
@@ -0,0 +1,96 @@
|
||||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import fastdeploy as fd
|
||||
import cv2
|
||||
import os
|
||||
import pickle
|
||||
import numpy as np
|
||||
|
||||
model_url = "https://bj.bcebos.com/fastdeploy/tests/yolov6_quant.tgz"
|
||||
fd.download_and_decompress(model_url, ".")
|
||||
|
||||
|
||||
def test_quant_mkldnn():
|
||||
model_path = "./yolov6_quant"
|
||||
model_file = os.path.join(model_path, "model.pdmodel")
|
||||
params_file = os.path.join(model_path, "model.pdiparams")
|
||||
|
||||
input_file = os.path.join(model_path, "input.npy")
|
||||
output_file = os.path.join(model_path, "mkldnn_output.npy")
|
||||
|
||||
option = fd.RuntimeOption()
|
||||
option.use_paddle_backend()
|
||||
option.use_cpu()
|
||||
|
||||
option.set_model_path(model_file, params_file)
|
||||
runtime = fd.Runtime(option)
|
||||
input_name = runtime.get_input_info(0).name
|
||||
data = np.load(input_file)
|
||||
outs = runtime.infer({input_name: data})
|
||||
expected = np.load(output_file)
|
||||
diff = np.fabs(outs[0] - expected)
|
||||
thres = 1e-05
|
||||
assert diff.max() < thres, "The diff is %f, which is bigger than %f" % (
|
||||
diff.max(), thres)
|
||||
|
||||
|
||||
def test_quant_ort():
|
||||
model_path = "./yolov6_quant"
|
||||
model_file = os.path.join(model_path, "model.pdmodel")
|
||||
params_file = os.path.join(model_path, "model.pdiparams")
|
||||
|
||||
input_file = os.path.join(model_path, "input.npy")
|
||||
output_file = os.path.join(model_path, "ort_output.npy")
|
||||
|
||||
option = fd.RuntimeOption()
|
||||
option.use_ort_backend()
|
||||
option.use_cpu()
|
||||
|
||||
option.set_ort_graph_opt_level(1)
|
||||
|
||||
option.set_model_path(model_file, params_file)
|
||||
runtime = fd.Runtime(option)
|
||||
input_name = runtime.get_input_info(0).name
|
||||
data = np.load(input_file)
|
||||
outs = runtime.infer({input_name: data})
|
||||
expected = np.load(output_file)
|
||||
diff = np.fabs(outs[0] - expected)
|
||||
thres = 1e-05
|
||||
assert diff.max() < thres, "The diff is %f, which is bigger than %f" % (
|
||||
diff.max(), thres)
|
||||
|
||||
|
||||
def test_quant_trt():
|
||||
model_path = "./yolov6_quant"
|
||||
model_file = os.path.join(model_path, "model.pdmodel")
|
||||
params_file = os.path.join(model_path, "model.pdiparams")
|
||||
|
||||
input_file = os.path.join(model_path, "input.npy")
|
||||
output_file = os.path.join(model_path, "trt_output.npy")
|
||||
|
||||
option = fd.RuntimeOption()
|
||||
option.use_trt_backend()
|
||||
option.use_gpu()
|
||||
|
||||
option.set_model_path(model_file, params_file)
|
||||
runtime = fd.Runtime(option)
|
||||
input_name = runtime.get_input_info(0).name
|
||||
data = np.load(input_file)
|
||||
outs = runtime.infer({input_name: data})
|
||||
expected = np.load(output_file)
|
||||
diff = np.fabs(outs[0] - expected)
|
||||
thres = 1e-05
|
||||
assert diff.max() < thres, "The diff is %f, which is bigger than %f" % (
|
||||
diff.max(), thres)
|
Reference in New Issue
Block a user