import random

import numpy as np
from PIL import Image, ImageOps
import torch
import torch.utils.data as data

__all__ = ['BaseDataset']


class BaseDataset(data.Dataset):
    """Base class for segmentation datasets.

    Concrete datasets are expected to subclass this, set NUM_CLASS, and
    implement __getitem__/__len__ on top of the shared transforms below.
    """

    def __init__(self, root, split, mode=None, transform=None,
                 target_transform=None, base_size=1024, crop_size=512):
        self.root = root
        self.transform = transform
        self.target_transform = target_transform
        self.split = split
        self.mode = mode if mode is not None else split
        self.base_size = base_size
        self.crop_size = crop_size
        if self.mode == 'train':
            print('BaseDataset: base_size {}, crop_size {}'.format(
                base_size, crop_size))

    @property
    def num_class(self):
        # NUM_CLASS is defined by the concrete subclass
        return self.NUM_CLASS

    def _val_transform(self, img, mask):
        outsize = self.crop_size
        short_size = outsize
        # resize so the short side equals crop_size, preserving aspect ratio
        w, h = img.size
        if w > h:
            oh = short_size
            ow = int(1.0 * w * oh / h)
        else:
            ow = short_size
            oh = int(1.0 * h * ow / w)
        img = img.resize((ow, oh), Image.BILINEAR)
        mask = mask.resize((ow, oh), Image.NEAREST)
        # center crop to crop_size x crop_size
        w, h = img.size
        x1 = int(round((w - outsize) / 2.))
        y1 = int(round((h - outsize) / 2.))
        img = img.crop((x1, y1, x1 + outsize, y1 + outsize))
        mask = mask.crop((x1, y1, x1 + outsize, y1 + outsize))
        # final transform
        return img, self._mask_transform(mask)

    def _testval_transform(self, img, mask):
        outsize = self.crop_size
        short_size = outsize
        # resize the image only; the mask keeps its original resolution here
        w, h = img.size
        if w > h:
            oh = short_size
            ow = int(1.0 * w * oh / h)
        else:
            ow = short_size
            oh = int(1.0 * h * ow / w)
        img = img.resize((ow, oh), Image.BILINEAR)
        return img, self._mask_transform(mask)

    def _train_transform(self, img, mask):
        # random horizontal mirror
        if random.random() < 0.5:
            img = img.transpose(Image.FLIP_LEFT_RIGHT)
            mask = mask.transpose(Image.FLIP_LEFT_RIGHT)
        crop_size = self.crop_size
        # random scale: long side drawn from [0.5 * base_size, 2.0 * base_size]
        w, h = img.size
        long_size = random.randint(int(self.base_size * 0.5),
                                   int(self.base_size * 2.0))
        if h > w:
            oh = long_size
            ow = int(1.0 * w * long_size / h + 0.5)
            short_size = ow
        else:
            ow = long_size
            oh = int(1.0 * h * long_size / w + 0.5)
            short_size = oh
        img = img.resize((ow, oh), Image.BILINEAR)
        mask = mask.resize((ow, oh), Image.NEAREST)
        # zero-pad if the rescaled image is smaller than the crop window
        if short_size < crop_size:
            padh = crop_size - oh if oh < crop_size else 0
            padw = crop_size - ow if ow < crop_size else 0
            img = ImageOps.expand(img, border=(0, 0, padw, padh), fill=0)
            mask = ImageOps.expand(mask, border=(0, 0, padw, padh), fill=0)
        # random crop to crop_size x crop_size
        w, h = img.size
        x1 = random.randint(0, w - crop_size)
        y1 = random.randint(0, h - crop_size)
        img = img.crop((x1, y1, x1 + crop_size, y1 + crop_size))
        mask = mask.crop((x1, y1, x1 + crop_size, y1 + crop_size))
        # final transform
        return img, self._mask_transform(mask)

    def _mask_transform(self, mask):
        # convert the PIL label mask to a LongTensor of class indices
        return torch.from_numpy(np.array(mask)).long()
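

# Hypothetical usage sketch (an illustrative assumption, not part of the
# original file): a concrete dataset subclasses BaseDataset, sets NUM_CLASS,
# and implements __getitem__/__len__, dispatching to the mode-specific
# transforms above. The synthetic PIL images stand in for files read from
# disk; the class name, sizes, and class count below are made up.
class _ToySegmentation(BaseDataset):
    NUM_CLASS = 2  # illustrative class count

    def __getitem__(self, index):
        # synthetic image/mask pair instead of Image.open(...) on real paths
        img = Image.new('RGB', (2048, 1024))
        mask = Image.new('L', (2048, 1024))
        if self.mode == 'train':
            img, mask = self._train_transform(img, mask)
        elif self.mode == 'val':
            img, mask = self._val_transform(img, mask)
        else:  # e.g. 'testval'
            img, mask = self._testval_transform(img, mask)
        if self.transform is not None:
            img = self.transform(img)
        return img, mask

    def __len__(self):
        return 4


if __name__ == '__main__':
    ds = _ToySegmentation(root='.', split='train')
    img, mask = ds[0]
    print(img.size, mask.shape)  # (512, 512) crop and matching mask tensor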