import random

import numpy as np
from PIL import Image, ImageOps
import torch
import torch.utils.data as data

__all__ = ['BaseDataset']


class BaseDataset(data.Dataset):
    def __init__(self, root, split, mode=None, transform=None,
                 target_transform=None, base_size=1024, crop_size=512):
        self.root = root
        self.transform = transform
        self.target_transform = target_transform
        self.split = split
        self.mode = mode if mode is not None else split
        self.base_size = base_size
        self.crop_size = crop_size
        if self.mode == 'train':
            print('BaseDataset: base_size {}, crop_size {}'.format(
                base_size, crop_size))

    @property
    def num_class(self):
        # NUM_CLASS must be defined by each concrete subclass
        return self.NUM_CLASS

    def _val_transform(self, img, mask):
        outsize = self.crop_size
        short_size = outsize
        # resize so the shorter side equals the crop size
        w, h = img.size
        if w > h:
            oh = short_size
            ow = int(1.0 * w * oh / h)
        else:
            ow = short_size
            oh = int(1.0 * h * ow / w)
        img = img.resize((ow, oh), Image.BILINEAR)
        mask = mask.resize((ow, oh), Image.NEAREST)
        # center crop
        w, h = img.size
        x1 = int(round((w - outsize) / 2.))
        y1 = int(round((h - outsize) / 2.))
        img = img.crop((x1, y1, x1 + outsize, y1 + outsize))
        mask = mask.crop((x1, y1, x1 + outsize, y1 + outsize))
        # final transform
        return img, self._mask_transform(mask)

    def _testval_transform(self, img, mask):
        outsize = self.crop_size
        short_size = outsize
        # resize the image so the shorter side equals the crop size
        w, h = img.size
        if w > h:
            oh = short_size
            ow = int(1.0 * w * oh / h)
        else:
            ow = short_size
            oh = int(1.0 * h * ow / w)
        img = img.resize((ow, oh), Image.BILINEAR)
        # note: the mask is left at its original size here, so predictions
        # must be resized back to the mask's resolution before scoring
        return img, self._mask_transform(mask)

    def _train_transform(self, img, mask):
        # random horizontal flip
        if random.random() < 0.5:
            img = img.transpose(Image.FLIP_LEFT_RIGHT)
            mask = mask.transpose(Image.FLIP_LEFT_RIGHT)
        crop_size = self.crop_size
        # random scale: resize the longer side to [0.5, 2.0] x base_size
        w, h = img.size
        long_size = random.randint(int(self.base_size * 0.5),
                                   int(self.base_size * 2.0))
        if h > w:
            oh = long_size
            ow = int(1.0 * w * long_size / h + 0.5)
            short_size = ow
        else:
            ow = long_size
            oh = int(1.0 * h * long_size / w + 0.5)
            short_size = oh
        img = img.resize((ow, oh), Image.BILINEAR)
        mask = mask.resize((ow, oh), Image.NEAREST)
        # pad to crop size if the scaled image is too small
        if short_size < crop_size:
            padh = crop_size - oh if oh < crop_size else 0
            padw = crop_size - ow if ow < crop_size else 0
            img = ImageOps.expand(img, border=(0, 0, padw, padh), fill=0)
            mask = ImageOps.expand(mask, border=(0, 0, padw, padh), fill=0)
        # random crop to crop_size x crop_size
        w, h = img.size
        x1 = random.randint(0, w - crop_size)
        y1 = random.randint(0, h - crop_size)
        img = img.crop((x1, y1, x1 + crop_size, y1 + crop_size))
        mask = mask.crop((x1, y1, x1 + crop_size, y1 + crop_size))
        # final transform
        return img, self._mask_transform(mask)

    def _mask_transform(self, mask):
        return torch.from_numpy(np.array(mask)).long()
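

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module): a minimal, hypothetical
# subclass showing how the mode-dependent transforms above are typically
# wired up in __getitem__. The class name, directory layout, and NUM_CLASS
# value are illustrative assumptions, not part of any real dataset.
# ---------------------------------------------------------------------------
import os


class ToySegmentation(BaseDataset):
    NUM_CLASS = 2  # assumed: binary foreground/background labels

    def __init__(self, root, split='train', **kwargs):
        super(ToySegmentation, self).__init__(root, split, **kwargs)
        # assumed layout: root/images/*.png with matching masks in root/masks/
        img_dir = os.path.join(root, 'images')
        self.images = sorted(
            os.path.join(img_dir, f) for f in os.listdir(img_dir))
        self.masks = [p.replace('images', 'masks') for p in self.images]

    def __getitem__(self, index):
        img = Image.open(self.images[index]).convert('RGB')
        mask = Image.open(self.masks[index])
        # dispatch on mode, mirroring the base-class transform methods
        if self.mode == 'train':
            img, mask = self._train_transform(img, mask)
        elif self.mode == 'val':
            img, mask = self._val_transform(img, mask)
        else:  # 'testval'
            img, mask = self._testval_transform(img, mask)
        # optional user-supplied transforms (e.g. ToTensor + Normalize);
        # `transform` should produce a tensor before batching in a DataLoader
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            mask = self.target_transform(mask)
        return img, mask

    def __len__(self):
        return len(self.images)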