mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-05 16:48:03 +08:00
Modify yolov7 and visualize functions (#82)
modify yolov7 and visualize functions
This commit is contained in:
@@ -30,6 +30,8 @@ pybind11::dtype FDDataTypeToNumpyDataType(const FDDataType& fd_dtype) {
|
|||||||
dt = pybind11::dtype::of<float>();
|
dt = pybind11::dtype::of<float>();
|
||||||
} else if (fd_dtype == FDDataType::FP64) {
|
} else if (fd_dtype == FDDataType::FP64) {
|
||||||
dt = pybind11::dtype::of<double>();
|
dt = pybind11::dtype::of<double>();
|
||||||
|
} else if (fd_dtype == FDDataType::UINT8) {
|
||||||
|
dt = pybind11::dtype::of<uint8_t>();
|
||||||
} else {
|
} else {
|
||||||
FDASSERT(false, "The function doesn't support data type of " +
|
FDASSERT(false, "The function doesn't support data type of " +
|
||||||
Str(fd_dtype) + ".");
|
Str(fd_dtype) + ".");
|
||||||
@@ -46,6 +48,8 @@ FDDataType NumpyDataTypeToFDDataType(const pybind11::dtype& np_dtype) {
|
|||||||
return FDDataType::FP32;
|
return FDDataType::FP32;
|
||||||
} else if (np_dtype.is(pybind11::dtype::of<double>())) {
|
} else if (np_dtype.is(pybind11::dtype::of<double>())) {
|
||||||
return FDDataType::FP64;
|
return FDDataType::FP64;
|
||||||
|
} else if (np_dtype.is(pybind11::dtype::of<uint8_t>())) {
|
||||||
|
return FDDataType::UINT8;
|
||||||
}
|
}
|
||||||
FDASSERT(false,
|
FDASSERT(false,
|
||||||
"NumpyDataTypeToFDDataType() only support "
|
"NumpyDataTypeToFDDataType() only support "
|
||||||
@@ -66,6 +70,13 @@ void PyArrayToTensor(pybind11::array& pyarray, FDTensor* tensor,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Converts an FDTensor into a newly created pybind11::array (numpy array).
// The element type is mapped with FDDataTypeToNumpyDataType(tensor.dtype) and
// the array is constructed from tensor.shape; the tensor contents are then
// copied into the array with memcpy and the array is returned by value.
pybind11::array TensorToPyArray(const FDTensor& tensor) {
|
||||||
|
// Map the tensor's FDDataType to the matching numpy dtype.
auto numpy_dtype = FDDataTypeToNumpyDataType(tensor.dtype);
|
||||||
|
// Create the destination array with that dtype and the tensor's shape.
auto out = pybind11::array(numpy_dtype, tensor.shape);
|
||||||
|
// Bytes copied = number of elements * size of one element of tensor.dtype.
// NOTE(review): memcpy presumes tensor.Data() is host-readable memory —
// confirm for tensors that do not live on the CPU.
memcpy(out.mutable_data(), tensor.Data(), tensor.Numel() * FDDataTypeSize(tensor.dtype));
|
||||||
|
return out;
|
||||||
|
}
|
||||||
|
|
||||||
#ifdef ENABLE_VISION
|
#ifdef ENABLE_VISION
|
||||||
int NumpyDataTypeToOpenCvType(const pybind11::dtype& np_dtype) {
|
int NumpyDataTypeToOpenCvType(const pybind11::dtype& np_dtype) {
|
||||||
if (np_dtype.is(pybind11::dtype::of<int32_t>())) {
|
if (np_dtype.is(pybind11::dtype::of<int32_t>())) {
|
||||||
|
@@ -36,12 +36,14 @@ FDDataType NumpyDataTypeToFDDataType(const pybind11::dtype& np_dtype);
|
|||||||
|
|
||||||
void PyArrayToTensor(pybind11::array& pyarray, FDTensor* tensor,
|
void PyArrayToTensor(pybind11::array& pyarray, FDTensor* tensor,
|
||||||
bool share_buffer = false);
|
bool share_buffer = false);
|
||||||
|
pybind11::array TensorToPyArray(const FDTensor& tensor);
|
||||||
|
|
||||||
#ifdef ENABLE_VISION
|
#ifdef ENABLE_VISION
|
||||||
cv::Mat PyArrayToCvMat(pybind11::array& pyarray);
|
cv::Mat PyArrayToCvMat(pybind11::array& pyarray);
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
template <typename T> FDDataType CTypeToFDDataType() {
|
template <typename T>
|
||||||
|
FDDataType CTypeToFDDataType() {
|
||||||
if (std::is_same<T, int32_t>::value) {
|
if (std::is_same<T, int32_t>::value) {
|
||||||
return FDDataType::INT32;
|
return FDDataType::INT32;
|
||||||
} else if (std::is_same<T, int64_t>::value) {
|
} else if (std::is_same<T, int64_t>::value) {
|
||||||
@@ -57,8 +59,8 @@ template <typename T> FDDataType CTypeToFDDataType() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
std::vector<pybind11::array>
|
std::vector<pybind11::array> PyBackendInfer(
|
||||||
PyBackendInfer(T& self, const std::vector<std::string>& names,
|
T& self, const std::vector<std::string>& names,
|
||||||
std::vector<pybind11::array>& data) {
|
std::vector<pybind11::array>& data) {
|
||||||
std::vector<FDTensor> inputs(data.size());
|
std::vector<FDTensor> inputs(data.size());
|
||||||
for (size_t i = 0; i < data.size(); ++i) {
|
for (size_t i = 0; i < data.size(); ++i) {
|
||||||
|
@@ -23,6 +23,9 @@
|
|||||||
#include "fastdeploy/vision/deepinsight/partial_fc.h"
|
#include "fastdeploy/vision/deepinsight/partial_fc.h"
|
||||||
#include "fastdeploy/vision/deepinsight/scrfd.h"
|
#include "fastdeploy/vision/deepinsight/scrfd.h"
|
||||||
#include "fastdeploy/vision/deepinsight/vpl.h"
|
#include "fastdeploy/vision/deepinsight/vpl.h"
|
||||||
|
#include "fastdeploy/vision/detection/contrib/scaledyolov4.h"
|
||||||
|
#include "fastdeploy/vision/detection/contrib/yolor.h"
|
||||||
|
#include "fastdeploy/vision/detection/contrib/yolov7.h"
|
||||||
#include "fastdeploy/vision/linzaer/ultraface.h"
|
#include "fastdeploy/vision/linzaer/ultraface.h"
|
||||||
#include "fastdeploy/vision/megvii/yolox.h"
|
#include "fastdeploy/vision/megvii/yolox.h"
|
||||||
#include "fastdeploy/vision/meituan/yolov6.h"
|
#include "fastdeploy/vision/meituan/yolov6.h"
|
||||||
@@ -32,9 +35,6 @@
|
|||||||
#include "fastdeploy/vision/ppseg/model.h"
|
#include "fastdeploy/vision/ppseg/model.h"
|
||||||
#include "fastdeploy/vision/rangilyu/nanodet_plus.h"
|
#include "fastdeploy/vision/rangilyu/nanodet_plus.h"
|
||||||
#include "fastdeploy/vision/ultralytics/yolov5.h"
|
#include "fastdeploy/vision/ultralytics/yolov5.h"
|
||||||
#include "fastdeploy/vision/wongkinyiu/scaledyolov4.h"
|
|
||||||
#include "fastdeploy/vision/wongkinyiu/yolor.h"
|
|
||||||
#include "fastdeploy/vision/wongkinyiu/yolov7.h"
|
|
||||||
#include "fastdeploy/vision/zhkkke/modnet.h"
|
#include "fastdeploy/vision/zhkkke/modnet.h"
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
@@ -27,7 +27,7 @@ namespace vision {
|
|||||||
|
|
||||||
enum Layout { HWC, CHW };
|
enum Layout { HWC, CHW };
|
||||||
|
|
||||||
struct Mat {
|
struct FASTDEPLOY_DECL Mat {
|
||||||
explicit Mat(cv::Mat& mat) {
|
explicit Mat(cv::Mat& mat) {
|
||||||
cpu_mat = mat;
|
cpu_mat = mat;
|
||||||
device = Device::CPU;
|
device = Device::CPU;
|
||||||
|
@@ -12,13 +12,13 @@
|
|||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
#include "fastdeploy/vision/wongkinyiu/scaledyolov4.h"
|
#include "fastdeploy/vision/detection/contrib/scaledyolov4.h"
|
||||||
#include "fastdeploy/utils/perf.h"
|
#include "fastdeploy/utils/perf.h"
|
||||||
#include "fastdeploy/vision/utils/utils.h"
|
#include "fastdeploy/vision/utils/utils.h"
|
||||||
|
|
||||||
namespace fastdeploy {
|
namespace fastdeploy {
|
||||||
namespace vision {
|
namespace vision {
|
||||||
namespace wongkinyiu {
|
namespace detection {
|
||||||
|
|
||||||
void ScaledYOLOv4::LetterBox(Mat* mat, const std::vector<int>& size,
|
void ScaledYOLOv4::LetterBox(Mat* mat, const std::vector<int>& size,
|
||||||
const std::vector<float>& color, bool _auto,
|
const std::vector<float>& color, bool _auto,
|
||||||
@@ -65,8 +65,8 @@ ScaledYOLOv4::ScaledYOLOv4(const std::string& model_file,
|
|||||||
valid_cpu_backends = {Backend::ORT}; // 指定可用的CPU后端
|
valid_cpu_backends = {Backend::ORT}; // 指定可用的CPU后端
|
||||||
valid_gpu_backends = {Backend::ORT, Backend::TRT}; // 指定可用的GPU后端
|
valid_gpu_backends = {Backend::ORT, Backend::TRT}; // 指定可用的GPU后端
|
||||||
} else {
|
} else {
|
||||||
valid_cpu_backends = {Backend::PDINFER, Backend::ORT};
|
valid_cpu_backends = {Backend::PDINFER};
|
||||||
valid_gpu_backends = {Backend::PDINFER, Backend::ORT, Backend::TRT};
|
valid_gpu_backends = {Backend::PDINFER};
|
||||||
}
|
}
|
||||||
runtime_option = custom_option;
|
runtime_option = custom_option;
|
||||||
runtime_option.model_format = model_format;
|
runtime_option.model_format = model_format;
|
||||||
@@ -219,10 +219,6 @@ bool ScaledYOLOv4::Postprocess(
|
|||||||
|
|
||||||
bool ScaledYOLOv4::Predict(cv::Mat* im, DetectionResult* result,
|
bool ScaledYOLOv4::Predict(cv::Mat* im, DetectionResult* result,
|
||||||
float conf_threshold, float nms_iou_threshold) {
|
float conf_threshold, float nms_iou_threshold) {
|
||||||
#ifdef FASTDEPLOY_DEBUG
|
|
||||||
TIMERECORD_START(0)
|
|
||||||
#endif
|
|
||||||
|
|
||||||
Mat mat(*im);
|
Mat mat(*im);
|
||||||
std::vector<FDTensor> input_tensors(1);
|
std::vector<FDTensor> input_tensors(1);
|
||||||
|
|
||||||
@@ -239,34 +235,21 @@ bool ScaledYOLOv4::Predict(cv::Mat* im, DetectionResult* result,
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
#ifdef FASTDEPLOY_DEBUG
|
|
||||||
TIMERECORD_END(0, "Preprocess")
|
|
||||||
TIMERECORD_START(1)
|
|
||||||
#endif
|
|
||||||
|
|
||||||
input_tensors[0].name = InputInfoOfRuntime(0).name;
|
input_tensors[0].name = InputInfoOfRuntime(0).name;
|
||||||
std::vector<FDTensor> output_tensors;
|
std::vector<FDTensor> output_tensors;
|
||||||
if (!Infer(input_tensors, &output_tensors)) {
|
if (!Infer(input_tensors, &output_tensors)) {
|
||||||
FDERROR << "Failed to inference." << std::endl;
|
FDERROR << "Failed to inference." << std::endl;
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
#ifdef FASTDEPLOY_DEBUG
|
|
||||||
TIMERECORD_END(1, "Inference")
|
|
||||||
TIMERECORD_START(2)
|
|
||||||
#endif
|
|
||||||
|
|
||||||
if (!Postprocess(output_tensors[0], result, im_info, conf_threshold,
|
if (!Postprocess(output_tensors[0], result, im_info, conf_threshold,
|
||||||
nms_iou_threshold)) {
|
nms_iou_threshold)) {
|
||||||
FDERROR << "Failed to post process." << std::endl;
|
FDERROR << "Failed to post process." << std::endl;
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
#ifdef FASTDEPLOY_DEBUG
|
|
||||||
TIMERECORD_END(2, "Postprocess")
|
|
||||||
#endif
|
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace wongkinyiu
|
} // namespace detection
|
||||||
} // namespace vision
|
} // namespace vision
|
||||||
} // namespace fastdeploy
|
} // namespace fastdeploy
|
@@ -19,7 +19,7 @@
|
|||||||
|
|
||||||
namespace fastdeploy {
|
namespace fastdeploy {
|
||||||
namespace vision {
|
namespace vision {
|
||||||
namespace wongkinyiu {
|
namespace detection {
|
||||||
|
|
||||||
class FASTDEPLOY_DECL ScaledYOLOv4 : public FastDeployModel {
|
class FASTDEPLOY_DECL ScaledYOLOv4 : public FastDeployModel {
|
||||||
public:
|
public:
|
||||||
@@ -31,7 +31,7 @@ class FASTDEPLOY_DECL ScaledYOLOv4 : public FastDeployModel {
|
|||||||
const Frontend& model_format = Frontend::ONNX);
|
const Frontend& model_format = Frontend::ONNX);
|
||||||
|
|
||||||
// 定义模型的名称
|
// 定义模型的名称
|
||||||
virtual std::string ModelName() const { return "WongKinYiu/ScaledYOLOv4"; }
|
virtual std::string ModelName() const { return "ScaledYOLOv4"; }
|
||||||
|
|
||||||
// 模型预测接口,即用户调用的接口
|
// 模型预测接口,即用户调用的接口
|
||||||
// im 为用户的输入数据,目前对于CV均定义为cv::Mat
|
// im 为用户的输入数据,目前对于CV均定义为cv::Mat
|
||||||
@@ -98,6 +98,6 @@ class FASTDEPLOY_DECL ScaledYOLOv4 : public FastDeployModel {
|
|||||||
// auto check by fastdeploy after the internal Runtime already initialized.
|
// auto check by fastdeploy after the internal Runtime already initialized.
|
||||||
bool is_dynamic_input_;
|
bool is_dynamic_input_;
|
||||||
};
|
};
|
||||||
} // namespace wongkinyiu
|
} // namespace detection
|
||||||
} // namespace vision
|
} // namespace vision
|
||||||
} // namespace fastdeploy
|
} // namespace fastdeploy
|
@@ -0,0 +1,41 @@
|
|||||||
|
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#include "fastdeploy/pybind/main.h"
|
||||||
|
|
||||||
|
namespace fastdeploy {
|
||||||
|
// Registers the Python class "ScaledYOLOv4" on module `m`, wrapping
// vision::detection::ScaledYOLOv4 with FastDeployModel as its bound base.
// Exposes a four-argument constructor, a `predict` method and several
// read/write attributes.
void BindScaledYOLOv4(pybind11::module& m) {
|
||||||
|
pybind11::class_<vision::detection::ScaledYOLOv4, FastDeployModel>(
|
||||||
|
m, "ScaledYOLOv4")
|
||||||
|
// Constructor taking (std::string, std::string, RuntimeOption, Frontend);
// presumably model file, params file, runtime option and model format —
// verify against the ScaledYOLOv4 constructor declaration.
.def(pybind11::init<std::string, std::string, RuntimeOption, Frontend>())
|
||||||
|
// predict(data, conf_threshold, nms_iou_threshold): converts the numpy
// array to a cv::Mat, runs Predict on it and returns the DetectionResult.
.def("predict",
|
||||||
|
[](vision::detection::ScaledYOLOv4& self, pybind11::array& data,
|
||||||
|
float conf_threshold, float nms_iou_threshold) {
|
||||||
|
auto mat = PyArrayToCvMat(data);
|
||||||
|
vision::DetectionResult res;
|
||||||
|
// NOTE(review): the bool returned by Predict is not checked, so
// `res` is returned even when prediction fails — confirm intended.
self.Predict(&mat, &res, conf_threshold, nms_iou_threshold);
|
||||||
|
return res;
|
||||||
|
})
|
||||||
|
// Public data members exposed to Python as read/write attributes.
.def_readwrite("size", &vision::detection::ScaledYOLOv4::size)
|
||||||
|
.def_readwrite("padding_value",
|
||||||
|
&vision::detection::ScaledYOLOv4::padding_value)
|
||||||
|
.def_readwrite("is_mini_pad",
|
||||||
|
&vision::detection::ScaledYOLOv4::is_mini_pad)
|
||||||
|
.def_readwrite("is_no_pad", &vision::detection::ScaledYOLOv4::is_no_pad)
|
||||||
|
.def_readwrite("is_scale_up",
|
||||||
|
&vision::detection::ScaledYOLOv4::is_scale_up)
|
||||||
|
.def_readwrite("stride", &vision::detection::ScaledYOLOv4::stride)
|
||||||
|
.def_readwrite("max_wh", &vision::detection::ScaledYOLOv4::max_wh);
|
||||||
|
}
|
||||||
|
} // namespace fastdeploy
|
@@ -12,13 +12,13 @@
|
|||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
#include "fastdeploy/vision/wongkinyiu/yolor.h"
|
#include "fastdeploy/vision/detection/contrib/yolor.h"
|
||||||
#include "fastdeploy/utils/perf.h"
|
#include "fastdeploy/utils/perf.h"
|
||||||
#include "fastdeploy/vision/utils/utils.h"
|
#include "fastdeploy/vision/utils/utils.h"
|
||||||
|
|
||||||
namespace fastdeploy {
|
namespace fastdeploy {
|
||||||
namespace vision {
|
namespace vision {
|
||||||
namespace wongkinyiu {
|
namespace detection {
|
||||||
|
|
||||||
void YOLOR::LetterBox(Mat* mat, const std::vector<int>& size,
|
void YOLOR::LetterBox(Mat* mat, const std::vector<int>& size,
|
||||||
const std::vector<float>& color, bool _auto,
|
const std::vector<float>& color, bool _auto,
|
||||||
@@ -63,8 +63,8 @@ YOLOR::YOLOR(const std::string& model_file, const std::string& params_file,
|
|||||||
valid_cpu_backends = {Backend::ORT}; // 指定可用的CPU后端
|
valid_cpu_backends = {Backend::ORT}; // 指定可用的CPU后端
|
||||||
valid_gpu_backends = {Backend::ORT, Backend::TRT}; // 指定可用的GPU后端
|
valid_gpu_backends = {Backend::ORT, Backend::TRT}; // 指定可用的GPU后端
|
||||||
} else {
|
} else {
|
||||||
valid_cpu_backends = {Backend::PDINFER, Backend::ORT};
|
valid_cpu_backends = {Backend::PDINFER};
|
||||||
valid_gpu_backends = {Backend::PDINFER, Backend::ORT, Backend::TRT};
|
valid_gpu_backends = {Backend::PDINFER};
|
||||||
}
|
}
|
||||||
runtime_option = custom_option;
|
runtime_option = custom_option;
|
||||||
runtime_option.model_format = model_format;
|
runtime_option.model_format = model_format;
|
||||||
@@ -216,10 +216,6 @@ bool YOLOR::Postprocess(
|
|||||||
|
|
||||||
bool YOLOR::Predict(cv::Mat* im, DetectionResult* result, float conf_threshold,
|
bool YOLOR::Predict(cv::Mat* im, DetectionResult* result, float conf_threshold,
|
||||||
float nms_iou_threshold) {
|
float nms_iou_threshold) {
|
||||||
#ifdef FASTDEPLOY_DEBUG
|
|
||||||
TIMERECORD_START(0)
|
|
||||||
#endif
|
|
||||||
|
|
||||||
Mat mat(*im);
|
Mat mat(*im);
|
||||||
std::vector<FDTensor> input_tensors(1);
|
std::vector<FDTensor> input_tensors(1);
|
||||||
|
|
||||||
@@ -236,21 +232,12 @@ bool YOLOR::Predict(cv::Mat* im, DetectionResult* result, float conf_threshold,
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
#ifdef FASTDEPLOY_DEBUG
|
|
||||||
TIMERECORD_END(0, "Preprocess")
|
|
||||||
TIMERECORD_START(1)
|
|
||||||
#endif
|
|
||||||
|
|
||||||
input_tensors[0].name = InputInfoOfRuntime(0).name;
|
input_tensors[0].name = InputInfoOfRuntime(0).name;
|
||||||
std::vector<FDTensor> output_tensors;
|
std::vector<FDTensor> output_tensors;
|
||||||
if (!Infer(input_tensors, &output_tensors)) {
|
if (!Infer(input_tensors, &output_tensors)) {
|
||||||
FDERROR << "Failed to inference." << std::endl;
|
FDERROR << "Failed to inference." << std::endl;
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
#ifdef FASTDEPLOY_DEBUG
|
|
||||||
TIMERECORD_END(1, "Inference")
|
|
||||||
TIMERECORD_START(2)
|
|
||||||
#endif
|
|
||||||
|
|
||||||
if (!Postprocess(output_tensors[0], result, im_info, conf_threshold,
|
if (!Postprocess(output_tensors[0], result, im_info, conf_threshold,
|
||||||
nms_iou_threshold)) {
|
nms_iou_threshold)) {
|
||||||
@@ -258,12 +245,9 @@ bool YOLOR::Predict(cv::Mat* im, DetectionResult* result, float conf_threshold,
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
#ifdef FASTDEPLOY_DEBUG
|
|
||||||
TIMERECORD_END(2, "Postprocess")
|
|
||||||
#endif
|
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace wongkinyiu
|
} // namespace detection
|
||||||
} // namespace vision
|
} // namespace vision
|
||||||
} // namespace fastdeploy
|
} // namespace fastdeploy
|
@@ -19,7 +19,7 @@
|
|||||||
|
|
||||||
namespace fastdeploy {
|
namespace fastdeploy {
|
||||||
namespace vision {
|
namespace vision {
|
||||||
namespace wongkinyiu {
|
namespace detection {
|
||||||
|
|
||||||
class FASTDEPLOY_DECL YOLOR : public FastDeployModel {
|
class FASTDEPLOY_DECL YOLOR : public FastDeployModel {
|
||||||
public:
|
public:
|
||||||
@@ -30,7 +30,7 @@ class FASTDEPLOY_DECL YOLOR : public FastDeployModel {
|
|||||||
const Frontend& model_format = Frontend::ONNX);
|
const Frontend& model_format = Frontend::ONNX);
|
||||||
|
|
||||||
// 定义模型的名称
|
// 定义模型的名称
|
||||||
virtual std::string ModelName() const { return "WongKinYiu/yolor"; }
|
virtual std::string ModelName() const { return "YOLOR"; }
|
||||||
|
|
||||||
// 模型预测接口,即用户调用的接口
|
// 模型预测接口,即用户调用的接口
|
||||||
// im 为用户的输入数据,目前对于CV均定义为cv::Mat
|
// im 为用户的输入数据,目前对于CV均定义为cv::Mat
|
||||||
@@ -97,6 +97,6 @@ class FASTDEPLOY_DECL YOLOR : public FastDeployModel {
|
|||||||
// auto check by fastdeploy after the internal Runtime already initialized.
|
// auto check by fastdeploy after the internal Runtime already initialized.
|
||||||
bool is_dynamic_input_;
|
bool is_dynamic_input_;
|
||||||
};
|
};
|
||||||
} // namespace wongkinyiu
|
} // namespace detection
|
||||||
} // namespace vision
|
} // namespace vision
|
||||||
} // namespace fastdeploy
|
} // namespace fastdeploy
|
37
csrcs/fastdeploy/vision/detection/contrib/yolor_pybind.cc
Normal file
37
csrcs/fastdeploy/vision/detection/contrib/yolor_pybind.cc
Normal file
@@ -0,0 +1,37 @@
|
|||||||
|
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#include "fastdeploy/pybind/main.h"
|
||||||
|
|
||||||
|
namespace fastdeploy {
|
||||||
|
// Registers the Python class "YOLOR" on module `m`, wrapping
// vision::detection::YOLOR with FastDeployModel as its bound base.
// Exposes a four-argument constructor, a `predict` method and several
// read/write attributes.
void BindYOLOR(pybind11::module& m) {
|
||||||
|
pybind11::class_<vision::detection::YOLOR, FastDeployModel>(m, "YOLOR")
|
||||||
|
// Constructor taking (std::string, std::string, RuntimeOption, Frontend);
// presumably model file, params file, runtime option and model format —
// verify against the YOLOR constructor declaration.
.def(pybind11::init<std::string, std::string, RuntimeOption, Frontend>())
|
||||||
|
// predict(data, conf_threshold, nms_iou_threshold): converts the numpy
// array to a cv::Mat, runs Predict on it and returns the DetectionResult.
.def("predict",
|
||||||
|
[](vision::detection::YOLOR& self, pybind11::array& data,
|
||||||
|
float conf_threshold, float nms_iou_threshold) {
|
||||||
|
auto mat = PyArrayToCvMat(data);
|
||||||
|
vision::DetectionResult res;
|
||||||
|
// NOTE(review): the bool returned by Predict is not checked, so
// `res` is returned even when prediction fails — confirm intended.
self.Predict(&mat, &res, conf_threshold, nms_iou_threshold);
|
||||||
|
return res;
|
||||||
|
})
|
||||||
|
// Public data members exposed to Python as read/write attributes.
.def_readwrite("size", &vision::detection::YOLOR::size)
|
||||||
|
.def_readwrite("padding_value", &vision::detection::YOLOR::padding_value)
|
||||||
|
.def_readwrite("is_mini_pad", &vision::detection::YOLOR::is_mini_pad)
|
||||||
|
.def_readwrite("is_no_pad", &vision::detection::YOLOR::is_no_pad)
|
||||||
|
.def_readwrite("is_scale_up", &vision::detection::YOLOR::is_scale_up)
|
||||||
|
.def_readwrite("stride", &vision::detection::YOLOR::stride)
|
||||||
|
.def_readwrite("max_wh", &vision::detection::YOLOR::max_wh);
|
||||||
|
}
|
||||||
|
} // namespace fastdeploy
|
@@ -12,13 +12,13 @@
|
|||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
#include "fastdeploy/vision/wongkinyiu/yolov7.h"
|
#include "fastdeploy/vision/detection/contrib/yolov7.h"
|
||||||
#include "fastdeploy/utils/perf.h"
|
#include "fastdeploy/utils/perf.h"
|
||||||
#include "fastdeploy/vision/utils/utils.h"
|
#include "fastdeploy/vision/utils/utils.h"
|
||||||
|
|
||||||
namespace fastdeploy {
|
namespace fastdeploy {
|
||||||
namespace vision {
|
namespace vision {
|
||||||
namespace wongkinyiu {
|
namespace detection {
|
||||||
|
|
||||||
void YOLOv7::LetterBox(Mat* mat, const std::vector<int>& size,
|
void YOLOv7::LetterBox(Mat* mat, const std::vector<int>& size,
|
||||||
const std::vector<float>& color, bool _auto,
|
const std::vector<float>& color, bool _auto,
|
||||||
@@ -64,13 +64,12 @@ YOLOv7::YOLOv7(const std::string& model_file, const std::string& params_file,
|
|||||||
valid_cpu_backends = {Backend::ORT}; // 指定可用的CPU后端
|
valid_cpu_backends = {Backend::ORT}; // 指定可用的CPU后端
|
||||||
valid_gpu_backends = {Backend::ORT, Backend::TRT}; // 指定可用的GPU后端
|
valid_gpu_backends = {Backend::ORT, Backend::TRT}; // 指定可用的GPU后端
|
||||||
} else {
|
} else {
|
||||||
valid_cpu_backends = {Backend::PDINFER, Backend::ORT};
|
valid_cpu_backends = {Backend::PDINFER};
|
||||||
valid_gpu_backends = {Backend::PDINFER, Backend::ORT, Backend::TRT};
|
valid_gpu_backends = {Backend::PDINFER};
|
||||||
}
|
}
|
||||||
runtime_option = custom_option;
|
runtime_option = custom_option;
|
||||||
runtime_option.model_format = model_format;
|
runtime_option.model_format = model_format;
|
||||||
runtime_option.model_file = model_file;
|
runtime_option.model_file = model_file;
|
||||||
runtime_option.params_file = params_file;
|
|
||||||
initialized = Initialize();
|
initialized = Initialize();
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -217,10 +216,6 @@ bool YOLOv7::Postprocess(
|
|||||||
|
|
||||||
bool YOLOv7::Predict(cv::Mat* im, DetectionResult* result, float conf_threshold,
|
bool YOLOv7::Predict(cv::Mat* im, DetectionResult* result, float conf_threshold,
|
||||||
float nms_iou_threshold) {
|
float nms_iou_threshold) {
|
||||||
#ifdef FASTDEPLOY_DEBUG
|
|
||||||
TIMERECORD_START(0)
|
|
||||||
#endif
|
|
||||||
|
|
||||||
Mat mat(*im);
|
Mat mat(*im);
|
||||||
std::vector<FDTensor> input_tensors(1);
|
std::vector<FDTensor> input_tensors(1);
|
||||||
|
|
||||||
@@ -237,21 +232,12 @@ bool YOLOv7::Predict(cv::Mat* im, DetectionResult* result, float conf_threshold,
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
#ifdef FASTDEPLOY_DEBUG
|
|
||||||
TIMERECORD_END(0, "Preprocess")
|
|
||||||
TIMERECORD_START(1)
|
|
||||||
#endif
|
|
||||||
|
|
||||||
input_tensors[0].name = InputInfoOfRuntime(0).name;
|
input_tensors[0].name = InputInfoOfRuntime(0).name;
|
||||||
std::vector<FDTensor> output_tensors;
|
std::vector<FDTensor> output_tensors;
|
||||||
if (!Infer(input_tensors, &output_tensors)) {
|
if (!Infer(input_tensors, &output_tensors)) {
|
||||||
FDERROR << "Failed to inference." << std::endl;
|
FDERROR << "Failed to inference." << std::endl;
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
#ifdef FASTDEPLOY_DEBUG
|
|
||||||
TIMERECORD_END(1, "Inference")
|
|
||||||
TIMERECORD_START(2)
|
|
||||||
#endif
|
|
||||||
|
|
||||||
if (!Postprocess(output_tensors[0], result, im_info, conf_threshold,
|
if (!Postprocess(output_tensors[0], result, im_info, conf_threshold,
|
||||||
nms_iou_threshold)) {
|
nms_iou_threshold)) {
|
||||||
@@ -259,12 +245,9 @@ bool YOLOv7::Predict(cv::Mat* im, DetectionResult* result, float conf_threshold,
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
#ifdef FASTDEPLOY_DEBUG
|
|
||||||
TIMERECORD_END(2, "Postprocess")
|
|
||||||
#endif
|
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace wongkinyiu
|
} // namespace detection
|
||||||
} // namespace vision
|
} // namespace vision
|
||||||
} // namespace fastdeploy
|
} // namespace fastdeploy
|
@@ -19,18 +19,16 @@
|
|||||||
|
|
||||||
namespace fastdeploy {
|
namespace fastdeploy {
|
||||||
namespace vision {
|
namespace vision {
|
||||||
namespace wongkinyiu {
|
namespace detection {
|
||||||
|
|
||||||
class FASTDEPLOY_DECL YOLOv7 : public FastDeployModel {
|
class FASTDEPLOY_DECL YOLOv7 : public FastDeployModel {
|
||||||
public:
|
public:
|
||||||
// 当model_format为ONNX时,无需指定params_file
|
|
||||||
// 当model_format为Paddle时,则需同时指定model_file & params_file
|
|
||||||
YOLOv7(const std::string& model_file, const std::string& params_file = "",
|
YOLOv7(const std::string& model_file, const std::string& params_file = "",
|
||||||
const RuntimeOption& custom_option = RuntimeOption(),
|
const RuntimeOption& custom_option = RuntimeOption(),
|
||||||
const Frontend& model_format = Frontend::ONNX);
|
const Frontend& model_format = Frontend::ONNX);
|
||||||
|
|
||||||
// 定义模型的名称
|
// 定义模型的名称
|
||||||
virtual std::string ModelName() const { return "WongKinYiu/yolov7"; }
|
virtual std::string ModelName() const { return "yolov7"; }
|
||||||
|
|
||||||
// 模型预测接口,即用户调用的接口
|
// 模型预测接口,即用户调用的接口
|
||||||
// im 为用户的输入数据,目前对于CV均定义为cv::Mat
|
// im 为用户的输入数据,目前对于CV均定义为cv::Mat
|
||||||
@@ -97,6 +95,6 @@ class FASTDEPLOY_DECL YOLOv7 : public FastDeployModel {
|
|||||||
// auto check by fastdeploy after the internal Runtime already initialized.
|
// auto check by fastdeploy after the internal Runtime already initialized.
|
||||||
bool is_dynamic_input_;
|
bool is_dynamic_input_;
|
||||||
};
|
};
|
||||||
} // namespace wongkinyiu
|
} // namespace detection
|
||||||
} // namespace vision
|
} // namespace vision
|
||||||
} // namespace fastdeploy
|
} // namespace fastdeploy
|
37
csrcs/fastdeploy/vision/detection/contrib/yolov7_pybind.cc
Normal file
37
csrcs/fastdeploy/vision/detection/contrib/yolov7_pybind.cc
Normal file
@@ -0,0 +1,37 @@
|
|||||||
|
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#include "fastdeploy/pybind/main.h"
|
||||||
|
|
||||||
|
namespace fastdeploy {
|
||||||
|
// Registers the Python class "YOLOv7" on module `m`, wrapping
// vision::detection::YOLOv7 with FastDeployModel as its bound base.
// Exposes a four-argument constructor, a `predict` method and several
// read/write attributes.
void BindYOLOv7(pybind11::module& m) {
|
||||||
|
pybind11::class_<vision::detection::YOLOv7, FastDeployModel>(m, "YOLOv7")
|
||||||
|
// Constructor taking (std::string, std::string, RuntimeOption, Frontend);
// presumably model file, params file, runtime option and model format —
// verify against the YOLOv7 constructor declaration.
.def(pybind11::init<std::string, std::string, RuntimeOption, Frontend>())
|
||||||
|
// predict(data, conf_threshold, nms_iou_threshold): converts the numpy
// array to a cv::Mat, runs Predict on it and returns the DetectionResult.
.def("predict",
|
||||||
|
[](vision::detection::YOLOv7& self, pybind11::array& data,
|
||||||
|
float conf_threshold, float nms_iou_threshold) {
|
||||||
|
auto mat = PyArrayToCvMat(data);
|
||||||
|
vision::DetectionResult res;
|
||||||
|
// NOTE(review): the bool returned by Predict is not checked, so
// `res` is returned even when prediction fails — confirm intended.
self.Predict(&mat, &res, conf_threshold, nms_iou_threshold);
|
||||||
|
return res;
|
||||||
|
})
|
||||||
|
// Public data members exposed to Python as read/write attributes.
.def_readwrite("size", &vision::detection::YOLOv7::size)
|
||||||
|
.def_readwrite("padding_value", &vision::detection::YOLOv7::padding_value)
|
||||||
|
.def_readwrite("is_mini_pad", &vision::detection::YOLOv7::is_mini_pad)
|
||||||
|
.def_readwrite("is_no_pad", &vision::detection::YOLOv7::is_no_pad)
|
||||||
|
.def_readwrite("is_scale_up", &vision::detection::YOLOv7::is_scale_up)
|
||||||
|
.def_readwrite("stride", &vision::detection::YOLOv7::stride)
|
||||||
|
.def_readwrite("max_wh", &vision::detection::YOLOv7::max_wh);
|
||||||
|
}
|
||||||
|
} // namespace fastdeploy
|
30
csrcs/fastdeploy/vision/detection/detection_pybind.cc
Normal file
30
csrcs/fastdeploy/vision/detection/detection_pybind.cc
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#include "fastdeploy/pybind/main.h"
|
||||||
|
|
||||||
|
namespace fastdeploy {
|
||||||
|
|
||||||
|
void BindYOLOv7(pybind11::module& m);
|
||||||
|
void BindScaledYOLOv4(pybind11::module& m);
|
||||||
|
void BindYOLOR(pybind11::module& m);
|
||||||
|
|
||||||
|
// Creates the "detection" submodule of `m` and registers the YOLOv7,
// ScaledYOLOv4 and YOLOR bindings on that submodule (not on `m` itself).
void BindDetection(pybind11::module& m) {
|
||||||
|
// def_submodule(name, doc): the second argument is the submodule docstring.
auto detection_module =
|
||||||
|
m.def_submodule("detection", "Image object detection models.");
|
||||||
|
BindYOLOv7(detection_module);
|
||||||
|
BindScaledYOLOv4(detection_module);
|
||||||
|
BindYOLOR(detection_module);
|
||||||
|
}
|
||||||
|
} // namespace fastdeploy
|
@@ -18,7 +18,6 @@ namespace fastdeploy {
|
|||||||
|
|
||||||
void BindPPCls(pybind11::module& m);
|
void BindPPCls(pybind11::module& m);
|
||||||
void BindPPDet(pybind11::module& m);
|
void BindPPDet(pybind11::module& m);
|
||||||
void BindWongkinyiu(pybind11::module& m);
|
|
||||||
void BindPPSeg(pybind11::module& m);
|
void BindPPSeg(pybind11::module& m);
|
||||||
void BindUltralytics(pybind11::module& m);
|
void BindUltralytics(pybind11::module& m);
|
||||||
void BindMeituan(pybind11::module& m);
|
void BindMeituan(pybind11::module& m);
|
||||||
@@ -30,6 +29,8 @@ void BindBiubug6(pybind11::module& m);
|
|||||||
void BindPpogg(pybind11::module& m);
|
void BindPpogg(pybind11::module& m);
|
||||||
void BindDeepInsight(pybind11::module& m);
|
void BindDeepInsight(pybind11::module& m);
|
||||||
void BindZHKKKe(pybind11::module& m);
|
void BindZHKKKe(pybind11::module& m);
|
||||||
|
|
||||||
|
void BindDetection(pybind11::module& m);
|
||||||
#ifdef ENABLE_VISION_VISUALIZE
|
#ifdef ENABLE_VISION_VISUALIZE
|
||||||
void BindVisualize(pybind11::module& m);
|
void BindVisualize(pybind11::module& m);
|
||||||
#endif
|
#endif
|
||||||
@@ -88,7 +89,6 @@ void BindVision(pybind11::module& m) {
|
|||||||
BindPPDet(m);
|
BindPPDet(m);
|
||||||
BindPPSeg(m);
|
BindPPSeg(m);
|
||||||
BindUltralytics(m);
|
BindUltralytics(m);
|
||||||
BindWongkinyiu(m);
|
|
||||||
BindMeituan(m);
|
BindMeituan(m);
|
||||||
BindMegvii(m);
|
BindMegvii(m);
|
||||||
BindDeepCam(m);
|
BindDeepCam(m);
|
||||||
@@ -98,6 +98,8 @@ void BindVision(pybind11::module& m) {
|
|||||||
BindPpogg(m);
|
BindPpogg(m);
|
||||||
BindDeepInsight(m);
|
BindDeepInsight(m);
|
||||||
BindZHKKKe(m);
|
BindZHKKKe(m);
|
||||||
|
|
||||||
|
BindDetection(m);
|
||||||
#ifdef ENABLE_VISION_VISUALIZE
|
#ifdef ENABLE_VISION_VISUALIZE
|
||||||
BindVisualize(m);
|
BindVisualize(m);
|
||||||
#endif
|
#endif
|
||||||
|
@@ -23,11 +23,13 @@ namespace vision {
|
|||||||
// Default only support visualize num_classes <= 1000
|
// Default only support visualize num_classes <= 1000
|
||||||
// If need to visualize num_classes > 1000
|
// If need to visualize num_classes > 1000
|
||||||
// Please call Visualize::GetColorMap(num_classes) first
|
// Please call Visualize::GetColorMap(num_classes) first
|
||||||
void Visualize::VisDetection(cv::Mat* im, const DetectionResult& result,
|
cv::Mat Visualize::VisDetection(const cv::Mat& im,
|
||||||
int line_size, float font_size) {
|
const DetectionResult& result, int line_size,
|
||||||
|
float font_size) {
|
||||||
auto color_map = GetColorMap();
|
auto color_map = GetColorMap();
|
||||||
int h = im->rows;
|
int h = im.rows;
|
||||||
int w = im->cols;
|
int w = im.cols;
|
||||||
|
auto vis_im = im.clone();
|
||||||
for (size_t i = 0; i < result.boxes.size(); ++i) {
|
for (size_t i = 0; i < result.boxes.size(); ++i) {
|
||||||
cv::Rect rect(result.boxes[i][0], result.boxes[i][1],
|
cv::Rect rect(result.boxes[i][0], result.boxes[i][1],
|
||||||
result.boxes[i][2] - result.boxes[i][0],
|
result.boxes[i][2] - result.boxes[i][0],
|
||||||
@@ -50,10 +52,11 @@ void Visualize::VisDetection(cv::Mat* im, const DetectionResult& result,
|
|||||||
cv::Rect text_background =
|
cv::Rect text_background =
|
||||||
cv::Rect(result.boxes[i][0], result.boxes[i][1] - text_size.height,
|
cv::Rect(result.boxes[i][0], result.boxes[i][1] - text_size.height,
|
||||||
text_size.width, text_size.height);
|
text_size.width, text_size.height);
|
||||||
cv::rectangle(*im, rect, rect_color, line_size);
|
cv::rectangle(vis_im, rect, rect_color, line_size);
|
||||||
cv::putText(*im, text, origin, font, font_size, cv::Scalar(255, 255, 255),
|
cv::putText(vis_im, text, origin, font, font_size,
|
||||||
1);
|
cv::Scalar(255, 255, 255), 1);
|
||||||
}
|
}
|
||||||
|
return vis_im;
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace vision
|
} // namespace vision
|
||||||
|
@@ -24,12 +24,14 @@ namespace vision {
|
|||||||
// Default only support visualize num_classes <= 1000
|
// Default only support visualize num_classes <= 1000
|
||||||
// If need to visualize num_classes > 1000
|
// If need to visualize num_classes > 1000
|
||||||
// Please call Visualize::GetColorMap(num_classes) first
|
// Please call Visualize::GetColorMap(num_classes) first
|
||||||
void Visualize::VisFaceDetection(cv::Mat* im, const FaceDetectionResult& result,
|
cv::Mat Visualize::VisFaceDetection(const cv::Mat& im,
|
||||||
|
const FaceDetectionResult& result,
|
||||||
int line_size, float font_size) {
|
int line_size, float font_size) {
|
||||||
auto color_map = GetColorMap();
|
auto color_map = GetColorMap();
|
||||||
int h = im->rows;
|
int h = im.rows;
|
||||||
int w = im->cols;
|
int w = im.cols;
|
||||||
|
|
||||||
|
auto vis_im = im.clone();
|
||||||
bool vis_landmarks = false;
|
bool vis_landmarks = false;
|
||||||
if ((result.landmarks_per_face > 0) &&
|
if ((result.landmarks_per_face > 0) &&
|
||||||
(result.boxes.size() * result.landmarks_per_face ==
|
(result.boxes.size() * result.landmarks_per_face ==
|
||||||
@@ -57,9 +59,9 @@ void Visualize::VisFaceDetection(cv::Mat* im, const FaceDetectionResult& result,
|
|||||||
cv::Rect text_background =
|
cv::Rect text_background =
|
||||||
cv::Rect(result.boxes[i][0], result.boxes[i][1] - text_size.height,
|
cv::Rect(result.boxes[i][0], result.boxes[i][1] - text_size.height,
|
||||||
text_size.width, text_size.height);
|
text_size.width, text_size.height);
|
||||||
cv::rectangle(*im, rect, rect_color, line_size);
|
cv::rectangle(vis_im, rect, rect_color, line_size);
|
||||||
cv::putText(*im, text, origin, font, font_size, cv::Scalar(255, 255, 255),
|
cv::putText(vis_im, text, origin, font, font_size,
|
||||||
1);
|
cv::Scalar(255, 255, 255), 1);
|
||||||
// vis landmarks (if have)
|
// vis landmarks (if have)
|
||||||
if (vis_landmarks) {
|
if (vis_landmarks) {
|
||||||
cv::Scalar landmark_color = rect_color;
|
cv::Scalar landmark_color = rect_color;
|
||||||
@@ -69,10 +71,11 @@ void Visualize::VisFaceDetection(cv::Mat* im, const FaceDetectionResult& result,
|
|||||||
result.landmarks[i * result.landmarks_per_face + j][0]);
|
result.landmarks[i * result.landmarks_per_face + j][0]);
|
||||||
landmark.y = static_cast<int>(
|
landmark.y = static_cast<int>(
|
||||||
result.landmarks[i * result.landmarks_per_face + j][1]);
|
result.landmarks[i * result.landmarks_per_face + j][1]);
|
||||||
cv::circle(*im, landmark, line_size, landmark_color, -1);
|
cv::circle(vis_im, landmark, line_size, landmark_color, -1);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
return vis_im;
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace vision
|
} // namespace vision
|
||||||
|
@@ -65,12 +65,14 @@ static void RemoveSmallConnectedArea(cv::Mat* alpha_pred,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void Visualize::VisMattingAlpha(const cv::Mat& im, const MattingResult& result,
|
cv::Mat Visualize::VisMattingAlpha(const cv::Mat& im,
|
||||||
cv::Mat* vis_img,
|
const MattingResult& result,
|
||||||
bool remove_small_connected_area) {
|
bool remove_small_connected_area) {
|
||||||
// 只可视化alpha,fgr(前景)本身就是一张图 不需要可视化
|
// 只可视化alpha,fgr(前景)本身就是一张图 不需要可视化
|
||||||
FDASSERT((!im.empty()), "im can't be empty!");
|
FDASSERT((!im.empty()), "im can't be empty!");
|
||||||
FDASSERT((im.channels() == 3), "Only support 3 channels mat!");
|
FDASSERT((im.channels() == 3), "Only support 3 channels mat!");
|
||||||
|
|
||||||
|
auto vis_img = im.clone();
|
||||||
int out_h = static_cast<int>(result.shape[0]);
|
int out_h = static_cast<int>(result.shape[0]);
|
||||||
int out_w = static_cast<int>(result.shape[1]);
|
int out_w = static_cast<int>(result.shape[1]);
|
||||||
int height = im.rows;
|
int height = im.rows;
|
||||||
@@ -87,18 +89,11 @@ void Visualize::VisMattingAlpha(const cv::Mat& im, const MattingResult& result,
|
|||||||
cv::resize(alpha, alpha, cv::Size(width, height));
|
cv::resize(alpha, alpha, cv::Size(width, height));
|
||||||
}
|
}
|
||||||
|
|
||||||
int vis_h = (*vis_img).rows;
|
if ((vis_img).type() != CV_8UC3) {
|
||||||
int vis_w = (*vis_img).cols;
|
(vis_img).convertTo((vis_img), CV_8UC3);
|
||||||
|
|
||||||
if ((vis_h != height) || (vis_w != width)) {
|
|
||||||
// faster than resize
|
|
||||||
(*vis_img) = cv::Mat::zeros(height, width, CV_8UC3);
|
|
||||||
}
|
|
||||||
if ((*vis_img).type() != CV_8UC3) {
|
|
||||||
(*vis_img).convertTo((*vis_img), CV_8UC3);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
uchar* vis_data = static_cast<uchar*>(vis_img->data);
|
uchar* vis_data = static_cast<uchar*>(vis_img.data);
|
||||||
uchar* im_data = static_cast<uchar*>(im.data);
|
uchar* im_data = static_cast<uchar*>(im.data);
|
||||||
float* alpha_data = reinterpret_cast<float*>(alpha.data);
|
float* alpha_data = reinterpret_cast<float*>(alpha.data);
|
||||||
|
|
||||||
@@ -116,6 +111,7 @@ void Visualize::VisMattingAlpha(const cv::Mat& im, const MattingResult& result,
|
|||||||
(1.f - alpha_val) * 120.f);
|
(1.f - alpha_val) * 120.f);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
return vis_img;
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace vision
|
} // namespace vision
|
||||||
|
@@ -21,24 +21,24 @@
|
|||||||
namespace fastdeploy {
|
namespace fastdeploy {
|
||||||
namespace vision {
|
namespace vision {
|
||||||
|
|
||||||
void Visualize::VisSegmentation(const cv::Mat& im,
|
cv::Mat Visualize::VisSegmentation(const cv::Mat& im,
|
||||||
const SegmentationResult& result,
|
const SegmentationResult& result) {
|
||||||
cv::Mat* vis_img, const int& num_classes) {
|
auto color_map = GetColorMap();
|
||||||
auto color_map = GetColorMap(num_classes);
|
|
||||||
int64_t height = result.shape[0];
|
int64_t height = result.shape[0];
|
||||||
int64_t width = result.shape[1];
|
int64_t width = result.shape[1];
|
||||||
*vis_img = cv::Mat::zeros(height, width, CV_8UC3);
|
auto vis_img = cv::Mat(height, width, CV_8UC3);
|
||||||
|
|
||||||
int64_t index = 0;
|
int64_t index = 0;
|
||||||
for (int i = 0; i < height; i++) {
|
for (int i = 0; i < height; i++) {
|
||||||
for (int j = 0; j < width; j++) {
|
for (int j = 0; j < width; j++) {
|
||||||
int category_id = result.label_map[index++];
|
int category_id = result.label_map[index++];
|
||||||
vis_img->at<cv::Vec3b>(i, j)[0] = color_map[3 * category_id + 0];
|
vis_img.at<cv::Vec3b>(i, j)[0] = color_map[3 * category_id + 0];
|
||||||
vis_img->at<cv::Vec3b>(i, j)[1] = color_map[3 * category_id + 1];
|
vis_img.at<cv::Vec3b>(i, j)[1] = color_map[3 * category_id + 1];
|
||||||
vis_img->at<cv::Vec3b>(i, j)[2] = color_map[3 * category_id + 2];
|
vis_img.at<cv::Vec3b>(i, j)[2] = color_map[3 * category_id + 2];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
cv::addWeighted(im, .5, *vis_img, .5, 0, *vis_img);
|
cv::addWeighted(im, .5, vis_img, .5, 0, vis_img);
|
||||||
|
return vis_img;
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace vision
|
} // namespace vision
|
||||||
|
@@ -25,15 +25,14 @@ class FASTDEPLOY_DECL Visualize {
|
|||||||
static int num_classes_;
|
static int num_classes_;
|
||||||
static std::vector<int> color_map_;
|
static std::vector<int> color_map_;
|
||||||
static const std::vector<int>& GetColorMap(int num_classes = 1000);
|
static const std::vector<int>& GetColorMap(int num_classes = 1000);
|
||||||
static void VisDetection(cv::Mat* im, const DetectionResult& result,
|
static cv::Mat VisDetection(const cv::Mat& im, const DetectionResult& result,
|
||||||
int line_size = 2, float font_size = 0.5f);
|
int line_size = 2, float font_size = 0.5f);
|
||||||
static void VisFaceDetection(cv::Mat* im, const FaceDetectionResult& result,
|
static cv::Mat VisFaceDetection(const cv::Mat& im,
|
||||||
|
const FaceDetectionResult& result,
|
||||||
int line_size = 2, float font_size = 0.5f);
|
int line_size = 2, float font_size = 0.5f);
|
||||||
static void VisSegmentation(const cv::Mat& im,
|
static cv::Mat VisSegmentation(const cv::Mat& im,
|
||||||
const SegmentationResult& result,
|
const SegmentationResult& result);
|
||||||
cv::Mat* vis_img, const int& num_classes = 1000);
|
static cv::Mat VisMattingAlpha(const cv::Mat& im, const MattingResult& result,
|
||||||
static void VisMattingAlpha(const cv::Mat& im, const MattingResult& result,
|
|
||||||
cv::Mat* vis_img,
|
|
||||||
bool remove_small_connected_area = false);
|
bool remove_small_connected_area = false);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@@ -22,34 +22,41 @@ void BindVisualize(pybind11::module& m) {
|
|||||||
[](pybind11::array& im_data, vision::DetectionResult& result,
|
[](pybind11::array& im_data, vision::DetectionResult& result,
|
||||||
int line_size, float font_size) {
|
int line_size, float font_size) {
|
||||||
auto im = PyArrayToCvMat(im_data);
|
auto im = PyArrayToCvMat(im_data);
|
||||||
vision::Visualize::VisDetection(&im, result, line_size,
|
auto vis_im = vision::Visualize::VisDetection(
|
||||||
font_size);
|
im, result, line_size, font_size);
|
||||||
|
FDTensor out;
|
||||||
|
vision::Mat(vis_im).ShareWithTensor(&out);
|
||||||
|
return TensorToPyArray(out);
|
||||||
})
|
})
|
||||||
.def_static(
|
.def_static(
|
||||||
"vis_face_detection",
|
"vis_face_detection",
|
||||||
[](pybind11::array& im_data, vision::FaceDetectionResult& result,
|
[](pybind11::array& im_data, vision::FaceDetectionResult& result,
|
||||||
int line_size, float font_size) {
|
int line_size, float font_size) {
|
||||||
auto im = PyArrayToCvMat(im_data);
|
auto im = PyArrayToCvMat(im_data);
|
||||||
vision::Visualize::VisFaceDetection(&im, result, line_size,
|
auto vis_im = vision::Visualize::VisFaceDetection(
|
||||||
font_size);
|
im, result, line_size, font_size);
|
||||||
|
FDTensor out;
|
||||||
|
vision::Mat(vis_im).ShareWithTensor(&out);
|
||||||
|
return TensorToPyArray(out);
|
||||||
})
|
})
|
||||||
.def_static(
|
.def_static(
|
||||||
"vis_segmentation",
|
"vis_segmentation",
|
||||||
[](pybind11::array& im_data, vision::SegmentationResult& result,
|
[](pybind11::array& im_data, vision::SegmentationResult& result) {
|
||||||
pybind11::array& vis_im_data, const int& num_classes) {
|
|
||||||
cv::Mat im = PyArrayToCvMat(im_data);
|
cv::Mat im = PyArrayToCvMat(im_data);
|
||||||
cv::Mat vis_im = PyArrayToCvMat(vis_im_data);
|
auto vis_im = vision::Visualize::VisSegmentation(im, result);
|
||||||
vision::Visualize::VisSegmentation(im, result, &vis_im,
|
FDTensor out;
|
||||||
num_classes);
|
vision::Mat(vis_im).ShareWithTensor(&out);
|
||||||
|
return TensorToPyArray(out);
|
||||||
})
|
})
|
||||||
.def_static(
|
.def_static("vis_matting_alpha",
|
||||||
"vis_matting_alpha",
|
|
||||||
[](pybind11::array& im_data, vision::MattingResult& result,
|
[](pybind11::array& im_data, vision::MattingResult& result,
|
||||||
pybind11::array& vis_im_data, bool remove_small_connected_area) {
|
bool remove_small_connected_area) {
|
||||||
cv::Mat im = PyArrayToCvMat(im_data);
|
cv::Mat im = PyArrayToCvMat(im_data);
|
||||||
cv::Mat vis_im = PyArrayToCvMat(vis_im_data);
|
auto vis_im = vision::Visualize::VisMattingAlpha(
|
||||||
vision::Visualize::VisMattingAlpha(im, result, &vis_im,
|
im, result, remove_small_connected_area);
|
||||||
remove_small_connected_area);
|
FDTensor out;
|
||||||
|
vision::Mat(vis_im).ShareWithTensor(&out);
|
||||||
|
return TensorToPyArray(out);
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
} // namespace fastdeploy
|
} // namespace fastdeploy
|
||||||
|
@@ -1,79 +0,0 @@
|
|||||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
|
||||||
//
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
// you may not use this file except in compliance with the License.
|
|
||||||
// You may obtain a copy of the License at
|
|
||||||
//
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
//
|
|
||||||
// Unless required by applicable law or agreed to in writing, software
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
// See the License for the specific language governing permissions and
|
|
||||||
// limitations under the License.
|
|
||||||
|
|
||||||
#include "fastdeploy/pybind/main.h"
|
|
||||||
|
|
||||||
namespace fastdeploy {
|
|
||||||
void BindWongkinyiu(pybind11::module& m) {
|
|
||||||
auto wongkinyiu_module =
|
|
||||||
m.def_submodule("wongkinyiu", "https://github.com/WongKinYiu");
|
|
||||||
pybind11::class_<vision::wongkinyiu::YOLOv7, FastDeployModel>(
|
|
||||||
wongkinyiu_module, "YOLOv7")
|
|
||||||
.def(pybind11::init<std::string, std::string, RuntimeOption, Frontend>())
|
|
||||||
.def("predict",
|
|
||||||
[](vision::wongkinyiu::YOLOv7& self, pybind11::array& data,
|
|
||||||
float conf_threshold, float nms_iou_threshold) {
|
|
||||||
auto mat = PyArrayToCvMat(data);
|
|
||||||
vision::DetectionResult res;
|
|
||||||
self.Predict(&mat, &res, conf_threshold, nms_iou_threshold);
|
|
||||||
return res;
|
|
||||||
})
|
|
||||||
.def_readwrite("size", &vision::wongkinyiu::YOLOv7::size)
|
|
||||||
.def_readwrite("padding_value",
|
|
||||||
&vision::wongkinyiu::YOLOv7::padding_value)
|
|
||||||
.def_readwrite("is_mini_pad", &vision::wongkinyiu::YOLOv7::is_mini_pad)
|
|
||||||
.def_readwrite("is_no_pad", &vision::wongkinyiu::YOLOv7::is_no_pad)
|
|
||||||
.def_readwrite("is_scale_up", &vision::wongkinyiu::YOLOv7::is_scale_up)
|
|
||||||
.def_readwrite("stride", &vision::wongkinyiu::YOLOv7::stride)
|
|
||||||
.def_readwrite("max_wh", &vision::wongkinyiu::YOLOv7::max_wh);
|
|
||||||
|
|
||||||
pybind11::class_<vision::wongkinyiu::YOLOR, FastDeployModel>(
|
|
||||||
wongkinyiu_module, "YOLOR")
|
|
||||||
.def(pybind11::init<std::string, std::string, RuntimeOption, Frontend>())
|
|
||||||
.def("predict",
|
|
||||||
[](vision::wongkinyiu::YOLOR& self, pybind11::array& data,
|
|
||||||
float conf_threshold, float nms_iou_threshold) {
|
|
||||||
auto mat = PyArrayToCvMat(data);
|
|
||||||
vision::DetectionResult res;
|
|
||||||
self.Predict(&mat, &res, conf_threshold, nms_iou_threshold);
|
|
||||||
return res;
|
|
||||||
})
|
|
||||||
.def_readwrite("size", &vision::wongkinyiu::YOLOR::size)
|
|
||||||
.def_readwrite("padding_value", &vision::wongkinyiu::YOLOR::padding_value)
|
|
||||||
.def_readwrite("is_mini_pad", &vision::wongkinyiu::YOLOR::is_mini_pad)
|
|
||||||
.def_readwrite("is_no_pad", &vision::wongkinyiu::YOLOR::is_no_pad)
|
|
||||||
.def_readwrite("is_scale_up", &vision::wongkinyiu::YOLOR::is_scale_up)
|
|
||||||
.def_readwrite("stride", &vision::wongkinyiu::YOLOR::stride)
|
|
||||||
.def_readwrite("max_wh", &vision::wongkinyiu::YOLOR::max_wh);
|
|
||||||
|
|
||||||
pybind11::class_<vision::wongkinyiu::ScaledYOLOv4, FastDeployModel>(
|
|
||||||
wongkinyiu_module, "ScaledYOLOv4")
|
|
||||||
.def(pybind11::init<std::string, std::string, RuntimeOption, Frontend>())
|
|
||||||
.def("predict",
|
|
||||||
[](vision::wongkinyiu::ScaledYOLOv4& self, pybind11::array& data,
|
|
||||||
float conf_threshold, float nms_iou_threshold) {
|
|
||||||
auto mat = PyArrayToCvMat(data);
|
|
||||||
vision::DetectionResult res;
|
|
||||||
self.Predict(&mat, &res, conf_threshold, nms_iou_threshold);
|
|
||||||
return res;
|
|
||||||
})
|
|
||||||
.def_readwrite("size", &vision::wongkinyiu::ScaledYOLOv4::size)
|
|
||||||
.def_readwrite("padding_value", &vision::wongkinyiu::ScaledYOLOv4::padding_value)
|
|
||||||
.def_readwrite("is_mini_pad", &vision::wongkinyiu::ScaledYOLOv4::is_mini_pad)
|
|
||||||
.def_readwrite("is_no_pad", &vision::wongkinyiu::ScaledYOLOv4::is_no_pad)
|
|
||||||
.def_readwrite("is_scale_up", &vision::wongkinyiu::ScaledYOLOv4::is_scale_up)
|
|
||||||
.def_readwrite("stride", &vision::wongkinyiu::ScaledYOLOv4::stride)
|
|
||||||
.def_readwrite("max_wh", &vision::wongkinyiu::ScaledYOLOv4::max_wh);
|
|
||||||
}
|
|
||||||
} // namespace fastdeploy
|
|
@@ -20,8 +20,6 @@ from . import ppseg
|
|||||||
from . import ultralytics
|
from . import ultralytics
|
||||||
from . import meituan
|
from . import meituan
|
||||||
from . import megvii
|
from . import megvii
|
||||||
from . import visualize
|
|
||||||
from . import wongkinyiu
|
|
||||||
from . import deepcam
|
from . import deepcam
|
||||||
from . import rangilyu
|
from . import rangilyu
|
||||||
from . import linzaer
|
from . import linzaer
|
||||||
@@ -29,3 +27,7 @@ from . import biubug6
|
|||||||
from . import ppogg
|
from . import ppogg
|
||||||
from . import deepinsight
|
from . import deepinsight
|
||||||
from . import zhkkke
|
from . import zhkkke
|
||||||
|
|
||||||
|
from . import detection
|
||||||
|
|
||||||
|
from .visualize import *
|
||||||
|
18
fastdeploy/vision/detection/__init__.py
Normal file
18
fastdeploy/vision/detection/__init__.py
Normal file
@@ -0,0 +1,18 @@
|
|||||||
|
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from .yolov7 import YOLOv7
|
||||||
|
from .yolor import YOLOR
|
||||||
|
from .scaled_yolov4 import ScaledYOLOv4
|
116
fastdeploy/vision/detection/scaled_yolov4.py
Normal file
116
fastdeploy/vision/detection/scaled_yolov4.py
Normal file
@@ -0,0 +1,116 @@
|
|||||||
|
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
import logging
|
||||||
|
from ... import FastDeployModel, Frontend
|
||||||
|
from ... import c_lib_wrap as C
|
||||||
|
|
||||||
|
|
||||||
|
class ScaledYOLOv4(FastDeployModel):
    """Python wrapper around the `C.vision.detection.ScaledYOLOv4` binding."""

    def __init__(self,
                 model_file,
                 params_file="",
                 runtime_option=None,
                 model_format=Frontend.ONNX):
        # Call the base initializer to set up the backend option; the
        # initialized option is stored in self._runtime_option.
        super(ScaledYOLOv4, self).__init__(runtime_option)

        self._model = C.vision.detection.ScaledYOLOv4(
            model_file, params_file, self._runtime_option, model_format)
        # self.initialized tells whether the whole model was initialized
        # successfully.
        assert self.initialized, "ScaledYOLOv4 initialize failed."

    def predict(self, input_image, conf_threshold=0.25, nms_iou_threshold=0.5):
        """Forward the image and both thresholds to the wrapped model's `predict`."""
        return self._model.predict(input_image, conf_threshold,
                                   nms_iou_threshold)

    # Property wrappers for attributes of the ScaledYOLOv4 model.
    # Most of them relate to preprocessing; for example, setting
    # model.size = [1280, 1280] changes the resize target used during
    # preprocessing (provided the model supports it).
    @property
    def size(self):
        return self._model.size

    @property
    def padding_value(self):
        return self._model.padding_value

    @property
    def is_no_pad(self):
        return self._model.is_no_pad

    @property
    def is_mini_pad(self):
        return self._model.is_mini_pad

    @property
    def is_scale_up(self):
        return self._model.is_scale_up

    @property
    def stride(self):
        return self._model.stride

    @property
    def max_wh(self):
        return self._model.max_wh

    @size.setter
    def size(self, wh):
        # isinstance() requires a type or a tuple of types; passing a list
        # of types raises TypeError, so the accepted types are given as a tuple.
        assert isinstance(wh, (list, tuple)),\
            "The value to set `size` must be type of tuple or list."
        assert len(wh) == 2,\
            "The value to set `size` must contatins 2 elements means [width, height], but now it contains {} elements.".format(
            len(wh))
        self._model.size = wh

    @padding_value.setter
    def padding_value(self, value):
        assert isinstance(
            value,
            list), "The value to set `padding_value` must be type of list."
        self._model.padding_value = value

    @is_no_pad.setter
    def is_no_pad(self, value):
        assert isinstance(
            value, bool), "The value to set `is_no_pad` must be type of bool."
        self._model.is_no_pad = value

    @is_mini_pad.setter
    def is_mini_pad(self, value):
        assert isinstance(
            value,
            bool), "The value to set `is_mini_pad` must be type of bool."
        self._model.is_mini_pad = value

    @is_scale_up.setter
    def is_scale_up(self, value):
        assert isinstance(
            value,
            bool), "The value to set `is_scale_up` must be type of bool."
        self._model.is_scale_up = value

    @stride.setter
    def stride(self, value):
        assert isinstance(
            value, int), "The value to set `stride` must be type of int."
        self._model.stride = value

    @max_wh.setter
    def max_wh(self, value):
        assert isinstance(
            value, float), "The value to set `max_wh` must be type of float."
        self._model.max_wh = value
|
116
fastdeploy/vision/detection/yolor.py
Normal file
116
fastdeploy/vision/detection/yolor.py
Normal file
@@ -0,0 +1,116 @@
|
|||||||
|
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
import logging
|
||||||
|
from ... import FastDeployModel, Frontend
|
||||||
|
from ... import c_lib_wrap as C
|
||||||
|
|
||||||
|
|
||||||
|
class YOLOR(FastDeployModel):
    """Python wrapper around the `C.vision.detection.YOLOR` binding."""

    def __init__(self,
                 model_file,
                 params_file="",
                 runtime_option=None,
                 model_format=Frontend.ONNX):
        # Call the base initializer to set up the backend option; the
        # initialized option is stored in self._runtime_option.
        super(YOLOR, self).__init__(runtime_option)

        self._model = C.vision.detection.YOLOR(
            model_file, params_file, self._runtime_option, model_format)
        # self.initialized tells whether the whole model was initialized
        # successfully.
        assert self.initialized, "YOLOR initialize failed."

    def predict(self, input_image, conf_threshold=0.25, nms_iou_threshold=0.5):
        """Forward the image and both thresholds to the wrapped model's `predict`."""
        return self._model.predict(input_image, conf_threshold,
                                   nms_iou_threshold)

    # Property wrappers for attributes of the YOLOR model.
    # Most of them relate to preprocessing; for example, setting
    # model.size = [1280, 1280] changes the resize target used during
    # preprocessing (provided the model supports it).
    @property
    def size(self):
        return self._model.size

    @property
    def padding_value(self):
        return self._model.padding_value

    @property
    def is_no_pad(self):
        return self._model.is_no_pad

    @property
    def is_mini_pad(self):
        return self._model.is_mini_pad

    @property
    def is_scale_up(self):
        return self._model.is_scale_up

    @property
    def stride(self):
        return self._model.stride

    @property
    def max_wh(self):
        return self._model.max_wh

    @size.setter
    def size(self, wh):
        # isinstance() requires a type or a tuple of types; passing a list
        # of types raises TypeError, so the accepted types are given as a tuple.
        assert isinstance(wh, (list, tuple)),\
            "The value to set `size` must be type of tuple or list."
        assert len(wh) == 2,\
            "The value to set `size` must contatins 2 elements means [width, height], but now it contains {} elements.".format(
            len(wh))
        self._model.size = wh

    @padding_value.setter
    def padding_value(self, value):
        assert isinstance(
            value,
            list), "The value to set `padding_value` must be type of list."
        self._model.padding_value = value

    @is_no_pad.setter
    def is_no_pad(self, value):
        assert isinstance(
            value, bool), "The value to set `is_no_pad` must be type of bool."
        self._model.is_no_pad = value

    @is_mini_pad.setter
    def is_mini_pad(self, value):
        assert isinstance(
            value,
            bool), "The value to set `is_mini_pad` must be type of bool."
        self._model.is_mini_pad = value

    @is_scale_up.setter
    def is_scale_up(self, value):
        assert isinstance(
            value,
            bool), "The value to set `is_scale_up` must be type of bool."
        self._model.is_scale_up = value

    @stride.setter
    def stride(self, value):
        assert isinstance(
            value, int), "The value to set `stride` must be type of int."
        self._model.stride = value

    @max_wh.setter
    def max_wh(self, value):
        assert isinstance(
            value, float), "The value to set `max_wh` must be type of float."
        self._model.max_wh = value
|
116
fastdeploy/vision/detection/yolov7.py
Normal file
116
fastdeploy/vision/detection/yolov7.py
Normal file
@@ -0,0 +1,116 @@
|
|||||||
|
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
import logging
|
||||||
|
from ... import FastDeployModel, Frontend
|
||||||
|
from ... import c_lib_wrap as C
|
||||||
|
|
||||||
|
|
||||||
|
class YOLOv7(FastDeployModel):
    """Python wrapper around the `C.vision.detection.YOLOv7` binding."""

    def __init__(self,
                 model_file,
                 params_file="",
                 runtime_option=None,
                 model_format=Frontend.ONNX):
        # Call the base initializer to set up the backend option; the
        # initialized option is stored in self._runtime_option.
        super(YOLOv7, self).__init__(runtime_option)

        self._model = C.vision.detection.YOLOv7(
            model_file, params_file, self._runtime_option, model_format)
        # self.initialized tells whether the whole model was initialized
        # successfully.
        assert self.initialized, "YOLOv7 initialize failed."

    def predict(self, input_image, conf_threshold=0.25, nms_iou_threshold=0.5):
        """Forward the image and both thresholds to the wrapped model's `predict`."""
        return self._model.predict(input_image, conf_threshold,
                                   nms_iou_threshold)

    # Property wrappers for attributes of the YOLOv7 model.
    # Most of them relate to preprocessing; for example, setting
    # model.size = [1280, 1280] changes the resize target used during
    # preprocessing (provided the model supports it).
    @property
    def size(self):
        return self._model.size

    @property
    def padding_value(self):
        return self._model.padding_value

    @property
    def is_no_pad(self):
        return self._model.is_no_pad

    @property
    def is_mini_pad(self):
        return self._model.is_mini_pad

    @property
    def is_scale_up(self):
        return self._model.is_scale_up

    @property
    def stride(self):
        return self._model.stride

    @property
    def max_wh(self):
        return self._model.max_wh

    @size.setter
    def size(self, wh):
        # isinstance() requires a type or a tuple of types; passing a list
        # of types raises TypeError, so the accepted types are given as a tuple.
        assert isinstance(wh, (list, tuple)),\
            "The value to set `size` must be type of tuple or list."
        assert len(wh) == 2,\
            "The value to set `size` must contatins 2 elements means [width, height], but now it contains {} elements.".format(
            len(wh))
        self._model.size = wh

    @padding_value.setter
    def padding_value(self, value):
        assert isinstance(
            value,
            list), "The value to set `padding_value` must be type of list."
        self._model.padding_value = value

    @is_no_pad.setter
    def is_no_pad(self, value):
        assert isinstance(
            value, bool), "The value to set `is_no_pad` must be type of bool."
        self._model.is_no_pad = value

    @is_mini_pad.setter
    def is_mini_pad(self, value):
        assert isinstance(
            value,
            bool), "The value to set `is_mini_pad` must be type of bool."
        self._model.is_mini_pad = value

    @is_scale_up.setter
    def is_scale_up(self, value):
        assert isinstance(
            value,
            bool), "The value to set `is_scale_up` must be type of bool."
        self._model.is_scale_up = value

    @stride.setter
    def stride(self, value):
        assert isinstance(
            value, int), "The value to set `stride` must be type of int."
        self._model.stride = value

    @max_wh.setter
    def max_wh(self, value):
        assert isinstance(
            value, float), "The value to set `max_wh` must be type of float."
        self._model.max_wh = value
|
@@ -17,23 +17,22 @@ import logging
|
|||||||
from ... import c_lib_wrap as C
|
from ... import c_lib_wrap as C
|
||||||
|
|
||||||
|
|
||||||
def vis_detection(im_data, det_result, line_size=1, font_size=0.5):
|
def vis_detection(im_data, det_result, line_size=2, font_size=0.5):
|
||||||
C.vision.Visualize.vis_detection(im_data, det_result, line_size, font_size)
|
return C.vision.Visualize.vis_detection(im_data, det_result, line_size,
|
||||||
|
|
||||||
|
|
||||||
def vis_face_detection(im_data, face_det_result, line_size=1, font_size=0.5):
|
|
||||||
C.vision.Visualize.vis_face_detection(im_data, face_det_result, line_size,
|
|
||||||
font_size)
|
font_size)
|
||||||
|
|
||||||
|
|
||||||
def vis_segmentation(im_data, seg_result, vis_im_data, num_classes=1000):
|
def vis_face_detection(im_data, face_det_result, line_size=2, font_size=0.5):
|
||||||
C.vision.Visualize.vis_segmentation(im_data, seg_result, vis_im_data,
|
return C.vision.Visualize.vis_face_detection(im_data, face_det_result,
|
||||||
num_classes)
|
line_size, font_size)
|
||||||
|
|
||||||
|
|
||||||
|
def vis_segmentation(im_data, seg_result):
|
||||||
|
return C.vision.Visualize.vis_segmentation(im_data, seg_result)
|
||||||
|
|
||||||
|
|
||||||
def vis_matting_alpha(im_data,
|
def vis_matting_alpha(im_data,
|
||||||
matting_result,
|
matting_result,
|
||||||
vis_im_data,
|
|
||||||
remove_small_connected_area=False):
|
remove_small_connected_area=False):
|
||||||
C.vision.Visualize.vis_matting_alpha(im_data, matting_result, vis_im_data,
|
return C.vision.Visualize.vis_matting_alpha(im_data, matting_result,
|
||||||
remove_small_connected_area)
|
remove_small_connected_area)
|
||||||
|
@@ -16,8 +16,8 @@ im = cv2.imread("3.jpg")
|
|||||||
result = model.predict(im, conf_threshold=0.7, nms_iou_threshold=0.3)
|
result = model.predict(im, conf_threshold=0.7, nms_iou_threshold=0.3)
|
||||||
|
|
||||||
# 可视化结果
|
# 可视化结果
|
||||||
fd.vision.visualize.vis_face_detection(im, result)
|
vis_im = fd.vision.visualize.vis_face_detection(im, result)
|
||||||
cv2.imwrite("vis_result.jpg", im)
|
cv2.imwrite("vis_result.jpg", vis_im)
|
||||||
|
|
||||||
# 输出预测结果
|
# 输出预测结果
|
||||||
print(result)
|
print(result)
|
||||||
|
Reference in New Issue
Block a user