# Copyright (c) OpenMMLab. All rights reserved.
import collections
import os.path as osp
import random
from typing import Dict, List

import mmengine
from mmengine.dataset import BaseDataset

# from mmdet.registry import DATASETS


# @DATASETS.register_module()
class RefCocoDataset(BaseDataset):
    """RefCOCO dataset.

    The `Refcoco` and `Refcoco+` datasets are based on the paper
    "ReferItGame: Referring to Objects in Photographs of Natural Scenes".

    The `Refcocog` dataset is based on the paper
    "Generation and Comprehension of Unambiguous Object Descriptions".

    Args:
        ann_file (str): Annotation file path (loaded as JSON).
        data_root (str): The root directory for ``data_prefix`` and
            ``ann_file``. Defaults to ''.
        data_prefix (dict): Prefix for training data. Must contain the key
            ``'img_path'``, which is joined with each image file name.
        split_file (str): Split file path (loaded as pickle). A relative
            path is resolved against ``data_root``.
        split (str): Split name. Defaults to 'train'.
        text_mode (str): How the sentences of one referred object are
            turned into text entries. One of:

            - ``'original'``: keep every sentence as its own entry.
            - ``'random'``: keep one randomly chosen sentence.
            - ``'concat'``: join all sentences into a single entry.
            - ``'select_first'``: keep only the first sentence.

            Defaults to 'random'.
        **kwargs: Other keyword arguments in :class:`BaseDataset`.
    """

    def __init__(self,
                 data_root: str,
                 ann_file: str,
                 split_file: str,
                 data_prefix: Dict,
                 split: str = 'train',
                 text_mode: str = 'random',
                 **kwargs):
        # These attributes must be set before ``super().__init__`` because
        # they are read by ``_join_prefix`` and ``load_data_list``.
        self.split_file = split_file
        self.split = split
        assert text_mode in ['original', 'random', 'concat', 'select_first']
        self.text_mode = text_mode
        super().__init__(
            data_root=data_root,
            data_prefix=data_prefix,
            ann_file=ann_file,
            **kwargs,
        )

    def _join_prefix(self):
        """Resolve a relative, non-empty ``split_file`` against ``data_root``
        and then delegate to the parent implementation."""
        if not mmengine.is_abs(self.split_file) and self.split_file:
            self.split_file = osp.join(self.data_root, self.split_file)

        return super()._join_prefix()

    def _init_refs(self):
        """Initialize the refs for RefCOCO.

        Builds ``self.refs`` (ref id -> ref) and ``self.ref_to_ann``
        (ref id -> annotation) from ``self.splits`` and ``self.instances``.
        """
        anns = {ann['id']: ann for ann in self.instances['annotations']}

        refs, ref_to_ann = {}, {}
        for ref in self.splits:
            ref_id = ref['ref_id']
            refs[ref_id] = ref
            ref_to_ann[ref_id] = anns[ref['ann_id']]

        self.refs = refs
        self.ref_to_ann = ref_to_ann

    def _select_texts(self, texts: List[str]) -> List[str]:
        """Reduce the sentences of one referred object per ``text_mode``.

        Args:
            texts (List[str]): Lower-cased raw sentences of one annotation.

        Returns:
            List[str]: The text entries to emit for this annotation.

        Raises:
            ValueError: If ``self.text_mode`` is not a supported mode.
        """
        if self.text_mode == 'random':
            # random select one text
            idx = random.randint(0, len(texts) - 1)
            return [texts[idx]]
        if self.text_mode == 'concat':
            # concat all texts
            # NOTE(review): sentences are joined with no separator, so
            # adjacent sentences run together -- confirm this is intended.
            return [''.join(texts)]
        if self.text_mode == 'select_first':
            # select the first text
            return [texts[0]]
        if self.text_mode == 'original':
            # use all texts
            return texts
        raise ValueError(f'Invalid text mode "{self.text_mode}".')

    def load_data_list(self) -> List[dict]:
        """Load data list.

        Returns:
            List[dict]: One entry per image that has at least one reference
            in ``self.split``. Each entry has the keys ``img_path``,
            ``img_id``, ``instances`` (dicts with ``mask`` and
            ``ignore_flag``) and ``text``; ``instances`` and ``text`` are
            index-aligned.

        Raises:
            ValueError: If the requested split contains no sample.
        """
        # NOTE(review): the split file is unpickled; only load split files
        # that come from a trusted source.
        self.splits = mmengine.load(self.split_file, file_format='pkl')
        self.instances = mmengine.load(self.ann_file, file_format='json')
        self._init_refs()
        img_prefix = self.data_prefix['img_path']

        ref_ids = [
            ref['ref_id'] for ref in self.splits if ref['split'] == self.split
        ]

        # Merge each ref's fields into its annotation. This updates the
        # annotation dicts held by ``self.instances`` in place.
        full_anno = []
        for ref_id in ref_ids:
            ann = self.ref_to_ann[ref_id]
            ann.update(self.refs[ref_id])
            full_anno.append(ann)

        # Keying by ``ann_id`` keeps a single entry per annotation.
        image_id_list = []
        final_anno = {}
        for anno in full_anno:
            image_id_list.append(anno['image_id'])
            final_anno[anno['ann_id']] = anno

        image_annot = {
            image['id']: image
            for image in self.instances['images']
        }

        # Group the merged annotations by image.
        grounding_dict = collections.defaultdict(list)
        for anno in final_anno.values():
            grounding_dict[int(anno['image_id'])].append(anno)

        join_path = mmengine.fileio.get_file_backend(img_prefix).join_path

        data_list = []
        for image_id in set(image_id_list):
            image = image_annot[image_id]
            img_id = image['id']
            instances = []
            sentences = []
            for grounding_anno in grounding_dict[img_id]:
                texts = [x['raw'].lower() for x in grounding_anno['sentences']]
                text = self._select_texts(texts)
                # Build one dict per sentence. ``[d] * n`` would repeat the
                # same dict object, so editing one instance in place would
                # silently change all of them.
                instances.extend({
                    'mask': grounding_anno['segmentation'],
                    'ignore_flag': 0
                } for _ in text)
                sentences.extend(text)
            data_list.append({
                'img_path': join_path(img_prefix, image['file_name']),
                'img_id': img_id,
                'instances': instances,
                'text': sentences
            })

        if not data_list:
            raise ValueError(f'No sample in split "{self.split}".')

        return data_list