mirror of https://github.com/we0091234/crnn_plate_recognition.git
synced 2025-12-24 12:12:23 +08:00
163 lines · 5.6 KiB · Python
import os
import time
from pathlib import Path

import torch
import torch.optim as optim


def get_optimizer(config, model):
    """Build the optimizer named by config.TRAIN.OPTIMIZER over the model's trainable parameters."""
    optimizer = None

    if config.TRAIN.OPTIMIZER == "sgd":
        optimizer = optim.SGD(
            filter(lambda p: p.requires_grad, model.parameters()),
            lr=config.TRAIN.LR,
            momentum=config.TRAIN.MOMENTUM,
            weight_decay=config.TRAIN.WD,
            nesterov=config.TRAIN.NESTEROV
        )
    elif config.TRAIN.OPTIMIZER == "adam":
        optimizer = optim.Adam(
            filter(lambda p: p.requires_grad, model.parameters()),
            lr=config.TRAIN.LR,
        )
    elif config.TRAIN.OPTIMIZER == "rmsprop":
        optimizer = optim.RMSprop(
            filter(lambda p: p.requires_grad, model.parameters()),
            lr=config.TRAIN.LR,
            momentum=config.TRAIN.MOMENTUM,
            weight_decay=config.TRAIN.WD,
            # alpha=config.TRAIN.RMSPROP_ALPHA,
            # centered=config.TRAIN.RMSPROP_CENTERED
        )

    # NOTE: returns None if the optimizer name is not one of the above
    return optimizer

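# Usage (a minimal sketch; this yacs-style CfgNode construction is illustrative
# and the key values are placeholders, not the repo's actual defaults):
#
#   from yacs.config import CfgNode as CN
#   cfg = CN()
#   cfg.TRAIN = CN()
#   cfg.TRAIN.OPTIMIZER = "sgd"
#   cfg.TRAIN.LR = 0.001
#   cfg.TRAIN.MOMENTUM = 0.9
#   cfg.TRAIN.WD = 1e-4
#   cfg.TRAIN.NESTEROV = False
#   optimizer = get_optimizer(cfg, model)
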
def create_log_folder(cfg, phase='train'):
    """Create timestamped checkpoint and tensorboard log directories under cfg.OUTPUT_DIR."""
    root_output_dir = Path(cfg.OUTPUT_DIR)
    # set up logger
    if not root_output_dir.exists():
        print('=> creating {}'.format(root_output_dir))
        root_output_dir.mkdir()

    dataset = cfg.DATASET.DATASET
    model = cfg.MODEL.NAME

    time_str = time.strftime('%Y-%m-%d-%H-%M')
    checkpoints_output_dir = root_output_dir / dataset / model / time_str / 'checkpoints'

    print('=> creating {}'.format(checkpoints_output_dir))
    checkpoints_output_dir.mkdir(parents=True, exist_ok=True)

    tensorboard_log_dir = root_output_dir / dataset / model / time_str / 'log'

    print('=> creating {}'.format(tensorboard_log_dir))
    tensorboard_log_dir.mkdir(parents=True, exist_ok=True)

    return {'chs_dir': str(checkpoints_output_dir), 'tb_dir': str(tensorboard_log_dir)}

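# Resulting layout (timestamp is illustrative):
#
#   OUTPUT_DIR/
#     <DATASET>/<MODEL>/2025-01-01-12-00/
#       checkpoints/   -> returned as 'chs_dir'
#       log/           -> returned as 'tb_dir'
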
def get_batch_label(d, i):
    """Collect the label string for each dataset index in i."""
    label = []
    for idx in i:
        # each entry of d.labels is a single-item dict mapping image name -> label text
        label.append(list(d.labels[idx].values())[0])
    return label

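# Example (illustrative; assumes dataset.labels looks like
# [{'img_0001.jpg': 'label_text'}, ...]):
#
#   labels = get_batch_label(dataset, [0, 3, 7])
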
class strLabelConverter(object):
    """Convert between str and label.

    NOTE:
        Insert `blank` to the alphabet for CTC.

    Args:
        alphabet (str): set of the possible characters.
        ignore_case (bool, default=False): whether or not to ignore all of the case.
    """

    def __init__(self, alphabet, ignore_case=False):
        self._ignore_case = ignore_case
        if self._ignore_case:
            alphabet = alphabet.lower()
        self.alphabet = alphabet + '-'  # for `-1` index

        self.dict = {}
        for i, char in enumerate(alphabet):
            # NOTE: 0 is reserved for 'blank' required by wrap_ctc
            self.dict[char] = i + 1

    def encode(self, text):
        """Support batch or single str.

        Args:
            text (str or list of str): texts to convert.

        Returns:
            torch.IntTensor [length_0 + length_1 + ... length_{n - 1}]: encoded texts.
            torch.IntTensor [n]: length of each text.
        """
        length = []
        result = []
        # items may arrive as raw bytes (e.g. read straight from disk); decode them first
        decode_flag = isinstance(text[0], bytes)

        for item in text:
            if decode_flag:
                item = item.decode('utf-8', 'strict')
            length.append(len(item))
            for char in item:
                index = self.dict[char]
                result.append(index)

        return (torch.IntTensor(result), torch.IntTensor(length))

    def decode(self, t, length, raw=False):
        """Decode encoded texts back into strs.

        Args:
            t (torch.IntTensor [length_0 + length_1 + ... length_{n - 1}]): encoded texts.
            length (torch.IntTensor [n]): length of each text.

        Raises:
            AssertionError: when the texts and their lengths do not match.

        Returns:
            text (str or list of str): decoded texts.
        """
        if length.numel() == 1:
            length = length[0]
            assert t.numel() == length, "text with length: {} does not match declared length: {}".format(t.numel(), length)
            if raw:
                return ''.join([self.alphabet[i - 1] for i in t])
            else:
                char_list = []
                for i in range(length):
                    # CTC collapse: skip blanks (index 0) and repeated characters
                    if t[i] != 0 and (not (i > 0 and t[i - 1] == t[i])):
                        char_list.append(self.alphabet[t[i] - 1])
                return ''.join(char_list)
        else:
            # batch mode
            assert t.numel() == length.sum(), "texts with length: {} does not match declared length: {}".format(t.numel(), length.sum())
            texts = []
            index = 0
            for i in range(length.numel()):
                l = length[i]
                texts.append(
                    self.decode(
                        t[index:index + l], torch.IntTensor([l]), raw=raw))
                index += l
            return texts

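# Example round trip (illustrative alphabet; the real one is the plate-character
# set used elsewhere in this repo):
#
#   converter = strLabelConverter('0123456789')
#   text, length = converter.encode(['123', '45'])
#   # text -> IntTensor([2, 3, 4, 5, 6]), length -> IntTensor([3, 2])
#   converter.decode(text, length)  # ['123', '45']
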
def get_char_dict(path):
    """Read a GBK-encoded file with one character per line into a {index: char} dict."""
    with open(path, 'rb') as file:
        char_dict = {num: char.strip().decode('gbk', 'ignore') for num, char in enumerate(file.readlines())}
    return char_dict

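# Example (hypothetical path; the file is expected to hold one GBK-encoded
# character per line):
#
#   char_dict = get_char_dict('data/chars.txt')  # {0: '京', 1: '沪', ...}
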
def model_info(model):  # Prints a line-by-line description of a PyTorch model
    n_p = sum(x.numel() for x in model.parameters())  # number of parameters
    n_g = sum(x.numel() for x in model.parameters() if x.requires_grad)  # number of trainable (gradient) parameters
    print('\n%5s %50s %9s %12s %20s %12s %12s' % ('layer', 'name', 'gradient', 'parameters', 'shape', 'mu', 'sigma'))
    for i, (name, p) in enumerate(model.named_parameters()):
        name = name.replace('module_list.', '')
        print('%5g %50s %9s %12g %20s %12.3g %12.3g' % (
            i, name, p.requires_grad, p.numel(), list(p.shape), p.mean(), p.std()))
    print('Model Summary: %g layers, %g parameters, %g gradients\n' % (i + 1, n_p, n_g))
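
# Usage:
#
#   model_info(model)  # one row per named parameter, followed by a summary line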