mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-05 16:48:03 +08:00
yolov5 servitization optimization (#262)
* yolov5 split pre and post process * yolov5 postprocess * yolov5 postprocess
This commit is contained in:
@@ -67,6 +67,12 @@ void BindRuntime(pybind11::module& m) {
|
|||||||
pybind11::class_<Runtime>(m, "Runtime")
|
pybind11::class_<Runtime>(m, "Runtime")
|
||||||
.def(pybind11::init())
|
.def(pybind11::init())
|
||||||
.def("init", &Runtime::Init)
|
.def("init", &Runtime::Init)
|
||||||
|
.def("infer",
|
||||||
|
[](Runtime& self, std::vector<FDTensor>& inputs) {
|
||||||
|
std::vector<FDTensor> outputs(self.NumOutputs());
|
||||||
|
self.Infer(inputs, &outputs);
|
||||||
|
return outputs;
|
||||||
|
})
|
||||||
.def("infer",
|
.def("infer",
|
||||||
[](Runtime& self, std::map<std::string, pybind11::array>& data) {
|
[](Runtime& self, std::map<std::string, pybind11::array>& data) {
|
||||||
std::vector<FDTensor> inputs(data.size());
|
std::vector<FDTensor> inputs(data.size());
|
||||||
@@ -132,6 +138,32 @@ void BindRuntime(pybind11::module& m) {
|
|||||||
.value("FP64", FDDataType::FP64)
|
.value("FP64", FDDataType::FP64)
|
||||||
.value("UINT8", FDDataType::UINT8);
|
.value("UINT8", FDDataType::UINT8);
|
||||||
|
|
||||||
|
pybind11::class_<FDTensor>(m, "FDTensor", pybind11::buffer_protocol())
|
||||||
|
.def(pybind11::init())
|
||||||
|
.def("cpu_data",
|
||||||
|
[](FDTensor& self) {
|
||||||
|
auto ptr = self.CpuData();
|
||||||
|
auto numel = self.Numel();
|
||||||
|
auto dtype = FDDataTypeToNumpyDataType(self.dtype);
|
||||||
|
auto base = pybind11::array(dtype, self.shape);
|
||||||
|
return pybind11::array(dtype, self.shape, ptr, base);
|
||||||
|
})
|
||||||
|
.def("resize", static_cast<void (FDTensor::*)(size_t)>(&FDTensor::Resize))
|
||||||
|
.def("resize",
|
||||||
|
static_cast<void (FDTensor::*)(const std::vector<int64_t>&)>(
|
||||||
|
&FDTensor::Resize))
|
||||||
|
.def(
|
||||||
|
"resize",
|
||||||
|
[](FDTensor& self, const std::vector<int64_t>& shape,
|
||||||
|
const FDDataType& dtype, const std::string& name,
|
||||||
|
const Device& device) { self.Resize(shape, dtype, name, device); })
|
||||||
|
.def("numel", &FDTensor::Numel)
|
||||||
|
.def("nbytes", &FDTensor::Nbytes)
|
||||||
|
.def_readwrite("name", &FDTensor::name)
|
||||||
|
.def_readonly("shape", &FDTensor::shape)
|
||||||
|
.def_readonly("dtype", &FDTensor::dtype)
|
||||||
|
.def_readonly("device", &FDTensor::device);
|
||||||
|
|
||||||
m.def("get_available_backends", []() { return GetAvailableBackends(); });
|
m.def("get_available_backends", []() { return GetAvailableBackends(); });
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -73,6 +73,13 @@ void PyArrayToTensor(pybind11::array& pyarray, FDTensor* tensor,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void PyArrayToTensorList(std::vector<pybind11::array>& pyarrays, std::vector<FDTensor>* tensors,
|
||||||
|
bool share_buffer) {
|
||||||
|
for(auto i = 0; i < pyarrays.size(); ++i) {
|
||||||
|
PyArrayToTensor(pyarrays[i], &(*tensors)[i], share_buffer);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pybind11::array TensorToPyArray(const FDTensor& tensor) {
|
pybind11::array TensorToPyArray(const FDTensor& tensor) {
|
||||||
auto numpy_dtype = FDDataTypeToNumpyDataType(tensor.dtype);
|
auto numpy_dtype = FDDataTypeToNumpyDataType(tensor.dtype);
|
||||||
auto out = pybind11::array(numpy_dtype, tensor.shape);
|
auto out = pybind11::array(numpy_dtype, tensor.shape);
|
||||||
|
@@ -42,6 +42,9 @@ FDDataType NumpyDataTypeToFDDataType(const pybind11::dtype& np_dtype);
|
|||||||
|
|
||||||
void PyArrayToTensor(pybind11::array& pyarray, FDTensor* tensor,
|
void PyArrayToTensor(pybind11::array& pyarray, FDTensor* tensor,
|
||||||
bool share_buffer = false);
|
bool share_buffer = false);
|
||||||
|
void PyArrayToTensorList(std::vector<pybind11::array>& pyarray,
|
||||||
|
std::vector<FDTensor>* tensor,
|
||||||
|
bool share_buffer = false);
|
||||||
pybind11::array TensorToPyArray(const FDTensor& tensor);
|
pybind11::array TensorToPyArray(const FDTensor& tensor);
|
||||||
|
|
||||||
#ifdef ENABLE_VISION
|
#ifdef ENABLE_VISION
|
||||||
|
@@ -13,6 +13,7 @@
|
|||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
#include "fastdeploy/vision/detection/contrib/yolov5.h"
|
#include "fastdeploy/vision/detection/contrib/yolov5.h"
|
||||||
|
|
||||||
#include "fastdeploy/utils/perf.h"
|
#include "fastdeploy/utils/perf.h"
|
||||||
#include "fastdeploy/vision/utils/utils.h"
|
#include "fastdeploy/vision/utils/utils.h"
|
||||||
|
|
||||||
@@ -74,14 +75,14 @@ YOLOv5::YOLOv5(const std::string& model_file, const std::string& params_file,
|
|||||||
|
|
||||||
bool YOLOv5::Initialize() {
|
bool YOLOv5::Initialize() {
|
||||||
// parameters for preprocess
|
// parameters for preprocess
|
||||||
size = {640, 640};
|
size_ = {640, 640};
|
||||||
padding_value = {114.0, 114.0, 114.0};
|
padding_value_ = {114.0, 114.0, 114.0};
|
||||||
is_mini_pad = false;
|
is_mini_pad_ = false;
|
||||||
is_no_pad = false;
|
is_no_pad_ = false;
|
||||||
is_scale_up = false;
|
is_scale_up_ = false;
|
||||||
stride = 32;
|
stride_ = 32;
|
||||||
max_wh = 7680.0;
|
max_wh_ = 7680.0;
|
||||||
multi_label = true;
|
multi_label_ = true;
|
||||||
|
|
||||||
if (!InitRuntime()) {
|
if (!InitRuntime()) {
|
||||||
FDERROR << "Failed to initialize fastdeploy backend." << std::endl;
|
FDERROR << "Failed to initialize fastdeploy backend." << std::endl;
|
||||||
@@ -90,23 +91,34 @@ bool YOLOv5::Initialize() {
|
|||||||
// Check if the input shape is dynamic after Runtime already initialized,
|
// Check if the input shape is dynamic after Runtime already initialized,
|
||||||
// Note that, We need to force is_mini_pad 'false' to keep static
|
// Note that, We need to force is_mini_pad 'false' to keep static
|
||||||
// shape after padding (LetterBox) when the is_dynamic_shape is 'false'.
|
// shape after padding (LetterBox) when the is_dynamic_shape is 'false'.
|
||||||
is_dynamic_input_ = false;
|
// TODO(qiuyanjun): remove
|
||||||
auto shape = InputInfoOfRuntime(0).shape;
|
// is_dynamic_input_ = false;
|
||||||
for (int i = 0; i < shape.size(); ++i) {
|
// auto shape = InputInfoOfRuntime(0).shape;
|
||||||
// if height or width is dynamic
|
// for (int i = 0; i < shape.size(); ++i) {
|
||||||
if (i >= 2 && shape[i] <= 0) {
|
// // if height or width is dynamic
|
||||||
is_dynamic_input_ = true;
|
// if (i >= 2 && shape[i] <= 0) {
|
||||||
break;
|
// is_dynamic_input_ = true;
|
||||||
}
|
// break;
|
||||||
}
|
// }
|
||||||
if (!is_dynamic_input_) {
|
// }
|
||||||
is_mini_pad = false;
|
// if (!is_dynamic_input_) {
|
||||||
}
|
// is_mini_pad_ = false;
|
||||||
|
// }
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool YOLOv5::Preprocess(Mat* mat, FDTensor* output,
|
bool YOLOv5::Preprocess(Mat* mat, FDTensor* output,
|
||||||
std::map<std::string, std::array<float, 2>>* im_info) {
|
std::map<std::string, std::array<float, 2>>* im_info,
|
||||||
|
const std::vector<int>& size,
|
||||||
|
const std::vector<float> padding_value,
|
||||||
|
bool is_mini_pad, bool is_no_pad, bool is_scale_up,
|
||||||
|
int stride, float max_wh, bool multi_label) {
|
||||||
|
// Record the shape of image and the shape of preprocessed image
|
||||||
|
(*im_info)["input_shape"] = {static_cast<float>(mat->Height()),
|
||||||
|
static_cast<float>(mat->Width())};
|
||||||
|
(*im_info)["output_shape"] = {static_cast<float>(mat->Height()),
|
||||||
|
static_cast<float>(mat->Width())};
|
||||||
|
|
||||||
// process after image load
|
// process after image load
|
||||||
double ratio = (size[0] * 1.0) / std::max(static_cast<float>(mat->Height()),
|
double ratio = (size[0] * 1.0) / std::max(static_cast<float>(mat->Height()),
|
||||||
static_cast<float>(mat->Width()));
|
static_cast<float>(mat->Width()));
|
||||||
@@ -145,9 +157,11 @@ bool YOLOv5::Preprocess(Mat* mat, FDTensor* output,
|
|||||||
}
|
}
|
||||||
|
|
||||||
bool YOLOv5::Postprocess(
|
bool YOLOv5::Postprocess(
|
||||||
FDTensor& infer_result, DetectionResult* result,
|
std::vector<FDTensor>& infer_results, DetectionResult* result,
|
||||||
const std::map<std::string, std::array<float, 2>>& im_info,
|
const std::map<std::string, std::array<float, 2>>& im_info,
|
||||||
float conf_threshold, float nms_iou_threshold, bool multi_label) {
|
float conf_threshold, float nms_iou_threshold, bool multi_label,
|
||||||
|
float max_wh) {
|
||||||
|
auto& infer_result = infer_results[0];
|
||||||
FDASSERT(infer_result.shape[0] == 1, "Only support batch =1 now.");
|
FDASSERT(infer_result.shape[0] == 1, "Only support batch =1 now.");
|
||||||
result->Clear();
|
result->Clear();
|
||||||
if (multi_label) {
|
if (multi_label) {
|
||||||
@@ -251,13 +265,9 @@ bool YOLOv5::Predict(cv::Mat* im, DetectionResult* result, float conf_threshold,
|
|||||||
|
|
||||||
std::map<std::string, std::array<float, 2>> im_info;
|
std::map<std::string, std::array<float, 2>> im_info;
|
||||||
|
|
||||||
// Record the shape of image and the shape of preprocessed image
|
if (!Preprocess(&mat, &input_tensors[0], &im_info, size_, padding_value_,
|
||||||
im_info["input_shape"] = {static_cast<float>(mat.Height()),
|
is_mini_pad_, is_no_pad_, is_scale_up_, stride_, max_wh_,
|
||||||
static_cast<float>(mat.Width())};
|
multi_label_)) {
|
||||||
im_info["output_shape"] = {static_cast<float>(mat.Height()),
|
|
||||||
static_cast<float>(mat.Width())};
|
|
||||||
|
|
||||||
if (!Preprocess(&mat, &input_tensors[0], &im_info)) {
|
|
||||||
FDERROR << "Failed to preprocess input image." << std::endl;
|
FDERROR << "Failed to preprocess input image." << std::endl;
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
@@ -278,8 +288,8 @@ bool YOLOv5::Predict(cv::Mat* im, DetectionResult* result, float conf_threshold,
|
|||||||
TIMERECORD_START(2)
|
TIMERECORD_START(2)
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
if (!Postprocess(output_tensors[0], result, im_info, conf_threshold,
|
if (!Postprocess(output_tensors, result, im_info, conf_threshold,
|
||||||
nms_iou_threshold, multi_label)) {
|
nms_iou_threshold, multi_label_)) {
|
||||||
FDERROR << "Failed to post process." << std::endl;
|
FDERROR << "Failed to post process." << std::endl;
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
@@ -41,38 +41,18 @@ class FASTDEPLOY_DECL YOLOv5 : public FastDeployModel {
|
|||||||
float conf_threshold = 0.25,
|
float conf_threshold = 0.25,
|
||||||
float nms_iou_threshold = 0.5);
|
float nms_iou_threshold = 0.5);
|
||||||
|
|
||||||
// 以下为模型在预测时的一些参数,基本是前后处理所需
|
|
||||||
// 用户在创建模型后,可根据模型的要求,以及自己的需求
|
|
||||||
// 对参数进行修改
|
|
||||||
// tuple of (width, height)
|
|
||||||
std::vector<int> size;
|
|
||||||
// padding value, size should be same with Channels
|
|
||||||
std::vector<float> padding_value;
|
|
||||||
// only pad to the minimum rectange which height and width is times of stride
|
|
||||||
bool is_mini_pad;
|
|
||||||
// while is_mini_pad = false and is_no_pad = true, will resize the image to
|
|
||||||
// the set size
|
|
||||||
bool is_no_pad;
|
|
||||||
// if is_scale_up is false, the input image only can be zoom out, the maximum
|
|
||||||
// resize scale cannot exceed 1.0
|
|
||||||
bool is_scale_up;
|
|
||||||
// padding stride, for is_mini_pad
|
|
||||||
int stride;
|
|
||||||
// for offseting the boxes by classes when using NMS
|
|
||||||
float max_wh;
|
|
||||||
// for different strategies to get boxes when postprocessing
|
|
||||||
bool multi_label;
|
|
||||||
|
|
||||||
private:
|
|
||||||
// 初始化函数,包括初始化后端,以及其它模型推理需要涉及的操作
|
|
||||||
bool Initialize();
|
|
||||||
|
|
||||||
// 输入图像预处理操作
|
// 输入图像预处理操作
|
||||||
// Mat为FastDeploy定义的数据结构
|
// Mat为FastDeploy定义的数据结构
|
||||||
// FDTensor为预处理后的Tensor数据,传给后端进行推理
|
// FDTensor为预处理后的Tensor数据,传给后端进行推理
|
||||||
// im_info为预处理过程保存的数据,在后处理中需要用到
|
// im_info为预处理过程保存的数据,在后处理中需要用到
|
||||||
bool Preprocess(Mat* mat, FDTensor* outputs,
|
static bool Preprocess(Mat* mat, FDTensor* output,
|
||||||
std::map<std::string, std::array<float, 2>>* im_info);
|
std::map<std::string, std::array<float, 2>>* im_info,
|
||||||
|
const std::vector<int>& size = {640, 640},
|
||||||
|
const std::vector<float> padding_value = {114.0, 114.0,
|
||||||
|
114.0},
|
||||||
|
bool is_mini_pad = false, bool is_no_pad = false,
|
||||||
|
bool is_scale_up = false, int stride = 32,
|
||||||
|
float max_wh = 7680.0, bool multi_label = true);
|
||||||
|
|
||||||
// 后端推理结果后处理,输出给用户
|
// 后端推理结果后处理,输出给用户
|
||||||
// infer_result 为后端推理后的输出Tensor
|
// infer_result 为后端推理后的输出Tensor
|
||||||
@@ -81,17 +61,45 @@ class FASTDEPLOY_DECL YOLOv5 : public FastDeployModel {
|
|||||||
// conf_threshold 后处理时过滤box的置信度阈值
|
// conf_threshold 后处理时过滤box的置信度阈值
|
||||||
// nms_iou_threshold 后处理时NMS设定的iou阈值
|
// nms_iou_threshold 后处理时NMS设定的iou阈值
|
||||||
// multi_label 后处理时box选取是否采用多标签方式
|
// multi_label 后处理时box选取是否采用多标签方式
|
||||||
bool Postprocess(FDTensor& infer_result, DetectionResult* result,
|
static bool Postprocess(
|
||||||
const std::map<std::string, std::array<float, 2>>& im_info,
|
std::vector<FDTensor>& infer_results, DetectionResult* result,
|
||||||
float conf_threshold, float nms_iou_threshold,
|
const std::map<std::string, std::array<float, 2>>& im_info,
|
||||||
bool multi_label);
|
float conf_threshold, float nms_iou_threshold, bool multi_label,
|
||||||
|
float max_wh = 7680.0);
|
||||||
|
|
||||||
|
// 以下为模型在预测时的一些参数,基本是前后处理所需
|
||||||
|
// 用户在创建模型后,可根据模型的要求,以及自己的需求
|
||||||
|
// 对参数进行修改
|
||||||
|
// tuple of (width, height)
|
||||||
|
std::vector<int> size_;
|
||||||
|
// padding value, size should be same with Channels
|
||||||
|
std::vector<float> padding_value_;
|
||||||
|
// only pad to the minimum rectange which height and width is times of stride
|
||||||
|
bool is_mini_pad_;
|
||||||
|
// while is_mini_pad = false and is_no_pad = true, will resize the image to
|
||||||
|
// the set size
|
||||||
|
bool is_no_pad_;
|
||||||
|
// if is_scale_up is false, the input image only can be zoom out, the maximum
|
||||||
|
// resize scale cannot exceed 1.0
|
||||||
|
bool is_scale_up_;
|
||||||
|
// padding stride, for is_mini_pad
|
||||||
|
int stride_;
|
||||||
|
// for offseting the boxes by classes when using NMS
|
||||||
|
float max_wh_;
|
||||||
|
// for different strategies to get boxes when postprocessing
|
||||||
|
bool multi_label_;
|
||||||
|
|
||||||
|
private:
|
||||||
|
// 初始化函数,包括初始化后端,以及其它模型推理需要涉及的操作
|
||||||
|
bool Initialize();
|
||||||
|
|
||||||
// 查看输入是否为动态维度的 不建议直接使用 不同模型的逻辑可能不一致
|
// 查看输入是否为动态维度的 不建议直接使用 不同模型的逻辑可能不一致
|
||||||
bool IsDynamicInput() const { return is_dynamic_input_; }
|
bool IsDynamicInput() const { return is_dynamic_input_; }
|
||||||
|
|
||||||
void LetterBox(Mat* mat, std::vector<int> size, std::vector<float> color,
|
static void LetterBox(Mat* mat, std::vector<int> size,
|
||||||
bool _auto, bool scale_fill = false, bool scale_up = true,
|
std::vector<float> color, bool _auto,
|
||||||
int stride = 32);
|
bool scale_fill = false, bool scale_up = true,
|
||||||
|
int stride = 32);
|
||||||
|
|
||||||
// whether to inference with dynamic shape (e.g ONNX export with dynamic shape
|
// whether to inference with dynamic shape (e.g ONNX export with dynamic shape
|
||||||
// or not.)
|
// or not.)
|
||||||
|
@@ -26,13 +26,43 @@ void BindYOLOv5(pybind11::module& m) {
|
|||||||
self.Predict(&mat, &res, conf_threshold, nms_iou_threshold);
|
self.Predict(&mat, &res, conf_threshold, nms_iou_threshold);
|
||||||
return res;
|
return res;
|
||||||
})
|
})
|
||||||
.def_readwrite("size", &vision::detection::YOLOv5::size)
|
.def_static("preprocess",
|
||||||
.def_readwrite("padding_value", &vision::detection::YOLOv5::padding_value)
|
[](pybind11::array& data, const std::vector<int>& size,
|
||||||
.def_readwrite("is_mini_pad", &vision::detection::YOLOv5::is_mini_pad)
|
const std::vector<float> padding_value, bool is_mini_pad,
|
||||||
.def_readwrite("is_no_pad", &vision::detection::YOLOv5::is_no_pad)
|
bool is_no_pad, bool is_scale_up, int stride, float max_wh,
|
||||||
.def_readwrite("is_scale_up", &vision::detection::YOLOv5::is_scale_up)
|
bool multi_label) {
|
||||||
.def_readwrite("stride", &vision::detection::YOLOv5::stride)
|
auto mat = PyArrayToCvMat(data);
|
||||||
.def_readwrite("max_wh", &vision::detection::YOLOv5::max_wh)
|
fastdeploy::vision::Mat fd_mat(mat);
|
||||||
.def_readwrite("multi_label", &vision::detection::YOLOv5::multi_label);
|
FDTensor output;
|
||||||
|
std::map<std::string, std::array<float, 2>> im_info;
|
||||||
|
vision::detection::YOLOv5::Preprocess(
|
||||||
|
&fd_mat, &output, &im_info, size, padding_value,
|
||||||
|
is_mini_pad, is_no_pad, is_scale_up, stride, max_wh,
|
||||||
|
multi_label);
|
||||||
|
return make_pair(TensorToPyArray(output), im_info);
|
||||||
|
})
|
||||||
|
.def_static(
|
||||||
|
"postprocess",
|
||||||
|
[](std::vector<pybind11::array> infer_results,
|
||||||
|
const std::map<std::string, std::array<float, 2>>& im_info,
|
||||||
|
float conf_threshold, float nms_iou_threshold, bool multi_label,
|
||||||
|
float max_wh) {
|
||||||
|
std::vector<FDTensor> fd_infer_results(infer_results.size());
|
||||||
|
PyArrayToTensorList(infer_results, &fd_infer_results, true);
|
||||||
|
vision::DetectionResult result;
|
||||||
|
vision::detection::YOLOv5::Postprocess(
|
||||||
|
fd_infer_results, &result, im_info, conf_threshold,
|
||||||
|
nms_iou_threshold, multi_label, max_wh);
|
||||||
|
return result;
|
||||||
|
})
|
||||||
|
.def_readwrite("size", &vision::detection::YOLOv5::size_)
|
||||||
|
.def_readwrite("padding_value",
|
||||||
|
&vision::detection::YOLOv5::padding_value_)
|
||||||
|
.def_readwrite("is_mini_pad", &vision::detection::YOLOv5::is_mini_pad_)
|
||||||
|
.def_readwrite("is_no_pad", &vision::detection::YOLOv5::is_no_pad_)
|
||||||
|
.def_readwrite("is_scale_up", &vision::detection::YOLOv5::is_scale_up_)
|
||||||
|
.def_readwrite("stride", &vision::detection::YOLOv5::stride_)
|
||||||
|
.def_readwrite("max_wh", &vision::detection::YOLOv5::max_wh_)
|
||||||
|
.def_readwrite("multi_label", &vision::detection::YOLOv5::multi_label_);
|
||||||
}
|
}
|
||||||
} // namespace fastdeploy
|
} // namespace fastdeploy
|
||||||
|
@@ -17,9 +17,10 @@ import os
|
|||||||
import sys
|
import sys
|
||||||
|
|
||||||
from .c_lib_wrap import (Frontend, Backend, FDDataType, TensorInfo, Device,
|
from .c_lib_wrap import (Frontend, Backend, FDDataType, TensorInfo, Device,
|
||||||
is_built_with_gpu, is_built_with_ort,
|
FDTensor, is_built_with_gpu, is_built_with_ort,
|
||||||
is_built_with_paddle, is_built_with_trt,
|
is_built_with_paddle, is_built_with_trt,
|
||||||
get_default_cuda_directory)
|
get_default_cuda_directory)
|
||||||
|
|
||||||
from .runtime import Runtime, RuntimeOption
|
from .runtime import Runtime, RuntimeOption
|
||||||
from .model import FastDeployModel
|
from .model import FastDeployModel
|
||||||
from . import c_lib_wrap as C
|
from . import c_lib_wrap as C
|
||||||
|
@@ -23,7 +23,8 @@ class Runtime:
|
|||||||
runtime_option._option), "Initialize Runtime Failed!"
|
runtime_option._option), "Initialize Runtime Failed!"
|
||||||
|
|
||||||
def infer(self, data):
|
def infer(self, data):
|
||||||
assert isinstance(data, dict), "The input data should be type of dict."
|
assert isinstance(data, dict) or isinstance(
|
||||||
|
data, list), "The input data should be type of dict or list."
|
||||||
return self._runtime.infer(data)
|
return self._runtime.infer(data)
|
||||||
|
|
||||||
def num_inputs(self):
|
def num_inputs(self):
|
||||||
|
@@ -37,6 +37,31 @@ class YOLOv5(FastDeployModel):
|
|||||||
return self._model.predict(input_image, conf_threshold,
|
return self._model.predict(input_image, conf_threshold,
|
||||||
nms_iou_threshold)
|
nms_iou_threshold)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def preprocess(input_image,
|
||||||
|
size=[640, 640],
|
||||||
|
padding_value=[114.0, 114.0, 114.0],
|
||||||
|
is_mini_pad=False,
|
||||||
|
is_no_pad=False,
|
||||||
|
is_scale_up=False,
|
||||||
|
stride=32,
|
||||||
|
max_wh=7680.0,
|
||||||
|
multi_label=True):
|
||||||
|
return C.vision.detection.YOLOv5.preprocess(
|
||||||
|
input_image, size, padding_value, is_mini_pad, is_no_pad,
|
||||||
|
is_scale_up, stride, max_wh, multi_label)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def postprocess(infer_result,
|
||||||
|
im_info,
|
||||||
|
conf_threshold=0.25,
|
||||||
|
nms_iou_threshold=0.5,
|
||||||
|
multi_label=True,
|
||||||
|
max_wh=7680.0):
|
||||||
|
return C.vision.detection.YOLOv5.postprocess(
|
||||||
|
infer_result, im_info, conf_threshold, nms_iou_threshold,
|
||||||
|
multi_label, max_wh)
|
||||||
|
|
||||||
# 一些跟YOLOv5模型有关的属性封装
|
# 一些跟YOLOv5模型有关的属性封装
|
||||||
# 多数是预处理相关,可通过修改如model.size = [1280, 1280]改变预处理时resize的大小(前提是模型支持)
|
# 多数是预处理相关,可通过修改如model.size = [1280, 1280]改变预处理时resize的大小(前提是模型支持)
|
||||||
@property
|
@property
|
||||||
|
Reference in New Issue
Block a user