diff --git a/fastdeploy/vision/evaluation/detection.py b/fastdeploy/vision/evaluation/detection.py index 4aaaaaaa5..cd09046f7 100644 --- a/fastdeploy/vision/evaluation/detection.py +++ b/fastdeploy/vision/evaluation/detection.py @@ -22,18 +22,21 @@ import collections def eval_detection(model, - conf_threshold, - nms_iou_threshold, data_dir, ann_file, + conf_threshold=None, + nms_iou_threshold=None, plot=False): - assert isinstance(conf_threshold, ( - float, int - )), "The conf_threshold:{} need to be int or float".format(conf_threshold) - assert isinstance(nms_iou_threshold, ( - float, - int)), "The nms_iou_threshold:{} need to be int or float".format( - nms_iou_threshold) + if conf_threshold is not None or nms_iou_threshold is not None: + assert conf_threshold is not None and nms_iou_threshold is not None, "The conf_threshold and nms_iou_threshold should be setted at the same time" + assert isinstance(conf_threshold, ( + float, + int)), "The conf_threshold:{} need to be int or float".format( + conf_threshold) + assert isinstance(nms_iou_threshold, ( + float, + int)), "The nms_iou_threshold:{} need to be int or float".format( + nms_iou_threshold) eval_dataset = CocoDetection( data_dir=data_dir, ann_file=ann_file, shuffle=False) all_image_info = eval_dataset.file_list @@ -49,7 +52,10 @@ def eval_detection(model, image_num, desc="Inference Progress")): im = cv2.imread(image_info["image"]) im_id = image_info["im_id"] - result = model.predict(im, conf_threshold, nms_iou_threshold) + if conf_threshold is None and nms_iou_threshold is None: + result = model.predict(im) + else: + result = model.predict(im, conf_threshold, nms_iou_threshold) pred = { 'bbox': [[c] + [s] + b