import logging
import os
from pathlib import Path

import albumentations
import numpy as np
import torch
from tqdm import tqdm

logger = logging.getLogger(f'main.{__name__}')


class StandardNormalizeAudio(object):
    '''
        Frequency-wise normalization.

        Subtracts per-frequency means and divides by per-frequency stds. The stats are
        read from a cached text file when it exists; otherwise they are computed from
        the training spectrograms and written to that file.
    '''

    def __init__(self, specs_dir, train_ids_path='./data/vggsound_train.txt', cache_path='./data/'):
        '''
            Args:
                specs_dir: directory holding the `<id>_mel.npy` spectrogram files.
                train_ids_path: text file with one training id per line.
                cache_path: directory in which the stats file is looked up / stored.
        '''
        self.specs_dir = specs_dir
        self.train_ids_path = train_ids_path
        # making the stats filename to match the specs dir name
        self.cache_path = os.path.join(cache_path, f'train_means_stds_{Path(specs_dir).stem}.txt')
        logger.info('Assuming that the input stats are calculated using preprocessed spectrograms (log)')
        self.train_stats = self.calculate_or_load_stats()

    def __call__(self, item):
        '''
            Normalizes `item` with the train stats.

            Args:
                item: either a dict holding the spectrogram under 'input' or 'image'
                    (the entry is replaced in place), or a `torch.Tensor`.

            Returns:
                The same dict with the normalized entry, or the normalized tensor.

            Raises:
                NotImplementedError: for a dict without 'input'/'image' or for any other type.
        '''
        means = self.train_stats['means']
        stds = self.train_stats['stds']
        # just to generalize the input handling. Useful for FID, IS eval and training other stuff
        if isinstance(item, dict):
            if 'input' in item:
                input_key = 'input'
            elif 'image' in item:
                input_key = 'image'
            else:
                raise NotImplementedError
            item[input_key] = (item[input_key] - means) / stds
        elif isinstance(item, torch.Tensor):
            # broadcasts np.ndarray (80, 1) to (1, 80, 1) because item is torch.Tensor (B, 80, T)
            item = (item - means) / stds
        else:
            raise NotImplementedError
        return item

    def calculate_or_load_stats(self):
        '''
            Loads the cached stats or, if the cache file cannot be opened, calculates and caches them.

            Returns:
                dict with 'means' and 'stds', each an `np.ndarray` reshaped to a column (-1, 1).
        '''
        # logged before the attempt, so the message reflects what is actually happening
        logger.info('Trying to load train stats for Standard Normalization of inputs')
        try:
            # (F, 2)
            train_stats = np.loadtxt(self.cache_path)
            means, stds = train_stats.T
        except OSError:
            logger.info('Could not find the precalculated stats for Standard Normalization. Calculating...')
            # the context manager closes the ids file, which was previously left open
            with open(self.train_ids_path) as train_vid_ids:
                vid_ids = [line.rstrip() for line in train_vid_ids]
            # blank lines (e.g. a trailing one) would otherwise turn into a bogus '_mel.npy' path
            specs_paths = [os.path.join(self.specs_dir, f'{vid_id}_mel.npy') for vid_id in vid_ids if vid_id]
            means = []
            stds = []
            for path in tqdm(specs_paths):
                spec = np.load(path)
                # reduce over axis 1, leaving one value per row of the spectrogram
                means.append(spec.mean(axis=1))
                stds.append(spec.std(axis=1))
            # (F) <- (num_files, F)
            means = np.array(means).mean(axis=0)
            stds = np.array(stds).mean(axis=0)
            # saving in two columns
            np.savetxt(self.cache_path, np.vstack([means, stds]).T, fmt='%0.8f')
        means = means.reshape(-1, 1)
        stds = stds.reshape(-1, 1)
        return {'means': means, 'stds': stds}


class ToTensor(object):
    '''Converts `item['input']` from a numpy array to a float torch tensor.'''

    def __call__(self, item):
        '''Replaces `item['input']` in place with its float tensor version and returns `item`.'''
        item['input'] = torch.from_numpy(item['input']).float()
        # if 'target' in item:
        #     item['target'] = torch.tensor(item['target'])
        return item


class Crop(object):
    '''
        Crops `item['input']` to `cropped_shape` with albumentations.

        Becomes a no-op when `cropped_shape` is None.
    '''

    def __init__(self, cropped_shape=None, random_crop=False):
        '''
            Args:
                cropped_shape: pair `(mel_num, spec_len)` passed to the albumentations cropper,
                    or None to disable cropping.
                random_crop: if True uses `RandomCrop`, otherwise `CenterCrop`.
        '''
        self.cropped_shape = cropped_shape
        if cropped_shape is not None:
            mel_num, spec_len = cropped_shape
            if random_crop:
                self.cropper = albumentations.RandomCrop
            else:
                self.cropper = albumentations.CenterCrop
            self.preprocessor = albumentations.Compose([self.cropper(mel_num, spec_len)])
        else:
            # identity with the same calling convention as albumentations.Compose
            self.preprocessor = lambda **kwargs: kwargs

    def __call__(self, item):
        '''Replaces `item['input']` with the preprocessor output and returns `item`.'''
        item['input'] = self.preprocessor(image=item['input'])['image']
        return item


if __name__ == '__main__':
    cropper = Crop([80, 848])
    item = {'input': torch.rand([80, 860])}
    outputs = cropper(item)
    print(outputs['input'].shape)