mirror of
https://github.com/we0091234/crnn_plate_recognition.git
synced 2025-09-26 23:45:51 +08:00
130 lines
4.5 KiB
Python
130 lines
4.5 KiB
Python
from plateNet import myNet_ocr
|
|
import torch
|
|
import torch.nn as nn
|
|
import cv2
|
|
import numpy as np
|
|
import os
|
|
import time
|
|
import argparse
|
|
from alphabets import plate_chr
|
|
from LPRNet import build_lprnet
|
|
def cv_imread(path):
    """Read an image whose path may contain Chinese (non-ASCII) characters.

    The file is first loaded as a flat ``uint8`` buffer with NumPy and then
    decoded in memory by OpenCV, instead of passing the path to OpenCV.

    Args:
        path: Location of the image file on disk.

    Returns:
        Whatever ``cv2.imdecode`` produces for the buffer when called with
        flag ``-1`` (decode the image as stored, without channel conversion).
    """
    raw_bytes = np.fromfile(path, dtype=np.uint8)
    return cv2.imdecode(raw_bytes, -1)
|
|
|
|
def allFilePath(rootPath, allFIleList):
    """Recursively collect the paths of all regular files under *rootPath*.

    Entries are visited in ``os.listdir`` order; sub-directories are
    descended into at the point where they are encountered.

    Args:
        rootPath: Directory to scan.
        allFIleList: List that receives the joined path of every file found.
            It is extended in place; nothing is returned.

    Raises:
        OSError: If *rootPath* itself cannot be listed.
    """
    for entry in os.listdir(rootPath):
        full_path = os.path.join(rootPath, entry)
        if os.path.isfile(full_path):
            allFIleList.append(full_path)
        elif os.path.isdir(full_path):
            allFilePath(full_path, allFIleList)
        # Anything else (dangling symlink, socket, device node, ...) is
        # skipped: treating it as a directory would make os.listdir raise
        # and abort the whole scan.
|
|
|
|
# Normalization constants used by image_processing: pixel values are scaled
# by 1/255, then shifted by mean_value and divided by std_value.
mean_value,std_value=(0.588,0.193)
|
|
def decodePlate(preds):
    """Collapse a raw label sequence into its decoded form.

    A label is kept only when it is non-zero and differs from the label
    immediately before it in *preds*; zeros and consecutive repeats are
    dropped.  A zero between two equal labels therefore keeps both of them.

    Args:
        preds: Sequence of labels (anything iterable and comparable to 0).

    Returns:
        A list with the surviving labels, in their original order.
    """
    decoded = []
    previous = 0
    for current in preds:
        skip = current == 0 or current == previous
        if not skip:
            decoded.append(current)
        # The previous label is tracked for every position, kept or not.
        previous = current
    return decoded
|
|
|
|
def image_processing(img, device, img_size):
    """Turn an image array into a normalized single-item batch tensor.

    Steps: resize to the requested size, convert to float32, scale by 1/255,
    normalize with the module-level ``mean_value`` / ``std_value``, move the
    last axis to the front, wrap as a torch tensor on *device* and prepend a
    batch dimension of size 1.

    Args:
        img: Image array as accepted by ``cv2.resize``.  Assumed to have
            three axes ordered height, width, channel (the transpose below
            requires exactly three) -- TODO confirm against callers.
        device: Torch device the resulting tensor is moved to.
        img_size: ``(height, width)`` pair for the resized image.

    Returns:
        A tensor of shape ``(1, *transposed_image_shape)`` on *device*.
    """
    target_h, target_w = img_size
    # cv2.resize takes the target size as (width, height).
    resized = cv2.resize(img, (target_w, target_h)).astype(np.float32)
    normalized = (resized / 255. - mean_value) / std_value
    channels_first = normalized.transpose([2, 0, 1])

    tensor = torch.from_numpy(channels_first).to(device)
    return tensor.view(1, *tensor.size())
|
|
|
|
def get_plate_result(img, device, model, img_size):
    """Run the recognition model on one image and return the plate text.

    Args:
        img: Image array, forwarded unchanged to ``image_processing``.
        device: Torch device used for the input tensor.
        model: Callable network.  Its flattened output is treated as a
            sequence of indices into ``plate_chr`` (presumably what the
            ``export=True`` model produces -- confirm against plateNet).
        img_size: ``(height, width)`` pair the image is resized to.

    Returns:
        The string built by decoding the output with ``decodePlate`` and
        mapping each remaining index through ``plate_chr``.
    """
    batch = image_processing(img, device, img_size)
    # Inference only: skip autograd bookkeeping for the forward pass.
    with torch.no_grad():
        preds = model(batch)
    labels = preds.view(-1).detach().cpu().numpy()
    return "".join(plate_chr[int(idx)] for idx in decodePlate(labels))
|
|
|
|
def init_model(device, model_path):
    """Build the plate-recognition network and load its checkpoint.

    The checkpoint file is expected to be a mapping with a ``'state_dict'``
    entry (the weights) and a ``'cfg'`` entry (passed to the constructor).

    Args:
        device: Torch device used both as ``map_location`` when loading and
            as the destination of the model.
        model_path: Path to the checkpoint file.

    Returns:
        The ``myNet_ocr`` instance with weights loaded, moved to *device*
        and switched to evaluation mode.
    """
    checkpoint = torch.load(model_path, map_location=device)
    weights = checkpoint['state_dict']

    # export=True is the inference configuration of the network.
    # (build_lprnet is the alternative builder; it is not used here.)
    net = myNet_ocr(num_classes=78, export=True, cfg=checkpoint['cfg'])
    net.load_state_dict(weights)
    net.to(device)
    net.eval()
    return net
|
|
|
|
def _str2bool(value):
    """Parse a command-line boolean such as ``true`` / ``false`` / ``1`` / ``0``."""
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ("1", "true", "t", "yes", "y"):
        return True
    if text in ("0", "false", "f", "no", "n", ""):
        return False
    raise argparse.ArgumentTypeError("expected a boolean value, got %r" % value)


def _read_bgr_image(path):
    """Read *path* with cv_imread and return it as a 3-channel BGR image."""
    img = cv_imread(path)
    if img is None:
        raise ValueError("cannot decode image: %s" % path)
    if img.ndim == 2:
        # Single-channel image: BGRA2BGR would fail on it.
        return cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
    if img.shape[-1] == 4:
        return cv2.cvtColor(img, cv2.COLOR_BGRA2BGR)
    return img


if __name__ == '__main__':
    # Command-line entry point: recognise a single image, or every image
    # found under a directory (optionally scoring against file-name labels).
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_path', type=str, default='saved_model/best.pth', help='model.pt path(s)')
    parser.add_argument('--image_path', type=str, default='images/test.jpg', help='source')
    parser.add_argument('--img_h', type=int, default=48, help='height')
    parser.add_argument('--img_w', type=int, default=168, help='width')
    # True selects LPRNet, False selects plateNet.
    # NOTE(review): this flag is parsed but never read in this script.
    parser.add_argument('--LPRNet', action='store_true', help='use LPRNet')
    # For labelled images: compute the recognition accuracy.
    # type=bool would turn every non-empty string (including 'false') into
    # True, so the value is parsed explicitly.  Both a bare `--acc` and
    # `--acc true|false` are accepted.
    parser.add_argument('--acc', type=_str2bool, nargs='?', const=True, default=False, help=' get accuracy')
    opt = parser.parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    img_size = (opt.img_h, opt.img_w)
    model = init_model(device, opt.model_path)

    if os.path.isfile(opt.image_path):
        # Single image: print the recognised plate.
        img = _read_bgr_image(opt.image_path)
        plate = get_plate_result(img, device, model, img_size)
        print(plate)
    else:
        # Directory: process every file found below it.
        file_list = []
        allFilePath(opt.image_path, file_list)
        right = 0
        for pic_ in file_list:
            pic_name = os.path.basename(pic_)
            try:
                img = _read_bgr_image(pic_)
                plate = get_plate_result(img, device, model, img_size)
            except Exception as err:
                # Best effort: report the failing file and keep going.
                print("error", pic_, err)
                continue

            if not opt.acc:
                print(plate, pic_name)
                continue

            # The expected plate is the file-name part before the first '_'.
            plate_ori = pic_name.split('_')[0]
            if plate == plate_ori:
                right += 1
            else:
                print(plate_ori, "rec as ---> ", plate, pic_)

        if opt.acc:
            total = len(file_list)
            # Guard against an empty directory.
            accuracy = right / total if total else 0.0
            print("sum:%d ,right:%d , accuracy: %f" % (total, right, accuracy))
|
|
|
|
|
|
|