File size: 3,774 Bytes
a84a65c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import logging
import os
from pathlib import Path

import albumentations
import numpy as np
import torch
from tqdm import tqdm

# Named 'main.<module>' so it sits under the 'main' logger in the logging
# hierarchy (presumably configured by the entry-point script).
logger = logging.getLogger('main.' + __name__)


class StandardNormalizeAudio(object):
    '''
        Frequency-wise standard normalization of spectrograms.

        Per-frequency means and stds are computed over the training
        spectrograms listed in `train_ids_path` (each loaded from
        `<specs_dir>/<id>_mel.npy`). They are cached as a two-column text file
        in `cache_path` so later runs can reuse them. Note that the per-bin
        "std" is the average of the per-file stds, not a pooled std over all
        frames.
    '''
    def __init__(self, specs_dir, train_ids_path='./data/vggsound_train.txt', cache_path='./data/'):
        self.specs_dir = specs_dir
        self.train_ids_path = train_ids_path
        # making the stats filename to match the specs dir name
        self.cache_path = os.path.join(cache_path, f'train_means_stds_{Path(specs_dir).stem}.txt')
        logger.info('Assuming that the input stats are calculated using preprocessed spectrograms (log)')
        self.train_stats = self.calculate_or_load_stats()

    def _normalize(self, x):
        '''Return (x - means) / stds, broadcasting the (F, 1) stats over time.

        For torch tensors, the stats are first converted to a tensor on the
        same device as `x`, using its floating dtype. This makes GPU inputs
        work and stops float32 inputs from being promoted to float64 by the
        numpy stats.
        '''
        means = self.train_stats['means']
        stds = self.train_stats['stds']
        if isinstance(x, torch.Tensor):
            dtype = x.dtype if x.is_floating_point() else torch.get_default_dtype()
            means = torch.as_tensor(means, dtype=dtype, device=x.device)
            stds = torch.as_tensor(stds, dtype=dtype, device=x.device)
        return (x - means) / stds

    def __call__(self, item):
        '''Normalize `item` and return it.

        Args:
            item: one of two forms.
                - A dict whose spectrogram is stored under 'input', or
                  'image' if 'input' is absent. The value is replaced in
                  place by its normalized version.
                - A torch.Tensor, presumably (B, F, T). The (F, 1) stats
                  broadcast over the last two dimensions.
        Returns:
            The normalized item: the same dict, or a new tensor.
        Raises:
            NotImplementedError: for a dict with neither key, or for any
                other item type.
        '''
        # The dict form generalizes input handling. It is useful for FID/IS
        # evaluation and for training other models that use 'image'.
        if isinstance(item, dict):
            if 'input' in item:
                input_key = 'input'
            elif 'image' in item:
                input_key = 'image'
            else:
                raise NotImplementedError
            item[input_key] = self._normalize(item[input_key])
        elif isinstance(item, torch.Tensor):
            item = self._normalize(item)
        else:
            raise NotImplementedError
        return item

    def calculate_or_load_stats(self):
        '''Load cached per-frequency stats, or compute and cache them.

        Returns:
            A dict with 'means' and 'stds', each an np.ndarray of shape (F, 1).
        Raises:
            ValueError: if the stats must be computed but the train ids file
                contains no ids.
        '''
        logger.info('Trying to load train stats for Standard Normalization of inputs')
        try:
            # (F, 2); ndmin=2 keeps that shape even for a single-row file
            train_stats = np.loadtxt(self.cache_path, ndmin=2)
        except OSError:
            logger.info('Could not find the precalculated stats for Standard Normalization. Calculating...')
            means, stds = self._calculate_stats()
            # saving in two columns; the cache directory may not exist yet
            os.makedirs(os.path.dirname(self.cache_path) or '.', exist_ok=True)
            np.savetxt(self.cache_path, np.vstack([means, stds]).T, fmt='%0.8f')
        else:
            means, stds = train_stats.T
        return {'means': means.reshape(-1, 1), 'stds': stds.reshape(-1, 1)}

    def _calculate_stats(self):
        '''Average the per-file, per-frequency means and stds over the train set.

        Returns two (F,) arrays: the means and the stds.
        '''
        with open(self.train_ids_path) as ids_file:
            # skip blank lines, which would otherwise become '_mel.npy' paths
            train_vid_ids = [line.rstrip() for line in ids_file if line.strip()]
        if not train_vid_ids:
            raise ValueError(f'No train ids found in {self.train_ids_path}')
        specs_paths = [os.path.join(self.specs_dir, f'{vid_id}_mel.npy') for vid_id in train_vid_ids]
        per_file_means = []
        per_file_stds = []
        for path in tqdm(specs_paths):
            spec = np.load(path)
            per_file_means.append(spec.mean(axis=1))
            per_file_stds.append(spec.std(axis=1))
        # (F) <- (num_files, F)
        return np.array(per_file_means).mean(axis=0), np.array(per_file_stds).mean(axis=0)

class ToTensor(object):
    '''Convert an item's 'input' to a float32 tensor and its optional 'target' to a tensor.'''

    def __call__(self, item):
        '''Convert `item` in place and return it.

        'input' may be a numpy array or anything `torch.as_tensor` accepts,
        including an existing tensor. It is cast to float32. 'target' is
        converted only when present, so unlabeled items pass through.
        '''
        item['input'] = torch.as_tensor(item['input']).float()
        if 'target' in item:
            item['target'] = torch.as_tensor(item['target'])
        return item

class Crop(object):
    '''Crop item['input'] to `cropped_shape` = (mel_num, spec_len) using albumentations.

    Uses RandomCrop when `random_crop` is True and CenterCrop otherwise. When
    `cropped_shape` is None, the input is passed through unchanged.
    '''

    def __init__(self, cropped_shape=None, random_crop=False):
        self.cropped_shape = cropped_shape
        if cropped_shape is not None:
            mel_num, spec_len = cropped_shape
            if random_crop:
                self.cropper = albumentations.RandomCrop
            else:
                self.cropper = albumentations.CenterCrop
            self.preprocessor = albumentations.Compose([self.cropper(mel_num, spec_len)])
        else:
            self.cropper = None
            # A named function instead of a lambda keeps Crop picklable, which
            # DataLoader worker processes need under the 'spawn' start method.
            self.preprocessor = self._identity

    @staticmethod
    def _identity(**kwargs):
        '''No-op preprocessor with the same keyword call shape as albumentations.Compose.'''
        return kwargs

    def __call__(self, item):
        '''Replace item['input'] with its cropped version and return the item.'''
        item['input'] = self.preprocessor(image=item['input'])['image']
        return item


if __name__ == '__main__':
    # Smoke test: center-crop an (80, 860) spectrogram to (80, 848).
    # albumentations' Compose checks that 'image' is a numpy array, so a
    # numpy array is used here rather than a torch tensor.
    cropper = Crop([80, 848])
    item = {'input': np.random.rand(80, 860).astype(np.float32)}
    outputs = cropper(item)
    print(outputs['input'].shape)