# Mirror of https://github.com/we0091234/crnn_plate_recognition.git
# Synced 2025-09-26 15:41:10 +08:00 · 99 lines · 3.8 KiB · Python
import torch.nn as nn
|
|
import torch
|
|
import torch.nn.functional as F
|
|
|
|
class small_basic_block(nn.Module):
    """Bottleneck of four convolutions: 1x1 squeeze, 3x1, 1x3, then 1x1 expand.

    The first three convolutions produce ``ch_out // 4`` channels and each is
    followed by a ReLU; the last 1x1 convolution expands to ``ch_out`` and has
    no activation after it. The 3x1 and 1x3 convolutions are padded by one
    along their long axis, so the spatial size of the input is preserved.

    Args:
        ch_in: Number of input channels.
        ch_out: Number of output channels.
    """

    def __init__(self, ch_in, ch_out):
        super().__init__()
        mid = ch_out // 4
        stages = [
            nn.Conv2d(ch_in, mid, kernel_size=1),
            nn.ReLU(),
            nn.Conv2d(mid, mid, kernel_size=(3, 1), padding=(1, 0)),
            nn.ReLU(),
            nn.Conv2d(mid, mid, kernel_size=(1, 3), padding=(0, 1)),
            nn.ReLU(),
            nn.Conv2d(mid, ch_out, kernel_size=1),
        ]
        # The attribute must stay an nn.Sequential named ``block`` so that
        # parameter names remain ``block.<index>.weight`` / ``.bias``.
        self.block = nn.Sequential(*stages)

    def forward(self, x):
        """Apply the stacked convolutions to ``x`` and return the result."""
        out = self.block(x)
        return out
|
|
|
|
class LPRNet(nn.Module):
    """Convolutional plate-recognition network with fused multi-scale context.

    A sequential backbone is run layer by layer; the activations after four
    chosen layers are pooled to a common spatial size, each divided by the
    mean of its squared values, concatenated along the channel axis and mixed
    by a 1x1 convolution. The result is averaged over the height axis to give
    one score vector per horizontal position.

    Args:
        lpr_max_len: Stored on the instance; not otherwise used by this class.
        num_classes: Number of output classes (channels of the final backbone
            convolution and of the 1x1 fusion convolution).
        dropout_rate: Probability passed to both ``nn.Dropout`` layers.
        export: When True, ``forward`` returns arg-max class indices instead
            of log-probabilities.
    """

    # Backbone layer indices whose outputs are tapped for the global context.
    # Kept at class level so the set is defined once rather than rebuilt on
    # every forward pass.
    _TAP_LAYERS = (2, 6, 13, 22)

    def __init__(self, lpr_max_len, num_classes, dropout_rate, export=False):
        super().__init__()
        self.lpr_max_len = lpr_max_len
        self.num_classes = num_classes
        self.export = export

        # NOTE: the MaxPool3d layers receive 4-D activations, so the first
        # entry of their kernel/stride acts along the channel axis. A stride
        # of 2 (resp. 4) there reduces 128 (resp. 256) channels to 64, which
        # is why the layers that follow them declare 64 input channels.
        # Layer order must not change: parameter names and the tap indices in
        # ``_TAP_LAYERS`` both depend on it.
        self.backbone = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=1),  # 0
            nn.BatchNorm2d(num_features=64),  # 1
            nn.ReLU(),  # 2 (tapped)
            nn.MaxPool3d(kernel_size=(1, 3, 3), stride=(1, 1, 1)),  # 3
            small_basic_block(ch_in=64, ch_out=128),  # 4
            nn.BatchNorm2d(num_features=128),  # 5
            nn.ReLU(),  # 6 (tapped)
            nn.MaxPool3d(kernel_size=(1, 3, 3), stride=(2, 1, 2)),  # 7
            small_basic_block(ch_in=64, ch_out=256),  # 8
            nn.BatchNorm2d(num_features=256),  # 9
            nn.ReLU(),  # 10
            small_basic_block(ch_in=256, ch_out=256),  # 11
            nn.BatchNorm2d(num_features=256),  # 12
            nn.ReLU(),  # 13 (tapped)
            nn.MaxPool3d(kernel_size=(1, 3, 3), stride=(4, 1, 2)),  # 14
            nn.Dropout(dropout_rate),  # 15
            nn.Conv2d(in_channels=64, out_channels=256, kernel_size=(1, 4), stride=1),  # 16
            nn.BatchNorm2d(num_features=256),  # 17
            nn.ReLU(),  # 18
            nn.Dropout(dropout_rate),  # 19
            nn.Conv2d(in_channels=256, out_channels=num_classes, kernel_size=(13, 1), stride=1),  # 20
            nn.BatchNorm2d(num_features=num_classes),  # 21
            nn.ReLU(),  # 22 (tapped)
        )

        # 448 = 64 + 128 + 256, the channel counts of the first three taps;
        # the fourth tap contributes ``num_classes`` channels. The single conv
        # stays wrapped in nn.Sequential so its parameters remain
        # ``container.0.*``.
        self.container = nn.Sequential(
            nn.Conv2d(
                in_channels=448 + self.num_classes,
                out_channels=self.num_classes,
                kernel_size=(1, 1),
                stride=(1, 1),
            ),
        )

    def forward(self, x):
        """Run the network on a batch of images.

        Args:
            x: 4-D image batch with 3 channels. The pooling windows below are
                fixed, so the taps only line up spatially for suitable input
                sizes (the demo in this file uses 24x94) -- verify before
                feeding other sizes.

        Returns:
            If ``self.export`` is False: log-probabilities laid out as
            ``[width, batch, classes]``. If True: arg-max class indices laid
            out as ``[batch, width]``.
        """
        taps = []
        for idx, layer in enumerate(self.backbone.children()):
            x = layer(x)
            if idx in self._TAP_LAYERS:
                taps.append(x)

        context = []
        for pos, feat in enumerate(taps):
            # Functional pooling avoids constructing a new nn.AvgPool2d module
            # on every call; the result is identical.
            if pos in (0, 1):
                feat = F.avg_pool2d(feat, kernel_size=5, stride=5)
            elif pos == 2:
                feat = F.avg_pool2d(feat, kernel_size=(4, 10), stride=(4, 2))
            # Scale each tap by the mean of its squared activations (a single
            # scalar over the whole tensor) before concatenation.
            power = torch.mean(torch.pow(feat, 2))
            context.append(torch.div(feat, power))

        fused = self.container(torch.cat(context, 1))
        logits = torch.mean(fused, dim=2)  # collapse height -> [b, c, w]

        if self.export:
            # [b, c, w] -> [b, w, c], then pick the best class per position.
            return logits.transpose(2, 1).argmax(dim=2)

        logits = logits.permute(2, 0, 1)  # [w, b, c]
        return F.log_softmax(logits, dim=2)
|
|
|
|
def build_lprnet(lpr_max_len=8, num_classes=78, dropout_rate=0.5, export=False):
    """Create an ``LPRNet`` from the given hyper-parameters.

    Args:
        lpr_max_len: Forwarded to ``LPRNet`` (default 8).
        num_classes: Number of output classes (default 78).
        dropout_rate: Dropout probability used inside the network (default 0.5).
        export: Forwarded to ``LPRNet``; selects its arg-max output mode.

    Returns:
        The newly constructed ``LPRNet`` instance.
    """
    return LPRNet(
        lpr_max_len=lpr_max_len,
        num_classes=num_classes,
        dropout_rate=dropout_rate,
        export=export,
    )
|
|
|
|
if __name__ == "__main__":
    # Smoke test: build the export-mode model and print the output shape for
    # one random 24x94 RGB image.
    model = build_lprnet(export=True)
    # Export mode is an inference path, so switch dropout off and make
    # BatchNorm use (and not update) its running statistics.
    model.eval()
    x = torch.randn(1, 3, 24, 94)
    with torch.no_grad():  # no autograd graph is needed for a shape check
        out = model(x)
    print(out.shape)