Sijuade commited on
Commit
fae2821
1 Parent(s): 08eb57c

Create dataset/dataset.py

Browse files
Files changed (1) hide show
  1. dataset/dataset.py +14 -0
dataset/dataset.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torchvision
2
+
3
+ class CIFAR10Dataset(torchvision.datasets.CIFAR10):
4
+ def __init__(self, root="~/data/cifar10", train=True, download=True, transform=None):
5
+ super().__init__(root=root, train=train, download=download, transform=transform)
6
+
7
+ def __getitem__(self, index):
8
+ image, label = self.data[index], self.targets[index]
9
+
10
+ if self.transform is not None:
11
+ transformed = self.transform(image=image)
12
+ image = transformed["image"]
13
+
14
+ return image, label