|
from importlib import import_module |
|
|
|
from torch.utils.data import dataloader |
|
from torch.utils.data import ConcatDataset |
|
import torch |
|
import random |
|
|
|
class MyConcatDataset(ConcatDataset):
    """ConcatDataset that can forward a scale index to its child datasets."""

    def __init__(self, datasets):
        """Concatenate ``datasets`` exactly as ``ConcatDataset`` does."""
        super().__init__(datasets)

    def set_scale(self, idx_scale):
        """Call ``set_scale(idx_scale)`` on every child dataset that has it.

        Children without a ``set_scale`` attribute are skipped.
        """
        for dataset in self.datasets:
            if not hasattr(dataset, 'set_scale'):
                continue
            dataset.set_scale(idx_scale)
|
|
|
class Data:
    """Holder for the data loaders built from the run configuration.

    Attributes set by ``__init__``:
        loader_train: always ``None`` here; this class builds no training loader.
        loader_test: list with one ``DataLoader`` per entry of
            ``args.data_test``, in the same order.
    """

    # Dataset names that are served by ``data.benchmark.Benchmark``.
    _BENCHMARK_NAMES = ('Set5', 'Set14', 'B100', 'Urban100')

    def __init__(self, args):
        """Create one test loader per dataset name in ``args.data_test``.

        Args:
            args: Configuration object. This method reads ``args.data_test``
                (iterable of dataset names) and ``args.n_threads`` (used as
                ``num_workers``); the whole object is also forwarded to the
                dataset constructor.

        Raises:
            NotImplementedError: If a name in ``args.data_test`` is not one
                of the supported benchmark sets.
        """
        self.loader_train = None
        self.loader_test = []

        for d in args.data_test:
            if d not in self._BENCHMARK_NAMES:
                # The previous `assert NotImplementedError` was a no-op (a
                # class object is always truthy), so an unknown name either
                # crashed with NameError or silently reused the dataset from
                # the previous iteration. Fail loudly instead.
                raise NotImplementedError(
                    'Unsupported test dataset: {!r}'.format(d)
                )

            # Imported lazily so the project module is only required when a
            # benchmark set is actually requested.
            m = import_module('data.benchmark')
            testset = m.Benchmark(args, name=d)

            self.loader_test.append(
                dataloader.DataLoader(
                    testset,
                    batch_size=1,
                    shuffle=False,
                    pin_memory=False,
                    num_workers=args.n_threads,
                )
            )
|
|