mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-06 00:57:33 +08:00
179 lines
7.1 KiB
Python
179 lines
7.1 KiB
Python
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||
#
|
||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
# you may not use this file except in compliance with the License.
|
||
# You may obtain a copy of the License at
|
||
#
|
||
# http://www.apache.org/licenses/LICENSE-2.0
|
||
#
|
||
# Unless required by applicable law or agreed to in writing, software
|
||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
# See the License for the specific language governing permissions and
|
||
# limitations under the License.
|
||
|
||
from __future__ import absolute_import
|
||
import copy
|
||
import os.path as osp
|
||
import sys
|
||
import numpy as np
|
||
from . import fd_logging as logging
|
||
from .util import is_pic, get_num_workers
|
||
|
||
|
||
class CocoDetection(object):
|
||
"""读取MSCOCO格式的检测数据集,并对样本进行相应的处理,该格式的数据集同样可以应用到实例分割模型的训练中。
|
||
|
||
Args:
|
||
data_dir (str): 数据集所在的目录路径。
|
||
ann_file (str): 数据集的标注文件,为一个独立的json格式文件。
|
||
num_workers (int|str): 数据集中样本在预处理过程中的线程或进程数。默认为'auto'。当设为'auto'时,根据
|
||
系统的实际CPU核数设置`num_workers`: 如果CPU核数的一半大于8,则`num_workers`为8,否则为CPU核数的一半。
|
||
shuffle (bool): 是否需要对数据集中样本打乱顺序。默认为False。
|
||
allow_empty (bool): 是否加载负样本。默认为False。
|
||
empty_ratio (float): 用于指定负样本占总样本数的比例。如果小于0或大于等于1,则保留全部的负样本。默认为1。
|
||
"""
|
||
|
||
def __init__(self,
|
||
data_dir,
|
||
ann_file,
|
||
num_workers='auto',
|
||
shuffle=False,
|
||
allow_empty=False,
|
||
empty_ratio=1.):
|
||
|
||
from pycocotools.coco import COCO
|
||
self.data_dir = data_dir
|
||
self.data_fields = None
|
||
self.num_max_boxes = 1000
|
||
self.num_workers = get_num_workers(num_workers)
|
||
self.shuffle = shuffle
|
||
self.allow_empty = allow_empty
|
||
self.empty_ratio = empty_ratio
|
||
self.file_list = list()
|
||
neg_file_list = list()
|
||
self.labels = list()
|
||
|
||
coco = COCO(ann_file)
|
||
self.coco_gt = coco
|
||
img_ids = sorted(coco.getImgIds())
|
||
cat_ids = coco.getCatIds()
|
||
catid2clsid = dict({catid: i for i, catid in enumerate(cat_ids)})
|
||
cname2clsid = dict({
|
||
coco.loadCats(catid)[0]['name']: clsid
|
||
for catid, clsid in catid2clsid.items()
|
||
})
|
||
for label, cid in sorted(cname2clsid.items(), key=lambda d: d[1]):
|
||
self.labels.append(label)
|
||
logging.info("Starting to read file list from dataset...")
|
||
|
||
ct = 0
|
||
for img_id in img_ids:
|
||
is_empty = False
|
||
img_anno = coco.loadImgs(img_id)[0]
|
||
im_fname = osp.join(data_dir, img_anno['file_name'])
|
||
if not is_pic(im_fname):
|
||
continue
|
||
im_w = float(img_anno['width'])
|
||
im_h = float(img_anno['height'])
|
||
ins_anno_ids = coco.getAnnIds(imgIds=img_id, iscrowd=False)
|
||
instances = coco.loadAnns(ins_anno_ids)
|
||
|
||
bboxes = []
|
||
for inst in instances:
|
||
x, y, box_w, box_h = inst['bbox']
|
||
x1 = max(0, x)
|
||
y1 = max(0, y)
|
||
x2 = min(im_w - 1, x1 + max(0, box_w))
|
||
y2 = min(im_h - 1, y1 + max(0, box_h))
|
||
if inst['area'] > 0 and x2 >= x1 and y2 >= y1:
|
||
inst['clean_bbox'] = [x1, y1, x2, y2]
|
||
bboxes.append(inst)
|
||
else:
|
||
logging.warning(
|
||
"Found an invalid bbox in annotations: "
|
||
"im_id: {}, area: {} x1: {}, y1: {}, x2: {}, y2: {}."
|
||
.format(img_id, float(inst['area']), x1, y1, x2, y2))
|
||
num_bbox = len(bboxes)
|
||
if num_bbox == 0 and not self.allow_empty:
|
||
continue
|
||
elif num_bbox == 0:
|
||
is_empty = True
|
||
|
||
gt_bbox = np.zeros((num_bbox, 4), dtype=np.float32)
|
||
gt_class = np.zeros((num_bbox, 1), dtype=np.int32)
|
||
gt_score = np.ones((num_bbox, 1), dtype=np.float32)
|
||
is_crowd = np.zeros((num_bbox, 1), dtype=np.int32)
|
||
difficult = np.zeros((num_bbox, 1), dtype=np.int32)
|
||
gt_poly = [None] * num_bbox
|
||
|
||
has_segmentation = False
|
||
for i, box in reversed(list(enumerate(bboxes))):
|
||
catid = box['category_id']
|
||
gt_class[i][0] = catid2clsid[catid]
|
||
gt_bbox[i, :] = box['clean_bbox']
|
||
is_crowd[i][0] = box['iscrowd']
|
||
if 'segmentation' in box and box['iscrowd'] == 1:
|
||
gt_poly[i] = [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0]]
|
||
elif 'segmentation' in box and box['segmentation']:
|
||
if not np.array(
|
||
box['segmentation'],
|
||
dtype=object).size > 0 and not self.allow_empty:
|
||
gt_poly.pop(i)
|
||
is_crowd = np.delete(is_crowd, i)
|
||
gt_class = np.delete(gt_class, i)
|
||
gt_bbox = np.delete(gt_bbox, i)
|
||
else:
|
||
gt_poly[i] = box['segmentation']
|
||
has_segmentation = True
|
||
if has_segmentation and not any(gt_poly) and not self.allow_empty:
|
||
continue
|
||
|
||
im_info = {
|
||
'im_id': np.array([img_id]).astype('int32'),
|
||
'image_shape': np.array([im_h, im_w]).astype('int32'),
|
||
}
|
||
label_info = {
|
||
'is_crowd': is_crowd,
|
||
'gt_class': gt_class,
|
||
'gt_bbox': gt_bbox,
|
||
'gt_score': gt_score,
|
||
'gt_poly': gt_poly,
|
||
'difficult': difficult
|
||
}
|
||
|
||
if is_empty:
|
||
neg_file_list.append({
|
||
'image': im_fname,
|
||
**
|
||
im_info,
|
||
**
|
||
label_info
|
||
})
|
||
else:
|
||
self.file_list.append({
|
||
'image': im_fname,
|
||
**
|
||
im_info,
|
||
**
|
||
label_info
|
||
})
|
||
ct += 1
|
||
|
||
self.num_max_boxes = max(self.num_max_boxes, len(instances))
|
||
|
||
if not ct:
|
||
logging.error(
|
||
"No coco record found in %s' % (ann_file)", exit=True)
|
||
self.pos_num = len(self.file_list)
|
||
if self.allow_empty and neg_file_list:
|
||
self.file_list += self._sample_empty(neg_file_list)
|
||
logging.info(
|
||
"{} samples in file {}, including {} positive samples and {} negative samples.".
|
||
format(
|
||
len(self.file_list), ann_file, self.pos_num,
|
||
len(self.file_list) - self.pos_num))
|
||
self.num_samples = len(self.file_list)
|
||
|
||
self._epoch = 0
|