This commit is contained in:
sha-xiaobao
2022-11-29 14:34:40 +08:00
parent 8e10aee4c0
commit 5e9dad7dc6
2 changed files with 14 additions and 19 deletions

View File

@@ -20,12 +20,10 @@ from re import A
from collections import defaultdict
import numpy as np
sys.path.append(".")
from scipy.optimize import linear_sum_assignment
from scipy.spatial.distance import cdist
from sklearn.metrics import auc
sys.path.append(".")
def load_detection_file(kitti):

View File

@@ -19,16 +19,14 @@ import glob
from logging import raiseExceptions
from re import A
sys.path.append(".")
import copy
import argparse
from collections import defaultdict
from scipy.optimize import linear_sum_assignment
from scipy.spatial.distance import cdist
from sklearn.metrics import auc
import argparse
import numpy as np
from collections import defaultdict
sys.path.append(".")
# laser to base
@@ -124,31 +122,30 @@ def get_precision_recall(
recalls = np.full_like(dets_cls, np.nan)
threshs = np.full_like(dets_cls, np.nan)
indices = np.argsort(dets_cls, kind="mergesort") # mergesort for determinism.
indices = np.argsort(dets_cls, kind="mergesort")
for i, idx in enumerate(reversed(indices)):
frame = dets_inds[idx]
iframe = np.where(frames == frame)[0][0] # Can only be a single one.
frame_tp = dets_inds[idx]
iframe_tp = np.where(frames == frame_tp)[0][0]
# Accept this detection
dets_idxs = det_accepted_idxs[frame]
dets_idxs = det_accepted_idxs[frame_tp]
dets_idxs.append(idx)
threshs[i] = dets_cls[idx]
dets = dets_xy[dets_idxs]
print(dets)
gts_mask = gts_inds == frame
gts_mask = gts_inds == frame_tp
gts = gts_xy[gts_mask]
radii = a_rad[gts_mask]
if len(gts) == 0: # No GT, but there is a detection.
fps[iframe] += 1
if len(gts) == 0:
fps[iframe_tp] += 1
else:
not_in_radius = radii[:, None] < cdist(gts, dets)
igt, idet = linear_sum_assignment(not_in_radius)
tps[iframe] = np.sum(np.logical_not(not_in_radius[igt, idet]))
fps[iframe] = (len(dets) - tps[iframe])
fns[iframe] = len(gts) - tps[iframe]
tps[iframe_tp] = np.sum(np.logical_not(not_in_radius[igt, idet]))
fps[iframe_tp] = (len(dets) - tps[iframe_tp])
fns[iframe_tp] = len(gts) - tps[iframe_tp]
tp, fp, fn = np.sum(tps), np.sum(fps), np.sum(fns)
precisions[i] = tp / (fp + tp) if fp + tp > 0 else np.nan