import logging
import os
from pathlib import Path

import albumentations
import numpy as np
import torch
from tqdm import tqdm

logger = logging.getLogger(f'main.{__name__}')


class StandardNormalizeAudio(object):
    '''
    Frequency-wise normalization
    '''
    def __init__(self, specs_dir, train_ids_path='./data/vggsound_train.txt', cache_path='./data/'):
        self.specs_dir = specs_dir
        self.train_ids_path = train_ids_path
        # make the stats filename match the specs dir name
        self.cache_path = os.path.join(cache_path, f'train_means_stds_{Path(specs_dir).stem}.txt')
        logger.info('Assuming that the input stats are calculated using preprocessed spectrograms (log)')
        self.train_stats = self.calculate_or_load_stats()

    def __call__(self, item):
        # generalize input handling: useful for FID/IS evaluation and for training other models
        if isinstance(item, dict):
            if 'input' in item:
                input_key = 'input'
            elif 'image' in item:
                input_key = 'image'
            else:
                raise NotImplementedError
            item[input_key] = (item[input_key] - self.train_stats['means']) / self.train_stats['stds']
        elif isinstance(item, torch.Tensor):
            # broadcasts the np.ndarray stats (80, 1) to (1, 80, 1) because item is a torch.Tensor (B, 80, T)
            item = (item - self.train_stats['means']) / self.train_stats['stds']
        else:
            raise NotImplementedError
        return item

    def calculate_or_load_stats(self):
        logger.info('Trying to load train stats for Standard Normalization of inputs')
        try:
            # cached stats are stored as two columns: (F, 2)
            train_stats = np.loadtxt(self.cache_path)
            means, stds = train_stats.T
        except OSError:
            logger.info('Could not find the precalculated stats for Standard Normalization. Calculating...')
            with open(self.train_ids_path) as train_vid_ids:
                specs_paths = [os.path.join(self.specs_dir, f'{i.rstrip()}_mel.npy') for i in train_vid_ids]
            means = [None] * len(specs_paths)
            stds = [None] * len(specs_paths)
            for i, path in enumerate(tqdm(specs_paths)):
                spec = np.load(path)
                means[i] = spec.mean(axis=1)
                stds[i] = spec.std(axis=1)
            # (F) <- (num_files, F)
            means = np.array(means).mean(axis=0)
            stds = np.array(stds).mean(axis=0)
            # save in two columns
            np.savetxt(self.cache_path, np.vstack([means, stds]).T, fmt='%0.8f')
        means = means.reshape(-1, 1)
        stds = stds.reshape(-1, 1)
        return {'means': means, 'stds': stds}
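

# A minimal usage sketch for StandardNormalizeAudio (kept as a comment so that
# importing this module stays side-effect free). It assumes `specs_dir` holds
# `<video_id>_mel.npy` log-mel spectrograms of shape (80, T) and that the
# default train-id list and cache directory exist; the paths below are
# hypothetical placeholders.
#
#     normalizer = StandardNormalizeAudio('./path/to/mel_specs_dir')
#     spec = np.load('./path/to/mel_specs_dir/some_video_mel.npy')  # (80, T)
#     normalized = normalizer({'input': spec})['input']
#     # each frequency bin is standardized with its own train mean/std,
#     # broadcast over the time axis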


class ToTensor(object):
    def __call__(self, item):
        item['input'] = torch.from_numpy(item['input']).float()
        # only convert the target when it is present (e.g. unlabeled or eval items may not have one)
        if 'target' in item:
            item['target'] = torch.tensor(item['target'])
        return item


class Crop(object):
    def __init__(self, cropped_shape=None, random_crop=False):
        self.cropped_shape = cropped_shape
        if cropped_shape is not None:
            mel_num, spec_len = cropped_shape
            if random_crop:
                self.cropper = albumentations.RandomCrop
            else:
                self.cropper = albumentations.CenterCrop
            self.preprocessor = albumentations.Compose([self.cropper(mel_num, spec_len)])
        else:
            self.preprocessor = lambda **kwargs: kwargs

    def __call__(self, item):
        item['input'] = self.preprocessor(image=item['input'])['image']
        return item


if __name__ == '__main__':
    cropper = Crop([80, 848])
    # albumentations operates on numpy arrays, so the smoke test uses a random
    # (80, 860) "spectrogram" as an np.ndarray rather than a torch tensor
    item = {'input': np.random.rand(80, 860).astype(np.float32)}
    outputs = cropper(item)
    print(outputs['input'].shape)
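
    # A minimal end-to-end sketch of the intended chain (Crop -> ToTensor) on a
    # synthetic spectrogram. StandardNormalizeAudio is skipped here because it
    # needs precomputed train stats (or the training spectrograms) on disk.
    item = {'input': np.random.rand(80, 860).astype(np.float32), 'target': 3}
    for transform in [Crop([80, 848]), ToTensor()]:
        item = transform(item)
    print(item['input'].shape, item['input'].dtype, item['target'])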