Mirror of https://github.com/PaddlePaddle/FastDeploy.git
Add Paddle quantized model support for the ORT, TRT and MKLDNN deploy backends (#257)
* add quantize model support for trt and paddle
* fix bugs
* fix
* update paddle2onnx version
* update version
* add quantize test

Co-authored-by: Jason <jiangjiajun@baidu.com>
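A minimal end-to-end sketch of what this commit enables, assembled from the test added in tests/eval_example/test_quantize_diff.py (shown at the bottom of this diff): load a quantized Paddle model through RuntimeOption and run it on the Paddle Inference backend, which now detects the quantized model and switches MKLDNN to INT8 automatically. The model archive and file names are the ones the new test uses; treat this as an illustration, not official documentation.

import os
import numpy as np
import fastdeploy as fd

# Quantized YOLOv6 model used by the new test in this commit.
fd.download_and_decompress(
    "https://bj.bcebos.com/fastdeploy/tests/yolov6_quant.tgz", ".")

model_dir = "./yolov6_quant"
option = fd.RuntimeOption()
option.set_model_path(os.path.join(model_dir, "model.pdmodel"),
                      os.path.join(model_dir, "model.pdiparams"))
option.use_cpu()
option.use_paddle_backend()  # MKLDNN INT8 is enabled when a quantized model is detected

runtime = fd.Runtime(option)
data = np.load(os.path.join(model_dir, "input.npy"))
outputs = runtime.infer({runtime.get_input_info(0).name: data})
print(outputs[0].shape)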

cmake/paddle2onnx.cmake (2 changes, Normal file → Executable file)
@@ -43,7 +43,7 @@ else()
 endif(WIN32)

 set(PADDLE2ONNX_URL_BASE "https://bj.bcebos.com/fastdeploy/third_libs/")
-set(PADDLE2ONNX_VERSION "1.0.1rc")
+set(PADDLE2ONNX_VERSION "1.0.1")
 if(WIN32)
   set(PADDLE2ONNX_FILE "paddle2onnx-win-x64-${PADDLE2ONNX_VERSION}.zip")
   if(NOT CMAKE_CL_64)

fastdeploy/backends/paddle/paddle_backend.cc

@@ -16,13 +16,23 @@

 namespace fastdeploy {

-void PaddleBackend::BuildOption(const PaddleBackendOption& option) {
+void PaddleBackend::BuildOption(const PaddleBackendOption& option,
+                                const std::string& model_file) {
   if (option.use_gpu) {
     config_.EnableUseGpu(option.gpu_mem_init_size, option.gpu_id);
   } else {
     config_.DisableGpu();
     if (option.enable_mkldnn) {
       config_.EnableMKLDNN();
+      std::string contents;
+      if (!ReadBinaryFromFile(model_file, &contents)) {
+        return;
+      }
+      auto reader =
+          paddle2onnx::PaddleReader(contents.c_str(), contents.size());
+      if (reader.is_quantize_model) {
+        config_.EnableMkldnnInt8();
+      }
       config_.SetMkldnnCacheCapacity(option.mkldnn_cache_size);
     }
   }
@@ -52,7 +62,7 @@ bool PaddleBackend::InitFromPaddle(const std::string& model_file,
     return false;
   }
   config_.SetModel(model_file, params_file);
-  BuildOption(option);
+  BuildOption(option, model_file);
   predictor_ = paddle_infer::CreatePredictor(config_);
   std::vector<std::string> input_names = predictor_->GetInputNames();
   std::vector<std::string> output_names = predictor_->GetOutputNames();

fastdeploy/backends/paddle/paddle_backend.h (4 changes, Normal file → Executable file)
@@ -20,6 +20,7 @@
 #include <vector>

 #include "fastdeploy/backends/backend.h"
+#include "paddle2onnx/converter.h"
 #include "paddle_inference_api.h"  // NOLINT

 namespace fastdeploy {
@@ -61,7 +62,8 @@ class PaddleBackend : public BaseBackend {
  public:
   PaddleBackend() {}
   virtual ~PaddleBackend() = default;
-  void BuildOption(const PaddleBackendOption& option);
+  void BuildOption(const PaddleBackendOption& option,
+                   const std::string& model_file);

   bool InitFromPaddle(
       const std::string& model_file, const std::string& params_file,

fastdeploy/backends/tensorrt/trt_backend.cc

@@ -131,10 +131,13 @@ bool TrtBackend::InitFromPaddle(const std::string& model_file,
   }
   char* model_content_ptr;
   int model_content_size = 0;
+  char* calibration_cache_ptr;
+  int calibration_cache_size = 0;
   if (!paddle2onnx::Export(model_file.c_str(), params_file.c_str(),
                            &model_content_ptr, &model_content_size, 11, true,
                            verbose, true, true, true, custom_ops.data(),
-                           custom_ops.size())) {
+                           custom_ops.size(), "tensorrt",
+                           &calibration_cache_ptr, &calibration_cache_size)) {
     FDERROR << "Error occured while export PaddlePaddle to ONNX format."
             << std::endl;
     return false;
@@ -151,6 +154,13 @@ bool TrtBackend::InitFromPaddle(const std::string& model_file,
     delete[] model_content_ptr;
     std::string onnx_model_proto(new_model, new_model + new_model_size);
     delete[] new_model;
+    if (calibration_cache_size) {
+      std::string calibration_str(
+          calibration_cache_ptr,
+          calibration_cache_ptr + calibration_cache_size);
+      calibration_str_ = calibration_str;
+      delete[] calibration_cache_ptr;
+    }
     return InitFromOnnx(onnx_model_proto, option, true);
   }

@@ -158,6 +168,12 @@ bool TrtBackend::InitFromPaddle(const std::string& model_file,
                                 model_content_ptr + model_content_size);
   delete[] model_content_ptr;
   model_content_ptr = nullptr;
+  if (calibration_cache_size) {
+    std::string calibration_str(calibration_cache_ptr,
+                                calibration_cache_ptr + calibration_cache_size);
+    calibration_str_ = calibration_str;
+    delete[] calibration_cache_ptr;
+  }
   return InitFromOnnx(onnx_model_proto, option, true);
 #else
   FDERROR << "Didn't compile with PaddlePaddle frontend, you can try to "
@@ -409,6 +425,7 @@ bool TrtBackend::BuildTrtEngine() {
                    "will use FP32 instead."
                 << std::endl;
     } else {
+      FDINFO << "[TrtBackend] Use FP16 to inference." << std::endl;
       config->setFlag(nvinfer1::BuilderFlag::kFP16);
     }
   }
@@ -459,6 +476,20 @@ bool TrtBackend::BuildTrtEngine() {
   }
   config->addOptimizationProfile(profile);

+  if (calibration_str_.size()) {
+    if (!builder_->platformHasFastInt8()) {
+      FDWARNING << "Detected INT8 is not supported in the current GPU, "
+                   "will use FP32 instead."
+                << std::endl;
+    } else {
+      FDINFO << "[TrtBackend] Use INT8 to inference." << std::endl;
+      config->setFlag(nvinfer1::BuilderFlag::kINT8);
+      Int8EntropyCalibrator2* calibrator =
+          new Int8EntropyCalibrator2(calibration_str_);
+      config->setInt8Calibrator(calibrator);
+    }
+  }
+
   FDUniquePtr<nvinfer1::IHostMemory> plan{
       builder_->buildSerializedNetwork(*network_, *config)};
   if (!plan) {

fastdeploy/backends/tensorrt/trt_backend.h (28 changes, Normal file → Executable file)
@@ -26,6 +26,32 @@
 #include "fastdeploy/backends/backend.h"
 #include "fastdeploy/backends/tensorrt/utils.h"

+class Int8EntropyCalibrator2 : public nvinfer1::IInt8EntropyCalibrator2 {
+ public:
+  explicit Int8EntropyCalibrator2(const std::string& calibration_cache)
+      : calibration_cache_(calibration_cache) {}
+
+  int getBatchSize() const noexcept override { return 0; }
+
+  bool getBatch(void* bindings[], const char* names[],
+                int nbBindings) noexcept override {
+    return false;
+  }
+
+  const void* readCalibrationCache(size_t& length) noexcept override {
+    length = calibration_cache_.size();
+    return length ? calibration_cache_.data() : nullptr;
+  }
+
+  void writeCalibrationCache(const void* cache,
+                             size_t length) noexcept override {
+    std::cout << "NOT IMPLEMENT." << std::endl;
+  }
+
+ private:
+  const std::string calibration_cache_;
+};
+
 namespace fastdeploy {

 struct TrtValueInfo {
@@ -95,6 +121,8 @@ class TrtBackend : public BaseBackend {
   std::map<std::string, FDDeviceBuffer> inputs_buffer_;
   std::map<std::string, FDDeviceBuffer> outputs_buffer_;

+  std::string calibration_str_;
+
   // Sometimes while the number of outputs > 1
   // the output order of tensorrt may not be same
   // with the original onnx model

fastdeploy/pybind/runtime.cc (1 change, Normal file → Executable file)
@@ -25,6 +25,7 @@ void BindRuntime(pybind11::module& m) {
       .def("set_cpu_thread_num", &RuntimeOption::SetCpuThreadNum)
       .def("use_paddle_backend", &RuntimeOption::UsePaddleBackend)
       .def("use_ort_backend", &RuntimeOption::UseOrtBackend)
+      .def("set_ort_graph_opt_level", &RuntimeOption::SetOrtGraphOptLevel)
       .def("use_trt_backend", &RuntimeOption::UseTrtBackend)
       .def("use_openvino_backend", &RuntimeOption::UseOpenVINOBackend)
       .def("use_lite_backend", &RuntimeOption::UseLiteBackend)

fastdeploy/runtime.cc (7 changes, Normal file → Executable file)
@@ -198,6 +198,13 @@ void RuntimeOption::SetCpuThreadNum(int thread_num) {
   cpu_thread_num = thread_num;
 }

+void RuntimeOption::SetOrtGraphOptLevel(int level) {
+  std::vector<int> supported_level{-1, 0, 1, 2};
+  auto valid_level = std::find(supported_level.begin(), supported_level.end(), level) != supported_level.end();
+  FDASSERT(valid_level, "The level must be -1, 0, 1, 2.");
+  ort_graph_opt_level = level;
+}
+
 // use paddle inference backend
 void RuntimeOption::UsePaddleBackend() {
 #ifdef ENABLE_PADDLE_BACKEND

fastdeploy/runtime.h (4 changes, Normal file → Executable file)
@@ -22,6 +22,7 @@

 #include <map>
 #include <vector>
+#include <algorithm>

 #include "fastdeploy/backends/backend.h"
 #include "fastdeploy/utils/perf.h"
@@ -104,6 +105,9 @@ struct FASTDEPLOY_DECL RuntimeOption {
   */
   void SetCpuThreadNum(int thread_num);

+  /// Use ORT graph opt level
+  void SetOrtGraphOptLevel(int level = -1);
+
   /// Set Paddle Inference as inference backend, support CPU/GPU
   void UsePaddleBackend();


python/fastdeploy/runtime.py (3 changes, Normal file → Executable file)
@@ -117,6 +117,9 @@ class RuntimeOption:
         """
         return self._option.set_cpu_thread_num(thread_num)

+    def set_ort_graph_opt_level(self, level=-1):
+        return self._option.set_ort_graph_opt_level(level)
+
     def use_paddle_backend(self):
         """Use Paddle Inference backend, support inference Paddle model on CPU/Nvidia GPU.
         """
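The set_ort_graph_opt_level binding added above is what the quantized ORT test below relies on; the C++ side (runtime.cc above) only accepts -1, 0, 1 or 2. A short sketch of where the call fits when building a runtime for the quantized model, mirroring the test_quant_ort case:

import os
import fastdeploy as fd

model_dir = "./yolov6_quant"
option = fd.RuntimeOption()
option.use_ort_backend()
option.use_cpu()
# Allowed values per the FDASSERT in runtime.cc: -1, 0, 1, 2; the test passes 1.
option.set_ort_graph_opt_level(1)
option.set_model_path(os.path.join(model_dir, "model.pdmodel"),
                      os.path.join(model_dir, "model.pdiparams"))
runtime = fd.Runtime(option)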

tests/eval_example/test_quantize_diff.py (96 additions, new Executable file)
@@ -0,0 +1,96 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import fastdeploy as fd
+import cv2
+import os
+import pickle
+import numpy as np
+
+model_url = "https://bj.bcebos.com/fastdeploy/tests/yolov6_quant.tgz"
+fd.download_and_decompress(model_url, ".")
+
+
+def test_quant_mkldnn():
+    model_path = "./yolov6_quant"
+    model_file = os.path.join(model_path, "model.pdmodel")
+    params_file = os.path.join(model_path, "model.pdiparams")
+
+    input_file = os.path.join(model_path, "input.npy")
+    output_file = os.path.join(model_path, "mkldnn_output.npy")
+
+    option = fd.RuntimeOption()
+    option.use_paddle_backend()
+    option.use_cpu()
+
+    option.set_model_path(model_file, params_file)
+    runtime = fd.Runtime(option)
+    input_name = runtime.get_input_info(0).name
+    data = np.load(input_file)
+    outs = runtime.infer({input_name: data})
+    expected = np.load(output_file)
+    diff = np.fabs(outs[0] - expected)
+    thres = 1e-05
+    assert diff.max() < thres, "The diff is %f, which is bigger than %f" % (
+        diff.max(), thres)
+
+
+def test_quant_ort():
+    model_path = "./yolov6_quant"
+    model_file = os.path.join(model_path, "model.pdmodel")
+    params_file = os.path.join(model_path, "model.pdiparams")
+
+    input_file = os.path.join(model_path, "input.npy")
+    output_file = os.path.join(model_path, "ort_output.npy")
+
+    option = fd.RuntimeOption()
+    option.use_ort_backend()
+    option.use_cpu()
+
+    option.set_ort_graph_opt_level(1)
+
+    option.set_model_path(model_file, params_file)
+    runtime = fd.Runtime(option)
+    input_name = runtime.get_input_info(0).name
+    data = np.load(input_file)
+    outs = runtime.infer({input_name: data})
+    expected = np.load(output_file)
+    diff = np.fabs(outs[0] - expected)
+    thres = 1e-05
+    assert diff.max() < thres, "The diff is %f, which is bigger than %f" % (
+        diff.max(), thres)
+
+
+def test_quant_trt():
+    model_path = "./yolov6_quant"
+    model_file = os.path.join(model_path, "model.pdmodel")
+    params_file = os.path.join(model_path, "model.pdiparams")
+
+    input_file = os.path.join(model_path, "input.npy")
+    output_file = os.path.join(model_path, "trt_output.npy")
+
+    option = fd.RuntimeOption()
+    option.use_trt_backend()
+    option.use_gpu()
+
+    option.set_model_path(model_file, params_file)
+    runtime = fd.Runtime(option)
+    input_name = runtime.get_input_info(0).name
+    data = np.load(input_file)
+    outs = runtime.infer({input_name: data})
+    expected = np.load(output_file)
+    diff = np.fabs(outs[0] - expected)
+    thres = 1e-05
+    assert diff.max() < thres, "The diff is %f, which is bigger than %f" % (
+        diff.max(), thres)