onnx-trt support

This commit is contained in:
we0091234
2022-12-13 13:52:53 +08:00
parent 79c609f2b5
commit 3688ccc532
3 changed files with 9 additions and 4 deletions

View File

@@ -3,13 +3,14 @@ import torch
import torch.nn.functional as F
class myNet_ocr(nn.Module):
def __init__(self,cfg=None,num_classes=78,export=False):
def __init__(self,cfg=None,num_classes=78,export=False,trt=False):
super(myNet_ocr, self).__init__()
if cfg is None:
cfg =[32,32,64,64,'M',128,128,'M',196,196,'M',256,256]
# cfg =[32,32,'M',64,64,'M',128,128,'M',256,256]
self.feature = self.make_layers(cfg, True)
self.export = export
self.trt= trt
# self.classifier = nn.Linear(cfg[-1], num_classes)
# self.loc = nn.MaxPool2d((2, 2), (5, 1), (0, 1),ceil_mode=True)
# self.loc = nn.AvgPool2d((2, 2), (5, 2), (0, 1),ceil_mode=False)
@@ -47,8 +48,9 @@ class myNet_ocr(nn.Module):
if self.export:
conv = x.squeeze(2) # b *512 * width
conv = conv.transpose(2,1) # [w, b, c]
# conv =conv.argmax(dim=2)
# out = conv.float()
if self.trt:
conv =conv.argmax(dim=2)
out = conv.float()
return conv
else:
b, c, h, w = x.size()