add yolov6 c++ and yolov6 pybind (#16)

* update .gitignore

* Added checking for cmake include dir

* fixed missing trt_backend option bug when init from trt

* remove un-need data layout and add pre-check for dtype

* changed RGB2BRG to BGR2RGB in ppcls model

* add model_zoo yolov6 c++/python demo

* fixed CMakeLists.txt typos

* update yolov6 cpp/README.md
This commit is contained in:
DefTruth
2022-07-14 16:12:28 +08:00
committed by GitHub
parent 5f83b3c532
commit de7c06a309
27 changed files with 955 additions and 6 deletions

2
.gitignore vendored
View File

@@ -10,3 +10,5 @@ build-debug.sh
fastdeploy.egg-info fastdeploy.egg-info
.setuptools-cmake-build .setuptools-cmake-build
fastdeploy/version.py fastdeploy/version.py
fastdeploy/LICENSE*
fastdeploy/ThirdPartyNotices*

View File

@@ -38,6 +38,9 @@ option(ENABLE_VISION_VISUALIZE "if to enable visualize vision model result toolb
option(ENABLE_OPENCV_CUDA "if to enable opencv with cuda, this will allow process image with GPU." OFF) option(ENABLE_OPENCV_CUDA "if to enable opencv with cuda, this will allow process image with GPU." OFF)
option(ENABLE_DEBUG "if to enable print debug information, this may reduce performance." OFF) option(ENABLE_DEBUG "if to enable print debug information, this may reduce performance." OFF)
# Whether to build fastdeply with vision/text/... examples, only for testings.
option(WTIH_VISION_EXAMPLES "Whether to build fastdeply with vision examples" ON)
if(ENABLE_DEBUG) if(ENABLE_DEBUG)
add_definitions(-DFASTDEPLOY_DEBUG) add_definitions(-DFASTDEPLOY_DEBUG)
endif() endif()
@@ -50,6 +53,13 @@ option(BUILD_FASTDEPLOY_PYTHON "if build python lib for fastdeploy." OFF)
include_directories(${PROJECT_SOURCE_DIR}) include_directories(${PROJECT_SOURCE_DIR})
include_directories(${CMAKE_CURRENT_BINARY_DIR}) include_directories(${CMAKE_CURRENT_BINARY_DIR})
if (WTIH_VISION_EXAMPLES AND EXISTS ${PROJECT_SOURCE_DIR}/examples)
# ENABLE_VISION and ENABLE_VISION_VISUALIZE must be ON if enable vision examples.
message(STATUS "Found WTIH_VISION_EXAMPLES ON, so, force ENABLE_VISION and ENABLE_VISION_VISUALIZE ON")
set(ENABLE_VISION ON CACHE BOOL "force to enable vision models usage" FORCE)
set(ENABLE_VISION_VISUALIZE ON CACHE BOOL "force to enable visualize vision model result toolbox" FORCE)
endif()
add_definitions(-DFASTDEPLOY_LIB) add_definitions(-DFASTDEPLOY_LIB)
file(GLOB_RECURSE ALL_DEPLOY_SRCS ${PROJECT_SOURCE_DIR}/fastdeploy/*.cc) file(GLOB_RECURSE ALL_DEPLOY_SRCS ${PROJECT_SOURCE_DIR}/fastdeploy/*.cc)
file(GLOB_RECURSE DEPLOY_ORT_SRCS ${PROJECT_SOURCE_DIR}/fastdeploy/backends/ort/*.cc) file(GLOB_RECURSE DEPLOY_ORT_SRCS ${PROJECT_SOURCE_DIR}/fastdeploy/backends/ort/*.cc)
@@ -170,6 +180,13 @@ endif()
set_target_properties(fastdeploy PROPERTIES VERSION ${FASTDEPLOY_VERSION}) set_target_properties(fastdeploy PROPERTIES VERSION ${FASTDEPLOY_VERSION})
target_link_libraries(fastdeploy ${DEPEND_LIBS}) target_link_libraries(fastdeploy ${DEPEND_LIBS})
# add examples after prepare include paths for third-parties
if (WTIH_VISION_EXAMPLES AND EXISTS ${PROJECT_SOURCE_DIR}/examples)
add_definitions(-DWTIH_VISION_EXAMPLES)
set(EXECUTABLE_OUTPUT_PATH ${PROJECT_SOURCE_DIR}/examples/bin)
add_subdirectory(examples)
endif()
include(external/summary.cmake) include(external/summary.cmake)
fastdeploy_summary() fastdeploy_summary()

8
examples/.gitignore vendored Normal file
View File

@@ -0,0 +1,8 @@
*.jpg
*.png
*.jpeg
*.onnx
*.engine
*.pd*
*.nb
bin

22
examples/CMakeLists.txt Normal file
View File

@@ -0,0 +1,22 @@
function(add_fastdeploy_executable field url model)
# temp target name/file var in function scope
set(TEMP_TARGET_FILE ${PROJECT_SOURCE_DIR}/examples/${field}/${url}_${model}.cc)
set(TEMP_TARGET_NAME ${field}_${url}_${model})
if (EXISTS ${TEMP_TARGET_FILE} AND TARGET fastdeploy)
add_executable(${TEMP_TARGET_NAME} ${TEMP_TARGET_FILE})
target_link_libraries(${TEMP_TARGET_NAME} PUBLIC fastdeploy)
message(STATUS "Found source file: [${field}/${url}_${model}.cc], ADD!!! fastdeploy executable: [${TEMP_TARGET_NAME}] !")
else ()
message(WARNING "Can not found source file: [${field}/${url}_${model}.cc], SKIP!!! fastdeploy executable: [${TEMP_TARGET_NAME}] !")
endif()
unset(TEMP_TARGET_FILE)
unset(TEMP_TARGET_NAME)
endfunction()
# vision examples
if (WTIH_VISION_EXAMPLES)
add_fastdeploy_executable(vision ultralytics yolov5)
add_fastdeploy_executable(vision meituan yolov6)
endif()
# other examples ...

11
examples/resources/.gitignore vendored Normal file
View File

@@ -0,0 +1,11 @@
images/*.jpg
images/*.jpeg
images/*.png
models/*.onnx
models/*.pd*
models/*.engine
models/*.trt
models/*.nb
outputs/*.jpg
outputs/*.jpeg
outputs/*.png

3
examples/resources/images/.gitignore vendored Normal file
View File

@@ -0,0 +1,3 @@
*.jpg
*.jpeg
*.png

5
examples/resources/models/.gitignore vendored Normal file
View File

@@ -0,0 +1,5 @@
*.onnx
*.engine
*.pd*
*.nb
*.trt

3
examples/resources/outputs/.gitignore vendored Normal file
View File

@@ -0,0 +1,3 @@
*.jpg
*.png
*.jpeg

View File

@@ -0,0 +1,53 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/vision.h"
int main() {
namespace vis = fastdeploy::vision;
std::string model_file = "../resources/models/yolov6s.onnx";
std::string img_path = "../resources/images/bus.jpg";
std::string vis_path = "../resources/outputs/meituan_yolov6_vis_result.jpg";
auto model = vis::meituan::YOLOv6(model_file);
if (!model.Initialized()) {
std::cerr << "Init Failed." << std::endl;
return -1;
} else {
std::cout << "Init Done! Dynamic Mode: "
<< model.IsDynamicShape() << std::endl;
}
model.EnableDebug();
cv::Mat im = cv::imread(img_path);
cv::Mat vis_im = im.clone();
vis::DetectionResult res;
if (!model.Predict(&im, &res)) {
std::cerr << "Prediction Failed." << std::endl;
return -1;
} else {
std::cout << "Prediction Done!" << std::endl;
}
// 输出预测框结果
std::cout << res.Str() << std::endl;
// 可视化预测结果
vis::Visualize::VisDetection(&vis_im, res);
cv::imwrite(vis_path, vis_im);
std::cout << "Detect Done! Saved: " << vis_path << std::endl;
return 0;
}

View File

@@ -0,0 +1,52 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/vision.h"
int main() {
namespace vis = fastdeploy::vision;
std::string model_file = "../resources/models/yolov5s.onnx";
std::string img_path = "../resources/images/bus.jpg";
std::string vis_path = "../resources/outputs/ultralytics_yolov5_vis_result.jpg";
auto model = vis::ultralytics::YOLOv5(model_file);
if (!model.Initialized()) {
std::cerr << "Init Failed! Model: " << model_file << std::endl;
return -1;
} else {
std::cout << "Init Done! Model:" << model_file << std::endl;
}
model.EnableDebug();
cv::Mat im = cv::imread(img_path);
cv::Mat vis_im = im.clone();
vis::DetectionResult res;
if (!model.Predict(&im, &res)) {
std::cerr << "Prediction Failed." << std::endl;
return -1;
} else {
std::cout << "Prediction Done!" << std::endl;
}
// 输出预测框结果
std::cout << res.Str() << std::endl;
// 可视化预测结果
vis::Visualize::VisDetection(&vis_im, res);
cv::imwrite(vis_path, vis_im);
std::cout << "Detect Done! Saved: " << vis_path << std::endl;
return 0;
}

View File

@@ -18,6 +18,7 @@ import shutil
import requests import requests
import time import time
import zipfile import zipfile
import tarfile
import hashlib import hashlib
import tqdm import tqdm
import logging import logging

View File

@@ -51,5 +51,5 @@ class FastDeployModel:
@property @property
def initialized(self): def initialized(self):
if self._model is None: if self._model is None:
return false return False
return self._model.initialized() return self._model.initialized()

View File

@@ -17,6 +17,7 @@
#ifdef ENABLE_VISION #ifdef ENABLE_VISION
#include "fastdeploy/vision/ppcls/model.h" #include "fastdeploy/vision/ppcls/model.h"
#include "fastdeploy/vision/ultralytics/yolov5.h" #include "fastdeploy/vision/ultralytics/yolov5.h"
#include "fastdeploy/vision/meituan/yolov6.h"
#endif #endif
#include "fastdeploy/vision/visualize/visualize.h" #include "fastdeploy/vision/visualize/visualize.h"

View File

@@ -16,4 +16,5 @@ from __future__ import absolute_import
from . import evaluation from . import evaluation
from . import ppcls from . import ppcls
from . import ultralytics from . import ultralytics
from . import meituan
from . import visualize from . import visualize

View File

@@ -0,0 +1,120 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
import logging
from ... import FastDeployModel, Frontend
from ... import fastdeploy_main as C
class YOLOv6(FastDeployModel):
def __init__(self,
model_file,
params_file="",
runtime_option=None,
model_format=Frontend.ONNX):
# 调用基函数进行backend_option的初始化
# 初始化后的option保存在self._runtime_option
super(YOLOv6, self).__init__(runtime_option)
self._model = C.vision.meituan.YOLOv6(
model_file, params_file, self._runtime_option, model_format)
# 通过self.initialized判断整个模型的初始化是否成功
assert self.initialized, "YOLOv6 initialize failed."
def predict(self, input_image, conf_threshold=0.25, nms_iou_threshold=0.5):
return self._model.predict(input_image, conf_threshold,
nms_iou_threshold)
# BOOL: 查看输入的模型是否为动态维度的
def is_dynamic_shape(self):
return self._model.is_dynamic_shape()
# 一些跟YOLOv6模型有关的属性封装
# 多数是预处理相关可通过修改如model.size = [1280, 1280]改变预处理时resize的大小前提是模型支持
@property
def size(self):
return self._model.size
@property
def padding_value(self):
return self._model.padding_value
@property
def is_no_pad(self):
return self._model.is_no_pad
@property
def is_mini_pad(self):
return self._model.is_mini_pad
@property
def is_scale_up(self):
return self._model.is_scale_up
@property
def stride(self):
return self._model.stride
@property
def max_wh(self):
return self._model.max_wh
@size.setter
def size(self, wh):
assert isinstance(wh, [list, tuple]),\
"The value to set `size` must be type of tuple or list."
assert len(wh) == 2,\
"The value to set `size` must contatins 2 elements means [width, height], but now it contains {} elements.".format(
len(wh))
self._model.size = wh
@padding_value.setter
def padding_value(self, value):
assert isinstance(
value,
list), "The value to set `padding_value` must be type of list."
self._model.padding_value = value
@is_no_pad.setter
def is_no_pad(self, value):
assert isinstance(
value, bool), "The value to set `is_no_pad` must be type of bool."
self._model.is_no_pad = value
@is_mini_pad.setter
def is_mini_pad(self, value):
assert isinstance(
value,
bool), "The value to set `is_mini_pad` must be type of bool."
self._model.is_mini_pad = value
@is_scale_up.setter
def is_scale_up(self, value):
assert isinstance(
value,
bool), "The value to set `is_scale_up` must be type of bool."
self._model.is_scale_up = value
@stride.setter
def stride(self, value):
assert isinstance(
value, int), "The value to set `stride` must be type of int."
self._model.stride = value
@max_wh.setter
def max_wh(self, value):
assert isinstance(
value, float), "The value to set `max_wh` must be type of float."
self._model.max_wh = value

View File

@@ -0,0 +1,45 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/pybind/main.h"
namespace fastdeploy {
void BindMeituan(pybind11::module& m) {
auto meituan_module =
m.def_submodule("meituan", "https://github.com/meituan/YOLOv6");
pybind11::class_<vision::meituan::YOLOv6, FastDeployModel>(
meituan_module, "YOLOv6")
.def(pybind11::init<std::string, std::string, RuntimeOption, Frontend>())
.def("predict",
[](vision::meituan::YOLOv6& self, pybind11::array& data,
float conf_threshold, float nms_iou_threshold) {
auto mat = PyArrayToCvMat(data);
vision::DetectionResult res;
self.Predict(&mat, &res, conf_threshold, nms_iou_threshold);
return res;
})
.def("is_dynamic_shape",
[](vision::meituan::YOLOv6& self) {
return self.IsDynamicShape();
})
.def_readwrite("size", &vision::meituan::YOLOv6::size)
.def_readwrite("padding_value",
&vision::meituan::YOLOv6::padding_value)
.def_readwrite("is_mini_pad", &vision::meituan::YOLOv6::is_mini_pad)
.def_readwrite("is_no_pad", &vision::meituan::YOLOv6::is_no_pad)
.def_readwrite("is_scale_up", &vision::meituan::YOLOv6::is_scale_up)
.def_readwrite("stride", &vision::meituan::YOLOv6::stride)
.def_readwrite("max_wh", &vision::meituan::YOLOv6::max_wh);
}
} // namespace fastdeploy

View File

@@ -0,0 +1,262 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/vision/meituan/yolov6.h"
#include "fastdeploy/utils/perf.h"
#include "fastdeploy/vision/utils/utils.h"
namespace fastdeploy {
namespace vision {
namespace meituan {
void LetterBox(Mat* mat, std::vector<int> size, std::vector<float> color,
bool _auto, bool scale_fill = false, bool scale_up = true,
int stride = 32) {
float scale = std::min(size[1] * 1.0f / static_cast<float>(mat->Height()),
size[0] * 1.0f / static_cast<float>(mat->Width()));
if (!scale_up) {
scale = std::min(scale, 1.0f);
}
int resize_h = int(round(static_cast<float>(mat->Height()) * scale));
int resize_w = int(round(static_cast<float>(mat->Width()) * scale));
int pad_w = size[0] - resize_w;
int pad_h = size[1] - resize_h;
if (_auto) {
pad_h = pad_h % stride;
pad_w = pad_w % stride;
} else if (scale_fill) {
pad_h = 0;
pad_w = 0;
resize_h = size[1];
resize_w = size[0];
}
if (resize_h != mat->Height() || resize_w != mat->Width()) {
Resize::Run(mat, resize_w, resize_h);
}
if (pad_h > 0 || pad_w > 0) {
float half_h = pad_h * 1.0 / 2;
int top = int(round(half_h - 0.1));
int bottom = int(round(half_h + 0.1));
float half_w = pad_w * 1.0 / 2;
int left = int(round(half_w - 0.1));
int right = int(round(half_w + 0.1));
Pad::Run(mat, top, bottom, left, right, color);
}
}
YOLOv6::YOLOv6(const std::string& model_file, const std::string& params_file,
const RuntimeOption& custom_option,
const Frontend& model_format) {
if (model_format == Frontend::ONNX) {
valid_cpu_backends = {Backend::ORT}; // 指定可用的CPU后端
valid_gpu_backends = {Backend::ORT, Backend::TRT}; // 指定可用的GPU后端
} else {
valid_cpu_backends = {Backend::PDINFER, Backend::ORT};
valid_gpu_backends = {Backend::PDINFER, Backend::ORT, Backend::TRT};
}
runtime_option = custom_option;
runtime_option.model_format = model_format;
runtime_option.model_file = model_file;
runtime_option.params_file = params_file;
initialized = Initialize();
}
bool YOLOv6::Initialize() {
// parameters for preprocess
size = {640, 640};
padding_value = {114.0, 114.0, 114.0};
is_mini_pad = false;
is_no_pad = false;
is_scale_up = false;
stride = 32;
max_wh = 4096.0f;
if (!InitRuntime()) {
FDERROR << "Failed to initialize fastdeploy backend." << std::endl;
return false;
}
// Check if the input shape is dynamic after Runtime already initialized,
// Note that, YOLOv6 has 1 input only. We need to force is_mini_pad
// 'false' to keep static shape after padding (LetterBox)
// when the is_dynamic_shape is 'false'.
is_dynamic_shape_ = false;
auto shape = InputInfoOfRuntime(0).shape;
for (const auto &d: shape) {
if (d <= 0) {
is_dynamic_shape_ = true;
}
}
if (!is_dynamic_shape_) {
is_mini_pad = false;
}
return true;
}
bool YOLOv6::Preprocess(Mat* mat, FDTensor* output,
std::map<std::string, std::array<float, 2>>* im_info) {
// process after image load
float ratio = std::min(size[1] * 1.0f / static_cast<float>(mat->Height()),
size[0] * 1.0f / static_cast<float>(mat->Width()));
if (ratio != 1.0) {
int interp = cv::INTER_AREA;
if (ratio > 1.0) {
interp = cv::INTER_LINEAR;
}
int resize_h = int(round(static_cast<float>(mat->Height()) * ratio));
int resize_w = int(round(static_cast<float>(mat->Width()) * ratio));
Resize::Run(mat, resize_w, resize_h, -1, -1, interp);
}
// yolov6's preprocess steps
// 1. letterbox
// 2. BGR->RGB
// 3. HWC->CHW
LetterBox(mat, size, padding_value, is_mini_pad, is_no_pad, is_scale_up,
stride);
BGR2RGB::Run(mat);
Normalize::Run(mat, std::vector<float>(mat->Channels(), 0.0),
std::vector<float>(mat->Channels(), 1.0));
// Record output shape of preprocessed image
(*im_info)["output_shape"] = {static_cast<float>(mat->Height()),
static_cast<float>(mat->Width())};
HWC2CHW::Run(mat);
Cast::Run(mat, "float");
mat->ShareWithTensor(output);
output->shape.insert(output->shape.begin(), 1); // reshape to n, h, w, c
return true;
}
bool YOLOv6::Postprocess(
FDTensor& infer_result, DetectionResult* result,
const std::map<std::string, std::array<float, 2>>& im_info,
float conf_threshold, float nms_iou_threshold) {
FDASSERT(infer_result.shape[0] == 1, "Only support batch =1 now.");
result->Clear();
result->Reserve(infer_result.shape[1]);
if (infer_result.dtype != FDDataType::FP32) {
FDERROR << "Only support post process with float32 data." << std::endl;
return false;
}
float* data = static_cast<float*>(infer_result.Data());
for (size_t i = 0; i < infer_result.shape[1]; ++i) {
int s = i * infer_result.shape[2];
float confidence = data[s + 4];
float* max_class_score =
std::max_element(data + s + 5, data + s + infer_result.shape[2]);
confidence *= (*max_class_score);
// filter boxes by conf_threshold
if (confidence <= conf_threshold) {
continue;
}
int32_t label_id = std::distance(data + s + 5, max_class_score);
// convert from [x, y, w, h] to [x1, y1, x2, y2]
result->boxes.emplace_back(std::array<float, 4>{
data[s] - data[s + 2] / 2.0f + label_id * max_wh,
data[s + 1] - data[s + 3] / 2.0f + label_id * max_wh,
data[s + 0] + data[s + 2] / 2.0f + label_id * max_wh,
data[s + 1] + data[s + 3] / 2.0f + label_id * max_wh});
result->label_ids.push_back(label_id);
result->scores.push_back(confidence);
}
utils::NMS(result, nms_iou_threshold);
// scale the boxes to the origin image shape
auto iter_out = im_info.find("output_shape");
auto iter_ipt = im_info.find("input_shape");
FDASSERT(iter_out != im_info.end() && iter_ipt != im_info.end(),
"Cannot find input_shape or output_shape from im_info.");
float out_h = iter_out->second[0];
float out_w = iter_out->second[1];
float ipt_h = iter_ipt->second[0];
float ipt_w = iter_ipt->second[1];
float scale = std::min(out_h / ipt_h, out_w / ipt_w);
for (size_t i = 0; i < result->boxes.size(); ++i) {
float pad_h = (out_h - ipt_h * scale) / 2;
float pad_w = (out_w - ipt_w * scale) / 2;
int32_t label_id = (result->label_ids)[i];
// clip box
result->boxes[i][0] = result->boxes[i][0] - max_wh * label_id;
result->boxes[i][1] = result->boxes[i][1] - max_wh * label_id;
result->boxes[i][2] = result->boxes[i][2] - max_wh * label_id;
result->boxes[i][3] = result->boxes[i][3] - max_wh * label_id;
result->boxes[i][0] = std::max((result->boxes[i][0] - pad_w) / scale, 0.0f);
result->boxes[i][1] = std::max((result->boxes[i][1] - pad_h) / scale, 0.0f);
result->boxes[i][2] = std::max((result->boxes[i][2] - pad_w) / scale, 0.0f);
result->boxes[i][3] = std::max((result->boxes[i][3] - pad_h) / scale, 0.0f);
result->boxes[i][0] = std::min(result->boxes[i][0], ipt_w - 1.0f);
result->boxes[i][1] = std::min(result->boxes[i][1], ipt_h - 1.0f);
result->boxes[i][2] = std::min(result->boxes[i][2], ipt_w - 1.0f);
result->boxes[i][3] = std::min(result->boxes[i][3], ipt_h - 1.0f);
}
return true;
}
bool YOLOv6::Predict(cv::Mat* im, DetectionResult* result, float conf_threshold,
float nms_iou_threshold) {
#ifdef FASTDEPLOY_DEBUG
TIMERECORD_START(0)
#endif
Mat mat(*im);
std::vector<FDTensor> input_tensors(1);
std::map<std::string, std::array<float, 2>> im_info;
// Record the shape of image and the shape of preprocessed image
im_info["input_shape"] = {static_cast<float>(mat.Height()),
static_cast<float>(mat.Width())};
im_info["output_shape"] = {static_cast<float>(mat.Height()),
static_cast<float>(mat.Width())};
if (!Preprocess(&mat, &input_tensors[0], &im_info)) {
FDERROR << "Failed to preprocess input image." << std::endl;
return false;
}
#ifdef FASTDEPLOY_DEBUG
TIMERECORD_END(0, "Preprocess")
TIMERECORD_START(1)
#endif
input_tensors[0].name = InputInfoOfRuntime(0).name;
std::vector<FDTensor> output_tensors;
if (!Infer(input_tensors, &output_tensors)) {
FDERROR << "Failed to inference." << std::endl;
return false;
}
#ifdef FASTDEPLOY_DEBUG
TIMERECORD_END(1, "Inference")
TIMERECORD_START(2)
#endif
if (!Postprocess(output_tensors[0], result, im_info, conf_threshold,
nms_iou_threshold)) {
FDERROR << "Failed to post process." << std::endl;
return false;
}
#ifdef FASTDEPLOY_DEBUG
TIMERECORD_END(2, "Postprocess")
#endif
return true;
}
} // namespace meituan
} // namespace vision
} // namespace fastdeploy

View File

@@ -0,0 +1,100 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "fastdeploy/fastdeploy_model.h"
#include "fastdeploy/vision/common/processors/transform.h"
#include "fastdeploy/vision/common/result.h"
namespace fastdeploy {
namespace vision {
namespace meituan {
class FASTDEPLOY_DECL YOLOv6 : public FastDeployModel {
public:
// 当model_format为ONNX时无需指定params_file
// 当model_format为Paddle时则需同时指定model_file & params_file
YOLOv6(const std::string& model_file, const std::string& params_file = "",
const RuntimeOption& custom_option = RuntimeOption(),
const Frontend& model_format = Frontend::ONNX);
// 定义模型的名称
virtual std::string ModelName() const { return "meituan/YOLOv6"; }
// 初始化函数,包括初始化后端,以及其它模型推理需要涉及的操作
virtual bool Initialize();
// 输入图像预处理操作
// Mat为FastDeploy定义的数据结构
// FDTensor为预处理后的Tensor数据传给后端进行推理
// im_info为预处理过程保存的数据在后处理中需要用到
virtual bool Preprocess(Mat* mat, FDTensor* outputs,
std::map<std::string, std::array<float, 2>>* im_info);
// 后端推理结果后处理,输出给用户
// infer_result 为后端推理后的输出Tensor
// result 为模型预测的结果
// im_info 为预处理记录的信息后处理用于还原box
// conf_threshold 后处理时过滤box的置信度阈值
// nms_iou_threshold 后处理时NMS设定的iou阈值
virtual bool Postprocess(
FDTensor& infer_result, DetectionResult* result,
const std::map<std::string, std::array<float, 2>>& im_info,
float conf_threshold, float nms_iou_threshold);
// 模型预测接口,即用户调用的接口
// im 为用户的输入数据目前对于CV均定义为cv::Mat
// result 为模型预测的输出结构体
// conf_threshold 为后处理的参数
// nms_iou_threshold 为后处理的参数
virtual bool Predict(cv::Mat* im, DetectionResult* result,
float conf_threshold = 0.25,
float nms_iou_threshold = 0.5);
// 用户可以通过该接口 查看输入的模型是否为动态维度
virtual bool IsDynamicShape() const { return is_dynamic_shape_; }
// 以下为模型在预测时的一些参数,基本是前后处理所需
// 用户在创建模型后,可根据模型的要求,以及自己的需求
// 对参数进行修改
// tuple of (width, height)
std::vector<int> size;
// padding value, size should be same with Channels
std::vector<float> padding_value;
// only pad to the minimum rectange which height and width is times of stride
bool is_mini_pad;
// while is_mini_pad = false and is_no_pad = true, will resize the image to
// the set size
bool is_no_pad;
// if is_scale_up is false, the input image only can be zoom out, the maximum
// resize scale cannot exceed 1.0
bool is_scale_up;
// padding stride, for is_mini_pad
int stride;
// for offseting the boxes by classes when using NMS, default 4096 in meituan/YOLOv6
float max_wh;
protected:
// whether to inference with dynamic shape (e.g ONNX export with dynamic shape or not.)
// meituan/YOLOv6 official 'export_onnx.py' script will export static ONNX by default.
// while is_dynamic_shape if 'false', is_mini_pad will force 'false'. This value will
// auto check by fastdeploy after the internal Runtime already initialized.
bool is_dynamic_shape_;
};
} // namespace meituan
} // namespace vision
} // namespace fastdeploy

View File

@@ -18,6 +18,7 @@ namespace fastdeploy {
void BindPpClsModel(pybind11::module& m); void BindPpClsModel(pybind11::module& m);
void BindUltralytics(pybind11::module& m); void BindUltralytics(pybind11::module& m);
void BindMeituan(pybind11::module& m);
#ifdef ENABLE_VISION_VISUALIZE #ifdef ENABLE_VISION_VISUALIZE
void BindVisualize(pybind11::module& m); void BindVisualize(pybind11::module& m);
#endif #endif
@@ -40,6 +41,9 @@ void BindVision(pybind11::module& m) {
BindPpClsModel(m); BindPpClsModel(m);
BindUltralytics(m); BindUltralytics(m);
BindMeituan(m);
#ifdef ENABLE_VISION_VISUALIZE
BindVisualize(m); BindVisualize(m);
#endif
} }
} // namespace fastdeploy } // namespace fastdeploy

12
model_zoo/.gitignore vendored Normal file
View File

@@ -0,0 +1,12 @@
*.png
*.jpg
*.jpeg
*.onnx
*.zip
*.tar
*.pd*
*.engine
*.trt
*.nb
*.tgz
*.gz

View File

@@ -23,7 +23,7 @@ YOLOv5模型加载和初始化当model_format为`fd.Frontend.ONNX`时,只
> >
> **参数** > **参数**
> >
> > * **image_data**(np.ndarray): 输入数据注意需为HWCRGB格式 > > * **image_data**(np.ndarray): 输入数据注意需为HWCBGR格式
> > * **conf_threshold**(float): 检测框置信度过滤阈值 > > * **conf_threshold**(float): 检测框置信度过滤阈值
> > * **nms_iou_threshold**(float): NMS处理过程中iou阈值 > > * **nms_iou_threshold**(float): NMS处理过程中iou阈值
@@ -49,9 +49,9 @@ YOLOv5模型加载和初始化当model_format为`Frontend::ONNX`时,只需
> * **runtime_option**(RuntimeOption): 后端推理配置默认为None即采用默认配置 > * **runtime_option**(RuntimeOption): 后端推理配置默认为None即采用默认配置
> * **model_format**(Frontend): 模型格式 > * **model_format**(Frontend): 模型格式
#### predict函数 #### Predict函数
> ``` > ```
> YOLOv5::predict(cv::Mat* im, DetectionResult* result, > YOLOv5::Predict(cv::Mat* im, DetectionResult* result,
> float conf_threshold = 0.25, > float conf_threshold = 0.25,
> float nms_iou_threshold = 0.5) > float nms_iou_threshold = 0.5)
> ``` > ```
@@ -59,7 +59,7 @@ YOLOv5模型加载和初始化当model_format为`Frontend::ONNX`时,只需
> >
> **参数** > **参数**
> >
> > * **im**: 输入图像注意需为HWCRGB格式 > > * **im**: 输入图像注意需为HWCBGR格式
> > * **result**: 检测结果,包括检测框,各个框的置信度 > > * **result**: 检测结果,包括检测框,各个框的置信度
> > * **conf_threshold**: 检测框置信度过滤阈值 > > * **conf_threshold**: 检测框置信度过滤阈值
> > * **nms_iou_threshold**: NMS处理过程中iou阈值 > > * **nms_iou_threshold**: NMS处理过程中iou阈值

View File

@@ -0,0 +1,45 @@
# YOLOv6部署示例
本文档说明如何进行[YOLOv6](https://github.com/meituan/YOLOv6)的快速部署推理。本目录结构如下
```
.
├── cpp # C++ 代码目录
│   ├── CMakeLists.txt # C++ 代码编译CMakeLists文件
│   ├── README.md # C++ 代码编译部署文档
│   └── yolov6.cc # C++ 示例代码
├── README.md # YOLOv6 部署文档
└── yolov6.py # Python示例代码
```
## 安装FastDeploy
使用如下命令安装FastDeploy注意到此处安装的是`vision-cpu`,也可根据需求安装`vision-gpu`
```
# 安装fastdeploy-python工具
pip install fastdeploy-python
# 安装vision-cpu模块
fastdeploy install vision-cpu
```
## Python部署
执行如下代码即会自动下载YOLOv6模型和测试图片
```
python yolov6.py
```
执行完成后会将可视化结果保存在本地`vis_result.jpg`,同时输出检测结果如下
```
DetectionResult: [xmin, ymin, xmax, ymax, score, label_id]
11.772949,229.269287, 792.933838, 748.294189, 0.954794, 5
667.140381,396.185455, 807.701721, 881.810120, 0.900997, 0
223.271011,405.105743, 345.740723, 859.328552, 0.898938, 0
50.135777,405.863129, 245.485519, 904.153809, 0.888936, 0
0.000000,549.002869, 77.864723, 869.455017, 0.614145, 0
```
## 其它文档
- [C++部署](./cpp/README.md)
- [YOLOv6 API文档](./api.md)

View File

@@ -0,0 +1,71 @@
# YOLOv6 API说明
## Python API
### YOLOv6类
```
fastdeploy.vision.meituan.YOLOv6(model_file, params_file=None, runtime_option=None, model_format=fd.Frontend.ONNX)
```
YOLOv6模型加载和初始化当model_format为`fd.Frontend.ONNX`只需提供model_file`yolov6s.onnx`当model_format为`fd.Frontend.PADDLE`则需同时提供model_file和params_file。
**参数**
> * **model_file**(str): 模型文件路径
> * **params_file**(str): 参数文件路径
> * **runtime_option**(RuntimeOption): 后端推理配置默认为None即采用默认配置
> * **model_format**(Frontend): 模型格式
#### predict函数
> ```
> YOLOv6.predict(image_data, conf_threshold=0.25, nms_iou_threshold=0.5)
> ```
> 模型预测结口,输入图像直接输出检测结果。
>
> **参数**
>
> > * **image_data**(np.ndarray): 输入数据注意需为HWCBGR格式
> > * **conf_threshold**(float): 检测框置信度过滤阈值
> > * **nms_iou_threshold**(float): NMS处理过程中iou阈值
示例代码参考[yolov6.py](./yolov6.py)
## C++ API
### YOLOv6类
```
fastdeploy::vision::meituan::YOLOv6(
const string& model_file,
const string& params_file = "",
const RuntimeOption& runtime_option = RuntimeOption(),
const Frontend& model_format = Frontend::ONNX)
```
YOLOv6模型加载和初始化当model_format为`Frontend::ONNX`只需提供model_file`yolov6s.onnx`当model_format为`Frontend::PADDLE`则需同时提供model_file和params_file。
**参数**
> * **model_file**(str): 模型文件路径
> * **params_file**(str): 参数文件路径
> * **runtime_option**(RuntimeOption): 后端推理配置默认为None即采用默认配置
> * **model_format**(Frontend): 模型格式
#### Predict函数
> ```
> YOLOv6::Predict(cv::Mat* im, DetectionResult* result,
> float conf_threshold = 0.25,
> float nms_iou_threshold = 0.5)
> ```
> 模型预测接口,输入图像直接输出检测结果。
>
> **参数**
>
> > * **im**: 输入图像注意需为HWCBGR格式
> > * **result**: 检测结果,包括检测框,各个框的置信度
> > * **conf_threshold**: 检测框置信度过滤阈值
> > * **nms_iou_threshold**: NMS处理过程中iou阈值
示例代码参考[cpp/yolov6.cc](cpp/yolov6.cc)
## 其它API使用
- [模型部署RuntimeOption配置](../../../docs/api/runtime_option.md)

View File

@@ -0,0 +1,17 @@
PROJECT(yolov6_demo C CXX)
CMAKE_MINIMUM_REQUIRED (VERSION 3.16)
# 在低版本ABI环境中通过如下代码进行兼容性编译
# add_definitions(-D_GLIBCXX_USE_CXX11_ABI=0)
# 指定下载解压后的fastdeploy库路径
set(FASTDEPLOY_INSTALL_DIR ${PROJECT_SOURCE_DIR}/fastdeploy-linux-x64-0.0.3/)
include(${FASTDEPLOY_INSTALL_DIR}/FastDeploy.cmake)
# 添加FastDeploy依赖头文件
include_directories(${FASTDEPLOY_INCS})
add_executable(yolov6_demo ${PROJECT_SOURCE_DIR}/yolov6.cc)
# 添加FastDeploy库依赖
target_link_libraries(yolov6_demo ${FASTDEPLOY_LIBS})

View File

@@ -0,0 +1,30 @@
# 编译YOLOv6示例
```
# 下载和解压预测库
wget https://bj.bcebos.com/paddle2onnx/fastdeploy/fastdeploy-linux-x64-0.0.3.tgz
tar xvf fastdeploy-linux-x64-0.0.3.tgz
# 编译示例代码
mkdir build & cd build
cmake ..
make -j
# 下载模型和图片
wget https://github.com/meituan/YOLOv6/releases/download/0.1.0/yolov6s.onnx
wget https://raw.githubusercontent.com/ultralytics/yolov5/master/data/images/bus.jpg
# 执行
./yolov6_demo
```
执行完后可视化的结果保存在本地`vis_result.jpg`,同时会将检测框输出在终端,如下所示
```
DetectionResult: [xmin, ymin, xmax, ymax, score, label_id]
11.772949,229.269287, 792.933838, 748.294189, 0.954794, 5
667.140381,396.185455, 807.701721, 881.810120, 0.900997, 0
223.271011,405.105743, 345.740723, 859.328552, 0.898938, 0
50.135777,405.863129, 245.485519, 904.153809, 0.888936, 0
0.000000,549.002869, 77.864723, 869.455017, 0.614145, 0
```

View File

@@ -0,0 +1,40 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/vision.h"
int main() {
namespace vis = fastdeploy::vision;
auto model = vis::meituan::YOLOv6("yolov6s.onnx");
if (!model.Initialized()) {
std::cerr << "Init Failed." << std::endl;
return -1;
}
cv::Mat im = cv::imread("bus.jpg");
cv::Mat vis_im = im.clone();
vis::DetectionResult res;
if (!model.Predict(&im, &res)) {
std::cerr << "Prediction Failed." << std::endl;
return -1;
}
// 输出预测框结果
std::cout << res.Str() << std::endl;
// 可视化预测结果
vis::Visualize::VisDetection(&vis_im, res);
cv::imwrite("vis_result.jpg", vis_im);
return 0;
}

View File

@@ -0,0 +1,24 @@
import fastdeploy as fd
import cv2
# 下载模型和测试图片
model_url = "https://github.com/meituan/YOLOv6/releases/download/0.1.0/yolov6s.onnx"
test_jpg_url = "https://raw.githubusercontent.com/ultralytics/yolov5/master/data/images/bus.jpg"
fd.download(model_url, ".", show_progress=True)
fd.download(test_jpg_url, ".", show_progress=True)
# 加载模型
model = fd.vision.meituan.YOLOv6("yolov6s.onnx")
print(model.is_dynamic_shape())
# 预测图片
im = cv2.imread("bus.jpg")
result = model.predict(im, conf_threshold=0.25, nms_iou_threshold=0.5)
# 可视化结果
fd.vision.visualize.vis_detection(im, result)
cv2.imwrite("vis_result.jpg", im)
# 输出预测结果
print(result)
print(model.runtime_option)