Fix runtime with python (#63)

Author: Jason
Date: 2022-08-01 17:51:13 +08:00
Committed by: GitHub
Parent: fc38ec90b8
Commit: 90e0f53e48
14 changed files with 219 additions and 225 deletions

View File

@@ -26,35 +26,6 @@ namespace fastdeploy {
 std::vector<OrtCustomOp*> OrtBackend::custom_operators_ =
     std::vector<OrtCustomOp*>();
-ONNXTensorElementDataType GetOrtDtype(FDDataType fd_dtype) {
-  if (fd_dtype == FDDataType::FP32) {
-    return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
-  } else if (fd_dtype == FDDataType::FP64) {
-    return ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE;
-  } else if (fd_dtype == FDDataType::INT32) {
-    return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32;
-  } else if (fd_dtype == FDDataType::INT64) {
-    return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
-  }
-  FDERROR << "Unrecognized fastdeply data type:" << FDDataTypeStr(fd_dtype)
-          << "." << std::endl;
-  return ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;
-}
-
-FDDataType GetFdDtype(ONNXTensorElementDataType ort_dtype) {
-  if (ort_dtype == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT) {
-    return FDDataType::FP32;
-  } else if (ort_dtype == ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE) {
-    return FDDataType::FP64;
-  } else if (ort_dtype == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32) {
-    return FDDataType::INT32;
-  } else if (ort_dtype == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64) {
-    return FDDataType::INT64;
-  }
-  FDERROR << "Unrecognized ort data type:" << ort_dtype << "." << std::endl;
-  return FDDataType::FP32;
-}
-
 void OrtBackend::BuildOption(const OrtBackendOption& option) {
   option_ = option;
   if (option.graph_optimization_level >= 0) {

@@ -263,6 +234,7 @@ bool OrtBackend::Infer(std::vector<FDTensor>& inputs,
     (*outputs)[i].name = outputs_desc_[i].name;
     CopyToCpu(ort_outputs[i], &((*outputs)[i]));
   }
   return true;
 }

View File

@@ -27,8 +27,8 @@ ONNXTensorElementDataType GetOrtDtype(const FDDataType& fd_dtype) {
   } else if (fd_dtype == FDDataType::INT64) {
     return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
   }
-  FDERROR << "Unrecognized fastdeply data type:" << FDDataTypeStr(fd_dtype)
-          << "." << std::endl;
+  FDERROR << "Unrecognized fastdeply data type:" << Str(fd_dtype) << "."
+          << std::endl;
   return ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;
 }

View File

@@ -28,7 +28,7 @@ namespace fastdeploy {
 ONNXTensorElementDataType GetOrtDtype(const FDDataType& fd_dtype);
 // Convert OrtDataType to FDDataType
-FDDataType GetFdDtype(const ONNXTensorElementDataType* ort_dtype);
+FDDataType GetFdDtype(const ONNXTensorElementDataType& ort_dtype);
 // Create Ort::Value
 // is_backend_cuda specify if the onnxruntime use CUDAExectionProvider

View File

@@ -119,8 +119,8 @@ void FDTensor::PrintInfo(const std::string& prefix) {
   for (int i = 0; i < shape.size(); ++i) {
     std::cout << shape[i] << " ";
   }
-  std::cout << ", dtype=" << FDDataTypeStr(dtype) << ", mean=" << mean
-            << ", max=" << max << ", min=" << min << std::endl;
+  std::cout << ", dtype=" << Str(dtype) << ", mean=" << mean << ", max=" << max
+            << ", min=" << min << std::endl;
 }

 FDTensor::FDTensor(const std::string& tensor_name) { name = tensor_name; }

View File

@@ -17,7 +17,7 @@

 namespace fastdeploy {

-int FDDataTypeSize(FDDataType data_type) {
+int FDDataTypeSize(const FDDataType& data_type) {
   FDASSERT(data_type != FDDataType::FP16, "Float16 is not supported.");
   if (data_type == FDDataType::BOOL) {
     return sizeof(bool);

@@ -34,38 +34,12 @@ int FDDataTypeSize(FDDataType data_type) {
   } else if (data_type == FDDataType::UINT8) {
     return sizeof(uint8_t);
   } else {
-    FDASSERT(false, "Unexpected data type: " + FDDataTypeStr(data_type));
+    FDASSERT(false, "Unexpected data type: " + Str(data_type));
   }
   return -1;
 }

-std::string FDDataTypeStr(FDDataType data_type) {
-  FDASSERT(data_type != FDDataType::FP16, "Float16 is not supported.");
-  if (data_type == FDDataType::BOOL) {
-    return "bool";
-  } else if (data_type == FDDataType::INT16) {
-    return "int16";
-  } else if (data_type == FDDataType::INT32) {
-    return "int32";
-  } else if (data_type == FDDataType::INT64) {
-    return "int64";
-  } else if (data_type == FDDataType::FP16) {
-    return "float16";
-  } else if (data_type == FDDataType::FP32) {
-    return "float32";
-  } else if (data_type == FDDataType::FP64) {
-    return "float64";
-  } else if (data_type == FDDataType::UINT8) {
-    return "uint8";
-  } else if (data_type == FDDataType::INT8) {
-    return "int8";
-  } else {
-    FDASSERT(false, "Unexpected data type: " + FDDataTypeStr(data_type));
-  }
-  return "UNKNOWN!";
-}
-
-std::string Str(Device& d) {
+std::string Str(const Device& d) {
   std::string out;
   switch (d) {
     case Device::DEFAULT:

@@ -83,7 +57,7 @@ std::string Str(Device& d) {
   return out;
 }

-std::string Str(FDDataType& fdt) {
+std::string Str(const FDDataType& fdt) {
   std::string out;
   switch (fdt) {
     case FDDataType::BOOL:

View File

@@ -24,7 +24,7 @@ namespace fastdeploy {

 enum FASTDEPLOY_DECL Device { DEFAULT, CPU, GPU };

-FASTDEPLOY_DECL std::string Str(Device& d);
+FASTDEPLOY_DECL std::string Str(const Device& d);

 enum FASTDEPLOY_DECL FDDataType {
   BOOL,

@@ -51,9 +51,7 @@ enum FASTDEPLOY_DECL FDDataType {
   INT8
 };

-FASTDEPLOY_DECL std::string Str(FDDataType& fdt);
-FASTDEPLOY_DECL int32_t FDDataTypeSize(FDDataType data_dtype);
-FASTDEPLOY_DECL std::string FDDataTypeStr(FDDataType data_dtype);
+FASTDEPLOY_DECL std::string Str(const FDDataType& fdt);
+FASTDEPLOY_DECL int32_t FDDataTypeSize(const FDDataType& data_dtype);

 } // namespace fastdeploy

View File

@@ -79,6 +79,7 @@ void BindRuntime(pybind11::module& m) {
           memcpy(inputs[index].data.data(), iter->second.mutable_data(),
                  iter->second.nbytes());
           inputs[index].name = iter->first;
+          index += 1;
         }
         std::vector<FDTensor> outputs(self.NumOutputs());
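
The one-line `index += 1` above is the runtime fix this commit is named for: without it, every entry of a multi-input feed dict was copied into input slot 0. A minimal sketch of the Python call path it repairs, assuming the post-commit API; the model files and input shape are hypothetical, and passing `option._option` reflects that `Runtime.init` binds the underlying `C.RuntimeOption`:

import numpy as np
import fastdeploy as fd

option = fd.RuntimeOption()
option.set_model_path("model.pdmodel", "model.pdiparams")  # hypothetical files
option.use_cpu()
option.use_ort_backend()

runtime = fd.Runtime(option._option)  # init() takes the bound C.RuntimeOption

# One numpy array per model input; with the fix each lands in its own slot.
feed = {}
for i in range(runtime.num_inputs()):
    info = runtime.get_input_info(i)
    feed[info.name] = np.zeros([1, 3, 224, 224], dtype=np.float32)  # assumed shape
outputs = runtime.infer(feed)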

View File

@@ -32,7 +32,7 @@ pybind11::dtype FDDataTypeToNumpyDataType(const FDDataType& fd_dtype) {
     dt = pybind11::dtype::of<double>();
   } else {
     FDASSERT(false, "The function doesn't support data type of " +
-                        FDDataTypeStr(fd_dtype) + ".");
+                        Str(fd_dtype) + ".");
   }
   return dt;
 }

@@ -47,7 +47,8 @@ FDDataType NumpyDataTypeToFDDataType(const pybind11::dtype& np_dtype) {
   } else if (np_dtype.is(pybind11::dtype::of<double>())) {
     return FDDataType::FP64;
   }
-  FDASSERT(false, "NumpyDataTypeToFDDataType() only support "
-                  "int32/int64/float32/float64 now.");
+  FDASSERT(false,
+           "NumpyDataTypeToFDDataType() only support "
+           "int32/int64/float32/float64 now.");
   return FDDataType::FP32;
 }

View File

@@ -23,20 +23,31 @@ PPYOLOE::PPYOLOE(const std::string& model_file, const std::string& params_file,
   initialized = Initialize();
 }

-bool PPYOLOE::Initialize() {
-#ifdef ENABLE_PADDLE_FRONTEND
-  // remove multiclass_nms3 now
-  // this is a trick operation for ppyoloe while inference on trt
+void PPYOLOE::GetNmsInfo() {
   if (runtime_option.model_format == Frontend::PADDLE) {
     std::string contents;
     if (!ReadBinaryFromFile(runtime_option.model_file, &contents)) {
-      return false;
+      return;
     }
     auto reader = paddle2onnx::PaddleReader(contents.c_str(), contents.size());
     if (reader.has_nms) {
       has_nms_ = true;
+      background_label = reader.nms_params.background_label;
+      keep_top_k = reader.nms_params.keep_top_k;
+      nms_eta = reader.nms_params.nms_eta;
+      nms_threshold = reader.nms_params.nms_threshold;
+      score_threshold = reader.nms_params.score_threshold;
+      nms_top_k = reader.nms_params.nms_top_k;
+      normalized = reader.nms_params.normalized;
     }
   }
+}
+
+bool PPYOLOE::Initialize() {
+#ifdef ENABLE_PADDLE_FRONTEND
+  // remove multiclass_nms3 now
+  // this is a trick operation for ppyoloe while inference on trt
+  GetNmsInfo();
   runtime_option.remove_multiclass_nms_ = true;
   runtime_option.custom_op_info_["multiclass_nms3"] = "MultiClassNMS";
 #endif

@@ -52,8 +63,12 @@ bool PPYOLOE::Initialize() {
   if (has_nms_ && runtime_option.backend == Backend::TRT) {
     FDINFO << "Detected operator multiclass_nms3 in your model, will replace "
-              "it with fastdeploy::backend::MultiClassNMS replace it."
-           << std::endl;
+              "it with fastdeploy::backend::MultiClassNMS(background_label="
+           << background_label << ", keep_top_k=" << keep_top_k
+           << ", nms_eta=" << nms_eta << ", nms_threshold=" << nms_threshold
+           << ", score_threshold=" << score_threshold
+           << ", nms_top_k=" << nms_top_k << ", normalized=" << normalized
+           << ")." << std::endl;
     has_nms_ = false;
   }
   return true;
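
For context, a hedged sketch of how this branch is reached from Python: selecting the TRT backend strips multiclass_nms3 from the Paddle model and re-attaches it as fastdeploy::backend::MultiClassNMS with the parameters collected by GetNmsInfo(). The PPYOLOE constructor signature is assumed from this repo's vision API, and the export paths are hypothetical:

import fastdeploy as fd

option = fd.RuntimeOption()
option.use_gpu(0)
option.use_trt_backend()  # triggers the multiclass_nms3 replacement logged above

# Hypothetical export directory; constructor signature assumed.
model = fd.vision.ppdet.PPYOLOE(
    "ppyoloe_crn_l_300e_coco/model.pdmodel",
    "ppyoloe_crn_l_300e_coco/model.pdiparams",
    "ppyoloe_crn_l_300e_coco/infer_cfg.yml",
    runtime_option=option)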

View File

@@ -42,6 +42,10 @@ class FASTDEPLOY_DECL PPYOLOE : public FastDeployModel {
   int64_t nms_top_k = 10000;
   bool normalized = true;
   bool has_nms_ = false;
+
+  // This function will used to check if this model contains multiclass_nms
+  // and get parameters from the operator
+  void GetNmsInfo();
 };

 } // namespace ppdet
 } // namespace vision

View File

@@ -43,7 +43,7 @@ else()
 endif(WIN32)

 set(PADDLE2ONNX_URL_BASE "https://bj.bcebos.com/paddle2onnx/libs/")
-set(PADDLE2ONNX_VERSION "1.0.0rc2")
+set(PADDLE2ONNX_VERSION "1.0.0rc3")
 if(WIN32)
   set(PADDLE2ONNX_FILE "paddle2onnx-win-x64-${PADDLE2ONNX_VERSION}.zip")
   if(NOT CMAKE_CL_64)

View File

@@ -16,12 +16,14 @@ import logging
 import os
 import sys

+
 def add_dll_search_dir(dir_path):
     os.environ["path"] = dir_path + ";" + os.environ["path"]
     sys.path.insert(0, dir_path)
     if sys.version_info[:2] >= (3, 8):
         os.add_dll_directory(dir_path)

+
 if os.name == "nt":
     current_path = os.path.abspath(__file__)
     dirname = os.path.dirname(current_path)

@@ -33,82 +35,19 @@ if os.name == "nt":
             add_dll_search_dir(os.path.join(dirname, root, d))

 from .fastdeploy_main import Frontend, Backend, FDDataType, TensorInfo, Device
-from .fastdeploy_runtime import *
+from .runtime import Runtime, RuntimeOption
+from .model import FastDeployModel
 from . import fastdeploy_main as C
 from . import vision
 from .download import download, download_and_decompress


 def TensorInfoStr(tensor_info):
     message = "TensorInfo(name : '{}', dtype : '{}', shape : '{}')".format(
         tensor_info.name, tensor_info.dtype, tensor_info.shape)
     return message


-class RuntimeOption:
-    def __init__(self):
-        self._option = C.RuntimeOption()
-
-    def set_model_path(self, model_path, params_path="", model_format="paddle"):
-        return self._option.set_model_path(model_path, params_path, model_format)
-
-    def use_gpu(self, device_id=0):
-        return self._option.use_gpu(device_id)
-
-    def use_cpu(self):
-        return self._option.use_cpu()
-
-    def set_cpu_thread_num(self, thread_num=8):
-        return self._option.set_cpu_thread_num(thread_num)
-
-    def use_paddle_backend(self):
-        return self._option.use_paddle_backend()
-
-    def use_ort_backend(self):
-        return self._option.use_ort_backend()
-
-    def use_trt_backend(self):
-        return self._option.use_trt_backend()
-
-    def enable_paddle_mkldnn(self):
-        return self._option.enable_paddle_mkldnn()
-
-    def disable_paddle_mkldnn(self):
-        return self._option.disable_paddle_mkldnn()
-
-    def set_paddle_mkldnn_cache_size(self, cache_size):
-        return self._option.set_paddle_mkldnn_cache_size(cache_size)
-
-    def set_trt_input_shape(self, tensor_name, min_shape, opt_shape=None, max_shape=None):
-        if opt_shape is None and max_shape is None:
-            opt_shape = min_shape
-            max_shape = min_shape
-        else:
-            assert opt_shape is not None and max_shape is not None, "Set min_shape only, or set min_shape, opt_shape, max_shape both."
-        return self._option.set_trt_input_shape(tensor_name, min_shape, opt_shape, max_shape)
-
-    def set_trt_cache_file(self, cache_file_path):
-        return self._option.set_trt_cache_file(cache_file_path)
-
-    def enable_trt_fp16(self):
-        return self._option.enable_trt_fp16()
-
-    def dissable_trt_fp16(self):
-        return self._option.disable_trt_fp16()
-
-    def __repr__(self):
-        attrs = dir(self._option)
-        message = "RuntimeOption(\n"
-        for attr in attrs:
-            if attr.startswith("__"):
-                continue
-            if hasattr(getattr(self._option, attr), "__call__"):
-                continue
-            message += " {} : {}\t\n".format(attr, getattr(self._option, attr))
-        message.strip("\n")
-        message += ")"
-        return message
-
-
 def RuntimeOptionStr(runtime_option):
     attrs = dir(runtime_option)
     message = "RuntimeOption(\n"

@@ -122,5 +61,6 @@ def RuntimeOptionStr(runtime_option):
     message += ")"
     return message

+
 C.TensorInfo.__repr__ = TensorInfoStr
 C.RuntimeOption.__repr__ = RuntimeOptionStr

View File

@@ -54,35 +54,3 @@ class FastDeployModel:
         if self._model is None:
             return False
         return self._model.initialized()
-
-
-class Runtime:
-    def __init__(self, runtime_option):
-        self._runtime = C.Runtime()
-        assert self._runtime.init(runtime_option), "Initialize Runtime Failed!"
-
-    def infer(self, data):
-        assert isinstance(data, dict), "The input data should be type of dict."
-        return self._runtime.infer(data)
-
-    def num_inputs(self):
-        return self._runtime.num_inputs()
-
-    def num_outputs(self):
-        return self._runtime.num_outputs()
-
-    def get_input_info(self, index):
-        assert isinstance(
-            index, int), "The input parameter index should be type of int."
-        assert index < self.num_inputs(
-        ), "The input parameter index:{} should less than number of inputs:{}.".format(
-            index, self.num_inputs)
-        return self._runtime.get_input_info(index)
-
-    def get_output_info(self, index):
-        assert isinstance(
-            index, int), "The input parameter index should be type of int."
-        assert index < self.num_outputs(
-        ), "The input parameter index:{} should less than number of outputs:{}.".format(
-            index, self.num_outputs)
-        return self._runtime.get_output_info(index)

fastdeploy/runtime.py (new file, 121 additions)
View File

@@ -0,0 +1,121 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
import logging
from . import fastdeploy_main as C


class Runtime:
    def __init__(self, runtime_option):
        self._runtime = C.Runtime()
        assert self._runtime.init(runtime_option), "Initialize Runtime Failed!"

    def infer(self, data):
        assert isinstance(data, dict), "The input data should be type of dict."
        return self._runtime.infer(data)

    def num_inputs(self):
        return self._runtime.num_inputs()

    def num_outputs(self):
        return self._runtime.num_outputs()

    def get_input_info(self, index):
        assert isinstance(
            index, int), "The input parameter index should be type of int."
        assert index < self.num_inputs(
        ), "The input parameter index:{} should less than number of inputs:{}.".format(
            index, self.num_inputs)
        return self._runtime.get_input_info(index)

    def get_output_info(self, index):
        assert isinstance(
            index, int), "The input parameter index should be type of int."
        assert index < self.num_outputs(
        ), "The input parameter index:{} should less than number of outputs:{}.".format(
            index, self.num_outputs)
        return self._runtime.get_output_info(index)


class RuntimeOption:
    def __init__(self):
        self._option = C.RuntimeOption()

    def set_model_path(self, model_path, params_path="",
                       model_format="paddle"):
        return self._option.set_model_path(model_path, params_path,
                                           model_format)

    def use_gpu(self, device_id=0):
        return self._option.use_gpu(device_id)

    def use_cpu(self):
        return self._option.use_cpu()

    def set_cpu_thread_num(self, thread_num=8):
        return self._option.set_cpu_thread_num(thread_num)

    def use_paddle_backend(self):
        return self._option.use_paddle_backend()

    def use_ort_backend(self):
        return self._option.use_ort_backend()

    def use_trt_backend(self):
        return self._option.use_trt_backend()

    def enable_paddle_mkldnn(self):
        return self._option.enable_paddle_mkldnn()

    def disable_paddle_mkldnn(self):
        return self._option.disable_paddle_mkldnn()

    def set_paddle_mkldnn_cache_size(self, cache_size):
        return self._option.set_paddle_mkldnn_cache_size(cache_size)

    def set_trt_input_shape(self,
                            tensor_name,
                            min_shape,
                            opt_shape=None,
                            max_shape=None):
        if opt_shape is None and max_shape is None:
            opt_shape = min_shape
            max_shape = min_shape
        else:
            assert opt_shape is not None and max_shape is not None, "Set min_shape only, or set min_shape, opt_shape, max_shape both."
        return self._option.set_trt_input_shape(tensor_name, min_shape,
                                                opt_shape, max_shape)

    def set_trt_cache_file(self, cache_file_path):
        return self._option.set_trt_cache_file(cache_file_path)

    def enable_trt_fp16(self):
        return self._option.enable_trt_fp16()

    def disable_trt_fp16(self):
        return self._option.disable_trt_fp16()

    def __repr__(self):
        attrs = dir(self._option)
        message = "RuntimeOption(\n"
        for attr in attrs:
            if attr.startswith("__"):
                continue
            if hasattr(getattr(self._option, attr), "__call__"):
                continue
            message += " {} : {}\t\n".format(attr,
                                             getattr(self._option, attr))
        message.strip("\n")
        message += ")"
        return message
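
A short usage sketch of the contract encoded in set_trt_input_shape above: pass min_shape alone and it is reused as both opt_shape and max_shape, or pass all three to get a dynamic dimension. Tensor name, shapes, and file paths below are hypothetical:

from fastdeploy.runtime import Runtime, RuntimeOption

option = RuntimeOption()
option.set_model_path("model.pdmodel", "model.pdiparams")  # hypothetical files
option.use_gpu(0)
option.use_trt_backend()
option.enable_trt_fp16()

# Static profile: min_shape doubles as opt_shape and max_shape.
option.set_trt_input_shape("image", [1, 3, 640, 640])

# Dynamic batch: min, opt and max must then all be given.
option.set_trt_input_shape("image", [1, 3, 640, 640], [4, 3, 640, 640],
                           [8, 3, 640, 640])
option.set_trt_cache_file("ppyoloe.trt")  # cache the built engine between runs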