mirror of
https://github.com/we0091234/Car_recognition.git
synced 2025-09-27 04:45:52 +08:00
add car color
This commit is contained in:
@@ -16,6 +16,7 @@ from plate_recognition.plate_rec import get_plate_result,allFilePath,init_model,
|
||||
# from plate_recognition.plate_cls import cv_imread
|
||||
from plate_recognition.double_plate_split_merge import get_split_merge
|
||||
from plate_recognition.color_rec import plate_color_rec,init_color_model
|
||||
from car_recognition.car_rec import init_car_rec_model,get_color_and_score
|
||||
|
||||
clors = [(255,0,0),(0,255,0),(0,0,255),(255,255,0),(0,255,255)]
|
||||
danger=['危','险']
|
||||
@@ -65,7 +66,7 @@ def scale_coords_landmarks(img1_shape, coords, img0_shape, ratio_pad=None): #
|
||||
|
||||
coords[:, [0, 2, 4, 6]] -= pad[0] # x padding
|
||||
coords[:, [1, 3, 5, 7]] -= pad[1] # y padding
|
||||
coords[:, :10] /= gain
|
||||
coords[:, :8] /= gain
|
||||
#clip_coords(coords, img0_shape)
|
||||
coords[:, 0].clamp_(0, img0_shape[1]) # x1
|
||||
coords[:, 1].clamp_(0, img0_shape[0]) # y1
|
||||
@@ -79,7 +80,7 @@ def scale_coords_landmarks(img1_shape, coords, img0_shape, ratio_pad=None): #
|
||||
# coords[:, 9].clamp_(0, img0_shape[0]) # y5
|
||||
return coords
|
||||
|
||||
def get_plate_rec_landmark(img, xyxy, conf, landmarks, class_num,device,plate_rec_model,plate_color_model=None):
|
||||
def get_plate_rec_landmark(img, xyxy, conf, landmarks, class_num,device,plate_rec_model,car_rec_model):
|
||||
h,w,c = img.shape
|
||||
result_dict={}
|
||||
x1 = int(xyxy[0])
|
||||
@@ -90,10 +91,15 @@ def get_plate_rec_landmark(img, xyxy, conf, landmarks, class_num,device,plate_re
|
||||
rect=[x1,y1,x2,y2]
|
||||
|
||||
if int(class_num) ==2:
|
||||
#
|
||||
car_roi_img = img[y1:y2,x1:x2]
|
||||
car_color,color_conf=get_color_and_score(car_rec_model,car_roi_img,device)
|
||||
result_dict['class_type']=class_type[int(class_num)]
|
||||
result_dict['rect']=rect #车辆roi
|
||||
result_dict['score']=conf
|
||||
result_dict['score']=conf #车牌区域检测得分
|
||||
result_dict['object_no']=int(class_num)
|
||||
result_dict['car_color']=car_color
|
||||
result_dict['color_conf']=color_conf
|
||||
return result_dict
|
||||
|
||||
for i in range(4):
|
||||
@@ -103,10 +109,9 @@ def get_plate_rec_landmark(img, xyxy, conf, landmarks, class_num,device,plate_re
|
||||
|
||||
class_label= int(class_num) #车牌的的类型0代表单牌,1代表双层车牌
|
||||
roi_img = four_point_transform(img,landmarks_np) #透视变换得到车牌小图
|
||||
color_code = plate_color_rec(roi_img,plate_color_model,device) #车牌颜色识别
|
||||
if class_label: #判断是否是双层车牌,是双牌的话进行分割后然后拼接
|
||||
roi_img=get_split_merge(roi_img)
|
||||
plate_number = get_plate_result(roi_img,device,plate_rec_model) #对车牌小图进行识别
|
||||
plate_number ,plate_color= get_plate_result(roi_img,device,plate_rec_model) #对车牌小图进行识别,得到颜色和车牌号
|
||||
for dan in danger: #只要出现‘危’或者‘险’就是危险品车牌
|
||||
if dan in plate_number:
|
||||
plate_number='危险品'
|
||||
@@ -116,14 +121,14 @@ def get_plate_rec_landmark(img, xyxy, conf, landmarks, class_num,device,plate_re
|
||||
result_dict['landmarks']=landmarks_np.tolist() #车牌角点坐标
|
||||
result_dict['plate_no']=plate_number #车牌号
|
||||
result_dict['roi_height']=roi_img.shape[0] #车牌高度
|
||||
result_dict['plate_color']=color_code #车牌颜色
|
||||
result_dict['plate_color']=plate_color #车牌颜色
|
||||
result_dict['object_no']=class_label #单双层 0单层 1双层
|
||||
result_dict['score']=conf
|
||||
result_dict['score']=conf #车牌区域检测得分
|
||||
return result_dict
|
||||
|
||||
|
||||
|
||||
def detect_Recognition_plate(model, orgimg, device,plate_rec_model,img_size,plate_color_model=None):
|
||||
def detect_Recognition_plate(model, orgimg, device,plate_rec_model,img_size,car_rec_model=None):
|
||||
# Load model
|
||||
# img_size = opt_img_size
|
||||
conf_thres = 0.3
|
||||
@@ -183,7 +188,7 @@ def detect_Recognition_plate(model, orgimg, device,plate_rec_model,img_size,plat
|
||||
conf = det[j, 4].cpu().numpy()
|
||||
landmarks = det[j, 5:13].view(-1).tolist()
|
||||
class_num = det[j, 13].cpu().numpy()
|
||||
result_dict = get_plate_rec_landmark(orgimg, xyxy, conf, landmarks, class_num,device,plate_rec_model,plate_color_model)
|
||||
result_dict = get_plate_rec_landmark(orgimg, xyxy, conf, landmarks, class_num,device,plate_rec_model,car_rec_model)
|
||||
dict_list.append(result_dict)
|
||||
return dict_list
|
||||
# cv2.imwrite('result.jpg', orgimg)
|
||||
@@ -202,7 +207,7 @@ def draw_result(orgimg,dict_list):
|
||||
rect_area[2]=min(orgimg.shape[1],int(rect_area[2]+padding_w))
|
||||
rect_area[3]=min(orgimg.shape[0],int(rect_area[3]+padding_h))
|
||||
|
||||
height_area = result['roi_height']
|
||||
height_area = int(result['roi_height']/2)
|
||||
landmarks=result['landmarks']
|
||||
result_p = result['plate_no']
|
||||
if result['object_no']==0:#单层
|
||||
@@ -218,6 +223,13 @@ def draw_result(orgimg,dict_list):
|
||||
orgimg=cv2ImgAddText(orgimg,result_p,rect_area[0],rect_area[3],(0,255,0),height_area)
|
||||
else:
|
||||
orgimg=cv2ImgAddText(orgimg,result_p,rect_area[0]-height_area,rect_area[1]-height_area-10,(0,255,0),height_area)
|
||||
else:
|
||||
height_area=int((rect_area[3]-rect_area[1])/20)
|
||||
car_color = result['car_color']
|
||||
car_color_str="车辆颜色:"
|
||||
car_color_str+=car_color
|
||||
orgimg=cv2ImgAddText(orgimg,car_color_str,rect_area[0],rect_area[1],(0,255,0),height_area)
|
||||
|
||||
cv2.rectangle(orgimg,(rect_area[0],rect_area[1]),(rect_area[2],rect_area[3]),object_color[object_no],2) #画框
|
||||
print(result_str)
|
||||
return orgimg
|
||||
@@ -233,8 +245,8 @@ def get_second(capture):
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--detect_model', nargs='+', type=str, default='weights/detect.pt', help='model.pt path(s)') #检测模型
|
||||
parser.add_argument('--rec_model', type=str, default='weights/plate_rec.pth', help='model.pt path(s)')#识别模型
|
||||
parser.add_argument('--color_model',type=str,default='weights/color_classify.pth',help='plate color')#颜色识别模型
|
||||
parser.add_argument('--rec_model', type=str, default='weights/plate_rec_color.pth', help='model.pt path(s)')#车牌识别+车牌颜色识别模型
|
||||
parser.add_argument('--car_rec_model',type=str,default='weights/car_rec_color.pth',help='car_rec_model') #车辆识别模型
|
||||
parser.add_argument('--image_path', type=str, default='imgs', help='source')
|
||||
parser.add_argument('--img_size', type=int, default=384, help='inference size (pixels)')
|
||||
parser.add_argument('--output', type=str, default='result1', help='source')
|
||||
@@ -250,12 +262,12 @@ if __name__ == '__main__':
|
||||
|
||||
detect_model = load_model(opt.detect_model, device) #初始化检测模型
|
||||
plate_rec_model=init_model(device,opt.rec_model) #初始化识别模型
|
||||
car_rec_model = init_car_rec_model(opt.car_rec_model,device) #初始化车辆识别模型
|
||||
#算参数量
|
||||
total = sum(p.numel() for p in detect_model.parameters())
|
||||
total_1 = sum(p.numel() for p in plate_rec_model.parameters())
|
||||
print("detect params: %.2fM,rec params: %.2fM" % (total/1e6,total_1/1e6))
|
||||
|
||||
plate_color_model =init_color_model(opt.color_model,device)
|
||||
time_all = 0
|
||||
time_begin=time.time()
|
||||
if not opt.video: #处理图片
|
||||
@@ -273,7 +285,8 @@ if __name__ == '__main__':
|
||||
if img.shape[-1]==4:
|
||||
img=cv2.cvtColor(img,cv2.COLOR_BGRA2BGR)
|
||||
# detect_one(model,img_path,device)
|
||||
dict_list=detect_Recognition_plate(detect_model, img, device,plate_rec_model,opt.img_size,plate_color_model)
|
||||
dict_list=detect_Recognition_plate(detect_model, img, device,plate_rec_model,opt.img_size,car_rec_model)
|
||||
# print(dict_list)
|
||||
ori_img=draw_result(img,dict_list)
|
||||
img_name = os.path.basename(img_path)
|
||||
save_img_path = os.path.join(save_path,img_name)
|
||||
|
63
car_recognition/car_rec.py
Normal file
63
car_recognition/car_rec.py
Normal file
@@ -0,0 +1,63 @@
|
||||
from car_recognition.myNet import myNet
|
||||
import torch
|
||||
import cv2
|
||||
import torch.nn.functional as F
|
||||
import os
|
||||
import numpy as np
|
||||
|
||||
colors = ['黑色','蓝色','黄色','棕色','绿色','灰色','橙色','粉色','紫色','红色','白色']
|
||||
def init_car_rec_model(model_path,device):
|
||||
check_point = torch.load(model_path)
|
||||
cfg= check_point['cfg']
|
||||
model = myNet(num_classes=11,cfg=cfg)
|
||||
model.load_state_dict(check_point['state_dict'])
|
||||
model.to(device)
|
||||
model.eval()
|
||||
return model
|
||||
|
||||
def imge_processing(img,device):
|
||||
img = cv2.resize(img,(64,64))
|
||||
img = img.transpose([2,0,1])
|
||||
img = torch.from_numpy(img).float().to(device)
|
||||
img = img-127.5
|
||||
img = img.unsqueeze(0)
|
||||
return img
|
||||
|
||||
def allFilePath(rootPath,allFIleList):
|
||||
fileList = os.listdir(rootPath)
|
||||
for temp in fileList:
|
||||
if os.path.isfile(os.path.join(rootPath,temp)):
|
||||
allFIleList.append(os.path.join(rootPath,temp))
|
||||
else:
|
||||
allFilePath(os.path.join(rootPath,temp),allFIleList)
|
||||
|
||||
def get_color_and_score(model,img,device):
|
||||
img = imge_processing(img,device)
|
||||
result = model(img)
|
||||
out =F.softmax( result)
|
||||
_, predicted = torch.max(out.data, 1)
|
||||
out=out.data.cpu().numpy().tolist()[0]
|
||||
predicted = predicted.item()
|
||||
car_color= colors[predicted]
|
||||
color_conf = out[predicted]
|
||||
# print(pic_,colors[predicted[0]])
|
||||
return car_color,color_conf
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# root_file =r"/mnt/Gpan/BaiduNetdiskDownload/VehicleColour/VehicleColour/class/7"
|
||||
root_file =r"imgs"
|
||||
file_list=[]
|
||||
allFilePath(root_file,file_list)
|
||||
device = torch.device("cuda" if torch.cuda.is_available else "cpu")
|
||||
model_path = r"/mnt/Gpan/Mydata/pytorchPorject/Car_system/car_color/color_model/0.8682285244554049_epoth_117_model.pth"
|
||||
model = init_car_rec_model(model_path,device)
|
||||
for pic_ in file_list:
|
||||
img = cv2.imread(pic_)
|
||||
# img = imge_processing(img,device)
|
||||
color,conf = get_color_and_score(model,img,device)
|
||||
print(pic_,color,conf)
|
||||
|
||||
|
||||
|
||||
|
95
car_recognition/myNet.py
Normal file
95
car_recognition/myNet.py
Normal file
@@ -0,0 +1,95 @@
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.autograd import Variable
|
||||
from torchvision import models
|
||||
|
||||
|
||||
__all__ = ['myNet','myResNet18']
|
||||
|
||||
# defaultcfg = {
|
||||
# 11 : [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512],
|
||||
# 13 : [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512],
|
||||
# 16 : [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512],
|
||||
# 19 : [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512],
|
||||
# }
|
||||
# myCfg = [32,'M',64,'M',96,'M',128,'M',192,'M',256]
|
||||
myCfg = [32,'M',64,'M',96,'M',128,'M',256]
|
||||
# myCfg = [8,'M',16,'M',32,'M',64,'M',96]
|
||||
class myNet(nn.Module):
|
||||
def __init__(self,cfg=None,num_classes=3):
|
||||
super(myNet, self).__init__()
|
||||
if cfg is None:
|
||||
cfg = myCfg
|
||||
self.feature = self.make_layers(cfg, True)
|
||||
self.gap =nn.AdaptiveAvgPool2d((1,1))
|
||||
self.classifier = nn.Linear(cfg[-1], num_classes)
|
||||
# self.classifier = nn.Conv2d(cfg[-1],num_classes,kernel_size=1,stride=1)
|
||||
# self.bn_c= nn.BatchNorm2d(num_classes)
|
||||
# self.flatten = nn.Flatten()
|
||||
def make_layers(self, cfg, batch_norm=False):
|
||||
layers = []
|
||||
in_channels = 3
|
||||
for i in range(len(cfg)):
|
||||
if i == 0:
|
||||
conv2d =nn.Conv2d(in_channels, cfg[i], kernel_size=5,stride =1)
|
||||
if batch_norm:
|
||||
layers += [conv2d, nn.BatchNorm2d(cfg[i]), nn.ReLU(inplace=True)]
|
||||
else:
|
||||
layers += [conv2d, nn.ReLU(inplace=True)]
|
||||
in_channels = cfg[i]
|
||||
else :
|
||||
if cfg[i] == 'M':
|
||||
layers += [nn.MaxPool2d(kernel_size=3, stride=2,ceil_mode=True)]
|
||||
else:
|
||||
conv2d = nn.Conv2d(in_channels, cfg[i], kernel_size=3, padding=1,stride =1)
|
||||
if batch_norm:
|
||||
layers += [conv2d, nn.BatchNorm2d(cfg[i]), nn.ReLU(inplace=True)]
|
||||
else:
|
||||
layers += [conv2d, nn.ReLU(inplace=True)]
|
||||
in_channels = cfg[i]
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x):
|
||||
y = self.feature(x)
|
||||
y = nn.AvgPool2d(kernel_size=3, stride=1)(y)
|
||||
y = y.view(x.size(0), -1)
|
||||
y = self.classifier(y)
|
||||
|
||||
# y = self.flatten(y)
|
||||
return y
|
||||
|
||||
class myResNet18(nn.Module):
|
||||
def __init__(self,num_classes=1000):
|
||||
super(myResNet18,self).__init__()
|
||||
model_ft = models.resnet18(pretrained=True)
|
||||
self.model =model_ft
|
||||
self.model.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1,ceil_mode=True)
|
||||
self.model.averagePool = nn.AvgPool2d((5,5),stride=1,ceil_mode=True)
|
||||
self.cls=nn.Linear(512,num_classes)
|
||||
|
||||
def forward(self,x):
|
||||
x = self.model.conv1(x)
|
||||
x = self.model.bn1(x)
|
||||
x = self.model.relu(x)
|
||||
x = self.model.maxpool(x)
|
||||
|
||||
x = self.model.layer1(x)
|
||||
x = self.model.layer2(x)
|
||||
x = self.model.layer3(x)
|
||||
x = self.model.layer4(x)
|
||||
|
||||
x = self.model.averagePool(x)
|
||||
x = x.view(x.size(0), -1)
|
||||
x = self.cls(x)
|
||||
|
||||
return x
|
||||
if __name__ == '__main__':
|
||||
net = myNet(num_classes=2)
|
||||
# infeatures = net.cls.in_features
|
||||
# net.cls=nn.Linear(infeatures,2)
|
||||
x = torch.FloatTensor(16, 3, 64, 64)
|
||||
y = net(x)
|
||||
print(y.shape)
|
||||
# print(net)
|
@@ -169,7 +169,7 @@ def draw_result(orgimg,dict_list):
|
||||
return orgimg
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--detect_model', nargs='+', type=str, default='weights/plate_detect.pt', help='model.pt path(s)') #检测模型
|
||||
parser.add_argument('--detect_model', nargs='+', type=str, default='weights/detect.pt', help='model.pt path(s)') #检测模型
|
||||
parser.add_argument('--image_path', type=str, default='imgs', help='source')
|
||||
parser.add_argument('--img_size', type=int, default=640, help='inference size (pixels)')
|
||||
parser.add_argument('--output', type=str, default='result1', help='source')
|
||||
|
@@ -1,5 +1,6 @@
|
||||
import torch.nn as nn
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class myNet_ocr(nn.Module):
|
||||
@@ -121,6 +122,87 @@ class MyNet_color(nn.Module):
|
||||
|
||||
return logits
|
||||
|
||||
|
||||
class myNet_ocr_color(nn.Module):
|
||||
def __init__(self,cfg=None,num_classes=78,export=False,color_num=None):
|
||||
super(myNet_ocr_color, self).__init__()
|
||||
if cfg is None:
|
||||
cfg =[32,32,64,64,'M',128,128,'M',196,196,'M',256,256]
|
||||
# cfg =[32,32,'M',64,64,'M',128,128,'M',256,256]
|
||||
self.feature = self.make_layers(cfg, True)
|
||||
self.export = export
|
||||
self.color_num=color_num
|
||||
self.conv_out_num=12 #颜色第一个卷积层输出通道12
|
||||
if self.color_num:
|
||||
self.conv1=nn.Conv2d(cfg[-1],self.conv_out_num,kernel_size=3,stride=2)
|
||||
self.bn1=nn.BatchNorm2d(self.conv_out_num)
|
||||
self.relu1=nn.ReLU(inplace=True)
|
||||
self.gap =nn.AdaptiveAvgPool2d(output_size=1)
|
||||
self.color_classifier=nn.Conv2d(self.conv_out_num,self.color_num,kernel_size=1,stride=1)
|
||||
self.color_bn = nn.BatchNorm2d(self.color_num)
|
||||
self.flatten = nn.Flatten()
|
||||
# self.relu = nn.ReLU(inplace=True)
|
||||
# self.classifier = nn.Linear(cfg[-1], num_classes)
|
||||
# self.loc = nn.MaxPool2d((2, 2), (5, 1), (0, 1),ceil_mode=True)
|
||||
# self.loc = nn.AvgPool2d((2, 2), (5, 2), (0, 1),ceil_mode=False)
|
||||
self.loc = nn.MaxPool2d((5, 2), (1, 1),(0,1),ceil_mode=False)
|
||||
self.newCnn=nn.Conv2d(cfg[-1],num_classes,1,1)
|
||||
# self.newBn=nn.BatchNorm2d(num_classes)
|
||||
def make_layers(self, cfg, batch_norm=False):
|
||||
layers = []
|
||||
in_channels = 3
|
||||
for i in range(len(cfg)):
|
||||
if i == 0:
|
||||
conv2d =nn.Conv2d(in_channels, cfg[i], kernel_size=5,stride =1)
|
||||
if batch_norm:
|
||||
layers += [conv2d, nn.BatchNorm2d(cfg[i]), nn.ReLU(inplace=True)]
|
||||
else:
|
||||
layers += [conv2d, nn.ReLU(inplace=True)]
|
||||
in_channels = cfg[i]
|
||||
else :
|
||||
if cfg[i] == 'M':
|
||||
layers += [nn.MaxPool2d(kernel_size=3, stride=2,ceil_mode=True)]
|
||||
else:
|
||||
conv2d = nn.Conv2d(in_channels, cfg[i], kernel_size=3, padding=(1,1),stride =1)
|
||||
if batch_norm:
|
||||
layers += [conv2d, nn.BatchNorm2d(cfg[i]), nn.ReLU(inplace=True)]
|
||||
else:
|
||||
layers += [conv2d, nn.ReLU(inplace=True)]
|
||||
in_channels = cfg[i]
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.feature(x)
|
||||
if self.color_num:
|
||||
x_color=self.conv1(x)
|
||||
x_color=self.bn1(x_color)
|
||||
x_color =self.relu1(x_color)
|
||||
x_color = self.color_classifier(x_color)
|
||||
x_color = self.color_bn(x_color)
|
||||
x_color =self.gap(x_color)
|
||||
x_color = self.flatten(x_color)
|
||||
x=self.loc(x)
|
||||
x=self.newCnn(x)
|
||||
|
||||
if self.export:
|
||||
conv = x.squeeze(2) # b *512 * width
|
||||
conv = conv.transpose(2,1) # [w, b, c]
|
||||
if self.color_num:
|
||||
return conv,x_color
|
||||
return conv
|
||||
else:
|
||||
b, c, h, w = x.size()
|
||||
assert h == 1, "the height of conv must be 1"
|
||||
conv = x.squeeze(2) # b *512 * width
|
||||
conv = conv.permute(2, 0, 1) # [w, b, c]
|
||||
output = F.log_softmax(conv, dim=2)
|
||||
if self.color_num:
|
||||
return output,x_color
|
||||
return output
|
||||
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
x = torch.randn(1,3,48,216)
|
||||
model = myNet_ocr(num_classes=78,export=True)
|
||||
|
@@ -1,4 +1,4 @@
|
||||
from plate_recognition.plateNet import myNet_ocr
|
||||
from plate_recognition.plateNet import myNet_ocr,myNet_ocr_color
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import cv2
|
||||
@@ -21,6 +21,7 @@ def allFilePath(rootPath,allFIleList):
|
||||
allFilePath(os.path.join(rootPath,temp),allFIleList)
|
||||
device = torch.device('cuda') if torch.cuda.is_available() else torch.device("cpu")
|
||||
plateName=r"#京沪津渝冀晋蒙辽吉黑苏浙皖闽赣鲁豫鄂湘粤桂琼川贵云藏陕甘青宁新学警港澳挂使领民航危0123456789ABCDEFGHJKLMNPQRSTUVWXYZ险品"
|
||||
color_list=['黑色','蓝色','绿色','白色','黄色']
|
||||
mean_value,std_value=(0.588,0.193)
|
||||
def decodePlate(preds):
|
||||
pre=0
|
||||
@@ -47,17 +48,19 @@ def image_processing(img,device):
|
||||
|
||||
def get_plate_result(img,device,model):
|
||||
input = image_processing(img,device)
|
||||
preds = model(input)
|
||||
# preds =preds.argmax(dim=2) #找出概率最大的那个字符
|
||||
preds,color_preds = model(input)
|
||||
preds =preds.argmax(dim=2) #找出概率最大的那个字符
|
||||
color_preds = color_preds.argmax(dim=-1)
|
||||
# print(preds)
|
||||
preds=preds.view(-1).detach().cpu().numpy()
|
||||
color_preds=color_preds.item()
|
||||
newPreds=decodePlate(preds)
|
||||
plate=""
|
||||
for i in newPreds:
|
||||
plate+=plateName[i]
|
||||
# if not (plate[0] in plateName[1:44] ):
|
||||
# return ""
|
||||
return plate
|
||||
return plate,color_list[color_preds]
|
||||
|
||||
def init_model(device,model_path):
|
||||
# print( print(sys.path))
|
||||
@@ -66,7 +69,7 @@ def init_model(device,model_path):
|
||||
model_state=check_point['state_dict']
|
||||
cfg=check_point['cfg']
|
||||
model_path = os.sep.join([sys.path[0],model_path])
|
||||
model = myNet_ocr(num_classes=len(plateName),export=True,cfg=cfg)
|
||||
model = myNet_ocr_color(num_classes=len(plateName),export=True,cfg=cfg,color_num=len(color_list))
|
||||
|
||||
model.load_state_dict(model_state)
|
||||
model.to(device)
|
||||
|
Binary file not shown.
BIN
weights/plate_rec_color.pth
Normal file
BIN
weights/plate_rec_color.pth
Normal file
Binary file not shown.
Reference in New Issue
Block a user