import random

import numpy as np
import skimage.color as sc
import torch


def set_channel(*args, n_channels=3):
    """Convert each image to the requested number of channels.

    Args:
        *args: Images as NumPy arrays, either 2-D or 3-D with the channel
            axis last (the code reads the channel count from ``shape[2]``).
        n_channels: Target channel count. Only the 3 -> 1 and 1 -> 3
            conversions are performed; any other combination is returned
            with its channels untouched.

    Returns:
        A list with one 3-D array per input, in the same order.
    """

    def _set_channel(img):
        # Promote a 2-D image to 3-D with a trailing single channel.
        if img.ndim == 2:
            img = np.expand_dims(img, axis=2)

        c = img.shape[2]
        if n_channels == 1 and c == 3:
            # Keep only channel 0 (Y) of the YCbCr conversion.
            img = np.expand_dims(sc.rgb2ycbcr(img)[:, :, 0], 2)
        elif n_channels == 3 and c == 1:
            # Replicate the single channel along the last axis.
            img = np.concatenate([img] * n_channels, 2)

        return img

    return [_set_channel(a) for a in args]


def np2Tensor(*args, rgb_range=255, format='NCHW'):
    """Convert NumPy images to float32 torch tensors scaled to ``rgb_range``.

    Args:
        *args: Images as NumPy arrays. For ``format='NCHW'`` each array must
            be 3-D because its axes are permuted with ``(2, 0, 1)``.
        rgb_range: Values are multiplied by ``rgb_range / 255``.
        format: ``'NCHW'`` moves the last axis to the front; ``'NHWC'``
            keeps the axis order as given.

    Returns:
        A list with one float32 tensor per input, in the same order. The
        tensors never share memory with the input arrays.

    Raises:
        AssertionError: If ``format`` is neither ``'NCHW'`` nor ``'NHWC'``.
    """

    def _np2Tensor(img, channel_format):
        # Raised explicitly (rather than via ``assert``) so the check is not
        # stripped under ``python -O``; the exception type is kept the same.
        if channel_format not in ('NCHW', 'NHWC'):
            raise AssertionError(
                "format must be 'NCHW' or 'NHWC', got {!r}".format(
                    channel_format
                )
            )

        if channel_format == 'NCHW':
            img = img.transpose((2, 0, 1))

        # torch.from_numpy rejects negative strides (e.g. flipped views), so
        # normalise the memory layout on both code paths.
        img = np.ascontiguousarray(img)

        # copy=True guarantees a private buffer: without it a float32 input
        # would be aliased and the in-place scaling below would silently
        # rescale the caller's array.
        tensor = torch.from_numpy(img).to(dtype=torch.float32, copy=True)
        tensor.mul_(rgb_range / 255)

        return tensor

    return [_np2Tensor(a, format) for a in args]