mirror of https://github.com/hpc203/YOLOP-opencv-dnn.git
synced 2025-09-26 20:31:17 +08:00

Commit: Add files via upload
bdd100k.names (new file, 1 line):

car
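Both demo programs read this file at startup, one class label per line; a minimal sketch of the read, assuming the file sits in the working directory:

with open('bdd100k.names', 'rt') as f:
    classes = f.read().rstrip('\n').split('\n')
print(classes)  # ['car']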
export_onnx.py (new file, 136 lines):
import torch
import torch.nn as nn
from lib.models.common import Conv, SPP, Bottleneck, BottleneckCSP, Focus, Concat, Detect, SharpenConv
from torch.nn import Upsample
import cv2

# The lane line and drivable area segmentation branches do not share information
# with each other and are not linked.
YOLOP = [
    [24, 33, 42],  # Det_out_idx, Da_Segout_idx, LL_Segout_idx
    [-1, Focus, [3, 32, 3]],  # 0
    [-1, Conv, [32, 64, 3, 2]],  # 1
    [-1, BottleneckCSP, [64, 64, 1]],  # 2
    [-1, Conv, [64, 128, 3, 2]],  # 3
    [-1, BottleneckCSP, [128, 128, 3]],  # 4
    [-1, Conv, [128, 256, 3, 2]],  # 5
    [-1, BottleneckCSP, [256, 256, 3]],  # 6
    [-1, Conv, [256, 512, 3, 2]],  # 7
    [-1, SPP, [512, 512, [5, 9, 13]]],  # 8
    [-1, BottleneckCSP, [512, 512, 1, False]],  # 9
    [-1, Conv, [512, 256, 1, 1]],  # 10
    [-1, Upsample, [None, 2, 'nearest']],  # 11
    [[-1, 6], Concat, [1]],  # 12
    [-1, BottleneckCSP, [512, 256, 1, False]],  # 13
    [-1, Conv, [256, 128, 1, 1]],  # 14
    [-1, Upsample, [None, 2, 'nearest']],  # 15
    [[-1, 4], Concat, [1]],  # 16  # Encoder

    [-1, BottleneckCSP, [256, 128, 1, False]],  # 17
    [-1, Conv, [128, 128, 3, 2]],  # 18
    [[-1, 14], Concat, [1]],  # 19
    [-1, BottleneckCSP, [256, 256, 1, False]],  # 20
    [-1, Conv, [256, 256, 3, 2]],  # 21
    [[-1, 10], Concat, [1]],  # 22
    [-1, BottleneckCSP, [512, 512, 1, False]],  # 23
    [[17, 20, 23], Detect,
     [1, [[3, 9, 5, 11, 4, 20], [7, 18, 6, 39, 12, 31], [19, 50, 38, 81, 68, 157]], [128, 256, 512]]],
    # Detection head 24

    [16, Conv, [256, 128, 3, 1]],  # 25
    [-1, Upsample, [None, 2, 'nearest']],  # 26
    [-1, BottleneckCSP, [128, 64, 1, False]],  # 27
    [-1, Conv, [64, 32, 3, 1]],  # 28
    [-1, Upsample, [None, 2, 'nearest']],  # 29
    [-1, Conv, [32, 16, 3, 1]],  # 30
    [-1, BottleneckCSP, [16, 8, 1, False]],  # 31
    [-1, Upsample, [None, 2, 'nearest']],  # 32
    [-1, Conv, [8, 2, 3, 1]],  # 33  Driving area segmentation head

    [16, Conv, [256, 128, 3, 1]],  # 34
    [-1, Upsample, [None, 2, 'nearest']],  # 35
    [-1, BottleneckCSP, [128, 64, 1, False]],  # 36
    [-1, Conv, [64, 32, 3, 1]],  # 37
    [-1, Upsample, [None, 2, 'nearest']],  # 38
    [-1, Conv, [32, 16, 3, 1]],  # 39
    [-1, BottleneckCSP, [16, 8, 1, False]],  # 40
    [-1, Upsample, [None, 2, 'nearest']],  # 41
    [-1, Conv, [8, 2, 3, 1]]  # 42  Lane line segmentation head
]
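# Each row above is [from_index, module, args]: from_index -1 takes the previous
# layer's output, another integer takes that earlier layer's saved output, and a
# list of indices feeds several saved outputs (e.g. into Concat or Detect). The
# header row [24, 33, 42] records the indices of the detection, drivable-area and
# lane-line heads.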

class MCnet(nn.Module):
    def __init__(self, block_cfg):
        super(MCnet, self).__init__()
        layers, save = [], []
        self.nc = 1
        self.detector_index = -1
        self.det_out_idx = block_cfg[0][0]
        self.seg_out_idx = block_cfg[0][1:]
        self.num_anchors = 3
        self.num_outchannel = 5 + self.nc
        # Build model
        for i, (from_, block, args) in enumerate(block_cfg[1:]):
            block = eval(block) if isinstance(block, str) else block  # eval strings
            if block is Detect:
                self.detector_index = i
            block_ = block(*args)
            block_.index, block_.from_ = i, from_
            layers.append(block_)
            save.extend(x % i for x in ([from_] if isinstance(from_, int) else from_) if x != -1)  # append to savelist
        assert self.detector_index == block_cfg[0][0]

        self.model, self.save = nn.Sequential(*layers), sorted(save)
        self.names = [str(i) for i in range(self.nc)]

        # set stride and anchors for the detector
        # Detector = self.model[self.detector_index]  # detector
        # if isinstance(Detector, Detect):
        #     s = 128  # 2x min stride
        #     # for x in self.forward(torch.zeros(1, 3, s, s)):
        #     #     print(x.shape)
        #     with torch.no_grad():
        #         model_out = self.forward(torch.zeros(1, 3, s, s))
        #         detects, _, _ = model_out
        #         Detector.stride = torch.tensor([s / x.shape[-2] for x in detects])  # forward
        #     # print("stride" + str(Detector.stride))
        #     Detector.anchors /= Detector.stride.view(-1, 1, 1)  # Set the anchors for the corresponding scale
        #     check_anchor_order(Detector)
        #     self.stride = Detector.stride

    def forward(self, x):
        cache = []
        out = []
        det_out = None
        for i, block in enumerate(self.model):
            if block.from_ != -1:
                x = cache[block.from_] if isinstance(block.from_, int) else [x if j == -1 else cache[j] for j in block.from_]  # gather inputs for concat/detect
            x = block(x)
            if i in self.seg_out_idx:  # save the segmentation results
                # m = nn.Sigmoid()
                # out.append(m(x))
                out.append(torch.sigmoid(x))
            if i == self.detector_index:
                det_out = x
            cache.append(x if block.index in self.save else None)
        out[0] = out[0].view(2, 640, 640)  # drivable-area scores
        out[1] = out[1].view(2, 640, 640)  # lane-line scores
        return det_out, out[0], out[1]
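# Note: the two .view(2, 640, 640) calls above hardcode the 640x640 export
# resolution, so the resulting ONNX graph is static-shape; change them together
# with the dummy input below if exporting at another resolution.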

if __name__ == "__main__":
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = MCnet(YOLOP)
    checkpoint = torch.load('weights/End-to-end.pth', map_location=device)
    model.load_state_dict(checkpoint['state_dict'])
    model.eval()
    output_onnx = 'yolop.onnx'
    inputs = torch.randn(1, 3, 640, 640)
    # with torch.no_grad():
    #     output = model(inputs)
    #     print(output)

    torch.onnx.export(model, inputs, output_onnx, verbose=False, opset_version=12,
                      input_names=['images'], output_names=['det_out', 'drive_area_seg', 'lane_line_seg'])
    print('Converted', output_onnx, 'to ONNX')

    try:
        dnnnet = cv2.dnn.readNet(output_onnx)
        print('read success')
    except Exception:
        print('read failed')
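Beyond the cv2.dnn.readNet check above, the export can be sanity-checked by running one dummy inference; a minimal sketch, assuming the optional onnxruntime package is installed:

import numpy as np
import onnxruntime as ort

# Run the exported graph once and report the three output shapes.
sess = ort.InferenceSession('yolop.onnx')
dummy = np.random.randn(1, 3, 640, 640).astype(np.float32)
outs = sess.run(None, {'images': dummy})
for name, o in zip(['det_out', 'drive_area_seg', 'lane_line_seg'], outs):
    print(name, o.shape)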

images/ (new binary files, not shown):
  0ace96c3-48481887.jpg  66 KiB
  3c0e7240-96e390d2.jpg  36 KiB
  7dd9ef45-f197db95.jpg  82 KiB
  8e1c1ab0-a8b92173.jpg  77 KiB
  9aa94005-ff1d4c9a.jpg  43 KiB
  adb4871d-4d063244.jpg  84 KiB
main.cpp (new file, 228 lines):
#include <fstream>
#include <sstream>
#include <iostream>
#include <opencv2/dnn.hpp>
#include <opencv2/imgproc.hpp>
#include <opencv2/highgui.hpp>

using namespace cv;
using namespace dnn;
using namespace std;

class YOLO
{
public:
    YOLO(string modelpath, float confThreshold, float nmsThreshold, float objThreshold);
    Mat detect(Mat& frame);
private:
    const float mean[3] = { 0.485, 0.456, 0.406 };
    const float std[3] = { 0.229, 0.224, 0.225 };
    const float anchors[3][6] = { {3,9,5,11,4,20}, {7,18,6,39,12,31}, {19,50,38,81,68,157} };
    const float stride[3] = { 8.0, 16.0, 32.0 };
    const string classesFile = "bdd100k.names";
    const int inpWidth = 640;
    const int inpHeight = 640;
    float confThreshold;
    float nmsThreshold;
    float objThreshold;
    const bool keep_ratio = true;
    vector<string> classes;
    Net net;
    Mat resize_image(Mat srcimg, int* newh, int* neww, int* top, int* left);
    void normalize(Mat& srcimg);
    void drawPred(int classId, float conf, int left, int top, int right, int bottom, Mat& frame);
};

YOLO::YOLO(string modelpath, float confThreshold, float nmsThreshold, float objThreshold)
{
    this->confThreshold = confThreshold;
    this->nmsThreshold = nmsThreshold;
    this->objThreshold = objThreshold;

    ifstream ifs(this->classesFile.c_str());
    string line;
    while (getline(ifs, line)) this->classes.push_back(line);
    this->net = readNet(modelpath);
}

Mat YOLO::resize_image(Mat srcimg, int* newh, int* neww, int* top, int* left)
{
    int srch = srcimg.rows, srcw = srcimg.cols;
    *newh = this->inpHeight;
    *neww = this->inpWidth;
    Mat dstimg;
    if (this->keep_ratio && srch != srcw)
    {
        float hw_scale = (float)srch / srcw;
        if (hw_scale > 1)
        {
            *newh = this->inpHeight;
            *neww = int(this->inpWidth / hw_scale);
            // note: interpolation is the sixth argument of cv::resize, not the fourth
            resize(srcimg, dstimg, Size(*neww, *newh), 0, 0, INTER_AREA);
            *left = int((this->inpWidth - *neww) * 0.5);
            copyMakeBorder(dstimg, dstimg, 0, 0, *left, this->inpWidth - *neww - *left, BORDER_CONSTANT, 0);
        }
        else
        {
            *newh = (int)(this->inpHeight * hw_scale);
            *neww = this->inpWidth;
            resize(srcimg, dstimg, Size(*neww, *newh), 0, 0, INTER_AREA);
            *top = (int)((this->inpHeight - *newh) * 0.5);
            copyMakeBorder(dstimg, dstimg, *top, this->inpHeight - *newh - *top, 0, 0, BORDER_CONSTANT, 0);
        }
    }
    else
    {
        resize(srcimg, dstimg, Size(*neww, *newh), 0, 0, INTER_AREA);
    }
    return dstimg;
}

void YOLO::normalize(Mat& img)
{
    img.convertTo(img, CV_32F);
    const float scale = 1.0 / 255.0;
    for (int i = 0; i < img.rows; i++)
    {
        float* pdata = (float*)(img.data + i * img.step);
        for (int j = 0; j < img.cols; j++)
        {
            // (pixel / 255 - mean) / std, applied channel-wise
            pdata[0] = (pdata[0] * scale - this->mean[0]) / this->std[0];
            pdata[1] = (pdata[1] * scale - this->mean[1]) / this->std[1];
            pdata[2] = (pdata[2] * scale - this->mean[2]) / this->std[2];
            pdata += 3;
        }
    }
}
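// Note: the mean/std values above are the usual ImageNet statistics, quoted in RGB
// order, while cv::imread produces BGR images and this loop applies them to the
// channels as stored; if your weights are sensitive to channel order, swap channels
// first (e.g. with cvtColor or blobFromImage's swapRB flag).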

void YOLO::drawPred(int classId, float conf, int left, int top, int right, int bottom, Mat& frame)  // Draw the predicted bounding box
{
    // Draw a rectangle displaying the bounding box
    rectangle(frame, Point(left, top), Point(right, bottom), Scalar(0, 0, 255), 2);

    // Get the label for the class name and its confidence
    string label = format("%.2f", conf);
    label = this->classes[classId] + ":" + label;

    // Display the label at the top of the bounding box
    int baseLine;
    Size labelSize = getTextSize(label, FONT_HERSHEY_SIMPLEX, 0.5, 1, &baseLine);
    top = max(top, labelSize.height);
    //rectangle(frame, Point(left, top - int(1.5 * labelSize.height)), Point(left + int(1.5 * labelSize.width), top + baseLine), Scalar(0, 255, 0), FILLED);
    putText(frame, label, Point(left, top), FONT_HERSHEY_SIMPLEX, 1, Scalar(0, 255, 0), 1);
}

Mat YOLO::detect(Mat& srcimg)
{
    int newh = 0, neww = 0, padh = 0, padw = 0;
    Mat dstimg = this->resize_image(srcimg, &newh, &neww, &padh, &padw);
    this->normalize(dstimg);
    Mat blob = blobFromImage(dstimg);
    this->net.setInput(blob);
    vector<Mat> outs;
    this->net.forward(outs, this->net.getUnconnectedOutLayersNames());

    Mat outimg = srcimg.clone();
    float ratioh = (float)newh / srcimg.rows;
    float ratiow = (float)neww / srcimg.cols;
    int i = 0, j = 0, area = this->inpHeight * this->inpWidth;
    float* pdata_drive = (float*)outs[1].data;      // drivable area segmentation
    float* pdata_lane_line = (float*)outs[2].data;  // lane line segmentation
    for (i = 0; i < outimg.rows; i++)
    {
        for (j = 0; j < outimg.cols; j++)
        {
            const int x = int(j * ratiow) + padw;
            const int y = int(i * ratioh) + padh;
            // the pixel belongs to the mask when class 1 scores higher than class 0
            if (pdata_drive[y * this->inpWidth + x] < pdata_drive[area + y * this->inpWidth + x])
            {
                outimg.at<Vec3b>(i, j)[0] = 0;
                outimg.at<Vec3b>(i, j)[1] = 255;
                outimg.at<Vec3b>(i, j)[2] = 0;
            }
            if (pdata_lane_line[y * this->inpWidth + x] < pdata_lane_line[area + y * this->inpWidth + x])
            {
                outimg.at<Vec3b>(i, j)[0] = 255;
                outimg.at<Vec3b>(i, j)[1] = 0;
                outimg.at<Vec3b>(i, j)[2] = 0;
            }
        }
    }
    ///// generate proposals
    vector<int> classIds;
    vector<float> confidences;
    vector<Rect> boxes;
    ratioh = (float)srcimg.rows / newh;
    ratiow = (float)srcimg.cols / neww;
    int n = 0, q = 0, nout = this->classes.size() + 5, row_ind = 0;
    float* pdata = (float*)outs[0].data;
    for (n = 0; n < 3; n++)  // loop over the three detection scales
    {
        int num_grid_x = (int)(this->inpWidth / this->stride[n]);
        int num_grid_y = (int)(this->inpHeight / this->stride[n]);
        for (q = 0; q < 3; q++)  // loop over the anchors of this scale
        {
            const float anchor_w = this->anchors[n][q * 2];
            const float anchor_h = this->anchors[n][q * 2 + 1];
            for (i = 0; i < num_grid_y; i++)
            {
                for (j = 0; j < num_grid_x; j++)
                {
                    const float box_score = pdata[4];
                    if (box_score > this->objThreshold)
                    {
                        Mat scores = outs[0].row(row_ind).colRange(5, outs[0].cols);
                        Point classIdPoint;
                        double max_class_score;
                        // Get the value and location of the maximum score
                        minMaxLoc(scores, 0, &max_class_score, 0, &classIdPoint);
                        if (max_class_score > this->confThreshold)
                        {
                            float cx = (pdata[0] * 2.f - 0.5f + j) * this->stride[n];
                            float cy = (pdata[1] * 2.f - 0.5f + i) * this->stride[n];
                            float w = powf(pdata[2] * 2.f, 2.f) * anchor_w;
                            float h = powf(pdata[3] * 2.f, 2.f) * anchor_h;

                            int left = (cx - 0.5 * w - padw) * ratiow;
                            int top = (cy - 0.5 * h - padh) * ratioh;

                            classIds.push_back(classIdPoint.x);
                            confidences.push_back(max_class_score * box_score);
                            boxes.push_back(Rect(left, top, (int)(w * ratiow), (int)(h * ratioh)));
                        }
                    }
                    row_ind++;
                    pdata += nout;
                }
            }
        }
    }

    // Perform non-maximum suppression to eliminate redundant overlapping boxes with
    // lower confidences
    vector<int> indices;
    NMSBoxes(boxes, confidences, this->confThreshold, this->nmsThreshold, indices);
    for (size_t k = 0; k < indices.size(); ++k)
    {
        int idx = indices[k];
        Rect box = boxes[idx];
        this->drawPred(classIds[idx], confidences[idx], box.x, box.y,
            box.x + box.width, box.y + box.height, outimg);
    }
    return outimg;
}

int main()
{
    YOLO yolo_model("yolop.onnx", 0.25, 0.45, 0.5);
    string imgpath = "images/0ace96c3-48481887.jpg";
    Mat srcimg = imread(imgpath);
    Mat outimg = yolo_model.detect(srcimg);

    static const string kWinName = "Deep learning object detection in OpenCV";
    namedWindow(kWinName, WINDOW_NORMAL);
    imshow(kWinName, outimg);
    waitKey(0);
    destroyAllWindows();
}
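Both the C++ and Python post-processing use the same YOLOv5-style box decode: cx = (px*2 - 0.5 + grid_x) * stride and w = (pw*2)^2 * anchor_w. A small illustrative computation in Python, with made-up activation values:

# Decode one raw prediction (values already sigmoid-activated by the network).
stride = 8.0
anchor_w, anchor_h = 3.0, 9.0        # first anchor of the stride-8 scale
grid_x, grid_y = 10, 20              # cell indices j, i
px, py, pw, ph = 0.6, 0.4, 0.5, 0.7  # example raw outputs in [0, 1]

cx = (px * 2 - 0.5 + grid_x) * stride  # 85.6
cy = (py * 2 - 0.5 + grid_y) * stride  # 162.4
w = (pw * 2) ** 2 * anchor_w           # 3.0
h = (ph * 2) ** 2 * anchor_h           # 17.64
print(cx, cy, w, h)  # box center and size in 640x640 input coordinates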

main.py (new file, 158 lines):
import cv2
import argparse
import numpy as np

class yolop():
    def __init__(self, confThreshold=0.25, nmsThreshold=0.5, objThreshold=0.45):
        # The class list below is for the model trained on the bdd100k dataset; if you
        # deploy a model trained on your own dataset, you need to modify self.classes.
        with open('bdd100k.names', 'rt') as f:
            self.classes = f.read().rstrip('\n').split('\n')
        num_classes = len(self.classes)
        anchors = [[3, 9, 5, 11, 4, 20], [7, 18, 6, 39, 12, 31], [19, 50, 38, 81, 68, 157]]
        self.nl = len(anchors)
        self.na = len(anchors[0]) // 2
        self.no = num_classes + 5
        self.stride = np.array([8., 16., 32.])
        self.anchor_grid = np.asarray(anchors, dtype=np.float32).reshape(self.nl, -1, 2)
        self.inpWidth = 640
        self.inpHeight = 640
        self.generate_grid()
        self.net = cv2.dnn.readNet('yolop.onnx')
        self.confThreshold = confThreshold
        self.nmsThreshold = nmsThreshold
        self.objThreshold = objThreshold
        self.mean = np.array([0.485, 0.456, 0.406], dtype=np.float32).reshape(1, 1, 3)
        self.std = np.array([0.229, 0.224, 0.225], dtype=np.float32).reshape(1, 1, 3)
        self.keep_ratio = True

    def generate_grid(self):
        self.grid = [np.zeros(1)] * self.nl
        self.length = []
        self.areas = []
        for i in range(self.nl):
            h, w = int(self.inpHeight / self.stride[i]), int(self.inpWidth / self.stride[i])
            self.length.append(int(self.na * h * w))
            self.areas.append(h * w)
            if self.grid[i].shape[2:4] != (h, w):
                self.grid[i] = self._make_grid(w, h)

    def _make_grid(self, nx=20, ny=20):
        # (nx, ny) argument order so the grid also comes out right for non-square inputs
        xv, yv = np.meshgrid(np.arange(nx), np.arange(ny))
        return np.stack((xv, yv), 2).reshape((-1, 2)).astype(np.float32)

    def postprocess(self, frame, outs, newh, neww, padh, padw):
        frameHeight = frame.shape[0]
        frameWidth = frame.shape[1]
        ratioh, ratiow = frameHeight / newh, frameWidth / neww
        # Scan through all the bounding boxes output from the network and keep only the
        # ones with high confidence scores. Assign the box's class label as the class with the highest score.
        classIds = []
        confidences = []
        boxes = []
        for detection in outs:
            scores = detection[5:]
            classId = np.argmax(scores)
            confidence = scores[classId]
            if confidence > self.confThreshold and detection[4] > self.objThreshold:
                center_x = int((detection[0] - padw) * ratiow)
                center_y = int((detection[1] - padh) * ratioh)
                width = int(detection[2] * ratiow)
                height = int(detection[3] * ratioh)
                left = int(center_x - width / 2)
                top = int(center_y - height / 2)
                classIds.append(classId)
                confidences.append(float(confidence) * detection[4])
                boxes.append([left, top, width, height])

        # Perform non-maximum suppression to eliminate redundant overlapping boxes with
        # lower confidences.
        indices = cv2.dnn.NMSBoxes(boxes, confidences, self.confThreshold, self.nmsThreshold)
        # NMSBoxes returns shape (N, 1) or (N,) depending on the OpenCV version
        for i in np.array(indices).flatten():
            box = boxes[i]
            left = box[0]
            top = box[1]
            width = box[2]
            height = box[3]
            frame = self.drawPred(frame, classIds[i], confidences[i], left, top, left + width, top + height)
        return frame

    def drawPred(self, frame, classId, conf, left, top, right, bottom):
        # Draw a bounding box.
        cv2.rectangle(frame, (left, top), (right, bottom), (0, 0, 255), thickness=2)

        label = '%.2f' % conf
        label = '%s:%s' % (self.classes[classId], label)

        # Display the label at the top of the bounding box
        labelSize, baseLine = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)
        top = max(top, labelSize[1])
        # cv2.rectangle(frame, (left, top - round(1.5 * labelSize[1])), (left + round(1.5 * labelSize[0]), top + baseLine), (255, 255, 255), cv2.FILLED)
        cv2.putText(frame, label, (left, top - 10), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), thickness=1)
        return frame

    def resize_image(self, srcimg):
        padh, padw, newh, neww = 0, 0, self.inpHeight, self.inpWidth
        if self.keep_ratio and srcimg.shape[0] != srcimg.shape[1]:
            hw_scale = srcimg.shape[0] / srcimg.shape[1]
            if hw_scale > 1:
                newh, neww = self.inpHeight, int(self.inpWidth / hw_scale)
                img = cv2.resize(srcimg, (neww, newh), interpolation=cv2.INTER_AREA)
                padw = int((self.inpWidth - neww) * 0.5)
                img = cv2.copyMakeBorder(img, 0, 0, padw, self.inpWidth - neww - padw, cv2.BORDER_CONSTANT,
                                         value=0)  # add border
            else:
                newh, neww = int(self.inpHeight * hw_scale), self.inpWidth
                img = cv2.resize(srcimg, (neww, newh), interpolation=cv2.INTER_AREA)
                padh = int((self.inpHeight - newh) * 0.5)
                img = cv2.copyMakeBorder(img, padh, self.inpHeight - newh - padh, 0, 0, cv2.BORDER_CONSTANT, value=0)
        else:
            img = cv2.resize(srcimg, (self.inpWidth, self.inpHeight), interpolation=cv2.INTER_AREA)
        return img, newh, neww, padh, padw
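    # resize_image letterboxes the input: the long side is scaled to 640 and the
    # short side is padded symmetrically; newh/neww and padh/padw are returned so
    # detect() and postprocess() can map network coordinates back to the source image.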

    def _normalize(self, img):  ### c++: https://blog.csdn.net/wuqingshan2010/article/details/107727909
        img = img.astype(np.float32) / 255.0
        img = (img - self.mean) / self.std
        return img

    def detect(self, srcimg):
        img, newh, neww, padh, padw = self.resize_image(srcimg)
        img = self._normalize(img)
        blob = cv2.dnn.blobFromImage(img)
        # Sets the input to the network
        self.net.setInput(blob)

        # Runs the forward pass to get output of the output layers
        outs = self.net.forward(self.net.getUnconnectedOutLayersNames())
        # inference output
        outimg = srcimg.copy()
        # crop the padding away, then upsample each class mask back to the source size
        drive_area_mask = outs[1][:, padh:(self.inpHeight - padh), padw:(self.inpWidth - padw)]
        seg_id = np.argmax(drive_area_mask, axis=0).astype(np.uint8)
        seg_id = cv2.resize(seg_id, (srcimg.shape[1], srcimg.shape[0]), interpolation=cv2.INTER_NEAREST)
        outimg[seg_id == 1] = [0, 255, 0]

        lane_line_mask = outs[2][:, padh:(self.inpHeight - padh), padw:(self.inpWidth - padw)]
        seg_id = np.argmax(lane_line_mask, axis=0).astype(np.uint8)
        seg_id = cv2.resize(seg_id, (srcimg.shape[1], srcimg.shape[0]), interpolation=cv2.INTER_NEAREST)
        outimg[seg_id == 1] = [255, 0, 0]

        # decode the raw detections scale by scale (grid offset, then anchor scaling)
        det_out = outs[0]
        row_ind = 0
        for i in range(self.nl):
            det_out[row_ind:row_ind + self.length[i], 0:2] = (det_out[row_ind:row_ind + self.length[i], 0:2] * 2. - 0.5 + np.tile(self.grid[i], (self.na, 1))) * int(self.stride[i])
            det_out[row_ind:row_ind + self.length[i], 2:4] = (det_out[row_ind:row_ind + self.length[i], 2:4] * 2) ** 2 * np.repeat(self.anchor_grid[i], self.areas[i], axis=0)
            row_ind += self.length[i]
        outimg = self.postprocess(outimg, det_out, newh, neww, padh, padw)
        return outimg

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--imgpath", type=str, default='images/0ace96c3-48481887.jpg', help="image path")
    parser.add_argument('--confThreshold', default=0.25, type=float, help='class confidence')
    parser.add_argument('--nmsThreshold', default=0.45, type=float, help='nms iou thresh')
    parser.add_argument('--objThreshold', default=0.5, type=float, help='object confidence')
    args = parser.parse_args()

    yolonet = yolop(confThreshold=args.confThreshold, nmsThreshold=args.nmsThreshold, objThreshold=args.objThreshold)
    srcimg = cv2.imread(args.imgpath)
    outimg = yolonet.detect(srcimg)

    winName = 'Deep learning object detection in OpenCV'
    cv2.namedWindow(winName, 0)
    cv2.imshow(winName, outimg)
    cv2.waitKey(0)
    cv2.destroyAllWindows()
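The script processes a single image; a possible extension runs the same pipeline on a video stream. A sketch, assuming this file is importable as main.py and that yolop.onnx and bdd100k.names sit in the working directory:

import cv2
from main import yolop  # import the class defined above

net = yolop(confThreshold=0.25, nmsThreshold=0.45, objThreshold=0.5)
cap = cv2.VideoCapture('video.mp4')  # or 0 for a webcam
while True:
    ok, frame = cap.read()
    if not ok:
        break
    cv2.imshow('YOLOP', net.detect(frame))
    if cv2.waitKey(1) & 0xFF == ord('q'):
        break
cap.release()
cv2.destroyAllWindows()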