Evaluation support model containing nms (#51)

* Detection evaluation function

* Add license

* Fix python import problem

* Modify requirement.txt

* Add requirements.txt

* Evaluation support model containing nms

* Delete useless code

Co-authored-by: Jason <jiangjiajun@baidu.com>
This commit is contained in:
huangjianhui
2022-07-29 11:49:20 +08:00
committed by GitHub
parent 5c128c4b30
commit fc83115320

View File

@@ -22,18 +22,21 @@ import collections
def eval_detection(model, def eval_detection(model,
conf_threshold,
nms_iou_threshold,
data_dir, data_dir,
ann_file, ann_file,
conf_threshold=None,
nms_iou_threshold=None,
plot=False): plot=False):
assert isinstance(conf_threshold, ( if conf_threshold is not None or nms_iou_threshold is not None:
float, int assert conf_threshold is not None and nms_iou_threshold is not None, "The conf_threshold and nms_iou_threshold should be setted at the same time"
)), "The conf_threshold:{} need to be int or float".format(conf_threshold) assert isinstance(conf_threshold, (
assert isinstance(nms_iou_threshold, ( float,
float, int)), "The conf_threshold:{} need to be int or float".format(
int)), "The nms_iou_threshold:{} need to be int or float".format( conf_threshold)
nms_iou_threshold) assert isinstance(nms_iou_threshold, (
float,
int)), "The nms_iou_threshold:{} need to be int or float".format(
nms_iou_threshold)
eval_dataset = CocoDetection( eval_dataset = CocoDetection(
data_dir=data_dir, ann_file=ann_file, shuffle=False) data_dir=data_dir, ann_file=ann_file, shuffle=False)
all_image_info = eval_dataset.file_list all_image_info = eval_dataset.file_list
@@ -49,7 +52,10 @@ def eval_detection(model,
image_num, desc="Inference Progress")): image_num, desc="Inference Progress")):
im = cv2.imread(image_info["image"]) im = cv2.imread(image_info["image"])
im_id = image_info["im_id"] im_id = image_info["im_id"]
result = model.predict(im, conf_threshold, nms_iou_threshold) if conf_threshold is None and nms_iou_threshold is None:
result = model.predict(im)
else:
result = model.predict(im, conf_threshold, nms_iou_threshold)
pred = { pred = {
'bbox': 'bbox':
[[c] + [s] + b [[c] + [s] + b