mirror of
https://github.com/we0091234/Car_recognition.git
synced 2025-10-05 08:37:00 +08:00
first commit
This commit is contained in:
36
utils/infer_utils.py
Normal file
36
utils/infer_utils.py
Normal file
@@ -0,0 +1,36 @@
|
||||
import torch
|
||||
|
||||
|
||||
|
||||
def decode_infer(output, stride):
|
||||
# logging.info(torch.tensor(output.shape[0]))
|
||||
# logging.info(output.shape)
|
||||
# # bz is batch-size
|
||||
# bz = tuple(torch.tensor(output.shape[0]))
|
||||
# gridsize = tuple(torch.tensor(output.shape[-1]))
|
||||
# logging.info(gridsize)
|
||||
sh = torch.tensor(output.shape)
|
||||
bz = sh[0]
|
||||
gridsize = sh[-1]
|
||||
|
||||
output = output.permute(0, 2, 3, 1)
|
||||
output = output.view(bz, gridsize, gridsize, self.gt_per_grid, 5+self.numclass)
|
||||
x1y1, x2y2, conf, prob = torch.split(
|
||||
output, [2, 2, 1, self.numclass], dim=4)
|
||||
|
||||
shiftx = torch.arange(0, gridsize, dtype=torch.float32)
|
||||
shifty = torch.arange(0, gridsize, dtype=torch.float32)
|
||||
shifty, shiftx = torch.meshgrid([shiftx, shifty])
|
||||
shiftx = shiftx.unsqueeze(-1).repeat(bz, 1, 1, self.gt_per_grid)
|
||||
shifty = shifty.unsqueeze(-1).repeat(bz, 1, 1, self.gt_per_grid)
|
||||
|
||||
xy_grid = torch.stack([shiftx, shifty], dim=4).cuda()
|
||||
x1y1 = (xy_grid+0.5-torch.exp(x1y1))*stride
|
||||
x2y2 = (xy_grid+0.5+torch.exp(x2y2))*stride
|
||||
|
||||
xyxy = torch.cat((x1y1, x2y2), dim=4)
|
||||
conf = torch.sigmoid(conf)
|
||||
prob = torch.sigmoid(prob)
|
||||
output = torch.cat((xyxy, conf, prob), 4)
|
||||
output = output.view(bz, -1, 5+self.numclass)
|
||||
return output
|
Reference in New Issue
Block a user