|
|
|
|
|
|
|
import os |
|
import cv2 |
|
import numpy as np |
|
from loguru import logger |
|
from functools import wraps |
|
from pycocotools.coco import COCO |
|
from torch.utils.data.dataset import Dataset as torchDataset |
|
|
|
# Class names of the 80 COCO detection categories.
# NOTE(review): tuple position is assumed to correspond to the contiguous
# class index produced by COCODataset (sorted category ids) — confirm before
# using it for label-to-name lookup.
COCO_CLASSES = (
    'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light', 'fire hydrant',
    'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra',
    'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', 'kite',
    'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork',
    'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut',
    'cake', 'chair', 'couch', 'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote',
    'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors',
    'teddy bear', 'hair drier', 'toothbrush')
|
|
|
|
|
def remove_useless_info(coco):
    """Strip fields of a COCO dataset that detection training never reads.

    The ``COCO`` object is modified in place; any non-``COCO`` input is
    silently left untouched. This is mainly used for saving memory
    (saves about 30% mem).

    Args:
        coco: a ``pycocotools.coco.COCO`` instance.
    """
    if not isinstance(coco, COCO):
        return
    dataset = coco.dataset
    dataset.pop("info", None)
    dataset.pop("licenses", None)
    # Use .get so annotation files without an "images"/"annotations" list
    # are tolerated (the original guarded only "annotations").
    for img in dataset.get("images", []):
        # Per-image bookkeeping fields unused by training.
        for key in ("license", "coco_url", "date_captured", "flickr_url"):
            img.pop(key, None)
    # Segmentation polygons dominate annotation memory; boxes suffice here.
    for anno in dataset.get("annotations", []):
        anno.pop("segmentation", None)
|
|
|
|
|
class Dataset(torchDataset):
    """ This class is a subclass of the base :class:`torch.utils.data.Dataset`,
    that enables on the fly resizing of the ``input_dim``.

    Args:
        input_dimension (tuple): (width,height) tuple with default dimensions of the network
    """

    def __init__(self, input_dimension, mosaic=True):
        super().__init__()
        # Only the first two entries (width, height) are kept.
        self.__input_dim = input_dimension[:2]
        self.enable_mosaic = mosaic

    @property
    def input_dim(self):
        """
        Dimension that can be used by transforms to set the correct image size, etc.
        This allows transforms to have a single source of truth
        for the input dimension of the network.

        An externally-set ``self._input_dim`` override (e.g. from a
        multiscale trainer) takes precedence over the construction-time
        default.

        Return:
            list: Tuple containing the current width,height
        """
        return getattr(self, "_input_dim", self.__input_dim)

    @staticmethod
    def mosaic_getitem(getitem_fn):
        """
        Decorator method that needs to be used around the ``__getitem__`` method. |br|
        This decorator enables the closing mosaic

        When the incoming index is not a plain int, it is treated as a
        ``(enable_mosaic, real_index)`` pair: the flag is stored on the
        dataset and the real index is forwarded to the wrapped method.

        Example:
            >>> class CustomSet(ln.data.Dataset):
            ...     def __len__(self):
            ...         return 10
            ...     @ln.data.Dataset.mosaic_getitem
            ...     def __getitem__(self, index):
            ...         return self.enable_mosaic
        """

        @wraps(getitem_fn)
        def wrapper(self, index):
            if not isinstance(index, int):
                self.enable_mosaic = index[0]
                index = index[1]
            return getitem_fn(self, index)

        return wrapper
|
|
|
|
|
class COCODataset(Dataset):
    """
    COCO dataset class.

    Annotations are parsed once at construction time and held in memory;
    images are read lazily from disk, or served from ``self.imgs`` when an
    external caching mechanism has filled it with padded, pre-resized images.
    """

    def __init__(
        self,
        data_dir='data/COCO',
        json_file="instances_train2017.json",
        name="train2017",
        img_size=(416, 416),
        preproc=None
    ):
        """
        COCO dataset initialization. Annotation data are read into memory by COCO API.

        Args:
            data_dir (str): dataset root directory
            json_file (str): COCO json file name
            name (str): COCO data name (e.g. 'train2017' or 'val2017')
            img_size (tuple(int)): target image size after pre-processing
            preproc: data augmentation strategy
        """
        super().__init__(img_size)
        self.data_dir = data_dir
        self.json_file = json_file
        # Optional RAM cache of pre-resized images (filled externally).
        # Set before COCO() so __del__ is safe even if loading raises.
        self.imgs = None
        self.coco = COCO(os.path.join(self.data_dir, "annotations", self.json_file))
        remove_useless_info(self.coco)
        self.ids = self.coco.getImgIds()
        # Sorted sparse COCO category ids; position = contiguous class index.
        self.class_ids = sorted(self.coco.getCatIds())
        self.cats = self.coco.loadCats(self.coco.getCatIds())
        self._classes = tuple(c["name"] for c in self.cats)
        self.name = name
        self.img_size = img_size
        self.preproc = preproc
        self.annotations = self._load_coco_annotations()

    def __len__(self):
        return len(self.ids)

    def __del__(self):
        # Release the image cache. Assigning (instead of `del self.imgs`)
        # avoids an AttributeError on a partially-constructed instance,
        # e.g. when COCO() raised inside __init__.
        self.imgs = None

    def _load_coco_annotations(self):
        """Pre-load annotations for every image id (saves per-item parsing)."""
        return [self.load_anno_from_ids(_ids) for _ids in self.ids]

    def load_anno_from_ids(self, id_):
        """Load and pre-scale the annotations of a single image.

        Args:
            id_ (int): COCO image id.

        Returns:
            tuple: ``(res, img_info, resized_info, file_name)`` where
                res (numpy.ndarray): (num_objs, 5) rows of
                    [x1, y1, x2, y2, class_index], already scaled to the
                    resized image,
                img_info (tuple): original (height, width),
                resized_info (tuple): (height, width) after resizing,
                file_name (str): image file name.
        """
        im_ann = self.coco.loadImgs(id_)[0]
        width = im_ann["width"]
        height = im_ann["height"]
        anno_ids = self.coco.getAnnIds(imgIds=[int(id_)], iscrowd=False)
        annotations = self.coco.loadAnns(anno_ids)

        # Convert xywh boxes to xyxy, clip to the image, and drop
        # zero-area objects. Builtin min/max suffice for scalar pairs.
        objs = []
        for obj in annotations:
            x1 = max(0, obj["bbox"][0])
            y1 = max(0, obj["bbox"][1])
            x2 = min(width, x1 + max(0, obj["bbox"][2]))
            y2 = min(height, y1 + max(0, obj["bbox"][3]))
            if obj["area"] > 0 and x2 >= x1 and y2 >= y1:
                obj["clean_bbox"] = [x1, y1, x2, y2]
                objs.append(obj)

        res = np.zeros((len(objs), 5))
        for ix, obj in enumerate(objs):
            res[ix, 0:4] = obj["clean_bbox"]
            # Map the sparse COCO category id to a contiguous class index.
            res[ix, 4] = self.class_ids.index(obj["category_id"])

        # Same aspect-preserving resize ratio as load_resized_img, so the
        # stored boxes match the resized image.
        r = min(self.img_size[0] / height, self.img_size[1] / width)
        res[:, :4] *= r

        img_info = (height, width)
        resized_info = (int(height * r), int(width * r))
        file_name = (
            im_ann["file_name"]
            if "file_name" in im_ann
            else "{:012}".format(id_) + ".jpg"
        )
        return res, img_info, resized_info, file_name

    def load_anno(self, index):
        """Return the (num_objs, 5) annotation array for dataset ``index``."""
        return self.annotations[index][0]

    def load_resized_img(self, index):
        """Read the image at ``index`` and resize it, preserving aspect ratio,
        so that it fits inside ``self.img_size``."""
        img = self.load_image(index)
        r = min(self.img_size[0] / img.shape[0], self.img_size[1] / img.shape[1])
        resized_img = cv2.resize(
            img,
            (int(img.shape[1] * r), int(img.shape[0] * r)),
            interpolation=cv2.INTER_LINEAR,
        ).astype(np.uint8)
        return resized_img

    def load_image(self, index):
        """Read the original image for ``index`` from disk (BGR uint8).

        Raises:
            FileNotFoundError: if the file is missing or unreadable.
        """
        file_name = self.annotations[index][3]
        img_file = os.path.join(self.data_dir, self.name, file_name)
        img = cv2.imread(img_file)
        # Explicit raise: cv2.imread silently returns None on failure, and
        # an `assert` here would be stripped under `python -O`.
        if img is None:
            raise FileNotFoundError(f"file named {img_file} not found")
        return img

    def pull_item(self, index):
        """Return ``(img, target, img_info, img_id)`` without preprocessing.

        Serves the image from the RAM cache when available (cropping the
        padded cache entry back to its resized shape), otherwise reads and
        resizes it from disk. Copies are returned so callers may mutate.
        """
        id_ = self.ids[index]
        res, img_info, resized_info, _ = self.annotations[index]
        if self.imgs is not None:
            pad_img = self.imgs[index]
            img = pad_img[: resized_info[0], : resized_info[1], :].copy()
        else:
            img = self.load_resized_img(index)
        return img, res.copy(), img_info, np.array([id_])

    @Dataset.mosaic_getitem
    def __getitem__(self, index):
        """
        One image / label pair for the given index is picked up and pre-processed.

        Args:
            index (int): data index

        Returns:
            img (numpy.ndarray): pre-processed image
            target (torch.Tensor): pre-processed label data.
                The shape is :math:`[max_labels, 5]`.
                each label consists of [class, xc, yc, w, h]:
                    class (float): class index.
                    xc, yc (float) : center of bbox whose values range from 0 to 1.
                    w, h (float) : size of bbox whose values range from 0 to 1.
            img_info : tuple of h, w.
                h, w (int): original shape of the image
            img_id (int): same as the input index. Used for evaluation.
        """
        img, target, img_info, img_id = self.pull_item(index)
        if self.preproc is not None:
            img, target = self.preproc(img, target, self.input_dim)
        return img, target, img_info, img_id
|
|