import os
import glob
import random
import pickle

from data import common

import imageio
import torch.utils.data as data


class SRData(data.Dataset):
    """Paired low-/high-resolution image dataset for super-resolution.

    HR images are discovered as ``<dir_data>/<name>/HR/*<ext[0]>``; for every
    scale ``s`` in ``args.scale`` the matching LR image is expected at
    ``<dir_data>/<name>/LR_bicubic/X<s>/<stem>x<s><ext[1]>``.

    The loading mode is selected from the ``args.ext`` string:

    * contains ``'img'`` (or ``benchmark`` is true): images are decoded from
      the original files on every access;
    * contains ``'sep'``: each image is decoded once and pickled to a ``.pt``
      file in a mirrored tree under ``<dir_data>/<name>/bin``; items are then
      read back from those caches. If ``args.ext`` also contains ``'reset'``
      the caches are rebuilt even when they already exist.

    Args:
        args: options object; this class reads ``scale``, ``dir_data``,
            ``ext``, ``n_colors`` and ``rgb_range`` from it.
        name: dataset sub-directory under ``args.dir_data``.
        benchmark: if true, always read the original image files.
        input_data_format: layout string forwarded to ``common.np2Tensor``
            as ``format``; must be ``'NCHW'`` or ``'NHWC'``.

    Raises:
        ValueError: if ``input_data_format`` is not ``'NCHW'`` or ``'NHWC'``.
    """

    def __init__(self, args, name='', benchmark=True, input_data_format='NCHW'):
        # Explicit check instead of ``assert`` so it survives ``python -O``.
        if input_data_format not in ('NCHW', 'NHWC'):
            raise ValueError(
                "input_data_format must be 'NCHW' or 'NHWC', got {!r}".format(
                    input_data_format
                )
            )
        self.args = args
        self.name = name
        self.benchmark = benchmark
        # Consulted by _set_filesystem (selects the 'LR_bicubicL' directory)
        # and by set_scale (random scale selection).
        self.input_large = False
        self.scale = args.scale
        self.idx_scale = 0
        self.input_data_format = input_data_format

        self._set_filesystem(args.dir_data)
        path_bin = os.path.join(self.apath, 'bin')
        if 'img' not in args.ext:
            os.makedirs(path_bin, exist_ok=True)

        list_hr, list_lr = self._scan()
        if 'img' in args.ext or benchmark:
            self.images_hr, self.images_lr = list_hr, list_lr
        elif 'sep' in args.ext:
            self._prepare_binaries(list_hr, list_lr, path_bin)

    # The functions below are used to prepare images.
    def _prepare_binaries(self, list_hr, list_lr, path_bin):
        """Mirror the HR/LR directories under ``path_bin`` and build caches.

        Sets ``self.images_hr`` (list of cache paths) and ``self.images_lr``
        (one list of cache paths per scale) and creates any missing cache
        via ``_check_and_load``.
        """
        os.makedirs(
            self.dir_hr.replace(self.apath, path_bin, 1),
            exist_ok=True
        )
        for s in self.scale:
            os.makedirs(
                os.path.join(
                    self.dir_lr.replace(self.apath, path_bin, 1),
                    'X{}'.format(s)
                ),
                exist_ok=True
            )

        self.images_hr, self.images_lr = [], [[] for _ in self.scale]
        for hr_path in list_hr:
            bin_path = self._bin_path(hr_path, path_bin, self.ext[0])
            self.images_hr.append(bin_path)
            self._check_and_load(self.args.ext, hr_path, bin_path, verbose=True)
        for i, lr_paths in enumerate(list_lr):
            for lr_path in lr_paths:
                bin_path = self._bin_path(lr_path, path_bin, self.ext[1])
                self.images_lr[i].append(bin_path)
                self._check_and_load(
                    self.args.ext, lr_path, bin_path, verbose=True
                )

    def _bin_path(self, path, path_bin, ext):
        """Return the ``.pt`` cache path under ``path_bin`` for image ``path``.

        Only the first occurrence of the dataset root and the trailing
        ``ext`` suffix are rewritten, so directory names that happen to
        contain either string are left intact.
        """
        bin_path = path.replace(self.apath, path_bin, 1)
        # Guard on ``ext``: slicing with ``-len('')`` would drop everything.
        if ext and bin_path.endswith(ext):
            bin_path = bin_path[:-len(ext)]
        return bin_path + '.pt'

    def _scan(self):
        """List HR image paths and the expected LR path for each scale.

        Returns:
            ``(names_hr, names_lr)``: ``names_hr`` is the sorted list of HR
            files in ``self.dir_hr``; ``names_lr[si]`` holds the LR path for
            each HR file at scale ``self.scale[si]``, in the same order. LR
            paths are constructed, not checked for existence.
        """
        names_hr = sorted(
            glob.glob(os.path.join(self.dir_hr, '*' + self.ext[0]))
        )
        names_lr = [[] for _ in self.scale]
        for f in names_hr:
            filename, _ = os.path.splitext(os.path.basename(f))
            for si, s in enumerate(self.scale):
                names_lr[si].append(os.path.join(
                    self.dir_lr, 'X{}/{}x{}{}'.format(
                        s, filename, s, self.ext[1]
                    )
                ))

        return names_hr, names_lr

    def _set_filesystem(self, dir_data):
        """Derive dataset directories and the (HR, LR) file extensions."""
        self.apath = os.path.join(dir_data, self.name)
        self.dir_hr = os.path.join(self.apath, 'HR')
        self.dir_lr = os.path.join(self.apath, 'LR_bicubic')
        if self.input_large:
            self.dir_lr += 'L'
        # (HR extension, LR extension)
        self.ext = ('.png', '.png')

    def _check_and_load(self, ext, img, f, verbose=True):
        """Decode image ``img`` and pickle it to cache file ``f`` if needed.

        The cache is (re)built when ``f`` does not exist or ``ext`` contains
        ``'reset'``. The image is decoded *before* any file is opened, and
        the pickle is written to a temporary sibling that is then moved onto
        ``f`` with ``os.replace``. As a result, a failed decode or an
        interrupted write never leaves a truncated ``f`` that later runs
        would mistake for a valid cache.
        """
        if os.path.isfile(f) and 'reset' not in ext:
            return
        if verbose:
            print('Making a binary: {}'.format(f))
        image = imageio.imread(img)
        tmp = f + '.tmp'
        try:
            with open(tmp, 'wb') as _f:
                pickle.dump(image, _f)
            os.replace(tmp, f)
        except BaseException:
            # Do not leave a half-written temporary file behind.
            if os.path.exists(tmp):
                os.remove(tmp)
            raise

    def __getitem__(self, idx):
        """Return ``(lr_tensor, hr_tensor, filename)`` for item ``idx``.

        ``filename`` is the HR file's base name without extension. Channel
        count and value range come from ``args.n_colors`` and
        ``args.rgb_range``.
        """
        lr, hr, filename = self._load_file(idx)
        pair = self.get_patch(lr, hr)
        pair = common.set_channel(*pair, n_channels=self.args.n_colors)
        pair_t = common.np2Tensor(*pair, rgb_range=self.args.rgb_range, format=self.input_data_format)
        return pair_t[0], pair_t[1], filename

    def __len__(self):
        """Number of HR images in the dataset."""
        return len(self.images_hr)

    def _get_index(self, idx):
        """Map a sampler index to a storage index (identity here)."""
        return idx

    def _load_file(self, idx):
        """Load the raw LR/HR images for ``idx`` at the current scale.

        Returns:
            ``(lr, hr, filename)`` where ``filename`` is the HR file's base
            name without extension.

        Raises:
            ValueError: if ``args.ext`` selects neither the ``'img'`` nor the
                ``'sep'`` mode (and ``benchmark`` is false).
        """
        idx = self._get_index(idx)
        f_hr = self.images_hr[idx]
        f_lr = self.images_lr[self.idx_scale][idx]

        filename, _ = os.path.splitext(os.path.basename(f_hr))
        # This must mirror the mode test in __init__, which decided whether
        # the stored paths are raw images or pickled caches.
        if 'img' in self.args.ext or self.benchmark:
            hr = imageio.imread(f_hr)
            lr = imageio.imread(f_lr)
        elif 'sep' in self.args.ext:
            # NOTE: pickle.load trusts these cache files (written by
            # _check_and_load); only point dir_data at directories you
            # control.
            with open(f_hr, 'rb') as _f:
                hr = pickle.load(_f)
            with open(f_lr, 'rb') as _f:
                lr = pickle.load(_f)
        else:
            raise ValueError(
                "Unsupported args.ext {!r}: expected it to contain "
                "'img' or 'sep'".format(self.args.ext)
            )

        return lr, hr, filename

    def get_patch(self, lr, hr):
        """Crop ``hr`` to exactly ``scale`` times ``lr``'s spatial size.

        The first two axes are treated as height and width; ``lr`` is
        returned unchanged.
        """
        scale = self.scale[self.idx_scale]
        ih, iw = lr.shape[:2]
        hr = hr[0:ih * scale, 0:iw * scale]

        return lr, hr

    def set_scale(self, idx_scale):
        """Select which entry of ``self.scale`` subsequent items use.

        When ``self.input_large`` is true, ``idx_scale`` is ignored and an
        index is drawn uniformly at random instead.
        """
        if not self.input_large:
            self.idx_scale = idx_scale
        else:
            self.idx_scale = random.randint(0, len(self.scale) - 1)