mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-05 08:37:06 +08:00
Modify yolov7 and visualize functions (#82)
modify yolov7 and visualize functions
This commit is contained in:
@@ -30,6 +30,8 @@ pybind11::dtype FDDataTypeToNumpyDataType(const FDDataType& fd_dtype) {
|
||||
dt = pybind11::dtype::of<float>();
|
||||
} else if (fd_dtype == FDDataType::FP64) {
|
||||
dt = pybind11::dtype::of<double>();
|
||||
} else if (fd_dtype == FDDataType::UINT8) {
|
||||
dt = pybind11::dtype::of<uint8_t>();
|
||||
} else {
|
||||
FDASSERT(false, "The function doesn't support data type of " +
|
||||
Str(fd_dtype) + ".");
|
||||
@@ -46,6 +48,8 @@ FDDataType NumpyDataTypeToFDDataType(const pybind11::dtype& np_dtype) {
|
||||
return FDDataType::FP32;
|
||||
} else if (np_dtype.is(pybind11::dtype::of<double>())) {
|
||||
return FDDataType::FP64;
|
||||
} else if (np_dtype.is(pybind11::dtype::of<uint8_t>())) {
|
||||
return FDDataType::UINT8;
|
||||
}
|
||||
FDASSERT(false,
|
||||
"NumpyDataTypeToFDDataType() only support "
|
||||
@@ -66,6 +70,13 @@ void PyArrayToTensor(pybind11::array& pyarray, FDTensor* tensor,
|
||||
}
|
||||
}
|
||||
|
||||
pybind11::array TensorToPyArray(const FDTensor& tensor) {
|
||||
auto numpy_dtype = FDDataTypeToNumpyDataType(tensor.dtype);
|
||||
auto out = pybind11::array(numpy_dtype, tensor.shape);
|
||||
memcpy(out.mutable_data(), tensor.Data(), tensor.Numel() * FDDataTypeSize(tensor.dtype));
|
||||
return out;
|
||||
}
|
||||
|
||||
#ifdef ENABLE_VISION
|
||||
int NumpyDataTypeToOpenCvType(const pybind11::dtype& np_dtype) {
|
||||
if (np_dtype.is(pybind11::dtype::of<int32_t>())) {
|
||||
|
@@ -36,12 +36,14 @@ FDDataType NumpyDataTypeToFDDataType(const pybind11::dtype& np_dtype);
|
||||
|
||||
void PyArrayToTensor(pybind11::array& pyarray, FDTensor* tensor,
|
||||
bool share_buffer = false);
|
||||
pybind11::array TensorToPyArray(const FDTensor& tensor);
|
||||
|
||||
#ifdef ENABLE_VISION
|
||||
cv::Mat PyArrayToCvMat(pybind11::array& pyarray);
|
||||
#endif
|
||||
|
||||
template <typename T> FDDataType CTypeToFDDataType() {
|
||||
template <typename T>
|
||||
FDDataType CTypeToFDDataType() {
|
||||
if (std::is_same<T, int32_t>::value) {
|
||||
return FDDataType::INT32;
|
||||
} else if (std::is_same<T, int64_t>::value) {
|
||||
@@ -57,9 +59,9 @@ template <typename T> FDDataType CTypeToFDDataType() {
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
std::vector<pybind11::array>
|
||||
PyBackendInfer(T& self, const std::vector<std::string>& names,
|
||||
std::vector<pybind11::array>& data) {
|
||||
std::vector<pybind11::array> PyBackendInfer(
|
||||
T& self, const std::vector<std::string>& names,
|
||||
std::vector<pybind11::array>& data) {
|
||||
std::vector<FDTensor> inputs(data.size());
|
||||
for (size_t i = 0; i < data.size(); ++i) {
|
||||
// TODO(jiangjiajun) here is considered to use user memory directly
|
||||
@@ -85,4 +87,4 @@ PyBackendInfer(T& self, const std::vector<std::string>& names,
|
||||
return results;
|
||||
}
|
||||
|
||||
} // namespace fastdeploy
|
||||
} // namespace fastdeploy
|
||||
|
@@ -23,6 +23,9 @@
|
||||
#include "fastdeploy/vision/deepinsight/partial_fc.h"
|
||||
#include "fastdeploy/vision/deepinsight/scrfd.h"
|
||||
#include "fastdeploy/vision/deepinsight/vpl.h"
|
||||
#include "fastdeploy/vision/detection/contrib/scaledyolov4.h"
|
||||
#include "fastdeploy/vision/detection/contrib/yolor.h"
|
||||
#include "fastdeploy/vision/detection/contrib/yolov7.h"
|
||||
#include "fastdeploy/vision/linzaer/ultraface.h"
|
||||
#include "fastdeploy/vision/megvii/yolox.h"
|
||||
#include "fastdeploy/vision/meituan/yolov6.h"
|
||||
@@ -32,9 +35,6 @@
|
||||
#include "fastdeploy/vision/ppseg/model.h"
|
||||
#include "fastdeploy/vision/rangilyu/nanodet_plus.h"
|
||||
#include "fastdeploy/vision/ultralytics/yolov5.h"
|
||||
#include "fastdeploy/vision/wongkinyiu/scaledyolov4.h"
|
||||
#include "fastdeploy/vision/wongkinyiu/yolor.h"
|
||||
#include "fastdeploy/vision/wongkinyiu/yolov7.h"
|
||||
#include "fastdeploy/vision/zhkkke/modnet.h"
|
||||
#endif
|
||||
|
||||
|
@@ -27,7 +27,7 @@ namespace vision {
|
||||
|
||||
enum Layout { HWC, CHW };
|
||||
|
||||
struct Mat {
|
||||
struct FASTDEPLOY_DECL Mat {
|
||||
explicit Mat(cv::Mat& mat) {
|
||||
cpu_mat = mat;
|
||||
device = Device::CPU;
|
||||
@@ -76,5 +76,5 @@ struct Mat {
|
||||
Device device = Device::CPU;
|
||||
};
|
||||
|
||||
} // namespace vision
|
||||
} // namespace fastdeploy
|
||||
} // namespace vision
|
||||
} // namespace fastdeploy
|
||||
|
@@ -12,13 +12,13 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "fastdeploy/vision/wongkinyiu/scaledyolov4.h"
|
||||
#include "fastdeploy/vision/detection/contrib/scaledyolov4.h"
|
||||
#include "fastdeploy/utils/perf.h"
|
||||
#include "fastdeploy/vision/utils/utils.h"
|
||||
|
||||
namespace fastdeploy {
|
||||
namespace vision {
|
||||
namespace wongkinyiu {
|
||||
namespace detection {
|
||||
|
||||
void ScaledYOLOv4::LetterBox(Mat* mat, const std::vector<int>& size,
|
||||
const std::vector<float>& color, bool _auto,
|
||||
@@ -65,8 +65,8 @@ ScaledYOLOv4::ScaledYOLOv4(const std::string& model_file,
|
||||
valid_cpu_backends = {Backend::ORT}; // 指定可用的CPU后端
|
||||
valid_gpu_backends = {Backend::ORT, Backend::TRT}; // 指定可用的GPU后端
|
||||
} else {
|
||||
valid_cpu_backends = {Backend::PDINFER, Backend::ORT};
|
||||
valid_gpu_backends = {Backend::PDINFER, Backend::ORT, Backend::TRT};
|
||||
valid_cpu_backends = {Backend::PDINFER};
|
||||
valid_gpu_backends = {Backend::PDINFER};
|
||||
}
|
||||
runtime_option = custom_option;
|
||||
runtime_option.model_format = model_format;
|
||||
@@ -219,10 +219,6 @@ bool ScaledYOLOv4::Postprocess(
|
||||
|
||||
bool ScaledYOLOv4::Predict(cv::Mat* im, DetectionResult* result,
|
||||
float conf_threshold, float nms_iou_threshold) {
|
||||
#ifdef FASTDEPLOY_DEBUG
|
||||
TIMERECORD_START(0)
|
||||
#endif
|
||||
|
||||
Mat mat(*im);
|
||||
std::vector<FDTensor> input_tensors(1);
|
||||
|
||||
@@ -239,34 +235,21 @@ bool ScaledYOLOv4::Predict(cv::Mat* im, DetectionResult* result,
|
||||
return false;
|
||||
}
|
||||
|
||||
#ifdef FASTDEPLOY_DEBUG
|
||||
TIMERECORD_END(0, "Preprocess")
|
||||
TIMERECORD_START(1)
|
||||
#endif
|
||||
|
||||
input_tensors[0].name = InputInfoOfRuntime(0).name;
|
||||
std::vector<FDTensor> output_tensors;
|
||||
if (!Infer(input_tensors, &output_tensors)) {
|
||||
FDERROR << "Failed to inference." << std::endl;
|
||||
return false;
|
||||
}
|
||||
#ifdef FASTDEPLOY_DEBUG
|
||||
TIMERECORD_END(1, "Inference")
|
||||
TIMERECORD_START(2)
|
||||
#endif
|
||||
|
||||
if (!Postprocess(output_tensors[0], result, im_info, conf_threshold,
|
||||
nms_iou_threshold)) {
|
||||
FDERROR << "Failed to post process." << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
#ifdef FASTDEPLOY_DEBUG
|
||||
TIMERECORD_END(2, "Postprocess")
|
||||
#endif
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace wongkinyiu
|
||||
} // namespace detection
|
||||
} // namespace vision
|
||||
} // namespace fastdeploy
|
@@ -19,7 +19,7 @@
|
||||
|
||||
namespace fastdeploy {
|
||||
namespace vision {
|
||||
namespace wongkinyiu {
|
||||
namespace detection {
|
||||
|
||||
class FASTDEPLOY_DECL ScaledYOLOv4 : public FastDeployModel {
|
||||
public:
|
||||
@@ -31,7 +31,7 @@ class FASTDEPLOY_DECL ScaledYOLOv4 : public FastDeployModel {
|
||||
const Frontend& model_format = Frontend::ONNX);
|
||||
|
||||
// 定义模型的名称
|
||||
virtual std::string ModelName() const { return "WongKinYiu/ScaledYOLOv4"; }
|
||||
virtual std::string ModelName() const { return "ScaledYOLOv4"; }
|
||||
|
||||
// 模型预测接口,即用户调用的接口
|
||||
// im 为用户的输入数据,目前对于CV均定义为cv::Mat
|
||||
@@ -98,6 +98,6 @@ class FASTDEPLOY_DECL ScaledYOLOv4 : public FastDeployModel {
|
||||
// auto check by fastdeploy after the internal Runtime already initialized.
|
||||
bool is_dynamic_input_;
|
||||
};
|
||||
} // namespace wongkinyiu
|
||||
} // namespace detection
|
||||
} // namespace vision
|
||||
} // namespace fastdeploy
|
@@ -0,0 +1,41 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "fastdeploy/pybind/main.h"
|
||||
|
||||
namespace fastdeploy {
|
||||
void BindScaledYOLOv4(pybind11::module& m) {
|
||||
pybind11::class_<vision::detection::ScaledYOLOv4, FastDeployModel>(
|
||||
m, "ScaledYOLOv4")
|
||||
.def(pybind11::init<std::string, std::string, RuntimeOption, Frontend>())
|
||||
.def("predict",
|
||||
[](vision::detection::ScaledYOLOv4& self, pybind11::array& data,
|
||||
float conf_threshold, float nms_iou_threshold) {
|
||||
auto mat = PyArrayToCvMat(data);
|
||||
vision::DetectionResult res;
|
||||
self.Predict(&mat, &res, conf_threshold, nms_iou_threshold);
|
||||
return res;
|
||||
})
|
||||
.def_readwrite("size", &vision::detection::ScaledYOLOv4::size)
|
||||
.def_readwrite("padding_value",
|
||||
&vision::detection::ScaledYOLOv4::padding_value)
|
||||
.def_readwrite("is_mini_pad",
|
||||
&vision::detection::ScaledYOLOv4::is_mini_pad)
|
||||
.def_readwrite("is_no_pad", &vision::detection::ScaledYOLOv4::is_no_pad)
|
||||
.def_readwrite("is_scale_up",
|
||||
&vision::detection::ScaledYOLOv4::is_scale_up)
|
||||
.def_readwrite("stride", &vision::detection::ScaledYOLOv4::stride)
|
||||
.def_readwrite("max_wh", &vision::detection::ScaledYOLOv4::max_wh);
|
||||
}
|
||||
} // namespace fastdeploy
|
@@ -12,13 +12,13 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "fastdeploy/vision/wongkinyiu/yolor.h"
|
||||
#include "fastdeploy/vision/detection/contrib/yolor.h"
|
||||
#include "fastdeploy/utils/perf.h"
|
||||
#include "fastdeploy/vision/utils/utils.h"
|
||||
|
||||
namespace fastdeploy {
|
||||
namespace vision {
|
||||
namespace wongkinyiu {
|
||||
namespace detection {
|
||||
|
||||
void YOLOR::LetterBox(Mat* mat, const std::vector<int>& size,
|
||||
const std::vector<float>& color, bool _auto,
|
||||
@@ -63,8 +63,8 @@ YOLOR::YOLOR(const std::string& model_file, const std::string& params_file,
|
||||
valid_cpu_backends = {Backend::ORT}; // 指定可用的CPU后端
|
||||
valid_gpu_backends = {Backend::ORT, Backend::TRT}; // 指定可用的GPU后端
|
||||
} else {
|
||||
valid_cpu_backends = {Backend::PDINFER, Backend::ORT};
|
||||
valid_gpu_backends = {Backend::PDINFER, Backend::ORT, Backend::TRT};
|
||||
valid_cpu_backends = {Backend::PDINFER};
|
||||
valid_gpu_backends = {Backend::PDINFER};
|
||||
}
|
||||
runtime_option = custom_option;
|
||||
runtime_option.model_format = model_format;
|
||||
@@ -216,10 +216,6 @@ bool YOLOR::Postprocess(
|
||||
|
||||
bool YOLOR::Predict(cv::Mat* im, DetectionResult* result, float conf_threshold,
|
||||
float nms_iou_threshold) {
|
||||
#ifdef FASTDEPLOY_DEBUG
|
||||
TIMERECORD_START(0)
|
||||
#endif
|
||||
|
||||
Mat mat(*im);
|
||||
std::vector<FDTensor> input_tensors(1);
|
||||
|
||||
@@ -236,21 +232,12 @@ bool YOLOR::Predict(cv::Mat* im, DetectionResult* result, float conf_threshold,
|
||||
return false;
|
||||
}
|
||||
|
||||
#ifdef FASTDEPLOY_DEBUG
|
||||
TIMERECORD_END(0, "Preprocess")
|
||||
TIMERECORD_START(1)
|
||||
#endif
|
||||
|
||||
input_tensors[0].name = InputInfoOfRuntime(0).name;
|
||||
std::vector<FDTensor> output_tensors;
|
||||
if (!Infer(input_tensors, &output_tensors)) {
|
||||
FDERROR << "Failed to inference." << std::endl;
|
||||
return false;
|
||||
}
|
||||
#ifdef FASTDEPLOY_DEBUG
|
||||
TIMERECORD_END(1, "Inference")
|
||||
TIMERECORD_START(2)
|
||||
#endif
|
||||
|
||||
if (!Postprocess(output_tensors[0], result, im_info, conf_threshold,
|
||||
nms_iou_threshold)) {
|
||||
@@ -258,12 +245,9 @@ bool YOLOR::Predict(cv::Mat* im, DetectionResult* result, float conf_threshold,
|
||||
return false;
|
||||
}
|
||||
|
||||
#ifdef FASTDEPLOY_DEBUG
|
||||
TIMERECORD_END(2, "Postprocess")
|
||||
#endif
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace wongkinyiu
|
||||
} // namespace detection
|
||||
} // namespace vision
|
||||
} // namespace fastdeploy
|
@@ -19,7 +19,7 @@
|
||||
|
||||
namespace fastdeploy {
|
||||
namespace vision {
|
||||
namespace wongkinyiu {
|
||||
namespace detection {
|
||||
|
||||
class FASTDEPLOY_DECL YOLOR : public FastDeployModel {
|
||||
public:
|
||||
@@ -30,7 +30,7 @@ class FASTDEPLOY_DECL YOLOR : public FastDeployModel {
|
||||
const Frontend& model_format = Frontend::ONNX);
|
||||
|
||||
// 定义模型的名称
|
||||
virtual std::string ModelName() const { return "WongKinYiu/yolor"; }
|
||||
virtual std::string ModelName() const { return "YOLOR"; }
|
||||
|
||||
// 模型预测接口,即用户调用的接口
|
||||
// im 为用户的输入数据,目前对于CV均定义为cv::Mat
|
||||
@@ -97,6 +97,6 @@ class FASTDEPLOY_DECL YOLOR : public FastDeployModel {
|
||||
// auto check by fastdeploy after the internal Runtime already initialized.
|
||||
bool is_dynamic_input_;
|
||||
};
|
||||
} // namespace wongkinyiu
|
||||
} // namespace detection
|
||||
} // namespace vision
|
||||
} // namespace fastdeploy
|
37
csrcs/fastdeploy/vision/detection/contrib/yolor_pybind.cc
Normal file
37
csrcs/fastdeploy/vision/detection/contrib/yolor_pybind.cc
Normal file
@@ -0,0 +1,37 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "fastdeploy/pybind/main.h"
|
||||
|
||||
namespace fastdeploy {
|
||||
void BindYOLOR(pybind11::module& m) {
|
||||
pybind11::class_<vision::detection::YOLOR, FastDeployModel>(m, "YOLOR")
|
||||
.def(pybind11::init<std::string, std::string, RuntimeOption, Frontend>())
|
||||
.def("predict",
|
||||
[](vision::detection::YOLOR& self, pybind11::array& data,
|
||||
float conf_threshold, float nms_iou_threshold) {
|
||||
auto mat = PyArrayToCvMat(data);
|
||||
vision::DetectionResult res;
|
||||
self.Predict(&mat, &res, conf_threshold, nms_iou_threshold);
|
||||
return res;
|
||||
})
|
||||
.def_readwrite("size", &vision::detection::YOLOR::size)
|
||||
.def_readwrite("padding_value", &vision::detection::YOLOR::padding_value)
|
||||
.def_readwrite("is_mini_pad", &vision::detection::YOLOR::is_mini_pad)
|
||||
.def_readwrite("is_no_pad", &vision::detection::YOLOR::is_no_pad)
|
||||
.def_readwrite("is_scale_up", &vision::detection::YOLOR::is_scale_up)
|
||||
.def_readwrite("stride", &vision::detection::YOLOR::stride)
|
||||
.def_readwrite("max_wh", &vision::detection::YOLOR::max_wh);
|
||||
}
|
||||
} // namespace fastdeploy
|
@@ -12,13 +12,13 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "fastdeploy/vision/wongkinyiu/yolov7.h"
|
||||
#include "fastdeploy/vision/detection/contrib/yolov7.h"
|
||||
#include "fastdeploy/utils/perf.h"
|
||||
#include "fastdeploy/vision/utils/utils.h"
|
||||
|
||||
namespace fastdeploy {
|
||||
namespace vision {
|
||||
namespace wongkinyiu {
|
||||
namespace detection {
|
||||
|
||||
void YOLOv7::LetterBox(Mat* mat, const std::vector<int>& size,
|
||||
const std::vector<float>& color, bool _auto,
|
||||
@@ -64,13 +64,12 @@ YOLOv7::YOLOv7(const std::string& model_file, const std::string& params_file,
|
||||
valid_cpu_backends = {Backend::ORT}; // 指定可用的CPU后端
|
||||
valid_gpu_backends = {Backend::ORT, Backend::TRT}; // 指定可用的GPU后端
|
||||
} else {
|
||||
valid_cpu_backends = {Backend::PDINFER, Backend::ORT};
|
||||
valid_gpu_backends = {Backend::PDINFER, Backend::ORT, Backend::TRT};
|
||||
valid_cpu_backends = {Backend::PDINFER};
|
||||
valid_gpu_backends = {Backend::PDINFER};
|
||||
}
|
||||
runtime_option = custom_option;
|
||||
runtime_option.model_format = model_format;
|
||||
runtime_option.model_file = model_file;
|
||||
runtime_option.params_file = params_file;
|
||||
initialized = Initialize();
|
||||
}
|
||||
|
||||
@@ -217,10 +216,6 @@ bool YOLOv7::Postprocess(
|
||||
|
||||
bool YOLOv7::Predict(cv::Mat* im, DetectionResult* result, float conf_threshold,
|
||||
float nms_iou_threshold) {
|
||||
#ifdef FASTDEPLOY_DEBUG
|
||||
TIMERECORD_START(0)
|
||||
#endif
|
||||
|
||||
Mat mat(*im);
|
||||
std::vector<FDTensor> input_tensors(1);
|
||||
|
||||
@@ -237,21 +232,12 @@ bool YOLOv7::Predict(cv::Mat* im, DetectionResult* result, float conf_threshold,
|
||||
return false;
|
||||
}
|
||||
|
||||
#ifdef FASTDEPLOY_DEBUG
|
||||
TIMERECORD_END(0, "Preprocess")
|
||||
TIMERECORD_START(1)
|
||||
#endif
|
||||
|
||||
input_tensors[0].name = InputInfoOfRuntime(0).name;
|
||||
std::vector<FDTensor> output_tensors;
|
||||
if (!Infer(input_tensors, &output_tensors)) {
|
||||
FDERROR << "Failed to inference." << std::endl;
|
||||
return false;
|
||||
}
|
||||
#ifdef FASTDEPLOY_DEBUG
|
||||
TIMERECORD_END(1, "Inference")
|
||||
TIMERECORD_START(2)
|
||||
#endif
|
||||
|
||||
if (!Postprocess(output_tensors[0], result, im_info, conf_threshold,
|
||||
nms_iou_threshold)) {
|
||||
@@ -259,12 +245,9 @@ bool YOLOv7::Predict(cv::Mat* im, DetectionResult* result, float conf_threshold,
|
||||
return false;
|
||||
}
|
||||
|
||||
#ifdef FASTDEPLOY_DEBUG
|
||||
TIMERECORD_END(2, "Postprocess")
|
||||
#endif
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace wongkinyiu
|
||||
} // namespace detection
|
||||
} // namespace vision
|
||||
} // namespace fastdeploy
|
@@ -19,18 +19,16 @@
|
||||
|
||||
namespace fastdeploy {
|
||||
namespace vision {
|
||||
namespace wongkinyiu {
|
||||
namespace detection {
|
||||
|
||||
class FASTDEPLOY_DECL YOLOv7 : public FastDeployModel {
|
||||
public:
|
||||
// 当model_format为ONNX时,无需指定params_file
|
||||
// 当model_format为Paddle时,则需同时指定model_file & params_file
|
||||
YOLOv7(const std::string& model_file, const std::string& params_file = "",
|
||||
const RuntimeOption& custom_option = RuntimeOption(),
|
||||
const Frontend& model_format = Frontend::ONNX);
|
||||
|
||||
// 定义模型的名称
|
||||
virtual std::string ModelName() const { return "WongKinYiu/yolov7"; }
|
||||
virtual std::string ModelName() const { return "yolov7"; }
|
||||
|
||||
// 模型预测接口,即用户调用的接口
|
||||
// im 为用户的输入数据,目前对于CV均定义为cv::Mat
|
||||
@@ -97,6 +95,6 @@ class FASTDEPLOY_DECL YOLOv7 : public FastDeployModel {
|
||||
// auto check by fastdeploy after the internal Runtime already initialized.
|
||||
bool is_dynamic_input_;
|
||||
};
|
||||
} // namespace wongkinyiu
|
||||
} // namespace detection
|
||||
} // namespace vision
|
||||
} // namespace fastdeploy
|
37
csrcs/fastdeploy/vision/detection/contrib/yolov7_pybind.cc
Normal file
37
csrcs/fastdeploy/vision/detection/contrib/yolov7_pybind.cc
Normal file
@@ -0,0 +1,37 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "fastdeploy/pybind/main.h"
|
||||
|
||||
namespace fastdeploy {
|
||||
void BindYOLOv7(pybind11::module& m) {
|
||||
pybind11::class_<vision::detection::YOLOv7, FastDeployModel>(m, "YOLOv7")
|
||||
.def(pybind11::init<std::string, std::string, RuntimeOption, Frontend>())
|
||||
.def("predict",
|
||||
[](vision::detection::YOLOv7& self, pybind11::array& data,
|
||||
float conf_threshold, float nms_iou_threshold) {
|
||||
auto mat = PyArrayToCvMat(data);
|
||||
vision::DetectionResult res;
|
||||
self.Predict(&mat, &res, conf_threshold, nms_iou_threshold);
|
||||
return res;
|
||||
})
|
||||
.def_readwrite("size", &vision::detection::YOLOv7::size)
|
||||
.def_readwrite("padding_value", &vision::detection::YOLOv7::padding_value)
|
||||
.def_readwrite("is_mini_pad", &vision::detection::YOLOv7::is_mini_pad)
|
||||
.def_readwrite("is_no_pad", &vision::detection::YOLOv7::is_no_pad)
|
||||
.def_readwrite("is_scale_up", &vision::detection::YOLOv7::is_scale_up)
|
||||
.def_readwrite("stride", &vision::detection::YOLOv7::stride)
|
||||
.def_readwrite("max_wh", &vision::detection::YOLOv7::max_wh);
|
||||
}
|
||||
} // namespace fastdeploy
|
30
csrcs/fastdeploy/vision/detection/detection_pybind.cc
Normal file
30
csrcs/fastdeploy/vision/detection/detection_pybind.cc
Normal file
@@ -0,0 +1,30 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "fastdeploy/pybind/main.h"
|
||||
|
||||
namespace fastdeploy {
|
||||
|
||||
void BindYOLOv7(pybind11::module& m);
|
||||
void BindScaledYOLOv4(pybind11::module& m);
|
||||
void BindYOLOR(pybind11::module& m);
|
||||
|
||||
void BindDetection(pybind11::module& m) {
|
||||
auto detection_module =
|
||||
m.def_submodule("detection", "Image object detection models.");
|
||||
BindYOLOv7(detection_module);
|
||||
BindScaledYOLOv4(detection_module);
|
||||
BindYOLOR(detection_module);
|
||||
}
|
||||
} // namespace fastdeploy
|
@@ -18,7 +18,6 @@ namespace fastdeploy {
|
||||
|
||||
void BindPPCls(pybind11::module& m);
|
||||
void BindPPDet(pybind11::module& m);
|
||||
void BindWongkinyiu(pybind11::module& m);
|
||||
void BindPPSeg(pybind11::module& m);
|
||||
void BindUltralytics(pybind11::module& m);
|
||||
void BindMeituan(pybind11::module& m);
|
||||
@@ -30,6 +29,8 @@ void BindBiubug6(pybind11::module& m);
|
||||
void BindPpogg(pybind11::module& m);
|
||||
void BindDeepInsight(pybind11::module& m);
|
||||
void BindZHKKKe(pybind11::module& m);
|
||||
|
||||
void BindDetection(pybind11::module& m);
|
||||
#ifdef ENABLE_VISION_VISUALIZE
|
||||
void BindVisualize(pybind11::module& m);
|
||||
#endif
|
||||
@@ -88,7 +89,6 @@ void BindVision(pybind11::module& m) {
|
||||
BindPPDet(m);
|
||||
BindPPSeg(m);
|
||||
BindUltralytics(m);
|
||||
BindWongkinyiu(m);
|
||||
BindMeituan(m);
|
||||
BindMegvii(m);
|
||||
BindDeepCam(m);
|
||||
@@ -98,6 +98,8 @@ void BindVision(pybind11::module& m) {
|
||||
BindPpogg(m);
|
||||
BindDeepInsight(m);
|
||||
BindZHKKKe(m);
|
||||
|
||||
BindDetection(m);
|
||||
#ifdef ENABLE_VISION_VISUALIZE
|
||||
BindVisualize(m);
|
||||
#endif
|
||||
|
@@ -23,11 +23,13 @@ namespace vision {
|
||||
// Default only support visualize num_classes <= 1000
|
||||
// If need to visualize num_classes > 1000
|
||||
// Please call Visualize::GetColorMap(num_classes) first
|
||||
void Visualize::VisDetection(cv::Mat* im, const DetectionResult& result,
|
||||
int line_size, float font_size) {
|
||||
cv::Mat Visualize::VisDetection(const cv::Mat& im,
|
||||
const DetectionResult& result, int line_size,
|
||||
float font_size) {
|
||||
auto color_map = GetColorMap();
|
||||
int h = im->rows;
|
||||
int w = im->cols;
|
||||
int h = im.rows;
|
||||
int w = im.cols;
|
||||
auto vis_im = im.clone();
|
||||
for (size_t i = 0; i < result.boxes.size(); ++i) {
|
||||
cv::Rect rect(result.boxes[i][0], result.boxes[i][1],
|
||||
result.boxes[i][2] - result.boxes[i][0],
|
||||
@@ -50,10 +52,11 @@ void Visualize::VisDetection(cv::Mat* im, const DetectionResult& result,
|
||||
cv::Rect text_background =
|
||||
cv::Rect(result.boxes[i][0], result.boxes[i][1] - text_size.height,
|
||||
text_size.width, text_size.height);
|
||||
cv::rectangle(*im, rect, rect_color, line_size);
|
||||
cv::putText(*im, text, origin, font, font_size, cv::Scalar(255, 255, 255),
|
||||
1);
|
||||
cv::rectangle(vis_im, rect, rect_color, line_size);
|
||||
cv::putText(vis_im, text, origin, font, font_size,
|
||||
cv::Scalar(255, 255, 255), 1);
|
||||
}
|
||||
return vis_im;
|
||||
}
|
||||
|
||||
} // namespace vision
|
||||
|
@@ -24,12 +24,14 @@ namespace vision {
|
||||
// Default only support visualize num_classes <= 1000
|
||||
// If need to visualize num_classes > 1000
|
||||
// Please call Visualize::GetColorMap(num_classes) first
|
||||
void Visualize::VisFaceDetection(cv::Mat* im, const FaceDetectionResult& result,
|
||||
int line_size, float font_size) {
|
||||
cv::Mat Visualize::VisFaceDetection(const cv::Mat& im,
|
||||
const FaceDetectionResult& result,
|
||||
int line_size, float font_size) {
|
||||
auto color_map = GetColorMap();
|
||||
int h = im->rows;
|
||||
int w = im->cols;
|
||||
int h = im.rows;
|
||||
int w = im.cols;
|
||||
|
||||
auto vis_im = im.clone();
|
||||
bool vis_landmarks = false;
|
||||
if ((result.landmarks_per_face > 0) &&
|
||||
(result.boxes.size() * result.landmarks_per_face ==
|
||||
@@ -57,9 +59,9 @@ void Visualize::VisFaceDetection(cv::Mat* im, const FaceDetectionResult& result,
|
||||
cv::Rect text_background =
|
||||
cv::Rect(result.boxes[i][0], result.boxes[i][1] - text_size.height,
|
||||
text_size.width, text_size.height);
|
||||
cv::rectangle(*im, rect, rect_color, line_size);
|
||||
cv::putText(*im, text, origin, font, font_size, cv::Scalar(255, 255, 255),
|
||||
1);
|
||||
cv::rectangle(vis_im, rect, rect_color, line_size);
|
||||
cv::putText(vis_im, text, origin, font, font_size,
|
||||
cv::Scalar(255, 255, 255), 1);
|
||||
// vis landmarks (if have)
|
||||
if (vis_landmarks) {
|
||||
cv::Scalar landmark_color = rect_color;
|
||||
@@ -69,13 +71,14 @@ void Visualize::VisFaceDetection(cv::Mat* im, const FaceDetectionResult& result,
|
||||
result.landmarks[i * result.landmarks_per_face + j][0]);
|
||||
landmark.y = static_cast<int>(
|
||||
result.landmarks[i * result.landmarks_per_face + j][1]);
|
||||
cv::circle(*im, landmark, line_size, landmark_color, -1);
|
||||
cv::circle(vis_im, landmark, line_size, landmark_color, -1);
|
||||
}
|
||||
}
|
||||
}
|
||||
return vis_im;
|
||||
}
|
||||
|
||||
} // namespace vision
|
||||
} // namespace fastdeploy
|
||||
|
||||
#endif
|
||||
#endif
|
||||
|
@@ -65,12 +65,14 @@ static void RemoveSmallConnectedArea(cv::Mat* alpha_pred,
|
||||
}
|
||||
}
|
||||
|
||||
void Visualize::VisMattingAlpha(const cv::Mat& im, const MattingResult& result,
|
||||
cv::Mat* vis_img,
|
||||
bool remove_small_connected_area) {
|
||||
cv::Mat Visualize::VisMattingAlpha(const cv::Mat& im,
|
||||
const MattingResult& result,
|
||||
bool remove_small_connected_area) {
|
||||
// 只可视化alpha,fgr(前景)本身就是一张图 不需要可视化
|
||||
FDASSERT((!im.empty()), "im can't be empty!");
|
||||
FDASSERT((im.channels() == 3), "Only support 3 channels mat!");
|
||||
|
||||
auto vis_img = im.clone();
|
||||
int out_h = static_cast<int>(result.shape[0]);
|
||||
int out_w = static_cast<int>(result.shape[1]);
|
||||
int height = im.rows;
|
||||
@@ -87,18 +89,11 @@ void Visualize::VisMattingAlpha(const cv::Mat& im, const MattingResult& result,
|
||||
cv::resize(alpha, alpha, cv::Size(width, height));
|
||||
}
|
||||
|
||||
int vis_h = (*vis_img).rows;
|
||||
int vis_w = (*vis_img).cols;
|
||||
|
||||
if ((vis_h != height) || (vis_w != width)) {
|
||||
// faster than resize
|
||||
(*vis_img) = cv::Mat::zeros(height, width, CV_8UC3);
|
||||
}
|
||||
if ((*vis_img).type() != CV_8UC3) {
|
||||
(*vis_img).convertTo((*vis_img), CV_8UC3);
|
||||
if ((vis_img).type() != CV_8UC3) {
|
||||
(vis_img).convertTo((vis_img), CV_8UC3);
|
||||
}
|
||||
|
||||
uchar* vis_data = static_cast<uchar*>(vis_img->data);
|
||||
uchar* vis_data = static_cast<uchar*>(vis_img.data);
|
||||
uchar* im_data = static_cast<uchar*>(im.data);
|
||||
float* alpha_data = reinterpret_cast<float*>(alpha.data);
|
||||
|
||||
@@ -116,6 +111,7 @@ void Visualize::VisMattingAlpha(const cv::Mat& im, const MattingResult& result,
|
||||
(1.f - alpha_val) * 120.f);
|
||||
}
|
||||
}
|
||||
return vis_img;
|
||||
}
|
||||
|
||||
} // namespace vision
|
||||
|
@@ -21,24 +21,24 @@
|
||||
namespace fastdeploy {
|
||||
namespace vision {
|
||||
|
||||
void Visualize::VisSegmentation(const cv::Mat& im,
|
||||
const SegmentationResult& result,
|
||||
cv::Mat* vis_img, const int& num_classes) {
|
||||
auto color_map = GetColorMap(num_classes);
|
||||
cv::Mat Visualize::VisSegmentation(const cv::Mat& im,
|
||||
const SegmentationResult& result) {
|
||||
auto color_map = GetColorMap();
|
||||
int64_t height = result.shape[0];
|
||||
int64_t width = result.shape[1];
|
||||
*vis_img = cv::Mat::zeros(height, width, CV_8UC3);
|
||||
auto vis_img = cv::Mat(height, width, CV_8UC3);
|
||||
|
||||
int64_t index = 0;
|
||||
for (int i = 0; i < height; i++) {
|
||||
for (int j = 0; j < width; j++) {
|
||||
int category_id = result.label_map[index++];
|
||||
vis_img->at<cv::Vec3b>(i, j)[0] = color_map[3 * category_id + 0];
|
||||
vis_img->at<cv::Vec3b>(i, j)[1] = color_map[3 * category_id + 1];
|
||||
vis_img->at<cv::Vec3b>(i, j)[2] = color_map[3 * category_id + 2];
|
||||
vis_img.at<cv::Vec3b>(i, j)[0] = color_map[3 * category_id + 0];
|
||||
vis_img.at<cv::Vec3b>(i, j)[1] = color_map[3 * category_id + 1];
|
||||
vis_img.at<cv::Vec3b>(i, j)[2] = color_map[3 * category_id + 2];
|
||||
}
|
||||
}
|
||||
cv::addWeighted(im, .5, *vis_img, .5, 0, *vis_img);
|
||||
cv::addWeighted(im, .5, vis_img, .5, 0, vis_img);
|
||||
return vis_img;
|
||||
}
|
||||
|
||||
} // namespace vision
|
||||
|
@@ -25,16 +25,15 @@ class FASTDEPLOY_DECL Visualize {
|
||||
static int num_classes_;
|
||||
static std::vector<int> color_map_;
|
||||
static const std::vector<int>& GetColorMap(int num_classes = 1000);
|
||||
static void VisDetection(cv::Mat* im, const DetectionResult& result,
|
||||
int line_size = 2, float font_size = 0.5f);
|
||||
static void VisFaceDetection(cv::Mat* im, const FaceDetectionResult& result,
|
||||
int line_size = 2, float font_size = 0.5f);
|
||||
static void VisSegmentation(const cv::Mat& im,
|
||||
const SegmentationResult& result,
|
||||
cv::Mat* vis_img, const int& num_classes = 1000);
|
||||
static void VisMattingAlpha(const cv::Mat& im, const MattingResult& result,
|
||||
cv::Mat* vis_img,
|
||||
bool remove_small_connected_area = false);
|
||||
static cv::Mat VisDetection(const cv::Mat& im, const DetectionResult& result,
|
||||
int line_size = 2, float font_size = 0.5f);
|
||||
static cv::Mat VisFaceDetection(const cv::Mat& im,
|
||||
const FaceDetectionResult& result,
|
||||
int line_size = 2, float font_size = 0.5f);
|
||||
static cv::Mat VisSegmentation(const cv::Mat& im,
|
||||
const SegmentationResult& result);
|
||||
static cv::Mat VisMattingAlpha(const cv::Mat& im, const MattingResult& result,
|
||||
bool remove_small_connected_area = false);
|
||||
};
|
||||
|
||||
} // namespace vision
|
||||
|
@@ -22,34 +22,41 @@ void BindVisualize(pybind11::module& m) {
|
||||
[](pybind11::array& im_data, vision::DetectionResult& result,
|
||||
int line_size, float font_size) {
|
||||
auto im = PyArrayToCvMat(im_data);
|
||||
vision::Visualize::VisDetection(&im, result, line_size,
|
||||
font_size);
|
||||
auto vis_im = vision::Visualize::VisDetection(
|
||||
im, result, line_size, font_size);
|
||||
FDTensor out;
|
||||
vision::Mat(vis_im).ShareWithTensor(&out);
|
||||
return TensorToPyArray(out);
|
||||
})
|
||||
.def_static(
|
||||
"vis_face_detection",
|
||||
[](pybind11::array& im_data, vision::FaceDetectionResult& result,
|
||||
int line_size, float font_size) {
|
||||
auto im = PyArrayToCvMat(im_data);
|
||||
vision::Visualize::VisFaceDetection(&im, result, line_size,
|
||||
font_size);
|
||||
auto vis_im = vision::Visualize::VisFaceDetection(
|
||||
im, result, line_size, font_size);
|
||||
FDTensor out;
|
||||
vision::Mat(vis_im).ShareWithTensor(&out);
|
||||
return TensorToPyArray(out);
|
||||
})
|
||||
.def_static(
|
||||
"vis_segmentation",
|
||||
[](pybind11::array& im_data, vision::SegmentationResult& result,
|
||||
pybind11::array& vis_im_data, const int& num_classes) {
|
||||
[](pybind11::array& im_data, vision::SegmentationResult& result) {
|
||||
cv::Mat im = PyArrayToCvMat(im_data);
|
||||
cv::Mat vis_im = PyArrayToCvMat(vis_im_data);
|
||||
vision::Visualize::VisSegmentation(im, result, &vis_im,
|
||||
num_classes);
|
||||
auto vis_im = vision::Visualize::VisSegmentation(im, result);
|
||||
FDTensor out;
|
||||
vision::Mat(vis_im).ShareWithTensor(&out);
|
||||
return TensorToPyArray(out);
|
||||
})
|
||||
.def_static(
|
||||
"vis_matting_alpha",
|
||||
[](pybind11::array& im_data, vision::MattingResult& result,
|
||||
pybind11::array& vis_im_data, bool remove_small_connected_area) {
|
||||
cv::Mat im = PyArrayToCvMat(im_data);
|
||||
cv::Mat vis_im = PyArrayToCvMat(vis_im_data);
|
||||
vision::Visualize::VisMattingAlpha(im, result, &vis_im,
|
||||
remove_small_connected_area);
|
||||
});
|
||||
.def_static("vis_matting_alpha",
|
||||
[](pybind11::array& im_data, vision::MattingResult& result,
|
||||
bool remove_small_connected_area) {
|
||||
cv::Mat im = PyArrayToCvMat(im_data);
|
||||
auto vis_im = vision::Visualize::VisMattingAlpha(
|
||||
im, result, remove_small_connected_area);
|
||||
FDTensor out;
|
||||
vision::Mat(vis_im).ShareWithTensor(&out);
|
||||
return TensorToPyArray(out);
|
||||
});
|
||||
}
|
||||
} // namespace fastdeploy
|
||||
|
@@ -1,79 +0,0 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "fastdeploy/pybind/main.h"
|
||||
|
||||
namespace fastdeploy {
|
||||
void BindWongkinyiu(pybind11::module& m) {
|
||||
auto wongkinyiu_module =
|
||||
m.def_submodule("wongkinyiu", "https://github.com/WongKinYiu");
|
||||
pybind11::class_<vision::wongkinyiu::YOLOv7, FastDeployModel>(
|
||||
wongkinyiu_module, "YOLOv7")
|
||||
.def(pybind11::init<std::string, std::string, RuntimeOption, Frontend>())
|
||||
.def("predict",
|
||||
[](vision::wongkinyiu::YOLOv7& self, pybind11::array& data,
|
||||
float conf_threshold, float nms_iou_threshold) {
|
||||
auto mat = PyArrayToCvMat(data);
|
||||
vision::DetectionResult res;
|
||||
self.Predict(&mat, &res, conf_threshold, nms_iou_threshold);
|
||||
return res;
|
||||
})
|
||||
.def_readwrite("size", &vision::wongkinyiu::YOLOv7::size)
|
||||
.def_readwrite("padding_value",
|
||||
&vision::wongkinyiu::YOLOv7::padding_value)
|
||||
.def_readwrite("is_mini_pad", &vision::wongkinyiu::YOLOv7::is_mini_pad)
|
||||
.def_readwrite("is_no_pad", &vision::wongkinyiu::YOLOv7::is_no_pad)
|
||||
.def_readwrite("is_scale_up", &vision::wongkinyiu::YOLOv7::is_scale_up)
|
||||
.def_readwrite("stride", &vision::wongkinyiu::YOLOv7::stride)
|
||||
.def_readwrite("max_wh", &vision::wongkinyiu::YOLOv7::max_wh);
|
||||
|
||||
pybind11::class_<vision::wongkinyiu::YOLOR, FastDeployModel>(
|
||||
wongkinyiu_module, "YOLOR")
|
||||
.def(pybind11::init<std::string, std::string, RuntimeOption, Frontend>())
|
||||
.def("predict",
|
||||
[](vision::wongkinyiu::YOLOR& self, pybind11::array& data,
|
||||
float conf_threshold, float nms_iou_threshold) {
|
||||
auto mat = PyArrayToCvMat(data);
|
||||
vision::DetectionResult res;
|
||||
self.Predict(&mat, &res, conf_threshold, nms_iou_threshold);
|
||||
return res;
|
||||
})
|
||||
.def_readwrite("size", &vision::wongkinyiu::YOLOR::size)
|
||||
.def_readwrite("padding_value", &vision::wongkinyiu::YOLOR::padding_value)
|
||||
.def_readwrite("is_mini_pad", &vision::wongkinyiu::YOLOR::is_mini_pad)
|
||||
.def_readwrite("is_no_pad", &vision::wongkinyiu::YOLOR::is_no_pad)
|
||||
.def_readwrite("is_scale_up", &vision::wongkinyiu::YOLOR::is_scale_up)
|
||||
.def_readwrite("stride", &vision::wongkinyiu::YOLOR::stride)
|
||||
.def_readwrite("max_wh", &vision::wongkinyiu::YOLOR::max_wh);
|
||||
|
||||
pybind11::class_<vision::wongkinyiu::ScaledYOLOv4, FastDeployModel>(
|
||||
wongkinyiu_module, "ScaledYOLOv4")
|
||||
.def(pybind11::init<std::string, std::string, RuntimeOption, Frontend>())
|
||||
.def("predict",
|
||||
[](vision::wongkinyiu::ScaledYOLOv4& self, pybind11::array& data,
|
||||
float conf_threshold, float nms_iou_threshold) {
|
||||
auto mat = PyArrayToCvMat(data);
|
||||
vision::DetectionResult res;
|
||||
self.Predict(&mat, &res, conf_threshold, nms_iou_threshold);
|
||||
return res;
|
||||
})
|
||||
.def_readwrite("size", &vision::wongkinyiu::ScaledYOLOv4::size)
|
||||
.def_readwrite("padding_value", &vision::wongkinyiu::ScaledYOLOv4::padding_value)
|
||||
.def_readwrite("is_mini_pad", &vision::wongkinyiu::ScaledYOLOv4::is_mini_pad)
|
||||
.def_readwrite("is_no_pad", &vision::wongkinyiu::ScaledYOLOv4::is_no_pad)
|
||||
.def_readwrite("is_scale_up", &vision::wongkinyiu::ScaledYOLOv4::is_scale_up)
|
||||
.def_readwrite("stride", &vision::wongkinyiu::ScaledYOLOv4::stride)
|
||||
.def_readwrite("max_wh", &vision::wongkinyiu::ScaledYOLOv4::max_wh);
|
||||
}
|
||||
} // namespace fastdeploy
|
@@ -20,8 +20,6 @@ from . import ppseg
|
||||
from . import ultralytics
|
||||
from . import meituan
|
||||
from . import megvii
|
||||
from . import visualize
|
||||
from . import wongkinyiu
|
||||
from . import deepcam
|
||||
from . import rangilyu
|
||||
from . import linzaer
|
||||
@@ -29,3 +27,7 @@ from . import biubug6
|
||||
from . import ppogg
|
||||
from . import deepinsight
|
||||
from . import zhkkke
|
||||
|
||||
from . import detection
|
||||
|
||||
from .visualize import *
|
||||
|
18
fastdeploy/vision/detection/__init__.py
Normal file
18
fastdeploy/vision/detection/__init__.py
Normal file
@@ -0,0 +1,18 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from .yolov7 import YOLOv7
|
||||
from .yolor import YOLOR
|
||||
from .scaled_yolov4 import ScaledYOLOv4
|
116
fastdeploy/vision/detection/scaled_yolov4.py
Normal file
116
fastdeploy/vision/detection/scaled_yolov4.py
Normal file
@@ -0,0 +1,116 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
import logging
|
||||
from ... import FastDeployModel, Frontend
|
||||
from ... import c_lib_wrap as C
|
||||
|
||||
|
||||
class ScaledYOLOv4(FastDeployModel):
|
||||
def __init__(self,
|
||||
model_file,
|
||||
params_file="",
|
||||
runtime_option=None,
|
||||
model_format=Frontend.ONNX):
|
||||
# 调用基函数进行backend_option的初始化
|
||||
# 初始化后的option保存在self._runtime_option
|
||||
super(ScaledYOLOv4, self).__init__(runtime_option)
|
||||
|
||||
self._model = C.vision.detection.ScaledYOLOv4(
|
||||
model_file, params_file, self._runtime_option, model_format)
|
||||
# 通过self.initialized判断整个模型的初始化是否成功
|
||||
assert self.initialized, "ScaledYOLOv4 initialize failed."
|
||||
|
||||
def predict(self, input_image, conf_threshold=0.25, nms_iou_threshold=0.5):
|
||||
return self._model.predict(input_image, conf_threshold,
|
||||
nms_iou_threshold)
|
||||
|
||||
# 一些跟ScaledYOLOv4模型有关的属性封装
|
||||
# 多数是预处理相关,可通过修改如model.size = [1280, 1280]改变预处理时resize的大小(前提是模型支持)
|
||||
@property
|
||||
def size(self):
|
||||
return self._model.size
|
||||
|
||||
@property
|
||||
def padding_value(self):
|
||||
return self._model.padding_value
|
||||
|
||||
@property
|
||||
def is_no_pad(self):
|
||||
return self._model.is_no_pad
|
||||
|
||||
@property
|
||||
def is_mini_pad(self):
|
||||
return self._model.is_mini_pad
|
||||
|
||||
@property
|
||||
def is_scale_up(self):
|
||||
return self._model.is_scale_up
|
||||
|
||||
@property
|
||||
def stride(self):
|
||||
return self._model.stride
|
||||
|
||||
@property
|
||||
def max_wh(self):
|
||||
return self._model.max_wh
|
||||
|
||||
@size.setter
|
||||
def size(self, wh):
|
||||
assert isinstance(wh, [list, tuple]),\
|
||||
"The value to set `size` must be type of tuple or list."
|
||||
assert len(wh) == 2,\
|
||||
"The value to set `size` must contatins 2 elements means [width, height], but now it contains {} elements.".format(
|
||||
len(wh))
|
||||
self._model.size = wh
|
||||
|
||||
@padding_value.setter
|
||||
def padding_value(self, value):
|
||||
assert isinstance(
|
||||
value,
|
||||
list), "The value to set `padding_value` must be type of list."
|
||||
self._model.padding_value = value
|
||||
|
||||
@is_no_pad.setter
|
||||
def is_no_pad(self, value):
|
||||
assert isinstance(
|
||||
value, bool), "The value to set `is_no_pad` must be type of bool."
|
||||
self._model.is_no_pad = value
|
||||
|
||||
@is_mini_pad.setter
|
||||
def is_mini_pad(self, value):
|
||||
assert isinstance(
|
||||
value,
|
||||
bool), "The value to set `is_mini_pad` must be type of bool."
|
||||
self._model.is_mini_pad = value
|
||||
|
||||
@is_scale_up.setter
|
||||
def is_scale_up(self, value):
|
||||
assert isinstance(
|
||||
value,
|
||||
bool), "The value to set `is_scale_up` must be type of bool."
|
||||
self._model.is_scale_up = value
|
||||
|
||||
@stride.setter
|
||||
def stride(self, value):
|
||||
assert isinstance(
|
||||
value, int), "The value to set `stride` must be type of int."
|
||||
self._model.stride = value
|
||||
|
||||
@max_wh.setter
|
||||
def max_wh(self, value):
|
||||
assert isinstance(
|
||||
value, float), "The value to set `max_wh` must be type of float."
|
||||
self._model.max_wh = value
|
116
fastdeploy/vision/detection/yolor.py
Normal file
116
fastdeploy/vision/detection/yolor.py
Normal file
@@ -0,0 +1,116 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
import logging
|
||||
from ... import FastDeployModel, Frontend
|
||||
from ... import c_lib_wrap as C
|
||||
|
||||
|
||||
class YOLOR(FastDeployModel):
|
||||
def __init__(self,
|
||||
model_file,
|
||||
params_file="",
|
||||
runtime_option=None,
|
||||
model_format=Frontend.ONNX):
|
||||
# 调用基函数进行backend_option的初始化
|
||||
# 初始化后的option保存在self._runtime_option
|
||||
super(YOLOR, self).__init__(runtime_option)
|
||||
|
||||
self._model = C.vision.detection.YOLOR(
|
||||
model_file, params_file, self._runtime_option, model_format)
|
||||
# 通过self.initialized判断整个模型的初始化是否成功
|
||||
assert self.initialized, "YOLOR initialize failed."
|
||||
|
||||
def predict(self, input_image, conf_threshold=0.25, nms_iou_threshold=0.5):
|
||||
return self._model.predict(input_image, conf_threshold,
|
||||
nms_iou_threshold)
|
||||
|
||||
# 一些跟YOLOR模型有关的属性封装
|
||||
# 多数是预处理相关,可通过修改如model.size = [1280, 1280]改变预处理时resize的大小(前提是模型支持)
|
||||
@property
|
||||
def size(self):
|
||||
return self._model.size
|
||||
|
||||
@property
|
||||
def padding_value(self):
|
||||
return self._model.padding_value
|
||||
|
||||
@property
|
||||
def is_no_pad(self):
|
||||
return self._model.is_no_pad
|
||||
|
||||
@property
|
||||
def is_mini_pad(self):
|
||||
return self._model.is_mini_pad
|
||||
|
||||
@property
|
||||
def is_scale_up(self):
|
||||
return self._model.is_scale_up
|
||||
|
||||
@property
|
||||
def stride(self):
|
||||
return self._model.stride
|
||||
|
||||
@property
|
||||
def max_wh(self):
|
||||
return self._model.max_wh
|
||||
|
||||
@size.setter
|
||||
def size(self, wh):
|
||||
assert isinstance(wh, [list, tuple]),\
|
||||
"The value to set `size` must be type of tuple or list."
|
||||
assert len(wh) == 2,\
|
||||
"The value to set `size` must contatins 2 elements means [width, height], but now it contains {} elements.".format(
|
||||
len(wh))
|
||||
self._model.size = wh
|
||||
|
||||
@padding_value.setter
|
||||
def padding_value(self, value):
|
||||
assert isinstance(
|
||||
value,
|
||||
list), "The value to set `padding_value` must be type of list."
|
||||
self._model.padding_value = value
|
||||
|
||||
@is_no_pad.setter
|
||||
def is_no_pad(self, value):
|
||||
assert isinstance(
|
||||
value, bool), "The value to set `is_no_pad` must be type of bool."
|
||||
self._model.is_no_pad = value
|
||||
|
||||
@is_mini_pad.setter
|
||||
def is_mini_pad(self, value):
|
||||
assert isinstance(
|
||||
value,
|
||||
bool), "The value to set `is_mini_pad` must be type of bool."
|
||||
self._model.is_mini_pad = value
|
||||
|
||||
@is_scale_up.setter
|
||||
def is_scale_up(self, value):
|
||||
assert isinstance(
|
||||
value,
|
||||
bool), "The value to set `is_scale_up` must be type of bool."
|
||||
self._model.is_scale_up = value
|
||||
|
||||
@stride.setter
|
||||
def stride(self, value):
|
||||
assert isinstance(
|
||||
value, int), "The value to set `stride` must be type of int."
|
||||
self._model.stride = value
|
||||
|
||||
@max_wh.setter
|
||||
def max_wh(self, value):
|
||||
assert isinstance(
|
||||
value, float), "The value to set `max_wh` must be type of float."
|
||||
self._model.max_wh = value
|
116
fastdeploy/vision/detection/yolov7.py
Normal file
116
fastdeploy/vision/detection/yolov7.py
Normal file
@@ -0,0 +1,116 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
import logging
|
||||
from ... import FastDeployModel, Frontend
|
||||
from ... import c_lib_wrap as C
|
||||
|
||||
|
||||
class YOLOv7(FastDeployModel):
|
||||
def __init__(self,
|
||||
model_file,
|
||||
params_file="",
|
||||
runtime_option=None,
|
||||
model_format=Frontend.ONNX):
|
||||
# 调用基函数进行backend_option的初始化
|
||||
# 初始化后的option保存在self._runtime_option
|
||||
super(YOLOv7, self).__init__(runtime_option)
|
||||
|
||||
self._model = C.vision.detection.YOLOv7(
|
||||
model_file, params_file, self._runtime_option, model_format)
|
||||
# 通过self.initialized判断整个模型的初始化是否成功
|
||||
assert self.initialized, "YOLOv7 initialize failed."
|
||||
|
||||
def predict(self, input_image, conf_threshold=0.25, nms_iou_threshold=0.5):
|
||||
return self._model.predict(input_image, conf_threshold,
|
||||
nms_iou_threshold)
|
||||
|
||||
# 一些跟YOLOv7模型有关的属性封装
|
||||
# 多数是预处理相关,可通过修改如model.size = [1280, 1280]改变预处理时resize的大小(前提是模型支持)
|
||||
@property
|
||||
def size(self):
|
||||
return self._model.size
|
||||
|
||||
@property
|
||||
def padding_value(self):
|
||||
return self._model.padding_value
|
||||
|
||||
@property
|
||||
def is_no_pad(self):
|
||||
return self._model.is_no_pad
|
||||
|
||||
@property
|
||||
def is_mini_pad(self):
|
||||
return self._model.is_mini_pad
|
||||
|
||||
@property
|
||||
def is_scale_up(self):
|
||||
return self._model.is_scale_up
|
||||
|
||||
@property
|
||||
def stride(self):
|
||||
return self._model.stride
|
||||
|
||||
@property
|
||||
def max_wh(self):
|
||||
return self._model.max_wh
|
||||
|
||||
@size.setter
|
||||
def size(self, wh):
|
||||
assert isinstance(wh, [list, tuple]),\
|
||||
"The value to set `size` must be type of tuple or list."
|
||||
assert len(wh) == 2,\
|
||||
"The value to set `size` must contatins 2 elements means [width, height], but now it contains {} elements.".format(
|
||||
len(wh))
|
||||
self._model.size = wh
|
||||
|
||||
@padding_value.setter
|
||||
def padding_value(self, value):
|
||||
assert isinstance(
|
||||
value,
|
||||
list), "The value to set `padding_value` must be type of list."
|
||||
self._model.padding_value = value
|
||||
|
||||
@is_no_pad.setter
|
||||
def is_no_pad(self, value):
|
||||
assert isinstance(
|
||||
value, bool), "The value to set `is_no_pad` must be type of bool."
|
||||
self._model.is_no_pad = value
|
||||
|
||||
@is_mini_pad.setter
|
||||
def is_mini_pad(self, value):
|
||||
assert isinstance(
|
||||
value,
|
||||
bool), "The value to set `is_mini_pad` must be type of bool."
|
||||
self._model.is_mini_pad = value
|
||||
|
||||
@is_scale_up.setter
|
||||
def is_scale_up(self, value):
|
||||
assert isinstance(
|
||||
value,
|
||||
bool), "The value to set `is_scale_up` must be type of bool."
|
||||
self._model.is_scale_up = value
|
||||
|
||||
@stride.setter
|
||||
def stride(self, value):
|
||||
assert isinstance(
|
||||
value, int), "The value to set `stride` must be type of int."
|
||||
self._model.stride = value
|
||||
|
||||
@max_wh.setter
|
||||
def max_wh(self, value):
|
||||
assert isinstance(
|
||||
value, float), "The value to set `max_wh` must be type of float."
|
||||
self._model.max_wh = value
|
@@ -17,23 +17,22 @@ import logging
|
||||
from ... import c_lib_wrap as C
|
||||
|
||||
|
||||
def vis_detection(im_data, det_result, line_size=1, font_size=0.5):
|
||||
C.vision.Visualize.vis_detection(im_data, det_result, line_size, font_size)
|
||||
def vis_detection(im_data, det_result, line_size=2, font_size=0.5):
|
||||
return C.vision.Visualize.vis_detection(im_data, det_result, line_size,
|
||||
font_size)
|
||||
|
||||
|
||||
def vis_face_detection(im_data, face_det_result, line_size=1, font_size=0.5):
|
||||
C.vision.Visualize.vis_face_detection(im_data, face_det_result, line_size,
|
||||
font_size)
|
||||
def vis_face_detection(im_data, face_det_result, line_size=2, font_size=0.5):
|
||||
return C.vision.Visualize.vis_face_detection(im_data, face_det_result,
|
||||
line_size, font_size)
|
||||
|
||||
|
||||
def vis_segmentation(im_data, seg_result, vis_im_data, num_classes=1000):
|
||||
C.vision.Visualize.vis_segmentation(im_data, seg_result, vis_im_data,
|
||||
num_classes)
|
||||
def vis_segmentation(im_data, seg_result):
|
||||
return C.vision.Visualize.vis_segmentation(im_data, seg_result)
|
||||
|
||||
|
||||
def vis_matting_alpha(im_data,
|
||||
matting_result,
|
||||
vis_im_data,
|
||||
remove_small_connected_area=False):
|
||||
C.vision.Visualize.vis_matting_alpha(im_data, matting_result, vis_im_data,
|
||||
remove_small_connected_area)
|
||||
return C.vision.Visualize.vis_matting_alpha(im_data, matting_result,
|
||||
remove_small_connected_area)
|
||||
|
@@ -16,8 +16,8 @@ im = cv2.imread("3.jpg")
|
||||
result = model.predict(im, conf_threshold=0.7, nms_iou_threshold=0.3)
|
||||
|
||||
# 可视化结果
|
||||
fd.vision.visualize.vis_face_detection(im, result)
|
||||
cv2.imwrite("vis_result.jpg", im)
|
||||
vis_im = fd.vision.visualize.vis_face_detection(im, result)
|
||||
cv2.imwrite("vis_result.jpg", vis_im)
|
||||
|
||||
# 输出预测结果
|
||||
print(result)
|
||||
|
Reference in New Issue
Block a user