mirror of
https://github.com/we0091234/crnn_plate_recognition.git
synced 2025-09-27 07:52:06 +08:00
onnx-trt support
This commit is contained in:
@@ -3,13 +3,14 @@ import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
class myNet_ocr(nn.Module):
|
||||
def __init__(self,cfg=None,num_classes=78,export=False):
|
||||
def __init__(self,cfg=None,num_classes=78,export=False,trt=False):
|
||||
super(myNet_ocr, self).__init__()
|
||||
if cfg is None:
|
||||
cfg =[32,32,64,64,'M',128,128,'M',196,196,'M',256,256]
|
||||
# cfg =[32,32,'M',64,64,'M',128,128,'M',256,256]
|
||||
self.feature = self.make_layers(cfg, True)
|
||||
self.export = export
|
||||
self.trt= trt
|
||||
# self.classifier = nn.Linear(cfg[-1], num_classes)
|
||||
# self.loc = nn.MaxPool2d((2, 2), (5, 1), (0, 1),ceil_mode=True)
|
||||
# self.loc = nn.AvgPool2d((2, 2), (5, 2), (0, 1),ceil_mode=False)
|
||||
@@ -47,8 +48,9 @@ class myNet_ocr(nn.Module):
|
||||
if self.export:
|
||||
conv = x.squeeze(2) # b *512 * width
|
||||
conv = conv.transpose(2,1) # [w, b, c]
|
||||
# conv =conv.argmax(dim=2)
|
||||
# out = conv.float()
|
||||
if self.trt:
|
||||
conv =conv.argmax(dim=2)
|
||||
out = conv.float()
|
||||
return conv
|
||||
else:
|
||||
b, c, h, w = x.size()
|
||||
|
Reference in New Issue
Block a user