Spaces:
Running
on
Zero
Running
on
Zero
# Copyright (c) OpenMMLab. All rights reserved. | |
import collections | |
import os.path as osp | |
import random | |
from typing import Dict, List | |
import mmengine | |
from mmengine.dataset import BaseDataset | |
# from mmdet.registry import DATASETS | |
# @DATASETS.register_module() | |
class RefCocoDataset(BaseDataset):
    """RefCOCO dataset.

    The `Refcoco` and `Refcoco+` dataset is based on
    `ReferItGame: Referring to Objects in Photographs of Natural Scenes
    <http://tamaraberg.com/papers/referit.pdf>`_.

    The `Refcocog` dataset is based on
    `Generation and Comprehension of Unambiguous Object Descriptions
    <https://arxiv.org/abs/1511.02283>`_.

    Args:
        data_root (str): The root directory for ``data_prefix`` and
            ``ann_file``. A relative ``split_file`` is also joined onto it.
        ann_file (str): Annotation file path. Loaded as JSON; the code reads
            its ``'annotations'`` and ``'images'`` entries.
        split_file (str): Split file path. Loaded as pickle; the code
            iterates it as a sequence of ref dicts with ``'ref_id'``,
            ``'ann_id'`` and ``'split'`` keys.
        data_prefix (dict): Prefix for training data. Must provide the
            ``'img_path'`` key.
        split (str): Split name. Defaults to 'train'.
        text_mode (str): How the sentences of one annotation are used.
            Defaults to 'random'.

            - ``'random'``: keep one randomly chosen sentence.
            - ``'concat'``: keep all sentences joined into one string.
            - ``'select_first'``: keep only the first sentence.
            - ``'original'``: keep every sentence as a separate entry.
        **kwargs: Other keyword arguments in :class:`BaseDataset`.
    """

    def __init__(self,
                 data_root: str,
                 ann_file: str,
                 split_file: str,
                 data_prefix: Dict,
                 split: str = 'train',
                 text_mode: str = 'random',
                 **kwargs):
        # These attributes are assigned before ``super().__init__`` because
        # ``_join_prefix`` and ``load_data_list`` below read them.
        # NOTE(review): presumably the base initializer invokes those hooks
        # (non-lazy init) -- confirm against mmengine's BaseDataset.
        self.split_file = split_file
        self.split = split
        assert text_mode in ['original', 'random', 'concat', 'select_first'], (
            f'Invalid text mode "{text_mode}".')
        self.text_mode = text_mode
        super().__init__(
            data_root=data_root,
            data_prefix=data_prefix,
            ann_file=ann_file,
            **kwargs,
        )

    def _join_prefix(self):
        """Resolve a relative ``split_file`` against ``data_root``.

        The remaining paths are handled by the parent implementation, whose
        return value is passed through.
        """
        # Check for an empty value first so ``is_abs`` only ever sees a
        # real path.
        if self.split_file and not mmengine.is_abs(self.split_file):
            self.split_file = osp.join(self.data_root, self.split_file)
        return super()._join_prefix()

    def _init_refs(self):
        """Initialize the refs for RefCOCO.

        Reads ``self.splits`` and ``self.instances`` and stores two lookup
        tables on the instance:

        - ``self.refs``: ``ref_id`` -> ref dict.
        - ``self.ref_to_ann``: ``ref_id`` -> annotation dict referenced by
          the ref's ``ann_id``.
        """
        anns = {ann['id']: ann for ann in self.instances['annotations']}

        refs, ref_to_ann = {}, {}
        for ref in self.splits:
            ref_id = ref['ref_id']
            refs[ref_id] = ref
            ref_to_ann[ref_id] = anns[ref['ann_id']]

        self.refs = refs
        self.ref_to_ann = ref_to_ann

    def _select_texts(self, texts: List[str]) -> List[str]:
        """Pick the sentences to keep for one annotation per ``text_mode``.

        Args:
            texts (list[str]): All sentences of the annotation.

        Returns:
            list[str]: The sentences that become samples.

        Raises:
            ValueError: If ``self.text_mode`` is not a supported mode.
        """
        if self.text_mode == 'random':
            # random select one text
            idx = random.randint(0, len(texts) - 1)
            return [texts[idx]]
        if self.text_mode == 'concat':
            # concat all texts
            # NOTE(review): sentences are joined without any separator --
            # confirm this is intended.
            return [''.join(texts)]
        if self.text_mode == 'select_first':
            # select the first text
            return [texts[0]]
        if self.text_mode == 'original':
            # use all texts
            return texts
        raise ValueError(f'Invalid text mode "{self.text_mode}".')

    def load_data_list(self) -> List[dict]:
        """Load data list.

        Loads the split and annotation files, keeps the refs whose
        ``'split'`` equals ``self.split``, and groups them per image.

        Returns:
            list[dict]: One dict per image referenced by the selected split,
            with the keys ``'img_path'``, ``'img_id'``, ``'instances'`` (a
            list of ``{'mask', 'ignore_flag'}`` dicts) and ``'text'`` (the
            sentences, aligned index-by-index with ``'instances'``).

        Raises:
            ValueError: If the selected split yields no sample, or if
                ``self.text_mode`` is not a supported mode.
        """
        self.splits = mmengine.load(self.split_file, file_format='pkl')
        self.instances = mmengine.load(self.ann_file, file_format='json')
        self._init_refs()
        img_prefix = self.data_prefix['img_path']

        ref_ids = [
            ref['ref_id'] for ref in self.splits if ref['split'] == self.split
        ]

        # Merge every selected ref into its annotation dict (in place), so
        # the annotation also carries the ref's fields such as 'sentences'.
        full_anno = []
        for ref_id in ref_ids:
            ann = self.ref_to_ann[ref_id]
            ann.update(self.refs[ref_id])
            full_anno.append(ann)

        # Keep one annotation per 'ann_id' and remember the touched images.
        image_id_list = []
        final_anno = {}
        for anno in full_anno:
            image_id_list.append(anno['image_id'])
            final_anno[anno['ann_id']] = anno

        image_annot = {
            image['id']: image
            for image in self.instances['images']
        }
        images = [image_annot[image_id] for image_id in set(image_id_list)]

        grounding_dict = collections.defaultdict(list)
        for anno in final_anno.values():
            grounding_dict[int(anno['image_id'])].append(anno)

        join_path = mmengine.fileio.get_file_backend(img_prefix).join_path
        data_list = []
        for image in images:
            img_id = image['id']
            instances = []
            sentences = []
            for grounding_anno in grounding_dict[img_id]:
                texts = [x['raw'].lower() for x in grounding_anno['sentences']]
                text = self._select_texts(texts)
                # Build a separate dict per sentence: ``[{...}] * n`` would
                # repeat one shared dict, so editing one instance later
                # would silently edit all of them.
                instances.extend({
                    'mask': grounding_anno['segmentation'],
                    'ignore_flag': 0
                } for _ in text)
                sentences.extend(text)
            data_list.append({
                'img_path': join_path(img_prefix, image['file_name']),
                'img_id': img_id,
                'instances': instances,
                'text': sentences
            })

        if not data_list:
            raise ValueError(f'No sample in split "{self.split}".')
        return data_list