diff --git a/python/fastdeploy/vision/evaluation/detection.py b/python/fastdeploy/vision/evaluation/detection.py index 98c6794fe..a13e0429e 100644 --- a/python/fastdeploy/vision/evaluation/detection.py +++ b/python/fastdeploy/vision/evaluation/detection.py @@ -23,7 +23,8 @@ def eval_detection(model, ann_file, conf_threshold=None, nms_iou_threshold=None, - plot=False): + plot=False, + batch_size=1): from .utils import CocoDetection from .utils import COCOMetric import cv2 @@ -54,6 +55,8 @@ def eval_detection(model, start_time = 0 end_time = 0 average_inference_time = 0 + im_list = list() + im_id_list = list() for image_info, i in zip(all_image_info, trange( image_num, desc="Inference Progress")): @@ -61,19 +64,43 @@ def eval_detection(model, start_time = time.time() im = cv2.imread(image_info["image"]) im_id = image_info["im_id"] - if conf_threshold is None and nms_iou_threshold is None: - result = model.predict(im.copy()) + if batch_size == 1: + if conf_threshold is None and nms_iou_threshold is None: + result = model.predict(im.copy()) + else: + result = model.predict(im, conf_threshold, nms_iou_threshold) + pred = { + 'bbox': [[c] + [s] + b + for b, s, c in zip(result.boxes, result.scores, + result.label_ids)], + 'bbox_num': len(result.boxes), + 'im_id': im_id + } + eval_metric.update(im_id, pred) else: - result = model.predict(im, conf_threshold, nms_iou_threshold) - pred = { - 'bbox': - [[c] + [s] + b - for b, s, c in zip(result.boxes, result.scores, result.label_ids) - ], - 'bbox_num': len(result.boxes), - 'im_id': im_id - } - eval_metric.update(im_id, pred) + im_list.append(im) + im_id_list.append(im_id) + # If the batch_size is not satisfied, the remaining pictures are formed into a batch + if (i + 1) % batch_size != 0 and i != image_num - 1: + continue + if conf_threshold is None and nms_iou_threshold is None: + results = model.batch_predict(im_list) + else: + model.postprocessor.conf_threshold = conf_threshold + model.postprocessor.nms_threshold = nms_iou_threshold + results = model.batch_predict(im_list) + for k in range(len(im_list)): + pred = { + 'bbox': [[c] + [s] + b + for b, s, c in zip(results[k].boxes, results[ + k].scores, results[k].label_ids)], + 'bbox_num': len(results[k].boxes), + 'im_id': im_id_list[k] + } + eval_metric.update(im_id_list[k], pred) + im_list.clear() + im_id_list.clear() + if i == image_num - 1: end_time = time.time() average_inference_time = round(