Spaces:
Running
on
Zero
Running
on
Zero
import collections | |
import csv | |
import logging | |
import os | |
import random | |
from glob import glob | |
from pathlib import Path | |
import numpy as np | |
import torch | |
import torchvision | |
logger = logging.getLogger(f'main.{__name__}') | |
class VGGSound(torch.utils.data.Dataset):
    """Dataset of pre-extracted VGGSound spectrogram clips for a single split.

    The metadata csv (``meta_path``) is read row by row: column 0 is the video id,
    column 2 the class label and column 3 the original split (``'train'`` or
    ``'test'``). Column 1 is not used here (presumably the clip start time -- confirm
    against the csv).

    The clips of a split are listed in ``<splits_path>/vggsound_<split>.txt``, one clip
    id (video id followed by its timestamp suffix) per line. If that file does not
    exist, the train/valid/test lists are generated by :meth:`make_split_files`.

    Args:
        split: name of the split; selects ``vggsound_<split>.txt``.
        specs_dir: directory holding the ``<clip id>_mel.npy`` files.
        transforms: optional callable applied to the item dict in ``__getitem__``.
        splits_path: directory with (or for) the split list files.
        meta_path: path to the metadata csv.

    Attributes:
        label2target / target2label: mapping between class labels and integer targets
            (targets enumerate the alphabetically sorted labels).
        video2target: mapping from a video id to its integer target.
        dataset: list of paths to the ``*_mel.npy`` files of this split.
        class_counts: tensor with the number of clips per target in this split; it has
            one entry for every class of the metadata, including classes with 0 clips.
    """

    # Video ids are the first 11 characters of a clip id:
    # 'zyTX_1BXKDE_16000_26000'[:11] -> 'zyTX_1BXKDE'
    _VID_ID_LEN = 11
    _SPEC_SUFFIX = '_mel.npy'

    def __init__(self, split, specs_dir, transforms=None, splits_path='./data', meta_path='./data/vggsound.csv'):
        super().__init__()
        self.split = split
        self.specs_dir = specs_dir
        self.transforms = transforms
        self.splits_path = splits_path
        self.meta_path = meta_path

        vggsound_meta = self._read_meta(meta_path)
        unique_classes = sorted({row[2] for row in vggsound_meta})
        self.label2target = {label: target for target, label in enumerate(unique_classes)}
        self.target2label = {target: label for label, target in self.label2target.items()}
        self.video2target = {row[0]: self.label2target[row[2]] for row in vggsound_meta}

        split_clip_ids_path = os.path.join(splits_path, f'vggsound_{split}.txt')
        if not os.path.exists(split_clip_ids_path):
            self.make_split_files()
        with open(split_clip_ids_path) as split_file:
            clip_ids_with_timestamp = split_file.read().splitlines()
        self.dataset = [os.path.join(specs_dir, clip_id + self._SPEC_SUFFIX) for clip_id in clip_ids_with_timestamp]

        class2count = collections.Counter(
            self.video2target[Path(path).stem[:self._VID_ID_LEN]] for path in self.dataset
        )
        # Index over *all* known classes (Counter yields 0 for absent ones) so that
        # position `t` always holds the count of target `t`, even when this split
        # has no clips for some class.
        self.class_counts = torch.tensor([class2count[target] for target in range(len(self.label2target))])

    @staticmethod
    def _read_meta(meta_path):
        """Return the non-empty rows of the metadata csv as lists of strings."""
        with open(meta_path, newline='') as meta_file:
            return [row for row in csv.reader(meta_file, quotechar='"') if row]

    def __getitem__(self, idx):
        """Load the clip at position ``idx``.

        Returns:
            A dict with ``'input'`` (the array loaded from the ``.npy`` file),
            ``'input_path'``, ``'target'`` (int) and ``'label'`` (str); if ``transforms``
            was given, whatever it returns for that dict.
        """
        spec_path = self.dataset[idx]
        video_name = Path(spec_path).stem[:self._VID_ID_LEN]
        target = self.video2target[video_name]
        item = {
            'input': np.load(spec_path),
            'input_path': spec_path,
            'target': target,
            'label': self.target2label[target],
        }
        if self.transforms is not None:
            item = self.transforms(item)
        return item

    def __len__(self):
        """Return the number of clips in this split."""
        return len(self.dataset)

    def make_split_files(self):
        """Create ``vggsound_{train,valid,test}.txt`` in ``self.splits_path``.

        The test list keeps the videos marked ``'test'`` in the metadata. For every
        class, as many randomly chosen ``'train'`` videos as that class has test videos
        are moved to the validation list; the remaining ones stay in train. Only clips
        that are actually present in ``self.specs_dir`` are written.

        Raises:
            Exception: if a clip found in ``self.specs_dir`` belongs to none of the
                three splits. Nothing is written in that case.
        """
        # A private, fixed-seed generator: reproducible without reseeding the
        # process-wide `random` state.
        rng = random.Random(1337)
        logger.info(f'The split files do not exist @ {self.splits_path}. Calculating the new ones.')

        # The downloaded videos (some went missing on YouTube and no longer available)
        available_vid_paths = sorted(glob(os.path.join(self.specs_dir, '*' + self._SPEC_SUFFIX)))
        logger.info(f'The number of clips available after download: {len(available_vid_paths)}')

        # original (full) train and test sets
        vggsound_meta = self._read_meta(self.meta_path)
        train_vids = {row[0] for row in vggsound_meta if row[3] == 'train'}
        test_vids = {row[0] for row in vggsound_meta if row[3] == 'test'}
        logger.info(f'The number of videos in vggsound train set: {len(train_vids)}')
        logger.info(f'The number of videos in vggsound test set: {len(test_vids)}')

        # class counts in test set. We would like to have the same distribution in valid
        unique_classes = sorted({row[2] for row in vggsound_meta})
        label2target = {label: target for target, label in enumerate(unique_classes)}
        video2target = {row[0]: label2target[row[2]] for row in vggsound_meta}
        test_target2count = collections.Counter(video2target[vid] for vid in test_vids)

        # Group the train videos per class in a single pass. Iterating the *sorted* ids
        # matters: set iteration order of strings changes between interpreter runs, and
        # shuffling an arbitrarily ordered list would make the split irreproducible.
        train_vids_by_target = collections.defaultdict(list)
        for vid in sorted(train_vids):
            train_vids_by_target[video2target[vid]].append(vid)

        # now given the counts from test set, sample the same count for validation
        # and the rest leave in train
        train_vids_wo_valid, valid_vids = set(), set()
        for target in range(len(label2target)):
            class_train_vids = train_vids_by_target[target]
            rng.shuffle(class_train_vids)
            count = test_target2count[target]
            valid_vids.update(class_train_vids[:count])
            train_vids_wo_valid.update(class_train_vids[count:])

        # Assign every available clip before touching the disk, so that an unknown
        # clip cannot leave truncated split files behind.
        split2clips = {'train': [], 'valid': [], 'test': []}
        for path in available_vid_paths:
            # strip the suffix from the file name only (not from directory names)
            vid_name = Path(path).name[:-len(self._SPEC_SUFFIX)]
            vid = vid_name[:self._VID_ID_LEN]
            if vid in train_vids_wo_valid:
                split2clips['train'].append(vid_name)
            elif vid in valid_vids:
                split2clips['valid'].append(vid_name)
            elif vid in test_vids:
                split2clips['test'].append(vid_name)
            else:
                raise Exception(f'Clip {vid_name} is neither in train, valid nor test. Strange.')

        if self.splits_path:
            os.makedirs(self.splits_path, exist_ok=True)
        # each line is a clip id, i.e. the video id together with its timestamps
        for split_name, clip_names in split2clips.items():
            split_file_path = os.path.join(self.splits_path, f'vggsound_{split_name}.txt')
            with open(split_file_path, 'w') as split_file:
                split_file.writelines(clip_name + '\n' for clip_name in clip_names)
            logger.info(f'Put {len(clip_names)} clips to the {split_name} set and saved it to {split_file_path}')
if __name__ == '__main__':
    # Smoke test: build all three splits, then print the first item and the
    # per-class clip counts of each of them.
    from transforms import Crop, StandardNormalizeAudio, ToTensor

    specs_path = '/home/nvme/data/vggsound/features/melspec_10s_22050hz/'

    pipeline_steps = [
        StandardNormalizeAudio(specs_path),
        ToTensor(),
        Crop([80, 848]),
    ]
    pipeline = torchvision.transforms.transforms.Compose(pipeline_steps)

    split_names = ('train', 'valid', 'test')

    # instantiate every split before anything is printed
    datasets = {}
    for split_name in split_names:
        datasets[split_name] = VGGSound(split_name, specs_path, pipeline)

    for split_name in split_names:
        print(datasets[split_name][0])

    for split_name in split_names:
        print(datasets[split_name].class_counts)