# mirror of https://github.com/we0091234/crnn_plate_recognition.git
# synced 2025-09-26 23:45:51 +08:00
import argparse
|
|
from easydict import EasyDict as edict
|
|
import yaml
|
|
import os
|
|
import torch
|
|
import torch.backends.cudnn as cudnn
|
|
from torch.utils.data import DataLoader
|
|
import lib.models.crnn as crnn
|
|
import lib.utils.utils as utils
|
|
from lib.dataset import get_dataset
|
|
from lib.core import function
|
|
from lib.utils.utils import model_info
|
|
from plateNet import myNet_ocr
|
|
from alphabets import plateName,plate_chr
|
|
from LPRNet import build_lprnet
|
|
|
|
from tensorboardX import SummaryWriter
|
|
|
|
def parse_arg():
    """Parse command-line arguments and load the experiment configuration.

    Reads the YAML file given by ``--cfg``, wraps it in an EasyDict, and
    augments it with the plate alphabet and the requested input size.

    Returns:
        edict: configuration with ``DATASET.ALPHABETS``,
        ``MODEL.NUM_CLASSES``, ``HEIGHT`` and ``WIDTH`` filled in.
    """
    parser = argparse.ArgumentParser(description="train crnn")

    parser.add_argument('--cfg', help='experiment configuration filename', required=True, type=str)
    parser.add_argument('--img_h', type=int, default=48, help='height')
    parser.add_argument('--img_w', type=int, default=168, help='width')
    args = parser.parse_args()

    with open(args.cfg, 'r') as f:
        # BUGFIX: yaml.load(f) with no Loader raises TypeError on
        # PyYAML >= 6.0 and allows arbitrary object construction on
        # older versions; safe_load parses plain YAML safely.
        config = yaml.safe_load(f)
    config = edict(config)

    config.DATASET.ALPHABETS = plateName
    config.MODEL.NUM_CLASSES = len(config.DATASET.ALPHABETS)
    config.HEIGHT = args.img_h
    config.WIDTH = args.img_w
    return config
|
|
|
|
|
|
def main():
    """Train the license-plate OCR network end to end.

    Loads the config, builds the model/loss/optimizer, optionally restores
    a finetune or resume checkpoint, then alternates training and
    validation epochs, saving a checkpoint after every epoch.
    """
    # load config
    config = parse_arg()

    # create output folder (checkpoint dir + tensorboard log dir)
    output_dict = utils.create_log_folder(config, phase='train')

    # cudnn
    cudnn.benchmark = config.CUDNN.BENCHMARK
    cudnn.deterministic = config.CUDNN.DETERMINISTIC
    cudnn.enabled = config.CUDNN.ENABLED

    # writer dict: tensorboard writer plus global step counters shared
    # with function.train / function.validate
    writer_dict = {
        'writer': SummaryWriter(log_dir=output_dict['tb_dir']),
        'train_global_steps': 0,
        'valid_global_steps': 0,
    }

    # construct the recognition network
    # cfg = [32, 'M', 64, 'M', 128, 'M', 256]  # small model
    cfg = [32, 32, 64, 64, 'M', 128, 128, 'M', 196, 196, 'M', 256, 256]  # big model
    # alternative backbones kept for reference:
    # model = crnn.get_crnn(config, cfg=cfg)
    # model = build_lprnet(num_classes=len(plate_chr))
    model = myNet_ocr(num_classes=len(plate_chr), cfg=cfg)

    # get device
    if torch.cuda.is_available():
        device = torch.device("cuda:{}".format(config.GPUID))
    else:
        # BUGFIX: was "cpu:0" — plain "cpu" is the canonical spelling
        device = torch.device("cpu")

    model = model.to(device)

    # define loss function (CTC: alignment-free sequence loss)
    criterion = torch.nn.CTCLoss()

    last_epoch = config.TRAIN.BEGIN_EPOCH
    optimizer = utils.get_optimizer(config, model)
    if isinstance(config.TRAIN.LR_STEP, list):
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer, config.TRAIN.LR_STEP,
            config.TRAIN.LR_FACTOR, last_epoch - 1
        )
    else:
        lr_scheduler = torch.optim.lr_scheduler.StepLR(
            optimizer, config.TRAIN.LR_STEP,
            config.TRAIN.LR_FACTOR, last_epoch - 1
        )

    if config.TRAIN.FINETUNE.IS_FINETUNE:
        # NOTE: config key name FINETUNE_CHECKPOINIT (sic) must match the
        # YAML files, so the typo is preserved here.
        model_state_file = config.TRAIN.FINETUNE.FINETUNE_CHECKPOINIT
        if model_state_file == '':
            # BUGFIX: previously only printed a warning and then crashed
            # inside torch.load(''); fail fast with a clear error instead.
            raise FileNotFoundError(" => no checkpoint found")
        checkpoint = torch.load(model_state_file, map_location='cpu')
        if 'state_dict' in checkpoint.keys():
            checkpoint = checkpoint['state_dict']

        # partial-load variant kept for reference:
        # from collections import OrderedDict
        # model_dict = OrderedDict()
        # for k, v in checkpoint.items():
        #     if 'cnn' in k:
        #         model_dict[k[4:]] = v
        # model.cnn.load_state_dict(model_dict)
        model.load_state_dict(checkpoint)
        # if config.TRAIN.FINETUNE.FREEZE:
        #     for p in model.cnn.parameters():
        #         p.requires_grad = False

    elif config.TRAIN.RESUME.IS_RESUME:
        model_state_file = config.TRAIN.RESUME.FILE
        if model_state_file == '':
            # BUGFIX: same fail-fast as the finetune branch above.
            raise FileNotFoundError(" => no checkpoint found")
        checkpoint = torch.load(model_state_file, map_location='cpu')
        if 'state_dict' in checkpoint.keys():
            model.load_state_dict(checkpoint['state_dict'])
            last_epoch = checkpoint['epoch']
            # optimizer.load_state_dict(checkpoint['optimizer'])
            # lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
        else:
            model.load_state_dict(checkpoint)

    model_info(model)
    train_dataset = get_dataset(config)(config, input_w=config.WIDTH, input_h=config.HEIGHT, is_train=True)
    train_loader = DataLoader(
        dataset=train_dataset,
        batch_size=config.TRAIN.BATCH_SIZE_PER_GPU,
        shuffle=config.TRAIN.SHUFFLE,
        num_workers=config.WORKERS,
        pin_memory=config.PIN_MEMORY,
    )

    val_dataset = get_dataset(config)(config, input_w=config.WIDTH, input_h=config.HEIGHT, is_train=False)
    val_loader = DataLoader(
        dataset=val_dataset,
        batch_size=config.TEST.BATCH_SIZE_PER_GPU,
        shuffle=config.TEST.SHUFFLE,
        num_workers=config.WORKERS,
        pin_memory=config.PIN_MEMORY,
    )

    best_acc = 0.5  # only checkpoints above this accuracy count as "best"
    converter = utils.strLabelConverter(config.DATASET.ALPHABETS)
    for epoch in range(last_epoch, config.TRAIN.END_EPOCH):

        function.train(config, train_loader, train_dataset, converter, model,
                       criterion, optimizer, device, epoch, writer_dict, output_dict)
        lr_scheduler.step()

        acc = function.validate(config, val_loader, val_dataset, converter,
                                model, criterion, device, epoch, writer_dict, output_dict)

        is_best = acc > best_acc
        best_acc = max(acc, best_acc)

        print("is best:", is_best)
        print("best acc is:", best_acc)
        # save checkpoint (cfg is stored so the architecture can be rebuilt
        # at inference time from the checkpoint alone)
        torch.save(
            {
                "cfg": cfg,
                "state_dict": model.state_dict(),
                "epoch": epoch + 1,
                # "optimizer": optimizer.state_dict(),
                # "lr_scheduler": lr_scheduler.state_dict(),
                "best_acc": best_acc,
            }, os.path.join(output_dict['chs_dir'], "checkpoint_{}_acc_{:.4f}.pth".format(epoch, acc))
        )

    writer_dict['writer'].close()
|
|
|
|
|
|
# Script entry point: run training only when executed directly,
# not when this module is imported.
if __name__ == '__main__':
    main()
|