Modify yolov7 and visualize functions (#82)

modify yolov7 and visualize functions
This commit is contained in:
Jason
2022-08-09 10:16:41 +08:00
committed by GitHub
parent 36b2bfaa33
commit b2cd30e64f
29 changed files with 653 additions and 249 deletions

View File

@@ -30,6 +30,8 @@ pybind11::dtype FDDataTypeToNumpyDataType(const FDDataType& fd_dtype) {
dt = pybind11::dtype::of<float>(); dt = pybind11::dtype::of<float>();
} else if (fd_dtype == FDDataType::FP64) { } else if (fd_dtype == FDDataType::FP64) {
dt = pybind11::dtype::of<double>(); dt = pybind11::dtype::of<double>();
} else if (fd_dtype == FDDataType::UINT8) {
dt = pybind11::dtype::of<uint8_t>();
} else { } else {
FDASSERT(false, "The function doesn't support data type of " + FDASSERT(false, "The function doesn't support data type of " +
Str(fd_dtype) + "."); Str(fd_dtype) + ".");
@@ -46,6 +48,8 @@ FDDataType NumpyDataTypeToFDDataType(const pybind11::dtype& np_dtype) {
return FDDataType::FP32; return FDDataType::FP32;
} else if (np_dtype.is(pybind11::dtype::of<double>())) { } else if (np_dtype.is(pybind11::dtype::of<double>())) {
return FDDataType::FP64; return FDDataType::FP64;
} else if (np_dtype.is(pybind11::dtype::of<uint8_t>())) {
return FDDataType::UINT8;
} }
FDASSERT(false, FDASSERT(false,
"NumpyDataTypeToFDDataType() only support " "NumpyDataTypeToFDDataType() only support "
@@ -66,6 +70,13 @@ void PyArrayToTensor(pybind11::array& pyarray, FDTensor* tensor,
} }
} }
pybind11::array TensorToPyArray(const FDTensor& tensor) {
auto numpy_dtype = FDDataTypeToNumpyDataType(tensor.dtype);
auto out = pybind11::array(numpy_dtype, tensor.shape);
memcpy(out.mutable_data(), tensor.Data(), tensor.Numel() * FDDataTypeSize(tensor.dtype));
return out;
}
#ifdef ENABLE_VISION #ifdef ENABLE_VISION
int NumpyDataTypeToOpenCvType(const pybind11::dtype& np_dtype) { int NumpyDataTypeToOpenCvType(const pybind11::dtype& np_dtype) {
if (np_dtype.is(pybind11::dtype::of<int32_t>())) { if (np_dtype.is(pybind11::dtype::of<int32_t>())) {

View File

@@ -36,12 +36,14 @@ FDDataType NumpyDataTypeToFDDataType(const pybind11::dtype& np_dtype);
void PyArrayToTensor(pybind11::array& pyarray, FDTensor* tensor, void PyArrayToTensor(pybind11::array& pyarray, FDTensor* tensor,
bool share_buffer = false); bool share_buffer = false);
pybind11::array TensorToPyArray(const FDTensor& tensor);
#ifdef ENABLE_VISION #ifdef ENABLE_VISION
cv::Mat PyArrayToCvMat(pybind11::array& pyarray); cv::Mat PyArrayToCvMat(pybind11::array& pyarray);
#endif #endif
template <typename T> FDDataType CTypeToFDDataType() { template <typename T>
FDDataType CTypeToFDDataType() {
if (std::is_same<T, int32_t>::value) { if (std::is_same<T, int32_t>::value) {
return FDDataType::INT32; return FDDataType::INT32;
} else if (std::is_same<T, int64_t>::value) { } else if (std::is_same<T, int64_t>::value) {
@@ -57,9 +59,9 @@ template <typename T> FDDataType CTypeToFDDataType() {
} }
template <typename T> template <typename T>
std::vector<pybind11::array> std::vector<pybind11::array> PyBackendInfer(
PyBackendInfer(T& self, const std::vector<std::string>& names, T& self, const std::vector<std::string>& names,
std::vector<pybind11::array>& data) { std::vector<pybind11::array>& data) {
std::vector<FDTensor> inputs(data.size()); std::vector<FDTensor> inputs(data.size());
for (size_t i = 0; i < data.size(); ++i) { for (size_t i = 0; i < data.size(); ++i) {
// TODO(jiangjiajun) here is considered to use user memory directly // TODO(jiangjiajun) here is considered to use user memory directly
@@ -85,4 +87,4 @@ PyBackendInfer(T& self, const std::vector<std::string>& names,
return results; return results;
} }
} // namespace fastdeploy } // namespace fastdeploy

View File

@@ -23,6 +23,9 @@
#include "fastdeploy/vision/deepinsight/partial_fc.h" #include "fastdeploy/vision/deepinsight/partial_fc.h"
#include "fastdeploy/vision/deepinsight/scrfd.h" #include "fastdeploy/vision/deepinsight/scrfd.h"
#include "fastdeploy/vision/deepinsight/vpl.h" #include "fastdeploy/vision/deepinsight/vpl.h"
#include "fastdeploy/vision/detection/contrib/scaledyolov4.h"
#include "fastdeploy/vision/detection/contrib/yolor.h"
#include "fastdeploy/vision/detection/contrib/yolov7.h"
#include "fastdeploy/vision/linzaer/ultraface.h" #include "fastdeploy/vision/linzaer/ultraface.h"
#include "fastdeploy/vision/megvii/yolox.h" #include "fastdeploy/vision/megvii/yolox.h"
#include "fastdeploy/vision/meituan/yolov6.h" #include "fastdeploy/vision/meituan/yolov6.h"
@@ -32,9 +35,6 @@
#include "fastdeploy/vision/ppseg/model.h" #include "fastdeploy/vision/ppseg/model.h"
#include "fastdeploy/vision/rangilyu/nanodet_plus.h" #include "fastdeploy/vision/rangilyu/nanodet_plus.h"
#include "fastdeploy/vision/ultralytics/yolov5.h" #include "fastdeploy/vision/ultralytics/yolov5.h"
#include "fastdeploy/vision/wongkinyiu/scaledyolov4.h"
#include "fastdeploy/vision/wongkinyiu/yolor.h"
#include "fastdeploy/vision/wongkinyiu/yolov7.h"
#include "fastdeploy/vision/zhkkke/modnet.h" #include "fastdeploy/vision/zhkkke/modnet.h"
#endif #endif

View File

@@ -27,7 +27,7 @@ namespace vision {
enum Layout { HWC, CHW }; enum Layout { HWC, CHW };
struct Mat { struct FASTDEPLOY_DECL Mat {
explicit Mat(cv::Mat& mat) { explicit Mat(cv::Mat& mat) {
cpu_mat = mat; cpu_mat = mat;
device = Device::CPU; device = Device::CPU;
@@ -76,5 +76,5 @@ struct Mat {
Device device = Device::CPU; Device device = Device::CPU;
}; };
} // namespace vision } // namespace vision
} // namespace fastdeploy } // namespace fastdeploy

View File

@@ -12,13 +12,13 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "fastdeploy/vision/wongkinyiu/scaledyolov4.h" #include "fastdeploy/vision/detection/contrib/scaledyolov4.h"
#include "fastdeploy/utils/perf.h" #include "fastdeploy/utils/perf.h"
#include "fastdeploy/vision/utils/utils.h" #include "fastdeploy/vision/utils/utils.h"
namespace fastdeploy { namespace fastdeploy {
namespace vision { namespace vision {
namespace wongkinyiu { namespace detection {
void ScaledYOLOv4::LetterBox(Mat* mat, const std::vector<int>& size, void ScaledYOLOv4::LetterBox(Mat* mat, const std::vector<int>& size,
const std::vector<float>& color, bool _auto, const std::vector<float>& color, bool _auto,
@@ -65,8 +65,8 @@ ScaledYOLOv4::ScaledYOLOv4(const std::string& model_file,
valid_cpu_backends = {Backend::ORT}; // 指定可用的CPU后端 valid_cpu_backends = {Backend::ORT}; // 指定可用的CPU后端
valid_gpu_backends = {Backend::ORT, Backend::TRT}; // 指定可用的GPU后端 valid_gpu_backends = {Backend::ORT, Backend::TRT}; // 指定可用的GPU后端
} else { } else {
valid_cpu_backends = {Backend::PDINFER, Backend::ORT}; valid_cpu_backends = {Backend::PDINFER};
valid_gpu_backends = {Backend::PDINFER, Backend::ORT, Backend::TRT}; valid_gpu_backends = {Backend::PDINFER};
} }
runtime_option = custom_option; runtime_option = custom_option;
runtime_option.model_format = model_format; runtime_option.model_format = model_format;
@@ -219,10 +219,6 @@ bool ScaledYOLOv4::Postprocess(
bool ScaledYOLOv4::Predict(cv::Mat* im, DetectionResult* result, bool ScaledYOLOv4::Predict(cv::Mat* im, DetectionResult* result,
float conf_threshold, float nms_iou_threshold) { float conf_threshold, float nms_iou_threshold) {
#ifdef FASTDEPLOY_DEBUG
TIMERECORD_START(0)
#endif
Mat mat(*im); Mat mat(*im);
std::vector<FDTensor> input_tensors(1); std::vector<FDTensor> input_tensors(1);
@@ -239,34 +235,21 @@ bool ScaledYOLOv4::Predict(cv::Mat* im, DetectionResult* result,
return false; return false;
} }
#ifdef FASTDEPLOY_DEBUG
TIMERECORD_END(0, "Preprocess")
TIMERECORD_START(1)
#endif
input_tensors[0].name = InputInfoOfRuntime(0).name; input_tensors[0].name = InputInfoOfRuntime(0).name;
std::vector<FDTensor> output_tensors; std::vector<FDTensor> output_tensors;
if (!Infer(input_tensors, &output_tensors)) { if (!Infer(input_tensors, &output_tensors)) {
FDERROR << "Failed to inference." << std::endl; FDERROR << "Failed to inference." << std::endl;
return false; return false;
} }
#ifdef FASTDEPLOY_DEBUG
TIMERECORD_END(1, "Inference")
TIMERECORD_START(2)
#endif
if (!Postprocess(output_tensors[0], result, im_info, conf_threshold, if (!Postprocess(output_tensors[0], result, im_info, conf_threshold,
nms_iou_threshold)) { nms_iou_threshold)) {
FDERROR << "Failed to post process." << std::endl; FDERROR << "Failed to post process." << std::endl;
return false; return false;
} }
#ifdef FASTDEPLOY_DEBUG
TIMERECORD_END(2, "Postprocess")
#endif
return true; return true;
} }
} // namespace wongkinyiu } // namespace detection
} // namespace vision } // namespace vision
} // namespace fastdeploy } // namespace fastdeploy

View File

@@ -19,7 +19,7 @@
namespace fastdeploy { namespace fastdeploy {
namespace vision { namespace vision {
namespace wongkinyiu { namespace detection {
class FASTDEPLOY_DECL ScaledYOLOv4 : public FastDeployModel { class FASTDEPLOY_DECL ScaledYOLOv4 : public FastDeployModel {
public: public:
@@ -31,7 +31,7 @@ class FASTDEPLOY_DECL ScaledYOLOv4 : public FastDeployModel {
const Frontend& model_format = Frontend::ONNX); const Frontend& model_format = Frontend::ONNX);
// 定义模型的名称 // 定义模型的名称
virtual std::string ModelName() const { return "WongKinYiu/ScaledYOLOv4"; } virtual std::string ModelName() const { return "ScaledYOLOv4"; }
// 模型预测接口,即用户调用的接口 // 模型预测接口,即用户调用的接口
// im 为用户的输入数据目前对于CV均定义为cv::Mat // im 为用户的输入数据目前对于CV均定义为cv::Mat
@@ -98,6 +98,6 @@ class FASTDEPLOY_DECL ScaledYOLOv4 : public FastDeployModel {
// auto check by fastdeploy after the internal Runtime already initialized. // auto check by fastdeploy after the internal Runtime already initialized.
bool is_dynamic_input_; bool is_dynamic_input_;
}; };
} // namespace wongkinyiu } // namespace detection
} // namespace vision } // namespace vision
} // namespace fastdeploy } // namespace fastdeploy

View File

@@ -0,0 +1,41 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/pybind/main.h"
namespace fastdeploy {
void BindScaledYOLOv4(pybind11::module& m) {
pybind11::class_<vision::detection::ScaledYOLOv4, FastDeployModel>(
m, "ScaledYOLOv4")
.def(pybind11::init<std::string, std::string, RuntimeOption, Frontend>())
.def("predict",
[](vision::detection::ScaledYOLOv4& self, pybind11::array& data,
float conf_threshold, float nms_iou_threshold) {
auto mat = PyArrayToCvMat(data);
vision::DetectionResult res;
self.Predict(&mat, &res, conf_threshold, nms_iou_threshold);
return res;
})
.def_readwrite("size", &vision::detection::ScaledYOLOv4::size)
.def_readwrite("padding_value",
&vision::detection::ScaledYOLOv4::padding_value)
.def_readwrite("is_mini_pad",
&vision::detection::ScaledYOLOv4::is_mini_pad)
.def_readwrite("is_no_pad", &vision::detection::ScaledYOLOv4::is_no_pad)
.def_readwrite("is_scale_up",
&vision::detection::ScaledYOLOv4::is_scale_up)
.def_readwrite("stride", &vision::detection::ScaledYOLOv4::stride)
.def_readwrite("max_wh", &vision::detection::ScaledYOLOv4::max_wh);
}
} // namespace fastdeploy

View File

@@ -12,13 +12,13 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "fastdeploy/vision/wongkinyiu/yolor.h" #include "fastdeploy/vision/detection/contrib/yolor.h"
#include "fastdeploy/utils/perf.h" #include "fastdeploy/utils/perf.h"
#include "fastdeploy/vision/utils/utils.h" #include "fastdeploy/vision/utils/utils.h"
namespace fastdeploy { namespace fastdeploy {
namespace vision { namespace vision {
namespace wongkinyiu { namespace detection {
void YOLOR::LetterBox(Mat* mat, const std::vector<int>& size, void YOLOR::LetterBox(Mat* mat, const std::vector<int>& size,
const std::vector<float>& color, bool _auto, const std::vector<float>& color, bool _auto,
@@ -63,8 +63,8 @@ YOLOR::YOLOR(const std::string& model_file, const std::string& params_file,
valid_cpu_backends = {Backend::ORT}; // 指定可用的CPU后端 valid_cpu_backends = {Backend::ORT}; // 指定可用的CPU后端
valid_gpu_backends = {Backend::ORT, Backend::TRT}; // 指定可用的GPU后端 valid_gpu_backends = {Backend::ORT, Backend::TRT}; // 指定可用的GPU后端
} else { } else {
valid_cpu_backends = {Backend::PDINFER, Backend::ORT}; valid_cpu_backends = {Backend::PDINFER};
valid_gpu_backends = {Backend::PDINFER, Backend::ORT, Backend::TRT}; valid_gpu_backends = {Backend::PDINFER};
} }
runtime_option = custom_option; runtime_option = custom_option;
runtime_option.model_format = model_format; runtime_option.model_format = model_format;
@@ -216,10 +216,6 @@ bool YOLOR::Postprocess(
bool YOLOR::Predict(cv::Mat* im, DetectionResult* result, float conf_threshold, bool YOLOR::Predict(cv::Mat* im, DetectionResult* result, float conf_threshold,
float nms_iou_threshold) { float nms_iou_threshold) {
#ifdef FASTDEPLOY_DEBUG
TIMERECORD_START(0)
#endif
Mat mat(*im); Mat mat(*im);
std::vector<FDTensor> input_tensors(1); std::vector<FDTensor> input_tensors(1);
@@ -236,21 +232,12 @@ bool YOLOR::Predict(cv::Mat* im, DetectionResult* result, float conf_threshold,
return false; return false;
} }
#ifdef FASTDEPLOY_DEBUG
TIMERECORD_END(0, "Preprocess")
TIMERECORD_START(1)
#endif
input_tensors[0].name = InputInfoOfRuntime(0).name; input_tensors[0].name = InputInfoOfRuntime(0).name;
std::vector<FDTensor> output_tensors; std::vector<FDTensor> output_tensors;
if (!Infer(input_tensors, &output_tensors)) { if (!Infer(input_tensors, &output_tensors)) {
FDERROR << "Failed to inference." << std::endl; FDERROR << "Failed to inference." << std::endl;
return false; return false;
} }
#ifdef FASTDEPLOY_DEBUG
TIMERECORD_END(1, "Inference")
TIMERECORD_START(2)
#endif
if (!Postprocess(output_tensors[0], result, im_info, conf_threshold, if (!Postprocess(output_tensors[0], result, im_info, conf_threshold,
nms_iou_threshold)) { nms_iou_threshold)) {
@@ -258,12 +245,9 @@ bool YOLOR::Predict(cv::Mat* im, DetectionResult* result, float conf_threshold,
return false; return false;
} }
#ifdef FASTDEPLOY_DEBUG
TIMERECORD_END(2, "Postprocess")
#endif
return true; return true;
} }
} // namespace wongkinyiu } // namespace detection
} // namespace vision } // namespace vision
} // namespace fastdeploy } // namespace fastdeploy

View File

@@ -19,7 +19,7 @@
namespace fastdeploy { namespace fastdeploy {
namespace vision { namespace vision {
namespace wongkinyiu { namespace detection {
class FASTDEPLOY_DECL YOLOR : public FastDeployModel { class FASTDEPLOY_DECL YOLOR : public FastDeployModel {
public: public:
@@ -30,7 +30,7 @@ class FASTDEPLOY_DECL YOLOR : public FastDeployModel {
const Frontend& model_format = Frontend::ONNX); const Frontend& model_format = Frontend::ONNX);
// 定义模型的名称 // 定义模型的名称
virtual std::string ModelName() const { return "WongKinYiu/yolor"; } virtual std::string ModelName() const { return "YOLOR"; }
// 模型预测接口,即用户调用的接口 // 模型预测接口,即用户调用的接口
// im 为用户的输入数据目前对于CV均定义为cv::Mat // im 为用户的输入数据目前对于CV均定义为cv::Mat
@@ -97,6 +97,6 @@ class FASTDEPLOY_DECL YOLOR : public FastDeployModel {
// auto check by fastdeploy after the internal Runtime already initialized. // auto check by fastdeploy after the internal Runtime already initialized.
bool is_dynamic_input_; bool is_dynamic_input_;
}; };
} // namespace wongkinyiu } // namespace detection
} // namespace vision } // namespace vision
} // namespace fastdeploy } // namespace fastdeploy

View File

@@ -0,0 +1,37 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/pybind/main.h"
namespace fastdeploy {
void BindYOLOR(pybind11::module& m) {
pybind11::class_<vision::detection::YOLOR, FastDeployModel>(m, "YOLOR")
.def(pybind11::init<std::string, std::string, RuntimeOption, Frontend>())
.def("predict",
[](vision::detection::YOLOR& self, pybind11::array& data,
float conf_threshold, float nms_iou_threshold) {
auto mat = PyArrayToCvMat(data);
vision::DetectionResult res;
self.Predict(&mat, &res, conf_threshold, nms_iou_threshold);
return res;
})
.def_readwrite("size", &vision::detection::YOLOR::size)
.def_readwrite("padding_value", &vision::detection::YOLOR::padding_value)
.def_readwrite("is_mini_pad", &vision::detection::YOLOR::is_mini_pad)
.def_readwrite("is_no_pad", &vision::detection::YOLOR::is_no_pad)
.def_readwrite("is_scale_up", &vision::detection::YOLOR::is_scale_up)
.def_readwrite("stride", &vision::detection::YOLOR::stride)
.def_readwrite("max_wh", &vision::detection::YOLOR::max_wh);
}
} // namespace fastdeploy

View File

@@ -12,13 +12,13 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "fastdeploy/vision/wongkinyiu/yolov7.h" #include "fastdeploy/vision/detection/contrib/yolov7.h"
#include "fastdeploy/utils/perf.h" #include "fastdeploy/utils/perf.h"
#include "fastdeploy/vision/utils/utils.h" #include "fastdeploy/vision/utils/utils.h"
namespace fastdeploy { namespace fastdeploy {
namespace vision { namespace vision {
namespace wongkinyiu { namespace detection {
void YOLOv7::LetterBox(Mat* mat, const std::vector<int>& size, void YOLOv7::LetterBox(Mat* mat, const std::vector<int>& size,
const std::vector<float>& color, bool _auto, const std::vector<float>& color, bool _auto,
@@ -64,13 +64,12 @@ YOLOv7::YOLOv7(const std::string& model_file, const std::string& params_file,
valid_cpu_backends = {Backend::ORT}; // 指定可用的CPU后端 valid_cpu_backends = {Backend::ORT}; // 指定可用的CPU后端
valid_gpu_backends = {Backend::ORT, Backend::TRT}; // 指定可用的GPU后端 valid_gpu_backends = {Backend::ORT, Backend::TRT}; // 指定可用的GPU后端
} else { } else {
valid_cpu_backends = {Backend::PDINFER, Backend::ORT}; valid_cpu_backends = {Backend::PDINFER};
valid_gpu_backends = {Backend::PDINFER, Backend::ORT, Backend::TRT}; valid_gpu_backends = {Backend::PDINFER};
} }
runtime_option = custom_option; runtime_option = custom_option;
runtime_option.model_format = model_format; runtime_option.model_format = model_format;
runtime_option.model_file = model_file; runtime_option.model_file = model_file;
runtime_option.params_file = params_file;
initialized = Initialize(); initialized = Initialize();
} }
@@ -217,10 +216,6 @@ bool YOLOv7::Postprocess(
bool YOLOv7::Predict(cv::Mat* im, DetectionResult* result, float conf_threshold, bool YOLOv7::Predict(cv::Mat* im, DetectionResult* result, float conf_threshold,
float nms_iou_threshold) { float nms_iou_threshold) {
#ifdef FASTDEPLOY_DEBUG
TIMERECORD_START(0)
#endif
Mat mat(*im); Mat mat(*im);
std::vector<FDTensor> input_tensors(1); std::vector<FDTensor> input_tensors(1);
@@ -237,21 +232,12 @@ bool YOLOv7::Predict(cv::Mat* im, DetectionResult* result, float conf_threshold,
return false; return false;
} }
#ifdef FASTDEPLOY_DEBUG
TIMERECORD_END(0, "Preprocess")
TIMERECORD_START(1)
#endif
input_tensors[0].name = InputInfoOfRuntime(0).name; input_tensors[0].name = InputInfoOfRuntime(0).name;
std::vector<FDTensor> output_tensors; std::vector<FDTensor> output_tensors;
if (!Infer(input_tensors, &output_tensors)) { if (!Infer(input_tensors, &output_tensors)) {
FDERROR << "Failed to inference." << std::endl; FDERROR << "Failed to inference." << std::endl;
return false; return false;
} }
#ifdef FASTDEPLOY_DEBUG
TIMERECORD_END(1, "Inference")
TIMERECORD_START(2)
#endif
if (!Postprocess(output_tensors[0], result, im_info, conf_threshold, if (!Postprocess(output_tensors[0], result, im_info, conf_threshold,
nms_iou_threshold)) { nms_iou_threshold)) {
@@ -259,12 +245,9 @@ bool YOLOv7::Predict(cv::Mat* im, DetectionResult* result, float conf_threshold,
return false; return false;
} }
#ifdef FASTDEPLOY_DEBUG
TIMERECORD_END(2, "Postprocess")
#endif
return true; return true;
} }
} // namespace wongkinyiu } // namespace detection
} // namespace vision } // namespace vision
} // namespace fastdeploy } // namespace fastdeploy

View File

@@ -19,18 +19,16 @@
namespace fastdeploy { namespace fastdeploy {
namespace vision { namespace vision {
namespace wongkinyiu { namespace detection {
class FASTDEPLOY_DECL YOLOv7 : public FastDeployModel { class FASTDEPLOY_DECL YOLOv7 : public FastDeployModel {
public: public:
// 当model_format为ONNX时无需指定params_file
// 当model_format为Paddle时则需同时指定model_file & params_file
YOLOv7(const std::string& model_file, const std::string& params_file = "", YOLOv7(const std::string& model_file, const std::string& params_file = "",
const RuntimeOption& custom_option = RuntimeOption(), const RuntimeOption& custom_option = RuntimeOption(),
const Frontend& model_format = Frontend::ONNX); const Frontend& model_format = Frontend::ONNX);
// 定义模型的名称 // 定义模型的名称
virtual std::string ModelName() const { return "WongKinYiu/yolov7"; } virtual std::string ModelName() const { return "yolov7"; }
// 模型预测接口,即用户调用的接口 // 模型预测接口,即用户调用的接口
// im 为用户的输入数据目前对于CV均定义为cv::Mat // im 为用户的输入数据目前对于CV均定义为cv::Mat
@@ -97,6 +95,6 @@ class FASTDEPLOY_DECL YOLOv7 : public FastDeployModel {
// auto check by fastdeploy after the internal Runtime already initialized. // auto check by fastdeploy after the internal Runtime already initialized.
bool is_dynamic_input_; bool is_dynamic_input_;
}; };
} // namespace wongkinyiu } // namespace detection
} // namespace vision } // namespace vision
} // namespace fastdeploy } // namespace fastdeploy

View File

@@ -0,0 +1,37 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/pybind/main.h"
namespace fastdeploy {
void BindYOLOv7(pybind11::module& m) {
pybind11::class_<vision::detection::YOLOv7, FastDeployModel>(m, "YOLOv7")
.def(pybind11::init<std::string, std::string, RuntimeOption, Frontend>())
.def("predict",
[](vision::detection::YOLOv7& self, pybind11::array& data,
float conf_threshold, float nms_iou_threshold) {
auto mat = PyArrayToCvMat(data);
vision::DetectionResult res;
self.Predict(&mat, &res, conf_threshold, nms_iou_threshold);
return res;
})
.def_readwrite("size", &vision::detection::YOLOv7::size)
.def_readwrite("padding_value", &vision::detection::YOLOv7::padding_value)
.def_readwrite("is_mini_pad", &vision::detection::YOLOv7::is_mini_pad)
.def_readwrite("is_no_pad", &vision::detection::YOLOv7::is_no_pad)
.def_readwrite("is_scale_up", &vision::detection::YOLOv7::is_scale_up)
.def_readwrite("stride", &vision::detection::YOLOv7::stride)
.def_readwrite("max_wh", &vision::detection::YOLOv7::max_wh);
}
} // namespace fastdeploy

View File

@@ -0,0 +1,30 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/pybind/main.h"
namespace fastdeploy {
void BindYOLOv7(pybind11::module& m);
void BindScaledYOLOv4(pybind11::module& m);
void BindYOLOR(pybind11::module& m);
void BindDetection(pybind11::module& m) {
auto detection_module =
m.def_submodule("detection", "Image object detection models.");
BindYOLOv7(detection_module);
BindScaledYOLOv4(detection_module);
BindYOLOR(detection_module);
}
} // namespace fastdeploy

View File

@@ -18,7 +18,6 @@ namespace fastdeploy {
void BindPPCls(pybind11::module& m); void BindPPCls(pybind11::module& m);
void BindPPDet(pybind11::module& m); void BindPPDet(pybind11::module& m);
void BindWongkinyiu(pybind11::module& m);
void BindPPSeg(pybind11::module& m); void BindPPSeg(pybind11::module& m);
void BindUltralytics(pybind11::module& m); void BindUltralytics(pybind11::module& m);
void BindMeituan(pybind11::module& m); void BindMeituan(pybind11::module& m);
@@ -30,6 +29,8 @@ void BindBiubug6(pybind11::module& m);
void BindPpogg(pybind11::module& m); void BindPpogg(pybind11::module& m);
void BindDeepInsight(pybind11::module& m); void BindDeepInsight(pybind11::module& m);
void BindZHKKKe(pybind11::module& m); void BindZHKKKe(pybind11::module& m);
void BindDetection(pybind11::module& m);
#ifdef ENABLE_VISION_VISUALIZE #ifdef ENABLE_VISION_VISUALIZE
void BindVisualize(pybind11::module& m); void BindVisualize(pybind11::module& m);
#endif #endif
@@ -88,7 +89,6 @@ void BindVision(pybind11::module& m) {
BindPPDet(m); BindPPDet(m);
BindPPSeg(m); BindPPSeg(m);
BindUltralytics(m); BindUltralytics(m);
BindWongkinyiu(m);
BindMeituan(m); BindMeituan(m);
BindMegvii(m); BindMegvii(m);
BindDeepCam(m); BindDeepCam(m);
@@ -98,6 +98,8 @@ void BindVision(pybind11::module& m) {
BindPpogg(m); BindPpogg(m);
BindDeepInsight(m); BindDeepInsight(m);
BindZHKKKe(m); BindZHKKKe(m);
BindDetection(m);
#ifdef ENABLE_VISION_VISUALIZE #ifdef ENABLE_VISION_VISUALIZE
BindVisualize(m); BindVisualize(m);
#endif #endif

View File

@@ -23,11 +23,13 @@ namespace vision {
// Default only support visualize num_classes <= 1000 // Default only support visualize num_classes <= 1000
// If need to visualize num_classes > 1000 // If need to visualize num_classes > 1000
// Please call Visualize::GetColorMap(num_classes) first // Please call Visualize::GetColorMap(num_classes) first
void Visualize::VisDetection(cv::Mat* im, const DetectionResult& result, cv::Mat Visualize::VisDetection(const cv::Mat& im,
int line_size, float font_size) { const DetectionResult& result, int line_size,
float font_size) {
auto color_map = GetColorMap(); auto color_map = GetColorMap();
int h = im->rows; int h = im.rows;
int w = im->cols; int w = im.cols;
auto vis_im = im.clone();
for (size_t i = 0; i < result.boxes.size(); ++i) { for (size_t i = 0; i < result.boxes.size(); ++i) {
cv::Rect rect(result.boxes[i][0], result.boxes[i][1], cv::Rect rect(result.boxes[i][0], result.boxes[i][1],
result.boxes[i][2] - result.boxes[i][0], result.boxes[i][2] - result.boxes[i][0],
@@ -50,10 +52,11 @@ void Visualize::VisDetection(cv::Mat* im, const DetectionResult& result,
cv::Rect text_background = cv::Rect text_background =
cv::Rect(result.boxes[i][0], result.boxes[i][1] - text_size.height, cv::Rect(result.boxes[i][0], result.boxes[i][1] - text_size.height,
text_size.width, text_size.height); text_size.width, text_size.height);
cv::rectangle(*im, rect, rect_color, line_size); cv::rectangle(vis_im, rect, rect_color, line_size);
cv::putText(*im, text, origin, font, font_size, cv::Scalar(255, 255, 255), cv::putText(vis_im, text, origin, font, font_size,
1); cv::Scalar(255, 255, 255), 1);
} }
return vis_im;
} }
} // namespace vision } // namespace vision

View File

@@ -24,12 +24,14 @@ namespace vision {
// Default only support visualize num_classes <= 1000 // Default only support visualize num_classes <= 1000
// If need to visualize num_classes > 1000 // If need to visualize num_classes > 1000
// Please call Visualize::GetColorMap(num_classes) first // Please call Visualize::GetColorMap(num_classes) first
void Visualize::VisFaceDetection(cv::Mat* im, const FaceDetectionResult& result, cv::Mat Visualize::VisFaceDetection(const cv::Mat& im,
int line_size, float font_size) { const FaceDetectionResult& result,
int line_size, float font_size) {
auto color_map = GetColorMap(); auto color_map = GetColorMap();
int h = im->rows; int h = im.rows;
int w = im->cols; int w = im.cols;
auto vis_im = im.clone();
bool vis_landmarks = false; bool vis_landmarks = false;
if ((result.landmarks_per_face > 0) && if ((result.landmarks_per_face > 0) &&
(result.boxes.size() * result.landmarks_per_face == (result.boxes.size() * result.landmarks_per_face ==
@@ -57,9 +59,9 @@ void Visualize::VisFaceDetection(cv::Mat* im, const FaceDetectionResult& result,
cv::Rect text_background = cv::Rect text_background =
cv::Rect(result.boxes[i][0], result.boxes[i][1] - text_size.height, cv::Rect(result.boxes[i][0], result.boxes[i][1] - text_size.height,
text_size.width, text_size.height); text_size.width, text_size.height);
cv::rectangle(*im, rect, rect_color, line_size); cv::rectangle(vis_im, rect, rect_color, line_size);
cv::putText(*im, text, origin, font, font_size, cv::Scalar(255, 255, 255), cv::putText(vis_im, text, origin, font, font_size,
1); cv::Scalar(255, 255, 255), 1);
// vis landmarks (if have) // vis landmarks (if have)
if (vis_landmarks) { if (vis_landmarks) {
cv::Scalar landmark_color = rect_color; cv::Scalar landmark_color = rect_color;
@@ -69,10 +71,11 @@ void Visualize::VisFaceDetection(cv::Mat* im, const FaceDetectionResult& result,
result.landmarks[i * result.landmarks_per_face + j][0]); result.landmarks[i * result.landmarks_per_face + j][0]);
landmark.y = static_cast<int>( landmark.y = static_cast<int>(
result.landmarks[i * result.landmarks_per_face + j][1]); result.landmarks[i * result.landmarks_per_face + j][1]);
cv::circle(*im, landmark, line_size, landmark_color, -1); cv::circle(vis_im, landmark, line_size, landmark_color, -1);
} }
} }
} }
return vis_im;
} }
} // namespace vision } // namespace vision

View File

@@ -65,12 +65,14 @@ static void RemoveSmallConnectedArea(cv::Mat* alpha_pred,
} }
} }
void Visualize::VisMattingAlpha(const cv::Mat& im, const MattingResult& result, cv::Mat Visualize::VisMattingAlpha(const cv::Mat& im,
cv::Mat* vis_img, const MattingResult& result,
bool remove_small_connected_area) { bool remove_small_connected_area) {
// 只可视化alphafgr(前景)本身就是一张图 不需要可视化 // 只可视化alphafgr(前景)本身就是一张图 不需要可视化
FDASSERT((!im.empty()), "im can't be empty!"); FDASSERT((!im.empty()), "im can't be empty!");
FDASSERT((im.channels() == 3), "Only support 3 channels mat!"); FDASSERT((im.channels() == 3), "Only support 3 channels mat!");
auto vis_img = im.clone();
int out_h = static_cast<int>(result.shape[0]); int out_h = static_cast<int>(result.shape[0]);
int out_w = static_cast<int>(result.shape[1]); int out_w = static_cast<int>(result.shape[1]);
int height = im.rows; int height = im.rows;
@@ -87,18 +89,11 @@ void Visualize::VisMattingAlpha(const cv::Mat& im, const MattingResult& result,
cv::resize(alpha, alpha, cv::Size(width, height)); cv::resize(alpha, alpha, cv::Size(width, height));
} }
int vis_h = (*vis_img).rows; if ((vis_img).type() != CV_8UC3) {
int vis_w = (*vis_img).cols; (vis_img).convertTo((vis_img), CV_8UC3);
if ((vis_h != height) || (vis_w != width)) {
// faster than resize
(*vis_img) = cv::Mat::zeros(height, width, CV_8UC3);
}
if ((*vis_img).type() != CV_8UC3) {
(*vis_img).convertTo((*vis_img), CV_8UC3);
} }
uchar* vis_data = static_cast<uchar*>(vis_img->data); uchar* vis_data = static_cast<uchar*>(vis_img.data);
uchar* im_data = static_cast<uchar*>(im.data); uchar* im_data = static_cast<uchar*>(im.data);
float* alpha_data = reinterpret_cast<float*>(alpha.data); float* alpha_data = reinterpret_cast<float*>(alpha.data);
@@ -116,6 +111,7 @@ void Visualize::VisMattingAlpha(const cv::Mat& im, const MattingResult& result,
(1.f - alpha_val) * 120.f); (1.f - alpha_val) * 120.f);
} }
} }
return vis_img;
} }
} // namespace vision } // namespace vision

View File

@@ -21,24 +21,24 @@
namespace fastdeploy { namespace fastdeploy {
namespace vision { namespace vision {
void Visualize::VisSegmentation(const cv::Mat& im, cv::Mat Visualize::VisSegmentation(const cv::Mat& im,
const SegmentationResult& result, const SegmentationResult& result) {
cv::Mat* vis_img, const int& num_classes) { auto color_map = GetColorMap();
auto color_map = GetColorMap(num_classes);
int64_t height = result.shape[0]; int64_t height = result.shape[0];
int64_t width = result.shape[1]; int64_t width = result.shape[1];
*vis_img = cv::Mat::zeros(height, width, CV_8UC3); auto vis_img = cv::Mat(height, width, CV_8UC3);
int64_t index = 0; int64_t index = 0;
for (int i = 0; i < height; i++) { for (int i = 0; i < height; i++) {
for (int j = 0; j < width; j++) { for (int j = 0; j < width; j++) {
int category_id = result.label_map[index++]; int category_id = result.label_map[index++];
vis_img->at<cv::Vec3b>(i, j)[0] = color_map[3 * category_id + 0]; vis_img.at<cv::Vec3b>(i, j)[0] = color_map[3 * category_id + 0];
vis_img->at<cv::Vec3b>(i, j)[1] = color_map[3 * category_id + 1]; vis_img.at<cv::Vec3b>(i, j)[1] = color_map[3 * category_id + 1];
vis_img->at<cv::Vec3b>(i, j)[2] = color_map[3 * category_id + 2]; vis_img.at<cv::Vec3b>(i, j)[2] = color_map[3 * category_id + 2];
} }
} }
cv::addWeighted(im, .5, *vis_img, .5, 0, *vis_img); cv::addWeighted(im, .5, vis_img, .5, 0, vis_img);
return vis_img;
} }
} // namespace vision } // namespace vision

View File

@@ -25,16 +25,15 @@ class FASTDEPLOY_DECL Visualize {
static int num_classes_; static int num_classes_;
static std::vector<int> color_map_; static std::vector<int> color_map_;
static const std::vector<int>& GetColorMap(int num_classes = 1000); static const std::vector<int>& GetColorMap(int num_classes = 1000);
static void VisDetection(cv::Mat* im, const DetectionResult& result, static cv::Mat VisDetection(const cv::Mat& im, const DetectionResult& result,
int line_size = 2, float font_size = 0.5f); int line_size = 2, float font_size = 0.5f);
static void VisFaceDetection(cv::Mat* im, const FaceDetectionResult& result, static cv::Mat VisFaceDetection(const cv::Mat& im,
int line_size = 2, float font_size = 0.5f); const FaceDetectionResult& result,
static void VisSegmentation(const cv::Mat& im, int line_size = 2, float font_size = 0.5f);
const SegmentationResult& result, static cv::Mat VisSegmentation(const cv::Mat& im,
cv::Mat* vis_img, const int& num_classes = 1000); const SegmentationResult& result);
static void VisMattingAlpha(const cv::Mat& im, const MattingResult& result, static cv::Mat VisMattingAlpha(const cv::Mat& im, const MattingResult& result,
cv::Mat* vis_img, bool remove_small_connected_area = false);
bool remove_small_connected_area = false);
}; };
} // namespace vision } // namespace vision

View File

@@ -22,34 +22,41 @@ void BindVisualize(pybind11::module& m) {
[](pybind11::array& im_data, vision::DetectionResult& result, [](pybind11::array& im_data, vision::DetectionResult& result,
int line_size, float font_size) { int line_size, float font_size) {
auto im = PyArrayToCvMat(im_data); auto im = PyArrayToCvMat(im_data);
vision::Visualize::VisDetection(&im, result, line_size, auto vis_im = vision::Visualize::VisDetection(
font_size); im, result, line_size, font_size);
FDTensor out;
vision::Mat(vis_im).ShareWithTensor(&out);
return TensorToPyArray(out);
}) })
.def_static( .def_static(
"vis_face_detection", "vis_face_detection",
[](pybind11::array& im_data, vision::FaceDetectionResult& result, [](pybind11::array& im_data, vision::FaceDetectionResult& result,
int line_size, float font_size) { int line_size, float font_size) {
auto im = PyArrayToCvMat(im_data); auto im = PyArrayToCvMat(im_data);
vision::Visualize::VisFaceDetection(&im, result, line_size, auto vis_im = vision::Visualize::VisFaceDetection(
font_size); im, result, line_size, font_size);
FDTensor out;
vision::Mat(vis_im).ShareWithTensor(&out);
return TensorToPyArray(out);
}) })
.def_static( .def_static(
"vis_segmentation", "vis_segmentation",
[](pybind11::array& im_data, vision::SegmentationResult& result, [](pybind11::array& im_data, vision::SegmentationResult& result) {
pybind11::array& vis_im_data, const int& num_classes) {
cv::Mat im = PyArrayToCvMat(im_data); cv::Mat im = PyArrayToCvMat(im_data);
cv::Mat vis_im = PyArrayToCvMat(vis_im_data); auto vis_im = vision::Visualize::VisSegmentation(im, result);
vision::Visualize::VisSegmentation(im, result, &vis_im, FDTensor out;
num_classes); vision::Mat(vis_im).ShareWithTensor(&out);
return TensorToPyArray(out);
}) })
.def_static( .def_static("vis_matting_alpha",
"vis_matting_alpha", [](pybind11::array& im_data, vision::MattingResult& result,
[](pybind11::array& im_data, vision::MattingResult& result, bool remove_small_connected_area) {
pybind11::array& vis_im_data, bool remove_small_connected_area) { cv::Mat im = PyArrayToCvMat(im_data);
cv::Mat im = PyArrayToCvMat(im_data); auto vis_im = vision::Visualize::VisMattingAlpha(
cv::Mat vis_im = PyArrayToCvMat(vis_im_data); im, result, remove_small_connected_area);
vision::Visualize::VisMattingAlpha(im, result, &vis_im, FDTensor out;
remove_small_connected_area); vision::Mat(vis_im).ShareWithTensor(&out);
}); return TensorToPyArray(out);
});
} }
} // namespace fastdeploy } // namespace fastdeploy

View File

@@ -1,79 +0,0 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/pybind/main.h"
namespace fastdeploy {
void BindWongkinyiu(pybind11::module& m) {
auto wongkinyiu_module =
m.def_submodule("wongkinyiu", "https://github.com/WongKinYiu");
pybind11::class_<vision::wongkinyiu::YOLOv7, FastDeployModel>(
wongkinyiu_module, "YOLOv7")
.def(pybind11::init<std::string, std::string, RuntimeOption, Frontend>())
.def("predict",
[](vision::wongkinyiu::YOLOv7& self, pybind11::array& data,
float conf_threshold, float nms_iou_threshold) {
auto mat = PyArrayToCvMat(data);
vision::DetectionResult res;
self.Predict(&mat, &res, conf_threshold, nms_iou_threshold);
return res;
})
.def_readwrite("size", &vision::wongkinyiu::YOLOv7::size)
.def_readwrite("padding_value",
&vision::wongkinyiu::YOLOv7::padding_value)
.def_readwrite("is_mini_pad", &vision::wongkinyiu::YOLOv7::is_mini_pad)
.def_readwrite("is_no_pad", &vision::wongkinyiu::YOLOv7::is_no_pad)
.def_readwrite("is_scale_up", &vision::wongkinyiu::YOLOv7::is_scale_up)
.def_readwrite("stride", &vision::wongkinyiu::YOLOv7::stride)
.def_readwrite("max_wh", &vision::wongkinyiu::YOLOv7::max_wh);
pybind11::class_<vision::wongkinyiu::YOLOR, FastDeployModel>(
wongkinyiu_module, "YOLOR")
.def(pybind11::init<std::string, std::string, RuntimeOption, Frontend>())
.def("predict",
[](vision::wongkinyiu::YOLOR& self, pybind11::array& data,
float conf_threshold, float nms_iou_threshold) {
auto mat = PyArrayToCvMat(data);
vision::DetectionResult res;
self.Predict(&mat, &res, conf_threshold, nms_iou_threshold);
return res;
})
.def_readwrite("size", &vision::wongkinyiu::YOLOR::size)
.def_readwrite("padding_value", &vision::wongkinyiu::YOLOR::padding_value)
.def_readwrite("is_mini_pad", &vision::wongkinyiu::YOLOR::is_mini_pad)
.def_readwrite("is_no_pad", &vision::wongkinyiu::YOLOR::is_no_pad)
.def_readwrite("is_scale_up", &vision::wongkinyiu::YOLOR::is_scale_up)
.def_readwrite("stride", &vision::wongkinyiu::YOLOR::stride)
.def_readwrite("max_wh", &vision::wongkinyiu::YOLOR::max_wh);
pybind11::class_<vision::wongkinyiu::ScaledYOLOv4, FastDeployModel>(
wongkinyiu_module, "ScaledYOLOv4")
.def(pybind11::init<std::string, std::string, RuntimeOption, Frontend>())
.def("predict",
[](vision::wongkinyiu::ScaledYOLOv4& self, pybind11::array& data,
float conf_threshold, float nms_iou_threshold) {
auto mat = PyArrayToCvMat(data);
vision::DetectionResult res;
self.Predict(&mat, &res, conf_threshold, nms_iou_threshold);
return res;
})
.def_readwrite("size", &vision::wongkinyiu::ScaledYOLOv4::size)
.def_readwrite("padding_value", &vision::wongkinyiu::ScaledYOLOv4::padding_value)
.def_readwrite("is_mini_pad", &vision::wongkinyiu::ScaledYOLOv4::is_mini_pad)
.def_readwrite("is_no_pad", &vision::wongkinyiu::ScaledYOLOv4::is_no_pad)
.def_readwrite("is_scale_up", &vision::wongkinyiu::ScaledYOLOv4::is_scale_up)
.def_readwrite("stride", &vision::wongkinyiu::ScaledYOLOv4::stride)
.def_readwrite("max_wh", &vision::wongkinyiu::ScaledYOLOv4::max_wh);
}
} // namespace fastdeploy

View File

@@ -20,8 +20,6 @@ from . import ppseg
from . import ultralytics from . import ultralytics
from . import meituan from . import meituan
from . import megvii from . import megvii
from . import visualize
from . import wongkinyiu
from . import deepcam from . import deepcam
from . import rangilyu from . import rangilyu
from . import linzaer from . import linzaer
@@ -29,3 +27,7 @@ from . import biubug6
from . import ppogg from . import ppogg
from . import deepinsight from . import deepinsight
from . import zhkkke from . import zhkkke
from . import detection
from .visualize import *

View File

@@ -0,0 +1,18 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from .yolov7 import YOLOv7
from .yolor import YOLOR
from .scaled_yolov4 import ScaledYOLOv4

View File

@@ -0,0 +1,116 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
import logging
from ... import FastDeployModel, Frontend
from ... import c_lib_wrap as C
class ScaledYOLOv4(FastDeployModel):
def __init__(self,
model_file,
params_file="",
runtime_option=None,
model_format=Frontend.ONNX):
# 调用基函数进行backend_option的初始化
# 初始化后的option保存在self._runtime_option
super(ScaledYOLOv4, self).__init__(runtime_option)
self._model = C.vision.detection.ScaledYOLOv4(
model_file, params_file, self._runtime_option, model_format)
# 通过self.initialized判断整个模型的初始化是否成功
assert self.initialized, "ScaledYOLOv4 initialize failed."
def predict(self, input_image, conf_threshold=0.25, nms_iou_threshold=0.5):
return self._model.predict(input_image, conf_threshold,
nms_iou_threshold)
# 一些跟ScaledYOLOv4模型有关的属性封装
# 多数是预处理相关可通过修改如model.size = [1280, 1280]改变预处理时resize的大小前提是模型支持
@property
def size(self):
return self._model.size
@property
def padding_value(self):
return self._model.padding_value
@property
def is_no_pad(self):
return self._model.is_no_pad
@property
def is_mini_pad(self):
return self._model.is_mini_pad
@property
def is_scale_up(self):
return self._model.is_scale_up
@property
def stride(self):
return self._model.stride
@property
def max_wh(self):
return self._model.max_wh
@size.setter
def size(self, wh):
assert isinstance(wh, [list, tuple]),\
"The value to set `size` must be type of tuple or list."
assert len(wh) == 2,\
"The value to set `size` must contatins 2 elements means [width, height], but now it contains {} elements.".format(
len(wh))
self._model.size = wh
@padding_value.setter
def padding_value(self, value):
assert isinstance(
value,
list), "The value to set `padding_value` must be type of list."
self._model.padding_value = value
@is_no_pad.setter
def is_no_pad(self, value):
assert isinstance(
value, bool), "The value to set `is_no_pad` must be type of bool."
self._model.is_no_pad = value
@is_mini_pad.setter
def is_mini_pad(self, value):
assert isinstance(
value,
bool), "The value to set `is_mini_pad` must be type of bool."
self._model.is_mini_pad = value
@is_scale_up.setter
def is_scale_up(self, value):
assert isinstance(
value,
bool), "The value to set `is_scale_up` must be type of bool."
self._model.is_scale_up = value
@stride.setter
def stride(self, value):
assert isinstance(
value, int), "The value to set `stride` must be type of int."
self._model.stride = value
@max_wh.setter
def max_wh(self, value):
assert isinstance(
value, float), "The value to set `max_wh` must be type of float."
self._model.max_wh = value

View File

@@ -0,0 +1,116 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
import logging
from ... import FastDeployModel, Frontend
from ... import c_lib_wrap as C
class YOLOR(FastDeployModel):
def __init__(self,
model_file,
params_file="",
runtime_option=None,
model_format=Frontend.ONNX):
# 调用基函数进行backend_option的初始化
# 初始化后的option保存在self._runtime_option
super(YOLOR, self).__init__(runtime_option)
self._model = C.vision.detection.YOLOR(
model_file, params_file, self._runtime_option, model_format)
# 通过self.initialized判断整个模型的初始化是否成功
assert self.initialized, "YOLOR initialize failed."
def predict(self, input_image, conf_threshold=0.25, nms_iou_threshold=0.5):
return self._model.predict(input_image, conf_threshold,
nms_iou_threshold)
# 一些跟YOLOR模型有关的属性封装
# 多数是预处理相关可通过修改如model.size = [1280, 1280]改变预处理时resize的大小前提是模型支持
@property
def size(self):
return self._model.size
@property
def padding_value(self):
return self._model.padding_value
@property
def is_no_pad(self):
return self._model.is_no_pad
@property
def is_mini_pad(self):
return self._model.is_mini_pad
@property
def is_scale_up(self):
return self._model.is_scale_up
@property
def stride(self):
return self._model.stride
@property
def max_wh(self):
return self._model.max_wh
@size.setter
def size(self, wh):
assert isinstance(wh, [list, tuple]),\
"The value to set `size` must be type of tuple or list."
assert len(wh) == 2,\
"The value to set `size` must contatins 2 elements means [width, height], but now it contains {} elements.".format(
len(wh))
self._model.size = wh
@padding_value.setter
def padding_value(self, value):
assert isinstance(
value,
list), "The value to set `padding_value` must be type of list."
self._model.padding_value = value
@is_no_pad.setter
def is_no_pad(self, value):
assert isinstance(
value, bool), "The value to set `is_no_pad` must be type of bool."
self._model.is_no_pad = value
@is_mini_pad.setter
def is_mini_pad(self, value):
assert isinstance(
value,
bool), "The value to set `is_mini_pad` must be type of bool."
self._model.is_mini_pad = value
@is_scale_up.setter
def is_scale_up(self, value):
assert isinstance(
value,
bool), "The value to set `is_scale_up` must be type of bool."
self._model.is_scale_up = value
@stride.setter
def stride(self, value):
assert isinstance(
value, int), "The value to set `stride` must be type of int."
self._model.stride = value
@max_wh.setter
def max_wh(self, value):
assert isinstance(
value, float), "The value to set `max_wh` must be type of float."
self._model.max_wh = value

View File

@@ -0,0 +1,116 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
import logging
from ... import FastDeployModel, Frontend
from ... import c_lib_wrap as C
class YOLOv7(FastDeployModel):
def __init__(self,
model_file,
params_file="",
runtime_option=None,
model_format=Frontend.ONNX):
# 调用基函数进行backend_option的初始化
# 初始化后的option保存在self._runtime_option
super(YOLOv7, self).__init__(runtime_option)
self._model = C.vision.detection.YOLOv7(
model_file, params_file, self._runtime_option, model_format)
# 通过self.initialized判断整个模型的初始化是否成功
assert self.initialized, "YOLOv7 initialize failed."
def predict(self, input_image, conf_threshold=0.25, nms_iou_threshold=0.5):
return self._model.predict(input_image, conf_threshold,
nms_iou_threshold)
# 一些跟YOLOv7模型有关的属性封装
# 多数是预处理相关可通过修改如model.size = [1280, 1280]改变预处理时resize的大小前提是模型支持
@property
def size(self):
return self._model.size
@property
def padding_value(self):
return self._model.padding_value
@property
def is_no_pad(self):
return self._model.is_no_pad
@property
def is_mini_pad(self):
return self._model.is_mini_pad
@property
def is_scale_up(self):
return self._model.is_scale_up
@property
def stride(self):
return self._model.stride
@property
def max_wh(self):
return self._model.max_wh
@size.setter
def size(self, wh):
assert isinstance(wh, [list, tuple]),\
"The value to set `size` must be type of tuple or list."
assert len(wh) == 2,\
"The value to set `size` must contatins 2 elements means [width, height], but now it contains {} elements.".format(
len(wh))
self._model.size = wh
@padding_value.setter
def padding_value(self, value):
assert isinstance(
value,
list), "The value to set `padding_value` must be type of list."
self._model.padding_value = value
@is_no_pad.setter
def is_no_pad(self, value):
assert isinstance(
value, bool), "The value to set `is_no_pad` must be type of bool."
self._model.is_no_pad = value
@is_mini_pad.setter
def is_mini_pad(self, value):
assert isinstance(
value,
bool), "The value to set `is_mini_pad` must be type of bool."
self._model.is_mini_pad = value
@is_scale_up.setter
def is_scale_up(self, value):
assert isinstance(
value,
bool), "The value to set `is_scale_up` must be type of bool."
self._model.is_scale_up = value
@stride.setter
def stride(self, value):
assert isinstance(
value, int), "The value to set `stride` must be type of int."
self._model.stride = value
@max_wh.setter
def max_wh(self, value):
assert isinstance(
value, float), "The value to set `max_wh` must be type of float."
self._model.max_wh = value

View File

@@ -17,23 +17,22 @@ import logging
from ... import c_lib_wrap as C from ... import c_lib_wrap as C
def vis_detection(im_data, det_result, line_size=1, font_size=0.5): def vis_detection(im_data, det_result, line_size=2, font_size=0.5):
C.vision.Visualize.vis_detection(im_data, det_result, line_size, font_size) return C.vision.Visualize.vis_detection(im_data, det_result, line_size,
font_size)
def vis_face_detection(im_data, face_det_result, line_size=1, font_size=0.5): def vis_face_detection(im_data, face_det_result, line_size=2, font_size=0.5):
C.vision.Visualize.vis_face_detection(im_data, face_det_result, line_size, return C.vision.Visualize.vis_face_detection(im_data, face_det_result,
font_size) line_size, font_size)
def vis_segmentation(im_data, seg_result, vis_im_data, num_classes=1000): def vis_segmentation(im_data, seg_result):
C.vision.Visualize.vis_segmentation(im_data, seg_result, vis_im_data, return C.vision.Visualize.vis_segmentation(im_data, seg_result)
num_classes)
def vis_matting_alpha(im_data, def vis_matting_alpha(im_data,
matting_result, matting_result,
vis_im_data,
remove_small_connected_area=False): remove_small_connected_area=False):
C.vision.Visualize.vis_matting_alpha(im_data, matting_result, vis_im_data, return C.vision.Visualize.vis_matting_alpha(im_data, matting_result,
remove_small_connected_area) remove_small_connected_area)

View File

@@ -16,8 +16,8 @@ im = cv2.imread("3.jpg")
result = model.predict(im, conf_threshold=0.7, nms_iou_threshold=0.3) result = model.predict(im, conf_threshold=0.7, nms_iou_threshold=0.3)
# 可视化结果 # 可视化结果
fd.vision.visualize.vis_face_detection(im, result) vis_im = fd.vision.visualize.vis_face_detection(im, result)
cv2.imwrite("vis_result.jpg", im) cv2.imwrite("vis_result.jpg", vis_im)
# 输出预测结果 # 输出预测结果
print(result) print(result)