mirror of https://github.com/we0091234/yolov7_plate.git (synced 2025-09-26 21:01:13 +08:00)
test.py bug fix

.gitignore (vendored): 1 change

@@ -9,6 +9,7 @@ runs/
 .vscode/
 build/
 result1/
 result/
 *.pyc
+plate/
 # do not ignore the file types specified below

@@ -1,6 +1,6 @@
 # parameters
-nc: 1  # number of classes
-nkpt: 5  # number of keypoints
+nc: 2  # number of classes
+nkpt: 4  # number of keypoints
 depth_multiple: 1.0  # model depth multiple
 width_multiple: 1.0  # layer channel multiple
 dw_conv_kpt: True

@@ -9,7 +9,7 @@
 # train and val data as 1) directory: path/images/, 2) file: path/images.txt, or 3) list: [path1/images/, path2/images/]
 train: /mnt/Gpan/Mydata/pytorchPorject/datasets/ccpd/train_detect
-val: /mnt/Gpan/Mydata/pytorchPorject/datasets/ccpd/val_detect
+val: /mnt/Gpan/Mydata/pytorchPorject/datasets/ccpd/train_detect
 #val: /ssd_1t/derron/yolov5-face/data/widerface/train/  # 4952 images

 # number of classes

@@ -71,6 +71,7 @@ def get_plate_rec_landmark(img, xyxy, conf, landmarks, class_num,device,plate_re
     result_dict['landmarks']=landmarks_np.tolist()
     result_dict['plate_no']=plate_number
     result_dict['roi_height']=roi_img.shape[0]
+    result_dict['score']=conf
     return result_dict

 def detect_Recognition_plate(model, orgimg, device,plate_rec_model,img_size):

@@ -119,7 +120,7 @@ def draw_result(orgimg,dict_list):
         height_area = result['roi_height']
         landmarks=result['landmarks']
-        result = result['plate_no']
+        result = result['plate_no']+" "+"{:.2f}".format(result['score'])
         result_str+=result+" "
         cv2.rectangle(orgimg,(rect_area[0],rect_area[1]),(rect_area[2],rect_area[3]),(0,0,255),2)  # draw the box
         if len(result)>=7:
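
The two hunks above work together: get_plate_rec_landmark now stores the detection confidence under result_dict['score'], and draw_result reads it back when composing the label text. A minimal sketch of that round trip with a hand-made result dict (the values are illustrative only, not real model output):

# Illustration only: a fabricated dict shaped like the one get_plate_rec_landmark returns.
result = {
    'plate_no': '京A12345',
    'score': 0.9273,          # detection confidence, newly stored by this commit
    'roi_height': 32,
    'landmarks': [[10, 10], [90, 10], [90, 40], [10, 40]],
}

# Same expression as the patched draw_result line:
label = result['plate_no'] + " " + "{:.2f}".format(result['score'])
print(label)  # 京A12345 0.93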

@@ -133,9 +134,9 @@ def draw_result(orgimg,dict_list):

 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('--detect_model', nargs='+', type=str, default='runs/train/exp/weights/best.pt', help='model.pt path(s)')
+    parser.add_argument('--detect_model', nargs='+', type=str, default='weights/plate_detect.pt', help='model.pt path(s)')
     parser.add_argument('--rec_model', type=str, default='weights/plate_rec.pth', help='model.pt path(s)')
-    parser.add_argument('--source', type=str, default='imgs', help='source')  # file/folder, 0 for webcam
+    parser.add_argument('--source', type=str, default='../Chinese_license_plate_detection_recognition/mytest/', help='source')  # file/folder, 0 for webcam
     # parser.add_argument('--img-size', nargs= '+', type=int, default=640, help='inference size (pixels)')
     parser.add_argument('--img_size', type=int, default=640, help='inference size (pixels)')
     parser.add_argument('--output', type=str, default='result', help='source')

@@ -145,6 +146,7 @@ if __name__ == '__main__':
     opt = parser.parse_args()
     print(opt)
     model = attempt_load(opt.detect_model, map_location=device)
+    # torch.save()
     plate_rec_model=init_model(device,opt.rec_model)
     if not os.path.exists(opt.output):
         os.mkdir(opt.output)
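
For orientation, the block above only loads the two models and prepares the output folder. A driver that sits after it would plausibly iterate the source folder, call detect_Recognition_plate and draw_result (both shown in the hunks above), and save the annotated images. The helper below is a sketch under those assumptions, not the repo's actual code:

import os
import cv2

def run_folder(model, plate_rec_model, device, source, output, img_size=640):
    # Hypothetical helper: process every image in `source`, write results to `output`.
    os.makedirs(output, exist_ok=True)   # same intent as the os.mkdir guard above
    for name in os.listdir(source):
        img = cv2.imread(os.path.join(source, name))
        if img is None:                  # skip files OpenCV cannot read
            continue
        dict_list = detect_Recognition_plate(model, img, device, plate_rec_model, img_size)
        draw_result(img, dict_list)      # draws boxes and plate text onto img
        cv2.imwrite(os.path.join(output, name), img)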

test.py: 6 changes

@@ -362,8 +362,8 @@ def test(data,

 if __name__ == '__main__':
     parser = argparse.ArgumentParser(prog='test.py')
-    parser.add_argument('--weights', nargs='+', type=str, default='yolov5s.pt', help='model.pt path(s)')
-    parser.add_argument('--data', type=str, default='data/coco128.yaml', help='*.data path')
+    parser.add_argument('--weights', nargs='+', type=str, default='weights/plate_detect.pt', help='model.pt path(s)')
+    parser.add_argument('--data', type=str, default='data/plate.yaml', help='*.data path')
     parser.add_argument('--batch-size', type=int, default=32, help='size of each image batch')
     parser.add_argument('--img-size', type=int, default=640, help='inference size (pixels)')
     parser.add_argument('--conf-thres', type=float, default=0.001, help='object confidence threshold')

@@ -384,7 +384,7 @@ if __name__ == '__main__':
     parser.add_argument('--project', default='runs/test', help='save to project/name')
     parser.add_argument('--name', default='exp', help='save to project/name')
     parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment')
-    parser.add_argument('--kpt-label', type=int, default=5, help='number of keypoints')
+    parser.add_argument('--kpt-label', type=int, default=4, help='number of keypoints')
     parser.add_argument('--flip-test', action='store_true', help='Whether to run flip_test or not')
     opt = parser.parse_args()
     opt.save_json |= opt.data.endswith('coco.yaml')

train.py: 10 changes

@@ -417,7 +417,15 @@ def train(hyp, opt, device, tb_writer=None):
             # Save last, best and delete
             torch.save(ckpt, last)
             if best_fitness == fi:
-                torch.save(ckpt, best)
+                best_ckpt = {'epoch': epoch,
+                             'best_fitness': best_fitness,
+                             # 'training_results': results_file.read_text(),
+                             'model': deepcopy(model.module if is_parallel(model) else model).half(),
+                             'ema': deepcopy(ema.ema).half(),
+                             'updates': ema.updates,
+                             # 'optimizer': optimizer.state_dict(),
+                             'wandb_id': wandb_logger.wandb_run.id if wandb_logger.wandb else None}
+                torch.save(best_ckpt, best)
             if wandb_logger.wandb:
                 if ((epoch + 1) % opt.save_period == 0 and not final_epoch) and opt.save_period != -1:
                     wandb_logger.log_model(
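
The change above replaces the plain torch.save(ckpt, best) with a dedicated best_ckpt dict: FP16 copies of the model and EMA weights are kept while the optimizer state and training log are commented out, keeping best.pt lean. A generic, self-contained sketch of the same idea (the tiny model and file path are placeholders, not from the repo):

from copy import deepcopy
import torch
import torch.nn as nn

model = nn.Linear(10, 2)                 # placeholder for the real detector
best_fitness, epoch = 0.0, 0

best_ckpt = {
    'epoch': epoch,
    'best_fitness': best_fitness,
    'model': deepcopy(model).half(),     # FP16 copy keeps the checkpoint small
    # 'optimizer': ...                   # deliberately omitted, as in the commit
}
torch.save(best_ckpt, 'best.pt')

# Loading for inference later: restore FP32 before running on CPU.
# (On PyTorch >= 2.6, pass weights_only=False to load a pickled nn.Module.)
ckpt = torch.load('best.pt', map_location='cpu')
net = ckpt['model'].float().eval()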

@@ -361,7 +361,7 @@ class LoadImagesAndLabels(Dataset): # for training/testing
         self.stride = stride
         self.path = path
         self.kpt_label = kpt_label
-        self.flip_index = [1, 0, 2, 4, 3]  # face
+        self.flip_index = [1, 0, 2, 4, 3]  # face, 4 is the nose
+        self.flip_index_plate = [1, 0, 3, 2]  # license plate

         try:
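
flip_index lists, for each keypoint, which keypoint it becomes after a horizontal flip, so left/right landmarks swap slots while symmetric ones keep theirs; the new flip_index_plate = [1, 0, 3, 2] swaps the two top corners and the two bottom corners of a plate. A small sketch of how such an index is typically applied during flip augmentation (assumed usage and assumed TL, TR, BR, BL corner order, not copied from the repo's dataset code):

import numpy as np

flip_index_plate = [1, 0, 3, 2]   # corner k becomes corner flip_index_plate[k] after a flip

def flip_keypoints_lr(kpts, img_width, flip_index):
    # kpts: (N, K, 2) array of (x, y) pixel coordinates for N objects with K keypoints each.
    flipped = kpts.copy()
    flipped[..., 0] = img_width - flipped[..., 0]   # mirror the x coordinate
    return flipped[:, flip_index, :]                # reorder so corner labels stay consistent

plate = np.array([[[10., 5.], [60., 5.], [60., 30.], [10., 30.]]])  # assumed TL, TR, BR, BL
print(flip_keypoints_lr(plate, img_width=100, flip_index=flip_index_plate))
# after the flip the corners are still in TL, TR, BR, BL order: (40,5), (90,5), (90,30), (40,30)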

@@ -530,8 +530,8 @@ def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=Non

         # Detections matrix nx6 (xyxy, conf, cls)
         if multi_label:
-            i, j = (x[:, 5:] > conf_thres).nonzero(as_tuple=False).T
-            x = torch.cat((box[i], x[i, j + 5, None], j[:, None].float()), 1)
+            i, j = (x[:, 5:5+nc] > conf_thres).nonzero(as_tuple=False).T
+            x = torch.cat((box[i], x[i, j + 5, None], j[:, None].float(),x[i,5+nc:]), 1)
         else:  # best class only
             if not kpt_label:
                 conf, j = x[:, 5:].max(1, keepdim=True)
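
Here each prediction row carries box (4) and objectness (1), then nc class scores, then the keypoint values, so the old x[:, 5:] threshold was also testing keypoint columns. The patched branch restricts the threshold to the class columns 5:5+nc and concatenates the trailing keypoint columns x[i, 5+nc:] back onto every kept row. In YOLO-style NMS, multi_label is typically enabled only when nc > 1, so raising nc from 1 to 2 elsewhere in this commit is likely what made this branch reachable. A toy shape check of the slicing (the 3-values-per-keypoint layout is an assumption):

import torch

nc = 2                                    # classes after this commit
kpt_vals = 4 * 3                          # 4 keypoints times (x, y, conf), assumed layout
x = torch.rand(8, 5 + nc + kpt_vals)      # [cx, cy, w, h, obj, cls0, cls1, kpts...]
box = x[:, :4]                            # stand-in for the xyxy boxes
conf_thres = 0.25

# Patched multi-label branch: threshold only the nc class columns...
i, j = (x[:, 5:5 + nc] > conf_thres).nonzero(as_tuple=False).T
# ...and carry the keypoint columns along with each surviving row.
out = torch.cat((box[i], x[i, j + 5, None], j[:, None].float(), x[i, 5 + nc:]), 1)
print(out.shape)   # (num_kept, 4 + 1 + 1 + kpt_vals)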