mirror of
https://github.com/we0091234/crnn_plate_recognition.git
synced 2025-10-05 10:56:50 +08:00
onnx-trt support
This commit is contained in:
@@ -100,6 +100,8 @@ python export.py --weights saved_model/best.pth --save_path saved_model/best.onn
|
|||||||
|
|
||||||
导出onnx文件为 saved_model/best.onnx
|
导出onnx文件为 saved_model/best.onnx
|
||||||
|
|
||||||
|
如果需要onnx支持trt的话,支持[这里推理](https://github.com/we0091234/chinese_plate_tensorrt),则加上--trt
|
||||||
|
|
||||||
#### onnx 推理
|
#### onnx 推理
|
||||||
|
|
||||||
```
|
```
|
||||||
|
@@ -14,6 +14,7 @@ if __name__=="__main__":
|
|||||||
parser.add_argument('--batch_size', type=int, default=1, help='batch size')
|
parser.add_argument('--batch_size', type=int, default=1, help='batch size')
|
||||||
parser.add_argument('--dynamic', action='store_true', default=False, help='enable dynamic axis in onnx model')
|
parser.add_argument('--dynamic', action='store_true', default=False, help='enable dynamic axis in onnx model')
|
||||||
parser.add_argument('--simplify', action='store_true', default=False, help='simplified onnx')
|
parser.add_argument('--simplify', action='store_true', default=False, help='simplified onnx')
|
||||||
|
parser.add_argument('--trt', action='store_true', default=False, help='support trt')
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@@ -21,7 +22,7 @@ if __name__=="__main__":
|
|||||||
print(opt)
|
print(opt)
|
||||||
checkpoint = torch.load(opt.weights)
|
checkpoint = torch.load(opt.weights)
|
||||||
cfg = checkpoint['cfg']
|
cfg = checkpoint['cfg']
|
||||||
model = myNet_ocr(num_classes=len(plate_chr),cfg=cfg,export=True)
|
model = myNet_ocr(num_classes=len(plate_chr),cfg=cfg,export=True,trt=opt.trt)
|
||||||
model.load_state_dict(checkpoint['state_dict'])
|
model.load_state_dict(checkpoint['state_dict'])
|
||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
|
@@ -3,13 +3,14 @@ import torch
|
|||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
class myNet_ocr(nn.Module):
|
class myNet_ocr(nn.Module):
|
||||||
def __init__(self,cfg=None,num_classes=78,export=False):
|
def __init__(self,cfg=None,num_classes=78,export=False,trt=False):
|
||||||
super(myNet_ocr, self).__init__()
|
super(myNet_ocr, self).__init__()
|
||||||
if cfg is None:
|
if cfg is None:
|
||||||
cfg =[32,32,64,64,'M',128,128,'M',196,196,'M',256,256]
|
cfg =[32,32,64,64,'M',128,128,'M',196,196,'M',256,256]
|
||||||
# cfg =[32,32,'M',64,64,'M',128,128,'M',256,256]
|
# cfg =[32,32,'M',64,64,'M',128,128,'M',256,256]
|
||||||
self.feature = self.make_layers(cfg, True)
|
self.feature = self.make_layers(cfg, True)
|
||||||
self.export = export
|
self.export = export
|
||||||
|
self.trt= trt
|
||||||
# self.classifier = nn.Linear(cfg[-1], num_classes)
|
# self.classifier = nn.Linear(cfg[-1], num_classes)
|
||||||
# self.loc = nn.MaxPool2d((2, 2), (5, 1), (0, 1),ceil_mode=True)
|
# self.loc = nn.MaxPool2d((2, 2), (5, 1), (0, 1),ceil_mode=True)
|
||||||
# self.loc = nn.AvgPool2d((2, 2), (5, 2), (0, 1),ceil_mode=False)
|
# self.loc = nn.AvgPool2d((2, 2), (5, 2), (0, 1),ceil_mode=False)
|
||||||
@@ -47,8 +48,9 @@ class myNet_ocr(nn.Module):
|
|||||||
if self.export:
|
if self.export:
|
||||||
conv = x.squeeze(2) # b *512 * width
|
conv = x.squeeze(2) # b *512 * width
|
||||||
conv = conv.transpose(2,1) # [w, b, c]
|
conv = conv.transpose(2,1) # [w, b, c]
|
||||||
# conv =conv.argmax(dim=2)
|
if self.trt:
|
||||||
# out = conv.float()
|
conv =conv.argmax(dim=2)
|
||||||
|
out = conv.float()
|
||||||
return conv
|
return conv
|
||||||
else:
|
else:
|
||||||
b, c, h, w = x.size()
|
b, c, h, w = x.size()
|
||||||
|
Reference in New Issue
Block a user