Files
crnn_plate_recognition/lib/utils/utils.py
2022-10-03 09:14:38 +08:00

163 lines
5.6 KiB
Python

import torch.optim as optim
import time
from pathlib import Path
import os
import torch
def get_optimizer(config, model):
    """Build the optimizer selected by ``config.TRAIN.OPTIMIZER``.

    Only parameters with ``requires_grad=True`` are handed to the optimizer.

    Args:
        config: config object whose ``TRAIN`` section provides ``OPTIMIZER``
            ("sgd", "adam" or "rmsprop"), ``LR`` and, for sgd/rmsprop,
            ``MOMENTUM`` and ``WD`` (plus ``NESTEROV`` for sgd).
        model: a ``torch.nn.Module`` whose parameters are optimised.

    Returns:
        torch.optim.Optimizer or None: None when the optimizer name is not
        one of the three supported values.
    """
    train_cfg = config.TRAIN
    name = train_cfg.OPTIMIZER

    def trainable():
        # Frozen parameters are skipped so the optimizer never touches them.
        return (param for param in model.parameters() if param.requires_grad)

    if name == "sgd":
        return optim.SGD(
            trainable(),
            lr=train_cfg.LR,
            momentum=train_cfg.MOMENTUM,
            weight_decay=train_cfg.WD,
            nesterov=train_cfg.NESTEROV,
        )
    if name == "adam":
        # Adam is configured with the learning rate only (no weight decay).
        return optim.Adam(trainable(), lr=train_cfg.LR)
    if name == "rmsprop":
        return optim.RMSprop(
            trainable(),
            lr=train_cfg.LR,
            momentum=train_cfg.MOMENTUM,
            weight_decay=train_cfg.WD,
        )
    return None
def create_log_folder(cfg, phase='train'):
    """Create the checkpoint and TensorBoard log directories for a new run.

    Layout: ``<OUTPUT_DIR>/<DATASET.DATASET>/<MODEL.NAME>/<YYYY-mm-dd-HH-MM>/``
    containing ``checkpoints`` and ``log`` sub-directories.

    Args:
        cfg: config object providing ``OUTPUT_DIR``, ``DATASET.DATASET`` and
            ``MODEL.NAME``.
        phase: currently unused; kept for backward compatibility.

    Returns:
        dict: ``{'chs_dir': <checkpoints dir>, 'tb_dir': <log dir>}`` as str.
    """
    root_output_dir = Path(cfg.OUTPUT_DIR)
    # set up logger
    if not root_output_dir.exists():
        print('=> creating {}'.format(root_output_dir))
    # parents=True so a nested OUTPUT_DIR whose parents don't exist yet can be
    # created; exist_ok=True so a concurrent creation between the exists()
    # check and mkdir() is not an error.
    root_output_dir.mkdir(parents=True, exist_ok=True)

    dataset = cfg.DATASET.DATASET
    model = cfg.MODEL.NAME
    # Minute resolution: runs started within the same minute share a folder.
    time_str = time.strftime('%Y-%m-%d-%H-%M')
    run_dir = root_output_dir / dataset / model / time_str

    checkpoints_output_dir = run_dir / 'checkpoints'
    print('=> creating {}'.format(checkpoints_output_dir))
    checkpoints_output_dir.mkdir(parents=True, exist_ok=True)

    tensorboard_log_dir = run_dir / 'log'
    print('=> creating {}'.format(tensorboard_log_dir))
    tensorboard_log_dir.mkdir(parents=True, exist_ok=True)

    return {'chs_dir': str(checkpoints_output_dir), 'tb_dir': str(tensorboard_log_dir)}
def get_batch_label(d, i):
    """Return the label for each index of a batch.

    Args:
        d: dataset-like object whose ``labels[idx]`` is a mapping; the first
            value of that mapping is taken as the label.
        i: iterable of indices into ``d.labels``.

    Returns:
        list: one label per index, in the order of ``i``.
    """
    return [list(d.labels[idx].values())[0] for idx in i]
class strLabelConverter(object):
    """Convert between str and label.

    NOTE:
        Index 0 is reserved for the CTC `blank`; characters of ``alphabet``
        are mapped to 1..len(alphabet). A trailing '-' is appended to
        ``self.alphabet`` so that ``self.alphabet[index - 1]`` turns the
        blank index 0 into '-' via Python's negative indexing.

    Args:
        alphabet (str): set of the possible characters.
        ignore_case (bool, default=False): whether or not to ignore all of the case.
    """

    def __init__(self, alphabet, ignore_case=False):
        self._ignore_case = ignore_case
        if self._ignore_case:
            alphabet = alphabet.lower()
        self.alphabet = alphabet + '-'  # for `-1` index
        self.dict = {}
        for i, char in enumerate(alphabet):
            # NOTE: 0 is reserved for 'blank' required by wrap_ctc
            self.dict[char] = i + 1

    def encode(self, text):
        """Support batch or single str.

        Args:
            text (str, bytes, or list of str/bytes): texts to convert. bytes
                items are decoded as UTF-8. A single str/bytes is treated as
                a batch of one text.

        Returns:
            torch.IntTensor [length_0 + length_1 + ... length_{n - 1}]: encoded texts.
            torch.IntTensor [n]: length of each text.

        Raises:
            KeyError: if a character is not in the alphabet.
        """
        # A lone string would otherwise be iterated character by character,
        # producing one length entry per character instead of one per text.
        if isinstance(text, (str, bytes)):
            text = [text]
        length = []
        result = []
        for item in text:
            if isinstance(item, bytes):
                item = item.decode('utf-8', 'strict')
            if self._ignore_case:
                # Keys of self.dict were lowercased in __init__; lookups must match.
                item = item.lower()
            length.append(len(item))
            result.extend(self.dict[char] for char in item)
        return (torch.IntTensor(result), torch.IntTensor(length))

    def decode(self, t, length, raw=False):
        """Decode encoded texts back into strs.

        Args:
            t: torch.IntTensor [length_0 + length_1 + ... length_{n - 1}]: encoded texts.
            length: torch.IntTensor [n]: length of each text.
            raw (bool): if True, map every index to its character (blank
                becomes '-'); otherwise drop blanks (index 0) and collapse
                consecutive repeated indices (CTC best-path collapse).

        Raises:
            AssertionError: when the texts and its length does not match.

        Returns:
            str when ``length`` has a single element, otherwise a list of str.
        """
        if length.numel() == 1:
            n = int(length.item())
            assert t.numel() == n, "text with length: {} does not match declared length: {}".format(t.numel(), n)
            # Convert once to plain ints instead of indexing the tensor per element.
            codes = t.reshape(-1).tolist()
            if raw:
                return ''.join(self.alphabet[c - 1] for c in codes)
            char_list = []
            prev = None
            for c in codes:
                # Keep non-blank indices that differ from the previous index.
                if c != 0 and c != prev:
                    char_list.append(self.alphabet[c - 1])
                prev = c
            return ''.join(char_list)

        # batch mode
        total = int(length.sum())
        assert t.numel() == total, "texts with length: {} does not match declared length: {}".format(t.numel(), total)
        texts = []
        index = 0
        for l in length.reshape(-1).tolist():
            texts.append(
                self.decode(
                    t[index:index + l], torch.IntTensor([l]), raw=raw))
            index += l
        return texts
def get_char_dict(path):
    """Load a GBK-encoded character table, presumably one entry per line.

    Args:
        path: path of the character file.

    Returns:
        dict[int, str]: maps the 0-based line number to that line's text with
        surrounding whitespace stripped. Bytes that are not valid GBK are
        silently dropped (the 'ignore' error handler).
    """
    with open(path, 'rb') as file:
        # Strip raw bytes before decoding so '\r\n' endings are removed too.
        char_dict = {num: line.strip().decode('gbk', 'ignore') for num, line in enumerate(file)}
    # The original built the dict but never returned it (always returned None).
    return char_dict
def model_info(model):
    """Print a line-by-line summary of a PyTorch model's parameters.

    One row per named parameter (index, name with any 'module_list.' prefix
    removed, requires_grad, element count, shape, mean, std), followed by a
    totals line.

    Args:
        model: a ``torch.nn.Module``.
    """
    params = list(model.named_parameters())
    n_p = sum(p.numel() for _, p in params)  # number parameters
    n_g = sum(p.numel() for _, p in params if p.requires_grad)  # number gradients
    print('\n%5s %50s %9s %12s %20s %12s %12s' % ('layer', 'name', 'gradient', 'parameters', 'shape', 'mu', 'sigma'))
    for i, (name, p) in enumerate(params):
        name = name.replace('module_list.', '')
        # detach + float: statistics need no autograd graph and mean()/std()
        # reject integer dtypes.
        data = p.detach().float()
        print('%5g %50s %9s %12g %20s %12.3g %12.3g' % (
            i, name, p.requires_grad, p.numel(), list(p.shape), data.mean().item(), data.std().item()))
    # len(params) instead of the loop index: a parameterless model would leave
    # the loop variable unbound.
    print('Model Summary: %g layers, %g parameters, %g gradients\n' % (len(params), n_p, n_g))