yolox-s / coco.py
zhengrongzhang's picture
init model
1cff332
raw
history blame
8.48 kB
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
import os
import cv2
import numpy as np
from loguru import logger
from functools import wraps
from pycocotools.coco import COCO
from torch.utils.data.dataset import Dataset as torchDataset
# The 80 COCO object category names. The trailing comment on each row gives
# the tuple positions of the names on that row.
COCO_CLASSES = (
    "person", "bicycle", "car", "motorcycle", "airplane",  # 0-4
    "bus", "train", "truck", "boat", "traffic light",  # 5-9
    "fire hydrant", "stop sign", "parking meter", "bench", "bird",  # 10-14
    "cat", "dog", "horse", "sheep", "cow",  # 15-19
    "elephant", "bear", "zebra", "giraffe", "backpack",  # 20-24
    "umbrella", "handbag", "tie", "suitcase", "frisbee",  # 25-29
    "skis", "snowboard", "sports ball", "kite", "baseball bat",  # 30-34
    "baseball glove", "skateboard", "surfboard", "tennis racket", "bottle",  # 35-39
    "wine glass", "cup", "fork", "knife", "spoon",  # 40-44
    "bowl", "banana", "apple", "sandwich", "orange",  # 45-49
    "broccoli", "carrot", "hot dog", "pizza", "donut",  # 50-54
    "cake", "chair", "couch", "potted plant", "bed",  # 55-59
    "dining table", "toilet", "tv", "laptop", "mouse",  # 60-64
    "remote", "keyboard", "cell phone", "microwave", "oven",  # 65-69
    "toaster", "sink", "refrigerator", "book", "clock",  # 70-74
    "vase", "scissors", "teddy bear", "hair drier", "toothbrush",  # 75-79
)
def remove_useless_info(coco):
    """
    Strip metadata that this module never reads from a loaded COCO object.

    ``coco.dataset`` is edited in place and nothing is returned. Arguments
    that are not ``COCO`` instances are left untouched. The purpose is to
    lower memory usage (reported to save about 30% of memory).
    """
    if not isinstance(coco, COCO):
        return

    data = coco.dataset

    # Dataset-wide metadata.
    for top_level_key in ("info", "licenses"):
        data.pop(top_level_key, None)

    # Per-image metadata.
    per_image_keys = ("license", "coco_url", "date_captured", "flickr_url")
    for image_record in data["images"]:
        for key in per_image_keys:
            image_record.pop(key, None)

    # Segmentation data is dropped from every annotation; "annotations" may
    # be absent from the dataset dict, in which case nothing is iterated.
    for annotation in data.get("annotations", ()):
        annotation.pop("segmentation", None)
class Dataset(torchDataset):
    """
    Base dataset that carries the network input size and a mosaic switch.

    This is a subclass of :class:`torch.utils.data.Dataset` whose ``input_dim``
    can be changed on the fly: if an ``_input_dim`` attribute is set on the
    instance, it takes precedence over the size given at construction time.

    Args:
        input_dimension (tuple): default input dimensions of the network; only
            the first two entries are kept.
            NOTE(review): this was originally documented as (width, height),
            but ``COCODataset`` in this file treats element 0 as the height.
            Confirm the intended order with the transforms that consume it.
        mosaic (bool): initial value of ``enable_mosaic``.
    """

    def __init__(self, input_dimension, mosaic=True):
        super().__init__()
        # Name-mangled on purpose so it cannot collide with the public
        # ``_input_dim`` override checked in the ``input_dim`` property.
        self.__input_dim = input_dimension[:2]
        self.enable_mosaic = mosaic

    @property
    def input_dim(self):
        """
        Dimension that can be used by transforms to set the correct image size, etc.

        This gives transforms a single source of truth for the input
        dimension of the network.

        Returns:
            The ``_input_dim`` attribute if one has been set on the instance,
            otherwise the first two entries of the ``input_dimension`` passed
            to ``__init__``.
        """
        if hasattr(self, "_input_dim"):
            return self._input_dim
        return self.__input_dim

    @staticmethod
    def mosaic_getitem(getitem_fn):
        """
        Decorator to be used around a subclass's ``__getitem__`` method.

        The wrapped method accepts either a plain integer index (Python
        ``int`` or NumPy integer scalar) or a two-element sequence
        ``(enable_mosaic, index)``. In the second form ``self.enable_mosaic``
        is set to the first element before the wrapped method is called with
        the second element, which lets the caller switch mosaic on or off
        per item.

        Example:
            >>> class CustomSet(Dataset):
            ...     def __len__(self):
            ...         return 10
            ...     @Dataset.mosaic_getitem
            ...     def __getitem__(self, index):
            ...         return self.enable_mosaic
        """

        @wraps(getitem_fn)
        def wrapper(self, index):
            # NumPy integer scalars are not ``int`` instances; without the
            # explicit ``np.integer`` check they would be treated as a
            # ``(enable_mosaic, index)`` pair and fail on ``index[0]``.
            if not isinstance(index, (int, np.integer)):
                self.enable_mosaic = index[0]
                index = index[1]

            return getitem_fn(self, index)

        return wrapper
class COCODataset(Dataset):
    """
    COCO detection dataset backed by the pycocotools ``COCO`` API.

    The annotations of every image are parsed once in ``__init__`` and kept in
    ``self.annotations``; images are read from disk on demand.
    """

    def __init__(
        self,
        data_dir='data/COCO',
        json_file="instances_train2017.json",
        name="train2017",
        img_size=(416, 416),
        preproc=None
    ):
        """
        COCO dataset initialization. Annotation data are read into memory by COCO API.

        Args:
            data_dir (str): dataset root directory; the annotation file is
                read from ``<data_dir>/annotations/<json_file>``.
            json_file (str): COCO json file name
            name (str): COCO data name (e.g. 'train2017' or 'val2017'); images
                are read from ``<data_dir>/<name>/``.
            img_size (tuple(int)): target image size after pre-processing.
                Element 0 is compared against the image height and element 1
                against the image width when the resize ratio is computed.
            preproc: data augmentation strategy; when not None it is called as
                ``preproc(img, target, input_dim)`` and must return
                ``(img, target)``.
        """
        super().__init__(img_size)
        self.data_dir = data_dir
        self.json_file = json_file
        self.coco = COCO(os.path.join(self.data_dir, "annotations", self.json_file))
        remove_useless_info(self.coco)
        self.ids = self.coco.getImgIds()
        # Class indices used in the targets are positions in this sorted list.
        self.class_ids = sorted(self.coco.getCatIds())
        # NOTE(review): ``cats`` follows the order returned by getCatIds(),
        # which is not re-sorted here, so ``_classes`` lines up with
        # ``class_ids`` only if getCatIds() already returns sorted ids. Confirm.
        self.cats = self.coco.loadCats(self.coco.getCatIds())
        self._classes = tuple([c["name"] for c in self.cats])
        # Optional in-memory image store; never populated in this class.
        self.imgs = None
        self.name = name
        self.img_size = img_size
        self.preproc = preproc
        self.annotations = self._load_coco_annotations()

    def __len__(self):
        """Return the number of images in the dataset."""
        return len(self.ids)

    def __del__(self):
        # ``imgs`` does not exist when ``__init__`` raised before assigning it
        # (for example, the annotation file could not be loaded). Deleting it
        # unconditionally would raise AttributeError during finalization.
        self.__dict__.pop("imgs", None)

    def _load_coco_annotations(self):
        """Parse the annotations of every image id, in ``self.ids`` order."""
        return [self.load_anno_from_ids(_ids) for _ids in self.ids]

    def load_anno_from_ids(self, id_):
        """
        Build the label array and size information for one COCO image id.

        Args:
            id_ (int): COCO image id.

        Returns:
            tuple: ``(res, img_info, resized_info, file_name)`` where

            * ``res`` is a ``(num_objs, 5)`` float array whose rows are
              ``[x1, y1, x2, y2, class_index]``, with the box coordinates
              already multiplied by the resize ratio;
            * ``img_info`` is the original ``(height, width)``;
            * ``resized_info`` is ``(int(height * r), int(width * r))``;
            * ``file_name`` is the image file name from the annotation, or the
              zero-padded 12-digit id plus ``.jpg`` when the annotation has
              no ``file_name`` entry.
        """
        im_ann = self.coco.loadImgs(id_)[0]
        width = im_ann["width"]
        height = im_ann["height"]
        anno_ids = self.coco.getAnnIds(imgIds=[int(id_)], iscrowd=False)
        annotations = self.coco.loadAnns(anno_ids)

        objs = []
        for obj in annotations:
            # COCO boxes are stored as [x, y, w, h]; convert to corner format
            # and clip the box to the image bounds.
            x1 = np.max((0, obj["bbox"][0]))
            y1 = np.max((0, obj["bbox"][1]))
            x2 = np.min((width, x1 + np.max((0, obj["bbox"][2]))))
            y2 = np.min((height, y1 + np.max((0, obj["bbox"][3]))))
            if obj["area"] > 0 and x2 >= x1 and y2 >= y1:
                # The clipped box is stored on the annotation dict itself.
                obj["clean_bbox"] = [x1, y1, x2, y2]
                objs.append(obj)

        num_objs = len(objs)
        res = np.zeros((num_objs, 5))
        for ix, obj in enumerate(objs):
            cls = self.class_ids.index(obj["category_id"])
            res[ix, 0:4] = obj["clean_bbox"]
            res[ix, 4] = cls

        # Aspect-preserving ratio that fits the image inside ``img_size``.
        r = min(self.img_size[0] / height, self.img_size[1] / width)
        res[:, :4] *= r

        img_info = (height, width)
        resized_info = (int(height * r), int(width * r))
        file_name = (
            im_ann["file_name"]
            if "file_name" in im_ann
            else "{:012}".format(id_) + ".jpg"
        )
        return res, img_info, resized_info, file_name

    def load_anno(self, index):
        """Return the cached label array of the image at ``index`` (not a copy)."""
        return self.annotations[index][0]

    def load_resized_img(self, index):
        """
        Read the image at ``index`` and resize it to fit inside ``img_size``.

        The same ratio is applied to both axes, so the aspect ratio is kept.

        Returns:
            numpy.ndarray: the resized image as ``uint8``.
        """
        img = self.load_image(index)
        r = min(self.img_size[0] / img.shape[0], self.img_size[1] / img.shape[1])
        resized_img = cv2.resize(
            img,
            (int(img.shape[1] * r), int(img.shape[0] * r)),
            interpolation=cv2.INTER_LINEAR,
        ).astype(np.uint8)
        return resized_img

    def load_image(self, index):
        """
        Read the image at ``index`` from ``<data_dir>/<name>/<file_name>``.

        Raises:
            AssertionError: if ``cv2.imread`` returns None for the path.
        """
        file_name = self.annotations[index][3]
        img_file = os.path.join(self.data_dir, self.name, file_name)
        img = cv2.imread(img_file)
        # cv2.imread signals failure by returning None instead of raising.
        assert img is not None, f"file named {img_file} not found"
        return img

    def pull_item(self, index):
        """
        Return the resized image and raw labels for ``index`` without ``preproc``.

        Returns:
            tuple: ``(img, res, img_info, img_id)`` where ``res`` is a copy of
            the cached ``(num_objs, 5)`` label array, ``img_info`` is the
            original ``(height, width)`` and ``img_id`` is a one-element array
            holding the COCO image id.
        """
        id_ = self.ids[index]
        res, img_info, resized_info, _ = self.annotations[index]
        if self.imgs is not None:
            # Stored images may be larger than the resized size; crop to it.
            pad_img = self.imgs[index]
            img = pad_img[: resized_info[0], : resized_info[1], :].copy()
        else:
            img = self.load_resized_img(index)

        return img, res.copy(), img_info, np.array([id_])

    @Dataset.mosaic_getitem
    def __getitem__(self, index):
        """
        One image / label pair for the given index is picked up and pre-processed.

        Args:
            index (int): data index

        Returns:
            img (numpy.ndarray): the resized image, or whatever ``preproc``
                returns for it when ``preproc`` is set.
            target: the labels. Without ``preproc`` this is a
                ``(num_objs, 5)`` array whose rows are
                ``[x1, y1, x2, y2, class_index]`` in pixels of the resized
                image; with ``preproc`` the format is whatever ``preproc``
                returns.
            img_info : tuple of h, w.
                h, w (int): original shape of the image
            img_id (numpy.ndarray): one-element array holding the COCO image
                id of the sample. Used for evaluation.
        """
        img, target, img_info, img_id = self.pull_item(index)

        if self.preproc is not None:
            img, target = self.preproc(img, target, self.input_dim)
        return img, target, img_info, img_id