commit e7e1025c16196fe46cf55022e0f382b351dc4307 Author: hpc203 <1749069040@qq.com> Date: Sun Aug 29 10:48:26 2021 +0800 Add files via upload diff --git a/bdd100k.names b/bdd100k.names new file mode 100644 index 0000000..ee16feb --- /dev/null +++ b/bdd100k.names @@ -0,0 +1 @@ +car \ No newline at end of file diff --git a/export_onnx.py b/export_onnx.py new file mode 100644 index 0000000..53be746 --- /dev/null +++ b/export_onnx.py @@ -0,0 +1,136 @@ +import torch +import torch.nn as nn +from lib.models.common import Conv, SPP, Bottleneck, BottleneckCSP, Focus, Concat, Detect, SharpenConv +from torch.nn import Upsample +import cv2 + +# The lane line and the driving area segment branches without share information with each other and without link +YOLOP = [ + [24, 33, 42], # Det_out_idx, Da_Segout_idx, LL_Segout_idx + [-1, Focus, [3, 32, 3]], # 0 + [-1, Conv, [32, 64, 3, 2]], # 1 + [-1, BottleneckCSP, [64, 64, 1]], # 2 + [-1, Conv, [64, 128, 3, 2]], # 3 + [-1, BottleneckCSP, [128, 128, 3]], # 4 + [-1, Conv, [128, 256, 3, 2]], # 5 + [-1, BottleneckCSP, [256, 256, 3]], # 6 + [-1, Conv, [256, 512, 3, 2]], # 7 + [-1, SPP, [512, 512, [5, 9, 13]]], # 8 + [-1, BottleneckCSP, [512, 512, 1, False]], # 9 + [-1, Conv, [512, 256, 1, 1]], # 10 + [-1, Upsample, [None, 2, 'nearest']], # 11 + [[-1, 6], Concat, [1]], # 12 + [-1, BottleneckCSP, [512, 256, 1, False]], # 13 + [-1, Conv, [256, 128, 1, 1]], # 14 + [-1, Upsample, [None, 2, 'nearest']], # 15 + [[-1, 4], Concat, [1]], # 16 #Encoder + + [-1, BottleneckCSP, [256, 128, 1, False]], # 17 + [-1, Conv, [128, 128, 3, 2]], # 18 + [[-1, 14], Concat, [1]], # 19 + [-1, BottleneckCSP, [256, 256, 1, False]], # 20 + [-1, Conv, [256, 256, 3, 2]], # 21 + [[-1, 10], Concat, [1]], # 22 + [-1, BottleneckCSP, [512, 512, 1, False]], # 23 + [[17, 20, 23], Detect, + [1, [[3, 9, 5, 11, 4, 20], [7, 18, 6, 39, 12, 31], [19, 50, 38, 81, 68, 157]], [128, 256, 512]]], + # Detection head 24 + + [16, Conv, [256, 128, 3, 1]], # 25 + [-1, Upsample, [None, 2, 'nearest']], # 26 + [-1, BottleneckCSP, [128, 64, 1, False]], # 27 + [-1, Conv, [64, 32, 3, 1]], # 28 + [-1, Upsample, [None, 2, 'nearest']], # 29 + [-1, Conv, [32, 16, 3, 1]], # 30 + [-1, BottleneckCSP, [16, 8, 1, False]], # 31 + [-1, Upsample, [None, 2, 'nearest']], # 32 + [-1, Conv, [8, 2, 3, 1]], # 33 Driving area segmentation head + + [16, Conv, [256, 128, 3, 1]], # 34 + [-1, Upsample, [None, 2, 'nearest']], # 35 + [-1, BottleneckCSP, [128, 64, 1, False]], # 36 + [-1, Conv, [64, 32, 3, 1]], # 37 + [-1, Upsample, [None, 2, 'nearest']], # 38 + [-1, Conv, [32, 16, 3, 1]], # 39 + [-1, BottleneckCSP, [16, 8, 1, False]], # 40 + [-1, Upsample, [None, 2, 'nearest']], # 41 + [-1, Conv, [8, 2, 3, 1]] # 42 Lane line segmentation head +] + +class MCnet(nn.Module): + def __init__(self, block_cfg): + super(MCnet, self).__init__() + layers, save = [], [] + self.nc = 1 + self.detector_index = -1 + self.det_out_idx = block_cfg[0][0] + self.seg_out_idx = block_cfg[0][1:] + self.num_anchors = 3 + self.num_outchannel = 5 + self.nc + # Build model + for i, (from_, block, args) in enumerate(block_cfg[1:]): + block = eval(block) if isinstance(block, str) else block # eval strings + if block is Detect: + self.detector_index = i + block_ = block(*args) + block_.index, block_.from_ = i, from_ + layers.append(block_) + save.extend(x % i for x in ([from_] if isinstance(from_, int) else from_) if x != -1) # append to savelist + assert self.detector_index == block_cfg[0][0] + + self.model, self.save = nn.Sequential(*layers), sorted(save) + self.names = [str(i) for i in range(self.nc)] + + # set stride、anchor for detector + # Detector = self.model[self.detector_index] # detector + # if isinstance(Detector, Detect): + # s = 128 # 2x min stride + # # for x in self.forward(torch.zeros(1, 3, s, s)): + # # print (x.shape) + # with torch.no_grad(): + # model_out = self.forward(torch.zeros(1, 3, s, s)) + # detects, _, _ = model_out + # Detector.stride = torch.tensor([s / x.shape[-2] for x in detects]) # forward + # # print("stride"+str(Detector.stride )) + # Detector.anchors /= Detector.stride.view(-1, 1, 1) # Set the anchors for the corresponding scale + # check_anchor_order(Detector) + # self.stride = Detector.stride + def forward(self, x): + cache = [] + out = [] + det_out = None + for i, block in enumerate(self.model): + if block.from_ != -1: + x = cache[block.from_] if isinstance(block.from_, int) else [x if j == -1 else cache[j] for j in block.from_] # calculate concat detect + x = block(x) + if i in self.seg_out_idx: # save driving area segment result + # m = nn.Sigmoid() + # out.append(m(x)) + out.append(torch.sigmoid(x)) + if i == self.detector_index: + det_out = x + cache.append(x if block.index in self.save else None) + out[0] = out[0].view(2, 640, 640) + out[1] = out[1].view(2, 640, 640) + return det_out, out[0], out[1] + +if __name__ == "__main__": + device = 'cuda' if torch.cuda.is_available() else 'cpu' + model = MCnet(YOLOP) + checkpoint = torch.load('weights/End-to-end.pth', map_location=device) + model.load_state_dict(checkpoint['state_dict']) + model.eval() + output_onnx = 'yolop.onnx' + inputs = torch.randn(1, 3, 640, 640) + # with torch.no_grad(): + # output = model(inputs) + # print(output) + + torch.onnx.export(model, inputs, output_onnx, verbose=False, opset_version=12, input_names=['images'], output_names=['det_out', 'drive_area_seg', 'lane_line_seg']) + print('convert', output_onnx, 'to onnx finish!!!') + + try: + dnnnet = cv2.dnn.readNet(output_onnx) + print('read sucess') + except: + print('read failed') \ No newline at end of file diff --git a/images/0ace96c3-48481887.jpg b/images/0ace96c3-48481887.jpg new file mode 100644 index 0000000..8981747 Binary files /dev/null and b/images/0ace96c3-48481887.jpg differ diff --git a/images/3c0e7240-96e390d2.jpg b/images/3c0e7240-96e390d2.jpg new file mode 100644 index 0000000..bf2b675 Binary files /dev/null and b/images/3c0e7240-96e390d2.jpg differ diff --git a/images/7dd9ef45-f197db95.jpg b/images/7dd9ef45-f197db95.jpg new file mode 100644 index 0000000..5de5771 Binary files /dev/null and b/images/7dd9ef45-f197db95.jpg differ diff --git a/images/8e1c1ab0-a8b92173.jpg b/images/8e1c1ab0-a8b92173.jpg new file mode 100644 index 0000000..b9d3825 Binary files /dev/null and b/images/8e1c1ab0-a8b92173.jpg differ diff --git a/images/9aa94005-ff1d4c9a.jpg b/images/9aa94005-ff1d4c9a.jpg new file mode 100644 index 0000000..e0e9278 Binary files /dev/null and b/images/9aa94005-ff1d4c9a.jpg differ diff --git a/images/adb4871d-4d063244.jpg b/images/adb4871d-4d063244.jpg new file mode 100644 index 0000000..1564575 Binary files /dev/null and b/images/adb4871d-4d063244.jpg differ diff --git a/main.cpp b/main.cpp new file mode 100644 index 0000000..a289762 --- /dev/null +++ b/main.cpp @@ -0,0 +1,228 @@ +#include +#include +#include +#include +#include +#include + +using namespace cv; +using namespace dnn; +using namespace std; + +class YOLO +{ +public: + YOLO(string modelpath, float confThreshold, float nmsThreshold, float objThreshold); + Mat detect(Mat& frame); +private: + const float mean[3] = { 0.485, 0.456, 0.406 }; + const float std[3] = { 0.229, 0.224, 0.225 }; + const float anchors[3][6] = { {3,9,5,11,4,20}, {7,18,6,39,12,31},{19,50,38,81,68,157} }; + const float stride[3] = { 8.0, 16.0, 32.0 }; + const string classesFile = "bdd100k.names"; + const int inpWidth = 640; + const int inpHeight = 640; + float confThreshold; + float nmsThreshold; + float objThreshold; + const bool keep_ratio = true; + vector classes; + Net net; + Mat resize_image(Mat srcimg, int* newh, int* neww, int* top, int* left); + void normalize(Mat& srcimg); + void drawPred(int classId, float conf, int left, int top, int right, int bottom, Mat& frame); +}; + +YOLO::YOLO(string modelpath, float confThreshold, float nmsThreshold, float objThreshold) +{ + this->confThreshold = confThreshold; + this->nmsThreshold = nmsThreshold; + this->objThreshold = objThreshold; + + ifstream ifs(this->classesFile.c_str()); + string line; + while (getline(ifs, line)) this->classes.push_back(line); + this->net = readNet(modelpath); +} + +Mat YOLO::resize_image(Mat srcimg, int* newh, int* neww, int* top, int* left) +{ + int srch = srcimg.rows, srcw = srcimg.cols; + *newh = this->inpHeight; + *neww = this->inpWidth; + Mat dstimg; + if (this->keep_ratio && srch != srcw) + { + float hw_scale = (float)srch / srcw; + if (hw_scale > 1) + { + *newh = this->inpHeight; + *neww = int(this->inpWidth / hw_scale); + resize(srcimg, dstimg, Size(*neww, *newh), INTER_AREA); + *left = int((this->inpWidth - *neww) * 0.5); + copyMakeBorder(dstimg, dstimg, 0, 0, *left, this->inpWidth - *neww - *left, BORDER_CONSTANT, 0); + } + else + { + *newh = (int)this->inpHeight * hw_scale; + *neww = this->inpWidth; + resize(srcimg, dstimg, Size(*neww, *newh), INTER_AREA); + *top = (int)(this->inpHeight - *newh) * 0.5; + copyMakeBorder(dstimg, dstimg, *top, this->inpHeight - *newh - *top, 0, 0, BORDER_CONSTANT, 0); + } + } + else + { + resize(srcimg, dstimg, Size(*neww, *newh), INTER_AREA); + } + return dstimg; +} + +void YOLO::normalize(Mat& img) +{ + img.convertTo(img, CV_32F); + int i = 0, j = 0; + const float scale = 1.0 / 255.0; + for (i = 0; i < img.rows; i++) + { + float* pdata = (float*)(img.data + i * img.step); + for (j = 0; j < img.cols; j++) + { + pdata[0] = (pdata[0] * scale - this->mean[0]) / this->std[0]; + pdata[1] = (pdata[1] * scale - this->mean[1]) / this->std[1]; + pdata[2] = (pdata[2] * scale - this->mean[2]) / this->std[2]; + pdata += 3; + } + } +} + +void YOLO::drawPred(int classId, float conf, int left, int top, int right, int bottom, Mat& frame) // Draw the predicted bounding box +{ + //Draw a rectangle displaying the bounding box + rectangle(frame, Point(left, top), Point(right, bottom), Scalar(0, 0, 255), 2); + + //Get the label for the class name and its confidence + string label = format("%.2f", conf); + label = this->classes[classId] + ":" + label; + + //Display the label at the top of the bounding box + int baseLine; + Size labelSize = getTextSize(label, FONT_HERSHEY_SIMPLEX, 0.5, 1, &baseLine); + top = max(top, labelSize.height); + //rectangle(frame, Point(left, top - int(1.5 * labelSize.height)), Point(left + int(1.5 * labelSize.width), top + baseLine), Scalar(0, 255, 0), FILLED); + putText(frame, label, Point(left, top), FONT_HERSHEY_SIMPLEX, 1, Scalar(0, 255, 0), 1); +} + +Mat YOLO::detect(Mat& srcimg) +{ + int newh = 0, neww = 0, padh = 0, padw = 0; + Mat dstimg = this->resize_image(srcimg, &newh, &neww, &padh, &padw); + this->normalize(dstimg); + Mat blob = blobFromImage(dstimg); + this->net.setInput(blob); + vector outs; + this->net.forward(outs, this->net.getUnconnectedOutLayersNames()); + + Mat outimg = srcimg.clone(); + float ratioh = (float)newh / srcimg.rows; + float ratiow = (float)neww / srcimg.cols; + int i = 0, j = 0, area = this->inpHeight*this->inpWidth; + float* pdata_drive = (float*)outs[1].data; ///drive area segment + float* pdata_lane_line = (float*)outs[2].data; ///lane line segment + for (i = 0; i < outimg.rows; i++) + { + for (j = 0; j < outimg.cols; j++) + { + const int x = int(j*ratiow) + padw; + const int y = int(i*ratioh) + padh; + if (pdata_drive[y * this->inpWidth + x] < pdata_drive[area + y * this->inpWidth + x]) + { + outimg.at(i, j)[0] = 0; + outimg.at(i, j)[1] = 255; + outimg.at(i, j)[2] = 0; + } + if (pdata_lane_line[y * this->inpWidth + x] < pdata_lane_line[area + y * this->inpWidth + x]) + { + outimg.at(i, j)[0] = 255; + outimg.at(i, j)[1] = 0; + outimg.at(i, j)[2] = 0; + } + } + } + /////generate proposals + vector classIds; + vector confidences; + vector boxes; + ratioh = (float)srcimg.rows / newh; + ratiow = (float)srcimg.cols / neww; + int n = 0, q = 0, nout = this->classes.size() + 5, row_ind = 0; + float* pdata = (float*)outs[0].data; + for (n = 0; n < 3; n++) ///�߶� + { + int num_grid_x = (int)(this->inpWidth / this->stride[n]); + int num_grid_y = (int)(this->inpHeight / this->stride[n]); + for (q = 0; q < 3; q++) ///anchor�� + { + const float anchor_w = this->anchors[n][q * 2]; + const float anchor_h = this->anchors[n][q * 2 + 1]; + for (i = 0; i < num_grid_y; i++) + { + for (j = 0; j < num_grid_x; j++) + { + const float box_score = pdata[4]; + if (box_score > this->objThreshold) + { + Mat scores = outs[0].row(row_ind).colRange(5, outs[0].cols); + Point classIdPoint; + double max_class_socre; + // Get the value and location of the maximum score + minMaxLoc(scores, 0, &max_class_socre, 0, &classIdPoint); + if (max_class_socre > this->confThreshold) + { + float cx = (pdata[0] * 2.f - 0.5f + j) * this->stride[n]; ///cx + float cy = (pdata[1] * 2.f - 0.5f + i) * this->stride[n]; ///cy + float w = powf(pdata[2] * 2.f, 2.f) * anchor_w; ///w + float h = powf(pdata[3] * 2.f, 2.f) * anchor_h; ///h + + int left = (cx - 0.5*w - padw)*ratiow; + int top = (cy - 0.5*h - padh)*ratioh; + + classIds.push_back(classIdPoint.x); + confidences.push_back(max_class_socre * box_score); + boxes.push_back(Rect(left, top, (int)(w*ratiow), (int)(h*ratioh))); + } + } + row_ind++; + pdata += nout; + } + } + } + } + + // Perform non maximum suppression to eliminate redundant overlapping boxes with + // lower confidences + vector indices; + NMSBoxes(boxes, confidences, this->confThreshold, this->nmsThreshold, indices); + for (size_t i = 0; i < indices.size(); ++i) + { + int idx = indices[i]; + Rect box = boxes[idx]; + this->drawPred(classIds[idx], confidences[idx], box.x, box.y, + box.x + box.width, box.y + box.height, outimg); + } + return outimg; +} + +int main() +{ + YOLO yolo_model("yolop.onnx", 0.25, 0.45, 0.5); + string imgpath = "images/0ace96c3-48481887.jpg"; + Mat srcimg = imread(imgpath); + Mat outimg = yolo_model.detect(srcimg); + + static const string kWinName = "Deep learning object detection in OpenCV"; + namedWindow(kWinName, WINDOW_NORMAL); + imshow(kWinName, outimg); + waitKey(0); + destroyAllWindows(); +} \ No newline at end of file diff --git a/main.py b/main.py new file mode 100644 index 0000000..ea33338 --- /dev/null +++ b/main.py @@ -0,0 +1,158 @@ +import cv2 +import argparse +import numpy as np + +class yolop(): + def __init__(self, confThreshold=0.25, nmsThreshold=0.5, objThreshold=0.45): + with open('bdd100k.names', 'rt') as f: + self.classes = f.read().rstrip('\n').split('\n') ###这个是在bdd100k数据集上训练的模型做opencv部署的,如果你在自己的数据集上训练出的模型做opencv部署,那么需要修改self.classes + num_classes = len(self.classes) + anchors = [[3,9,5,11,4,20], [7,18,6,39,12,31], [19,50,38,81,68,157]] + self.nl = len(anchors) + self.na = len(anchors[0]) // 2 + self.no = num_classes + 5 + self.stride = np.array([8., 16., 32.]) + self.anchor_grid = np.asarray(anchors, dtype=np.float32).reshape(self.nl, -1, 2) + self.inpWidth = 640 + self.inpHeight = 640 + self.generate_grid() + self.net = cv2.dnn.readNet('yolop.onnx') + self.confThreshold = confThreshold + self.nmsThreshold = nmsThreshold + self.objThreshold = objThreshold + self.mean = np.array([0.485, 0.456, 0.406], dtype=np.float32).reshape(1, 1, 3) + self.std = np.array([0.229, 0.224, 0.225], dtype=np.float32).reshape(1, 1, 3) + self.keep_ratio = True + def generate_grid(self): + self.grid = [np.zeros(1)] * self.nl + self.length = [] + self.areas = [] + for i in range(self.nl): + h, w = int(self.inpHeight/self.stride[i]), int(self.inpWidth/self.stride[i]) + self.length.append(int(self.na * h * w)) + self.areas.append(h*w) + if self.grid[i].shape[2:4] != (h,w): + self.grid[i] = self._make_grid(w, h) + def _make_grid(self, nx=20, ny=20): + xv, yv = np.meshgrid(np.arange(ny), np.arange(nx)) + return np.stack((xv, yv), 2).reshape((-1, 2)).astype(np.float32) + + def postprocess(self, frame, outs, newh, neww, padh, padw): + frameHeight = frame.shape[0] + frameWidth = frame.shape[1] + ratioh, ratiow = frameHeight / newh, frameWidth / neww + # Scan through all the bounding boxes output from the network and keep only the + # ones with high confidence scores. Assign the box's class label as the class with the highest score. + classIds = [] + confidences = [] + boxes = [] + for detection in outs: + scores = detection[5:] + classId = np.argmax(scores) + confidence = scores[classId] + if confidence > self.confThreshold and detection[4] > self.objThreshold: + center_x = int((detection[0]-padw) * ratiow) + center_y = int((detection[1]-padh) * ratioh) + width = int(detection[2] * ratiow) + height = int(detection[3] * ratioh) + left = int(center_x - width / 2) + top = int(center_y - height / 2) + classIds.append(classId) + confidences.append(float(confidence) * detection[4]) + boxes.append([left, top, width, height]) + + # Perform non maximum suppression to eliminate redundant overlapping boxes with + # lower confidences. + indices = cv2.dnn.NMSBoxes(boxes, confidences, self.confThreshold, self.nmsThreshold) + for i in indices: + i = i[0] + box = boxes[i] + left = box[0] + top = box[1] + width = box[2] + height = box[3] + frame = self.drawPred(frame, classIds[i], confidences[i], left, top, left + width, top + height) + return frame + def drawPred(self, frame, classId, conf, left, top, right, bottom): + # Draw a bounding box. + cv2.rectangle(frame, (left, top), (right, bottom), (0, 0, 255), thickness=2) + + label = '%.2f' % conf + label = '%s:%s' % (self.classes[classId], label) + + # Display the label at the top of the bounding box + labelSize, baseLine = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1) + top = max(top, labelSize[1]) + # cv.rectangle(frame, (left, top - round(1.5 * labelSize[1])), (left + round(1.5 * labelSize[0]), top + baseLine), (255,255,255), cv.FILLED) + cv2.putText(frame, label, (left, top - 10), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), thickness=1) + return frame + def resize_image(self, srcimg): + padh, padw, newh, neww = 0, 0, self.inpHeight, self.inpWidth + if self.keep_ratio and srcimg.shape[0] != srcimg.shape[1]: + hw_scale = srcimg.shape[0] / srcimg.shape[1] + if hw_scale > 1: + newh, neww = self.inpHeight, int(self.inpWidth / hw_scale) + img = cv2.resize(srcimg, (neww, newh), interpolation=cv2.INTER_AREA) + padw = int((self.inpWidth - neww) * 0.5) + img = cv2.copyMakeBorder(img, 0, 0, padw, self.inpWidth - neww - padw, cv2.BORDER_CONSTANT, + value=0) # add border + else: + newh, neww = int(self.inpHeight * hw_scale), self.inpWidth + img = cv2.resize(srcimg, (neww, newh), interpolation=cv2.INTER_AREA) + padh = int((self.inpHeight - newh) * 0.5) + img = cv2.copyMakeBorder(img, padh, self.inpHeight - newh - padh, 0, 0, cv2.BORDER_CONSTANT, value=0) + else: + img = cv2.resize(srcimg, (self.inpWidth, self.inpHeight), interpolation=cv2.INTER_AREA) + return img, newh, neww, padh, padw + + def _normalize(self, img): ### c++: https://blog.csdn.net/wuqingshan2010/article/details/107727909 + img = img.astype(np.float32) / 255.0 + img = (img - self.mean) / self.std + return img + def detect(self, srcimg): + img, newh, neww, padh, padw = self.resize_image(srcimg) + img = self._normalize(img) + blob = cv2.dnn.blobFromImage(img) + # Sets the input to the network + self.net.setInput(blob) + + # Runs the forward pass to get output of the output layers + outs = self.net.forward(self.net.getUnconnectedOutLayersNames()) + # inference output + outimg = srcimg.copy() + drive_area_mask = outs[1][:, padh:(self.inpHeight - padh), padw:(self.inpWidth - padw)] + seg_id = np.argmax(drive_area_mask, axis=0).astype(np.uint8) + seg_id = cv2.resize(seg_id, (srcimg.shape[1], srcimg.shape[0]), interpolation=cv2.INTER_NEAREST) + outimg[seg_id == 1] = [0, 255, 0] + + lane_line_mask = outs[2][:, padh:(self.inpHeight - padh), padw:(self.inpWidth - padw)] + seg_id = np.argmax(lane_line_mask, axis=0).astype(np.uint8) + seg_id = cv2.resize(seg_id, (srcimg.shape[1], srcimg.shape[0]), interpolation=cv2.INTER_NEAREST) + outimg[seg_id == 1] = [255, 0, 0] + + det_out = outs[0] + row_ind = 0 + for i in range(self.nl): + det_out[row_ind:row_ind+self.length[i], 0:2] = (det_out[row_ind:row_ind+self.length[i], 0:2] * 2. - 0.5 + np.tile(self.grid[i],(self.na, 1))) * int(self.stride[i]) + det_out[row_ind:row_ind+self.length[i], 2:4] = (det_out[row_ind:row_ind+self.length[i], 2:4] * 2) ** 2 * np.repeat(self.anchor_grid[i], self.areas[i], axis=0) + row_ind += self.length[i] + outimg = self.postprocess(outimg, det_out, newh, neww, padh, padw) + return outimg + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--imgpath", type=str, default='images/0ace96c3-48481887.jpg', help="image path") + parser.add_argument('--confThreshold', default=0.25, type=float, help='class confidence') + parser.add_argument('--nmsThreshold', default=0.45, type=float, help='nms iou thresh') + parser.add_argument('--objThreshold', default=0.5, type=float, help='object confidence') + args = parser.parse_args() + + yolonet = yolop(confThreshold=args.confThreshold, nmsThreshold=args.nmsThreshold, objThreshold=args.objThreshold) + srcimg = cv2.imread(args.imgpath) + outimg = yolonet.detect(srcimg) + + winName = 'Deep learning object detection in OpenCV' + cv2.namedWindow(winName, 0) + cv2.imshow(winName, outimg) + cv2.waitKey(0) + cv2.destroyAllWindows() \ No newline at end of file