darklord25 committed on
Commit c10198c
1 Parent(s): 96ed3dd

Upload model files
models/R2D2_embedding.py ADDED
@@ -0,0 +1,46 @@
+import torch.nn as nn
+import torch
+
+# Embedding network used in Meta-learning with differentiable closed-form solvers
+# (Bertinetto et al., in submission to NIPS 2018).
+# They call the ridge regression version the "Ridge Regression Differentiable Discriminator (R2D2)."
+
+# Note that they use a peculiar ordering of functions, namely conv-BN-pooling-lrelu,
+# as opposed to the conventional one (conv-BN-lrelu-pooling).
+
+def R2D2_conv_block(in_channels, out_channels, retain_activation=True, keep_prob=1.0):
+    block = nn.Sequential(
+        nn.Conv2d(in_channels, out_channels, 3, padding=1),
+        nn.BatchNorm2d(out_channels),
+        nn.MaxPool2d(2)
+    )
+    if retain_activation:
+        block.add_module("LeakyReLU", nn.LeakyReLU(0.1))
+
+    if keep_prob < 1.0:
+        block.add_module("Dropout", nn.Dropout(p=1 - keep_prob, inplace=False))
+
+    return block
+
+class R2D2Embedding(nn.Module):
+    def __init__(self, x_dim=3, h1_dim=96, h2_dim=192, h3_dim=384, z_dim=512,
+                 retain_last_activation=False):
+        super(R2D2Embedding, self).__init__()
+
+        self.block1 = R2D2_conv_block(x_dim, h1_dim)
+        self.block2 = R2D2_conv_block(h1_dim, h2_dim)
+        self.block3 = R2D2_conv_block(h2_dim, h3_dim, keep_prob=0.9)
+        # In the last conv block, we disable the activation function to boost the classification accuracy.
+        # This trick was proposed by Gidaris et al. (CVPR 2018).
+        # With this trick, the accuracy goes up from 50% to 51%.
+        # Although the authors of R2D2 did not mention this trick in the paper,
+        # we were unable to reproduce the result of Bertinetto et al. without resorting to this trick.
+        self.block4 = R2D2_conv_block(h3_dim, z_dim, retain_activation=retain_last_activation, keep_prob=0.7)
+
+    def forward(self, x):
+        b1 = self.block1(x)
+        b2 = self.block2(b1)
+        b3 = self.block3(b2)
+        b4 = self.block4(b3)
+        # Flatten and concatenate the outputs of the 3rd and 4th conv blocks, as proposed in the R2D2 paper.
+        return torch.cat((b3.view(b3.size(0), -1), b4.view(b4.size(0), -1)), 1)
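For reference (not part of the committed file), a minimal usage sketch of R2D2Embedding; the output width depends on the input resolution, and the 84x84 mini-ImageNet-style input below is an assumption:

    import torch
    from models.R2D2_embedding import R2D2Embedding

    net = R2D2Embedding()
    x = torch.randn(4, 3, 84, 84)   # assumed 84x84 mini-ImageNet-style input
    emb = net(x)                    # flattened block-3 and block-4 features, concatenated
    print(emb.shape)                # torch.Size([4, 51200]) for 84x84 inputs: 384*10*10 + 512*5*5
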
models/ResNet12_embedding.py ADDED
@@ -0,0 +1,266 @@
+import timm
+import numpy as np
+import torch.nn as nn
+import torch
+import torch.nn.functional as F
+from models.dropblock import DropBlock
+from models.vit import ViT
+from torchvision import models
+import random
+
+# This ResNet network was designed following the practice of the following papers:
+# TADAM: Task dependent adaptive metric for improved few-shot learning (Oreshkin et al., in NIPS 2018) and
+# A Simple Neural Attentive Meta-Learner (Mishra et al., in ICLR 2018).
+
+
+def conv3x3(in_planes, out_planes, stride=1):
+    """3x3 convolution with padding"""
+    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
+                     padding=1, bias=False)
+
+
+class BasicBlock(nn.Module):
+    expansion = 1
+
+    def __init__(self, inplanes, planes, stride=1, downsample=None, drop_rate=0.0, drop_block=False, block_size=1):
+        super(BasicBlock, self).__init__()
+        self.conv1 = conv3x3(inplanes, planes)
+        self.bn1 = nn.BatchNorm2d(planes)
+        self.relu = nn.LeakyReLU(0.1)
+        self.conv2 = conv3x3(planes, planes)
+        self.bn2 = nn.BatchNorm2d(planes)
+        self.conv3 = conv3x3(planes, planes)
+        self.bn3 = nn.BatchNorm2d(planes)
+        self.maxpool = nn.MaxPool2d(stride)
+        self.downsample = downsample
+        self.stride = stride
+        self.drop_rate = drop_rate
+        self.num_batches_tracked = 0
+        self.drop_block = drop_block
+        self.block_size = block_size
+        self.DropBlock = DropBlock(block_size=self.block_size)
+
+    def forward(self, x):
+        self.num_batches_tracked += 1
+
+        residual = x
+
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+
+        out = self.conv2(out)
+        out = self.bn2(out)
+        out = self.relu(out)
+
+        out = self.conv3(out)
+        out = self.bn3(out)
+
+        if self.downsample is not None:
+            residual = self.downsample(x)
+        out += residual
+        out = self.relu(out)
+        out = self.maxpool(out)
+
+        if self.drop_rate > 0:
+            if self.drop_block:
+                feat_size = out.size()[2]
+                keep_rate = max(1.0 - self.drop_rate / (20 * 2000) *
+                                (self.num_batches_tracked), 1.0 - self.drop_rate)
+                gamma = (1 - keep_rate) / self.block_size**2 * \
+                    feat_size**2 / (feat_size - self.block_size + 1)**2
+                out = self.DropBlock(out, gamma=gamma)
+            else:
+                out = F.dropout(out, p=self.drop_rate,
+                                training=self.training, inplace=True)
+
+        return out
+
+
+class ResNet(nn.Module):
+
+    def __init__(self, block, keep_prob=1.0, avg_pool=False, drop_rate=0.0, dropblock_size=5):
+        self.inplanes = 3
+        super(ResNet, self).__init__()
+
+        self.layer1 = self._make_layer(
+            block, 64, stride=2, drop_rate=drop_rate)
+        self.layer2 = self._make_layer(
+            block, 160, stride=2, drop_rate=drop_rate)
+        self.layer3 = self._make_layer(
+            block, 320, stride=2, drop_rate=drop_rate, drop_block=True, block_size=dropblock_size)
+        self.layer4 = self._make_layer(
+            block, 640, stride=2, drop_rate=drop_rate, drop_block=True, block_size=dropblock_size)
+        if avg_pool:
+            self.avgpool = nn.AvgPool2d(5, stride=1)
+        self.keep_prob = keep_prob
+        self.keep_avg_pool = avg_pool
+        self.dropout = nn.Dropout(p=1 - self.keep_prob, inplace=False)
+        self.drop_rate = drop_rate
+        self.linear1_1 = nn.Linear(2560, 512)
+        self.linear1_2 = nn.Linear(512, 64)
+        self.linear2_1 = nn.Linear(2560, 512)
+        self.linear2_2 = nn.Linear(512, 64)
+        self.linear3_1 = nn.Linear(2560, 512)
+        self.linear3_2 = nn.Linear(512, 64)
+
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                nn.init.kaiming_normal_(
+                    m.weight, mode='fan_out', nonlinearity='leaky_relu')
+            elif isinstance(m, nn.BatchNorm2d):
+                nn.init.constant_(m.weight, 1)
+                nn.init.constant_(m.bias, 0)
+
+    def _make_layer(self, block, planes, stride=1, drop_rate=0.0, drop_block=False, block_size=1):
+        downsample = None
+        if stride != 1 or self.inplanes != planes * block.expansion:
+            downsample = nn.Sequential(
+                nn.Conv2d(self.inplanes, planes * block.expansion,
+                          kernel_size=1, stride=1, bias=False),
+                nn.BatchNorm2d(planes * block.expansion),
+            )
+
+        layers = []
+        layers.append(block(self.inplanes, planes, stride,
+                            downsample, drop_rate, drop_block, block_size))
+        self.inplanes = planes * block.expansion
+
+        return nn.Sequential(*layers)
+
+    def forward(self, x):
+        x = self.layer1(x)
+        x = self.layer2(x)
+        x = self.layer3(x)
+        x = self.layer4(x)
+        if self.keep_avg_pool:
+            x = self.avgpool(x)
+        x = x.view(x.size(0), -1)
+
+        # x1 = F.relu(self.linear1_1(x))
+        # x1 = self.linear1_2(x1)
+        # x2 = F.relu(self.linear2_1(x))
+        # x2 = self.linear2_2(x2)
+        # x3 = F.relu(self.linear3_1(x))
+        # x3 = self.linear3_2(x3)
+        # return [x1, x2, x3]
+        return x
+
+
+
+class custom_model(nn.Module):
+
+    def __init__(self, num_layer=5):
+        super(custom_model, self).__init__()
+        # self.classifier = timm.create_model('densenet121', pretrained=True)
+        self.classifier = timm.create_model('tf_efficientnet_b7_ns', pretrained=True)
+        # self.classifier = nn.Sequential(*list(classifier.children())[:-1])
+        self.num_layer = num_layer
+        # self.classifier = models.resnet34(pretrained=True, progress=True)
+        # self.classifier = ViT(
+        #     image_size = 32,
+        #     patch_size = 3,
+        #     dim = 512,
+        #     depth = 6,
+        #     heads = 8,
+        #     mlp_dim = 1000,
+        #     dropout = 0.1,
+        #     emb_dropout = 0.1
+        # )
+
+        # self.bn1 = nn.BatchNorm1d(num_features=1000)
+        self.bn2 = nn.BatchNorm1d(num_features=128)
+        self.dropout = nn.Dropout(0.4)
+
+        for i in range(num_layer):
+            setattr(self, "linear%d_1" % i, nn.Linear(1000, 512))
+            setattr(self, "batch_norm%d_1" % i, nn.BatchNorm1d(num_features=512))
+            setattr(self, "linear%d_2" % i, nn.Linear(512, 64))
+            # setattr(self, "linear%d_3" % i, nn.Linear(128, 64))
+
+        for m in self.modules():
+
+            # if isinstance(m, nn.Conv2d):
+            #     nn.init.kaiming_normal_(
+            #         m.weight, mode='fan_out', nonlinearity='leaky_relu')
+            # elif isinstance(m, nn.BatchNorm2d):
+            #     nn.init.constant_(m.weight, 1)
+            #     nn.init.constant_(m.bias, 0)
+
+            if isinstance(m, nn.Linear):
+
+                # m.weight = nn.parameter.Parameter(torch.randn(m.out_features, m.in_features) * torch.sqrt(torch.tensor(2/m.in_features, requires_grad=True)))
+                # Set the weight std from a fan-in drawn between in_features//2 and in_features
+                # (integer division so random.randint receives int arguments).
+                y = random.randint(m.in_features // 2, m.in_features)
+                m.weight.data.normal_(0.0, 1 / np.sqrt(y))
+                # m.bias.data should be 0
+                # m.bias.data.fill_(0)
+
+        self.act = nn.ReLU()
+
+    def forward(self, x):
+        x = self.classifier(x)
+        feat = []
+
+        # print(self.classifier.classifier.weight)
+
+        # feat = torch.zeros((self.num_layer, 256))
+        for i in range(self.num_layer):
+
+            x1 = self.act(getattr(self, "batch_norm%d_1" % i)(getattr(self, "linear%d_1" % i)(x)))
+            x2 = getattr(self, "linear%d_2" % i)(x1)
+            feat.append(x2)
+            # print(getattr(self, "linear%d_1" % i).weight)
+            # weights.append(getattr(self, "linear%d_1" % i).weight)
+
+        feat = torch.stack(feat, dim=0)
+        # weights = torch.stack(weights, dim=0)
+        # print(feat.size())
+        return feat
+
+
+
+
+def conv_block(in_channels, out_channels):
+    '''
+    returns a block conv-bn-relu-pool
+    '''
+    return nn.Sequential(
+        nn.Conv2d(in_channels, out_channels, 3, padding=1),
+        nn.BatchNorm2d(out_channels),
+        nn.ReLU(),
+        nn.MaxPool2d(2)
+    )
+
+
+class ProtoNet(nn.Module):
+    '''
+    Model as described in the reference paper,
+    source: https://github.com/jakesnell/prototypical-networks/blob/f0c48808e496989d01db59f86d4449d7aee9ab0c/protonets/models/few_shot.py#L62-L84
+    '''
+    def __init__(self, x_dim=3, hid_dim=64, z_dim=64):
+        super(ProtoNet, self).__init__()
+        self.encoder = nn.Sequential(
+            conv_block(x_dim, hid_dim),
+            conv_block(hid_dim, hid_dim),
+            conv_block(hid_dim, hid_dim),
+            conv_block(hid_dim, z_dim),
+        )
+
+    def forward(self, x):
+        x = self.encoder(x)
+        return [x.view(x.size(0), -1)]
+
+
+def resnet12(keep_prob=1.0, avg_pool=False, num_layer=5, **kwargs):
+    """Constructs the embedding model (note: as committed this returns custom_model;
+    the original ResNet-12 and ProtoNet constructions are left commented out below).
+    """
+    # model = ResNet(BasicBlock, keep_prob=keep_prob,
+    #                avg_pool=avg_pool, **kwargs)
+    model = custom_model(num_layer=num_layer)
+    # model = ProtoNet()
+
+    return model
+
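Note that resnet12() as committed builds custom_model, which wraps a pretrained timm EfficientNet and emits num_layer parallel 64-d projections of its 1000-d output. A rough usage sketch follows, assuming timm can resolve and download the tf_efficientnet_b7_ns weights; the input size here is an arbitrary assumption, since the EfficientNet backbone pools adaptively:

    import torch
    from models.ResNet12_embedding import resnet12

    model = resnet12(num_layer=5).eval()    # despite the name, this builds custom_model
    x = torch.randn(2, 3, 224, 224)         # assumed input size
    with torch.no_grad():
        feat = model(x)
    print(feat.shape)                       # torch.Size([5, 2, 64]): (num_layer, batch, 64)
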
models/__pycache__/R2D2_embedding.cpython-36.pyc ADDED
Binary file (1.56 kB).

models/__pycache__/R2D2_embedding.cpython-37.pyc ADDED
Binary file (1.56 kB).

models/__pycache__/R2D2_embedding.cpython-38.pyc ADDED
Binary file (1.58 kB).

models/__pycache__/ResNet12_embedding.cpython-36.pyc ADDED
Binary file (6.5 kB).

models/__pycache__/ResNet12_embedding.cpython-37.pyc ADDED
Binary file (6.55 kB).

models/__pycache__/ResNet12_embedding.cpython-38.pyc ADDED
Binary file (6.62 kB).

models/__pycache__/ResNet12_embedding_ablation.cpython-37.pyc ADDED
Binary file (5.84 kB).

models/__pycache__/Resnet_12_em.cpython-36.pyc ADDED
Binary file (6.61 kB).

models/__pycache__/classification_heads.cpython-36.pyc ADDED
Binary file (8.92 kB).

models/__pycache__/classification_heads.cpython-37.pyc ADDED
Binary file (8.89 kB).

models/__pycache__/classification_heads.cpython-38.pyc ADDED
Binary file (8.93 kB).

models/__pycache__/classification_heads.cpython-39.pyc ADDED
Binary file (8.9 kB).

models/__pycache__/dropblock.cpython-36.pyc ADDED
Binary file (1.97 kB).

models/__pycache__/dropblock.cpython-37.pyc ADDED
Binary file (1.97 kB).

models/__pycache__/dropblock.cpython-38.pyc ADDED
Binary file (2 kB).

models/__pycache__/protonet_embedding.cpython-36.pyc ADDED
Binary file (1.93 kB).

models/__pycache__/protonet_embedding.cpython-37.pyc ADDED
Binary file (1.92 kB).

models/__pycache__/protonet_embedding.cpython-38.pyc ADDED
Binary file (1.92 kB).

models/__pycache__/vit.cpython-36.pyc ADDED
Binary file (5.38 kB).

models/__pycache__/vit.cpython-37.pyc ADDED
Binary file (5.35 kB).

models/__pycache__/vit.cpython-38.pyc ADDED
Binary file (5.26 kB).
 
models/classification_heads.py ADDED
@@ -0,0 +1,367 @@
+import os
+import sys
+
+import torch
+from torch.autograd import Variable
+import torch.nn as nn
+#from qpth.qp import QPFunction
+import torch.nn.functional as F
+
+
+def sqrt_newton_schulz(A, numIters):
+    dim = A.shape[0]
+    normA = A.mul(A).sum(dim=0).sum(dim=0).sqrt()
+    Y = A.div(normA.expand_as(A))
+    I = torch.eye(dim, dim).float().cuda()
+    Z = torch.eye(dim, dim).float().cuda()
+    for i in range(numIters):
+        T = 0.5 * (3.0 * I - Z.mm(Y))
+        Y = Y.mm(T)
+        Z = T.mm(Z)
+
+    # sA = Y * torch.sqrt(normA).view(batchSize, 1, 1).expand_as(A)
+
+    # sA = Y * torch.sqrt(normA).expand_as(A)
+
+    sZ = Z * 1. / torch.sqrt(normA).expand_as(A)
+    return sZ
+
+
+def polar_decompose(input):
+    # square_mat = input.mm(input.transpose(0, 1))
+    # square_mat = square_mat/torch.norm(torch.diag(square_mat), p=1)
+    # ortho_mat = self.sqrt_newton_schulz(square_mat, numIters=1)
+
+    square_mat = input.transpose(0, 1).mm(input)
+    sA_minushalf = sqrt_newton_schulz(square_mat, 1)
+    ortho_mat = input.mm(sA_minushalf)
+
+    # return ortho_mat
+
+    return ortho_mat.mm(ortho_mat.transpose(0, 1))
+
+
+def computeGramMatrix(A, B):
+    """
+    Constructs a linear kernel matrix between A and B.
+    We assume that each row in A and B represents a d-dimensional feature vector.
+
+    Parameters:
+      A: a (n_batch, n, d) Tensor.
+      B: a (n_batch, m, d) Tensor.
+    Returns: a (n_batch, n, m) Tensor.
+    """
+
+    assert(A.dim() == 3)
+    assert(B.dim() == 3)
+    assert(A.size(0) == B.size(0) and A.size(2) == B.size(2))
+
+    return torch.bmm(A, B.transpose(1, 2))
+
+
+def binv(b_mat):
+    """
+    Computes the inverse of each matrix in the batch.
+    PyTorch 0.4.1 does not support batched matrix inverse,
+    hence we solve AX = I instead (via torch.gesv; torch.linalg.solve is the modern equivalent).
+
+    Parameters:
+      b_mat: a (n_batch, n, n) Tensor.
+    Returns: a (n_batch, n, n) Tensor.
+    """
+
+    id_matrix = b_mat.new_ones(b_mat.size(-1)).diag().expand_as(b_mat).cuda()
+    b_inv, _ = torch.gesv(id_matrix, b_mat)
+
+    return b_inv
+
+
+def one_hot(indices, depth):
+    """
+    Returns a one-hot tensor.
+    This is a PyTorch equivalent of Tensorflow's tf.one_hot.
+
+    Parameters:
+      indices: a (n_batch, m) Tensor or (m) Tensor.
+      depth: a scalar. Represents the depth of the one-hot dimension.
+    Returns: a (n_batch, m, depth) Tensor or (m, depth) Tensor.
+    """
+
+    encoded_indices = torch.zeros(indices.size() + torch.Size([depth])).cuda()
+    index = indices.view(indices.size() + torch.Size([1]))
+    encoded_indices = encoded_indices.scatter_(1, index, 1)
+
+    return encoded_indices
+
+def batched_kronecker(matrix1, matrix2):
+    matrix1_flatten = matrix1.reshape(matrix1.size()[0], -1)
+    matrix2_flatten = matrix2.reshape(matrix2.size()[0], -1)
+    return torch.bmm(matrix1_flatten.unsqueeze(2), matrix2_flatten.unsqueeze(1)).reshape([matrix1.size()[0]] + list(matrix1.size()[1:]) + list(matrix2.size()[1:])).permute([0, 1, 3, 2, 4]).reshape(matrix1.size(0), matrix1.size(1) * matrix2.size(1), matrix1.size(2) * matrix2.size(2))
+
+
+
+
+################# Uncomment this if you have installed QPFunction and want to run the Ridge head
+# def MetaOptNetHead_Ridge(query, support, support_labels, n_way, n_shot, lambda_reg=50.0, double_precision=True):
+#     """
+#     Fits the support set with ridge regression and
+#     returns the classification score on the query set.
+#
+#     Parameters:
+#       query: a (tasks_per_batch, n_query, d) Tensor.
+#       support: a (tasks_per_batch, n_support, d) Tensor.
+#       support_labels: a (tasks_per_batch, n_support) Tensor.
+#       n_way: a scalar. Represents the number of classes in a few-shot classification task.
+#       n_shot: a scalar. Represents the number of support examples given per class.
+#       lambda_reg: a scalar. Represents the strength of L2 regularization.
+#     Returns: a (tasks_per_batch, n_query, n_way) Tensor.
+#     """
+#
+#     tasks_per_batch = query.size(0)
+#     n_support = support.size(1)
+#     n_query = query.size(1)
+#
+#     assert(query.dim() == 3)
+#     assert(support.dim() == 3)
+#     assert(query.size(0) == support.size(0) and query.size(2) == support.size(2))
+#     assert(n_support == n_way * n_shot)  # n_support must equal n_way * n_shot
+#
+#     # Here we solve the dual problem:
+#     # Note that the classes are indexed by m and samples are indexed by i.
+#     # min_{\alpha} 0.5 \sum_m ||w_m(\alpha)||^2 + \sum_i \sum_m e^m_i alpha^m_i
+#
+#     # where w_m(\alpha) = \sum_i \alpha^m_i x_i,
+#
+#     # \alpha is an (n_support, n_way) matrix
+#     kernel_matrix = computeGramMatrix(support, support)
+#     kernel_matrix += lambda_reg * torch.eye(n_support).expand(tasks_per_batch, n_support, n_support).cuda()
+#
+#     block_kernel_matrix = kernel_matrix.repeat(n_way, 1, 1)  # (n_way * tasks_per_batch, n_support, n_support)
+#
+#     support_labels_one_hot = one_hot(support_labels.view(tasks_per_batch * n_support), n_way)  # (tasks_per_batch * n_support, n_way)
+#     support_labels_one_hot = support_labels_one_hot.transpose(0, 1)  # (n_way, tasks_per_batch * n_support)
+#     support_labels_one_hot = support_labels_one_hot.reshape(n_way * tasks_per_batch, n_support)  # (n_way * tasks_per_batch, n_support)
+#
+#     G = block_kernel_matrix
+#     e = -2.0 * support_labels_one_hot
+#
+#     # This is a fake inequality constraint as qpth does not support QP without an inequality constraint.
+#     id_matrix_1 = torch.zeros(tasks_per_batch * n_way, n_support, n_support)
+#     C = Variable(id_matrix_1)
+#     h = Variable(torch.zeros((tasks_per_batch * n_way, n_support)))
+#     dummy = Variable(torch.Tensor()).cuda()  # We want to ignore the equality constraint.
+#
+#     # if double_precision:
+#     G, e, C, h = [x.double().cuda() for x in [G, e, C, h]]
+#
+#
+#     qp_sol = QPFunction(verbose=False)(G, e.detach(), C.detach(), h.detach(), dummy.detach(), dummy.detach())
+#     qp_sol = qp_sol.reshape(n_way, tasks_per_batch, n_support)
+#     qp_sol = qp_sol.permute(1, 2, 0)
+#
+#
+#     # Compute the classification score.
+#     compatibility = computeGramMatrix(support, query)
+#     compatibility = compatibility.float()
+#     compatibility = compatibility.unsqueeze(3).expand(tasks_per_batch, n_support, n_query, n_way)
+#     qp_sol = qp_sol.reshape(tasks_per_batch, n_support, n_way)
+#     logits = qp_sol.float().unsqueeze(2).expand(tasks_per_batch, n_support, n_query, n_way)
+#     logits = logits * compatibility
+#     logits = torch.sum(logits, 1)
+#
+#     return logits
+
+def R2D2Head(query, support, support_labels, n_way, n_shot, l2_regularizer_lambda=50.0):
+    """
+    Fits the support set with ridge regression and
+    returns the classification score on the query set.
+
+    This model is the classification head described in:
+    Meta-learning with differentiable closed-form solvers
+    (Bertinetto et al., in submission to NIPS 2018).
+
+    Parameters:
+      query: a (tasks_per_batch, n_query, d) Tensor.
+      support: a (tasks_per_batch, n_support, d) Tensor.
+      support_labels: a (tasks_per_batch, n_support) Tensor.
+      n_way: a scalar. Represents the number of classes in a few-shot classification task.
+      n_shot: a scalar. Represents the number of support examples given per class.
+      l2_regularizer_lambda: a scalar. Represents the strength of L2 regularization.
+    Returns: a (tasks_per_batch, n_query, n_way) Tensor.
+    """
+
+    tasks_per_batch = query.size(0)
+    n_support = support.size(1)
+
+    assert(query.dim() == 3)
+    assert(support.dim() == 3)
+    assert(query.size(0) == support.size(0) and query.size(2) == support.size(2))
+    assert(n_support == n_way * n_shot)  # n_support must equal n_way * n_shot
+
+    support_labels_one_hot = one_hot(support_labels.view(tasks_per_batch * n_support), n_way)
+    support_labels_one_hot = support_labels_one_hot.view(tasks_per_batch, n_support, n_way)
+
+    id_matrix = torch.eye(n_support).expand(tasks_per_batch, n_support, n_support).cuda()
+
+    # Compute the dual-form solution of the ridge regression.
+    # W = X^T (X X^T + lambda * I)^(-1) Y
+    ridge_sol = computeGramMatrix(support, support) + l2_regularizer_lambda * id_matrix
+    ridge_sol = binv(ridge_sol)
+    ridge_sol = torch.bmm(support.transpose(1, 2), ridge_sol)
+    ridge_sol = torch.bmm(ridge_sol, support_labels_one_hot)
+
+    # Compute the classification score.
+    # score = X W
+    logits = torch.bmm(query, ridge_sol)
+
+    return logits
+
+
+
+def ProtoNetHead(query, support, support_labels, n_way, n_shot, normalize=True):
+    """
+    Constructs the prototype representation of each class (= mean of the support vectors of each class) and
+    returns the classification score (= negative squared L2 distance to each class prototype) on the query set.
+
+    This model is the classification head described in:
+    Prototypical Networks for Few-shot Learning
+    (Snell et al., NIPS 2017).
+
+    Parameters:
+      query: a (tasks_per_batch, n_query, d) Tensor.
+      support: a (tasks_per_batch, n_support, d) Tensor.
+      support_labels: a (tasks_per_batch, n_support) Tensor.
+      n_way: a scalar. Represents the number of classes in a few-shot classification task.
+      n_shot: a scalar. Represents the number of support examples given per class.
+      normalize: a boolean. Represents whether we want to normalize the distances by the embedding dimension.
+    Returns: a (tasks_per_batch, n_query, n_way) Tensor.
+    """
+
+    tasks_per_batch = query.size(0)
+    n_support = support.size(1)
+    n_query = query.size(1)
+    d = query.size(2)
+
+    assert(query.dim() == 3)
+    assert(support.dim() == 3)
+    assert(query.size(0) == support.size(0) and query.size(2) == support.size(2))
+    assert(n_support == n_way * n_shot)  # n_support must equal n_way * n_shot
+
+    support_labels_one_hot = one_hot(support_labels.view(tasks_per_batch * n_support), n_way)
+    support_labels_one_hot = support_labels_one_hot.view(tasks_per_batch, n_support, n_way)
+
+    # From:
+    # https://github.com/gidariss/FewShotWithoutForgetting/blob/master/architectures/PrototypicalNetworksHead.py
+    # ************************* Compute Prototypes **************************
+    labels_train_transposed = support_labels_one_hot.transpose(1, 2)
+
+    prototypes = torch.bmm(labels_train_transposed, support)
+    # Divide by the number of examples per novel category.
+    prototypes = prototypes.div(
+        labels_train_transposed.sum(dim=2, keepdim=True).expand_as(prototypes)
+    )
+
+    # Distance matrix vectorization trick
+    AB = computeGramMatrix(query, prototypes)
+    AA = (query * query).sum(dim=2, keepdim=True)
+    BB = (prototypes * prototypes).sum(dim=2, keepdim=True).reshape(tasks_per_batch, 1, n_way)
+    logits = AA.expand_as(AB) - 2 * AB + BB.expand_as(AB)
+    logits = -logits
+
+    if normalize:
+        logits = logits / d
+
+    return logits
+
+
+
+def SubspaceNetHead(query, support, support_labels, n_way, n_shot, normalize=True):
+    """
+    Constructs the subspace representation of each class (spanned by the support vectors of that class) and
+    returns the classification score (= negative squared distance to each class subspace) on the query set.
+
+    Our algorithm uses subspaces here.
+
+    Parameters:
+      query: a (tasks_per_batch, n_query, d) Tensor.
+      support: a (tasks_per_batch, n_support, d) Tensor.
+      support_labels: a (tasks_per_batch, n_support) Tensor.
+      n_way: a scalar. Represents the number of classes in a few-shot classification task.
+      n_shot: a scalar. Represents the number of support examples given per class.
+      normalize: a boolean. Represents whether we want to normalize the distances by the embedding dimension.
+    Returns: a (tasks_per_batch, n_query, n_way) Tensor.
+    """
+
+    tasks_per_batch = query.size(0)
+    n_support = support.size(1)
+    n_query = query.size(1)
+    d = query.size(2)
+
+    assert(query.dim() == 3)
+    assert(support.dim() == 3)
+    assert(query.size(0) == support.size(0) and query.size(2) == support.size(2))
+    assert(n_support == n_way * n_shot)  # n_support must equal n_way * n_shot
+
+    support_labels_one_hot = one_hot(support_labels.view(tasks_per_batch * n_support), n_way)
+    #support_labels_one_hot = support_labels_one_hot.view(tasks_per_batch, n_support, n_way)
+
+
+    support_reshape = support.view(tasks_per_batch * n_support, -1)
+
+    support_labels_reshaped = support_labels.contiguous().view(-1)
+    class_representatives = []
+    for class_idx in range(n_way):  # renamed from `nn` to avoid shadowing torch.nn
+        idxss = torch.nonzero((support_labels_reshaped == class_idx), as_tuple=False)
+        all_support_perclass = support_reshape[idxss, :]
+        class_representatives.append(all_support_perclass.view(tasks_per_batch, n_shot, -1))
+
+    class_representatives = torch.stack(class_representatives)
+    class_representatives = class_representatives.transpose(0, 1)  # tasks_per_batch, n_way, n_shot, -1
+    class_representatives = class_representatives.transpose(2, 3).contiguous().view(tasks_per_batch * n_way, -1, n_shot)
+
+    dist = []
+    for cc in range(tasks_per_batch * n_way):
+        batch_idx = cc // n_way
+        qq = query[batch_idx]
+        uu, _, _ = torch.svd(class_representatives[cc].double())
+        uu = uu.float()
+        subspace = uu[:, :n_shot - 1].transpose(0, 1)
+        projection = subspace.transpose(0, 1).mm(subspace.mm(qq.transpose(0, 1))).transpose(0, 1)
+        dist_perclass = torch.sum((qq - projection)**2, dim=-1)
+        dist.append(dist_perclass)
+
+    dist = torch.stack(dist).view(tasks_per_batch, n_way, -1).transpose(1, 2)
+    logits = -dist
+
+    if normalize:
+        logits = logits / d
+
+    return logits
+
+
+
+
+class ClassificationHead(nn.Module):
+    def __init__(self, base_learner='MetaOptNet', enable_scale=True):
+        super(ClassificationHead, self).__init__()
+        if ('Subspace' in base_learner):
+            self.head = SubspaceNetHead
+        elif ('Ridge' in base_learner):
+            self.head = MetaOptNetHead_Ridge  # requires uncommenting MetaOptNetHead_Ridge above (needs qpth's QPFunction)
+        elif ('R2D2' in base_learner):
+            self.head = R2D2Head
+        elif ('Proto' in base_learner):
+            self.head = ProtoNetHead
+        else:
+            print("Cannot recognize the base learner type")
+            assert(False)
+
+        # Add a learnable scale
+        self.enable_scale = enable_scale
+        self.scale = nn.Parameter(torch.FloatTensor([1.0]))
+
+    def forward(self, query, support, support_labels, n_way, n_shot, **kwargs):
+        if self.enable_scale:
+            return self.scale * self.head(query, support, support_labels, n_way, n_shot, **kwargs)
+        else:
+            return self.head(query, support, support_labels, n_way, n_shot, **kwargs)
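A minimal episodic sketch of the ProtoNet head (not part of the committed file). A GPU is assumed, because one_hot() and the heads call .cuda() internally; the 8-d embeddings and 2 tasks per batch are arbitrary choices:

    import torch
    from models.classification_heads import ClassificationHead

    head = ClassificationHead(base_learner='Proto').cuda()
    support = torch.randn(2, 5, 8).cuda()                  # (tasks_per_batch, n_way * n_shot, d)
    query = torch.randn(2, 15, 8).cuda()                   # (tasks_per_batch, n_query, d)
    support_labels = torch.arange(5).repeat(2, 1).cuda()   # one support example per class
    logits = head(query, support, support_labels, n_way=5, n_shot=1)
    print(logits.shape)                                    # torch.Size([2, 15, 5])
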
models/closerlook_classifier.py ADDED
@@ -0,0 +1,28 @@
+import torch.nn as nn
+import torch
+from torch.nn.utils.weight_norm import WeightNorm
+
+
+class distLinear(nn.Module):
+    def __init__(self, indim, outdim):
+        super(distLinear, self).__init__()
+        self.L = nn.Linear(indim, outdim, bias=False)
+        self.class_wise_learnable_norm = True  # see issues #4 and #8 in the original GitHub repo
+        if self.class_wise_learnable_norm:
+            WeightNorm.apply(self.L, 'weight', dim=0)  # split the weight update component into direction and norm
+
+        if outdim <= 200:
+            self.scale_factor = 2  # a fixed scale factor to turn the cosine value into a reasonably large input for softmax; to reproduce the CUB result with ResNet10, use 4 (see issue #31 in the original GitHub repo)
+        else:
+            self.scale_factor = 10  # in Omniglot, a larger scale factor is required to handle >1000 output classes
+
+    def forward(self, x):
+        x_norm = torch.norm(x, p=2, dim=1).unsqueeze(1).expand_as(x)
+        x_normalized = x.div(x_norm + 0.00001)
+        if not self.class_wise_learnable_norm:
+            L_norm = torch.norm(self.L.weight.data, p=2, dim=1).unsqueeze(1).expand_as(self.L.weight.data)
+            self.L.weight.data = self.L.weight.data.div(L_norm + 0.00001)
+        cos_dist = self.L(x_normalized)  # matrix product via the forward function; when using WeightNorm, this also multiplies the cosine distance by a class-wise learnable norm (see issues #4 and #8 in the original GitHub repo)
+        scores = self.scale_factor * cos_dist
+
+        return scores
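A short sketch of the cosine classifier above (hypothetical dimensions; not part of the committed file):

    import torch
    from models.closerlook_classifier import distLinear

    clf = distLinear(indim=64, outdim=5)   # 64-d features, 5 classes; outdim <= 200, so scale_factor = 2
    feats = torch.randn(8, 64)
    scores = clf(feats)                    # scaled cosine similarities
    print(scores.shape)                    # torch.Size([8, 5])
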
models/dropblock.py ADDED
@@ -0,0 +1,66 @@
+import torch
+import torch.nn.functional as F
+from torch import nn
+from torch.distributions import Bernoulli
+
+
+class DropBlock(nn.Module):
+    def __init__(self, block_size):
+        super(DropBlock, self).__init__()
+
+        self.block_size = block_size
+        #self.gamma = gamma
+        #self.bernouli = Bernoulli(gamma)
+
+    def forward(self, x, gamma):
+        # shape: (bsize, channels, height, width)
+
+        if self.training:
+            batch_size, channels, height, width = x.shape
+
+            bernoulli = Bernoulli(gamma)
+            mask = bernoulli.sample((batch_size, channels, height - (self.block_size - 1), width - (self.block_size - 1))).cuda()
+            #print((x.sample[-2], x.sample[-1]))
+            block_mask = self._compute_block_mask(mask)
+            #print (block_mask.size())
+            #print (x.size())
+            countM = block_mask.size()[0] * block_mask.size()[1] * block_mask.size()[2] * block_mask.size()[3]
+            count_ones = block_mask.sum()
+
+            return block_mask * x * (countM / count_ones)
+        else:
+            return x
+
+    def _compute_block_mask(self, mask):
+        left_padding = int((self.block_size - 1) / 2)
+        right_padding = int(self.block_size / 2)
+
+        batch_size, channels, height, width = mask.shape
+        #print ("mask", mask[0][0])
+        non_zero_idxs = torch.nonzero(mask, as_tuple=False)
+        # print(type(non_zero_idxs))
+        nr_blocks = non_zero_idxs.size(0)
+
+        offsets = torch.stack(
+            [
+                torch.arange(self.block_size).view(-1, 1).expand(self.block_size, self.block_size).reshape(-1),  # - left_padding,
+                torch.arange(self.block_size).repeat(self.block_size),  # - left_padding
+            ]
+        ).t().cuda()
+        offsets = torch.cat((torch.zeros(self.block_size**2, 2).cuda().long(), offsets.long()), 1)
+
+        if nr_blocks > 0:
+            non_zero_idxs = non_zero_idxs.repeat(self.block_size ** 2, 1)
+            offsets = offsets.repeat(nr_blocks, 1).view(-1, 4)
+            offsets = offsets.long()
+
+            block_idxs = non_zero_idxs + offsets
+            #block_idxs += left_padding
+            padded_mask = F.pad(mask, (left_padding, right_padding, left_padding, right_padding))
+            padded_mask[block_idxs[:, 0], block_idxs[:, 1], block_idxs[:, 2], block_idxs[:, 3]] = 1.
+        else:
+            padded_mask = F.pad(mask, (left_padding, right_padding, left_padding, right_padding))
+
+        block_mask = 1 - padded_mask  # [:height, :width]
+        return block_mask
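A usage sketch for DropBlock (not part of the committed file). A GPU is assumed, since the mask is sampled with .cuda(); gamma=0.1 is an arbitrary value here (BasicBlock.forward above derives it from drop_rate and the feature size):

    import torch
    from models.dropblock import DropBlock

    drop = DropBlock(block_size=5).train()      # DropBlock is a no-op in eval mode
    feat = torch.randn(4, 64, 10, 10).cuda()    # e.g. a layer-3/layer-4 feature map
    out = drop(feat, gamma=0.1)
    print(out.shape)                            # torch.Size([4, 64, 10, 10])
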
models/protonet_embedding.py ADDED
@@ -0,0 +1,43 @@
+import torch.nn as nn
+import math
+
+class ConvBlock(nn.Module):
+    def __init__(self, in_channels, out_channels, retain_activation=True):
+        super(ConvBlock, self).__init__()
+
+        self.block = nn.Sequential(
+            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False),
+            nn.BatchNorm2d(out_channels)
+        )
+
+        if retain_activation:
+            self.block.add_module("ReLU", nn.ReLU(inplace=True))
+        self.block.add_module("MaxPool2d", nn.MaxPool2d(kernel_size=2, stride=2, padding=0))
+
+    def forward(self, x):
+        out = self.block(x)
+        return out
+
+# Embedding network used in Matching Networks (Vinyals et al., NIPS 2016), Meta-LSTM (Ravi & Larochelle, ICLR 2017),
+# MAML (w/ h_dim=z_dim=32) (Finn et al., ICML 2017), Prototypical Networks (Snell et al., NIPS 2017).
+
+class ProtoNetEmbedding(nn.Module):
+    def __init__(self, x_dim=3, h_dim=64, z_dim=64, retain_last_activation=True):
+        super(ProtoNetEmbedding, self).__init__()
+        self.encoder = nn.Sequential(
+            ConvBlock(x_dim, h_dim),
+            ConvBlock(h_dim, h_dim),
+            ConvBlock(h_dim, h_dim),
+            ConvBlock(h_dim, z_dim, retain_activation=retain_last_activation),
+        )
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+                m.weight.data.normal_(0, math.sqrt(2. / n))
+            elif isinstance(m, nn.BatchNorm2d):
+                m.weight.data.fill_(1)
+                m.bias.data.zero_()
+
+    def forward(self, x):
+        x = self.encoder(x)
+        return x.view(x.size(0), -1)
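A minimal sketch of ProtoNetEmbedding (not part of the committed file), assuming an 84x84 mini-ImageNet-style input:

    import torch
    from models.protonet_embedding import ProtoNetEmbedding

    net = ProtoNetEmbedding()
    x = torch.randn(4, 3, 84, 84)
    emb = net(x)                     # four conv-bn-relu-pool blocks, then flatten
    print(emb.shape)                 # torch.Size([4, 1600]) = 64 channels * 5 * 5
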
models/vit.py ADDED
@@ -0,0 +1,127 @@
+# https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/vit_pytorch.py
+
+import torch
+import torch.nn.functional as F
+from einops import rearrange
+from torch import nn
+
+MIN_NUM_PATCHES = 16
+
+class Residual(nn.Module):
+    def __init__(self, fn):
+        super().__init__()
+        self.fn = fn
+    def forward(self, x, **kwargs):
+        return self.fn(x, **kwargs) + x
+
+class PreNorm(nn.Module):
+    def __init__(self, dim, fn):
+        super().__init__()
+        self.norm = nn.LayerNorm(dim)
+        self.fn = fn
+    def forward(self, x, **kwargs):
+        return self.fn(self.norm(x), **kwargs)
+
+class FeedForward(nn.Module):
+    def __init__(self, dim, hidden_dim, dropout = 0.):
+        super().__init__()
+        self.net = nn.Sequential(
+            nn.Linear(dim, hidden_dim),
+            nn.GELU(),
+            nn.Dropout(dropout),
+            nn.Linear(hidden_dim, dim),
+            nn.Dropout(dropout)
+        )
+    def forward(self, x):
+        return self.net(x)
+
+class Attention(nn.Module):
+    def __init__(self, dim, heads = 8, dropout = 0.):
+        super().__init__()
+        self.heads = heads
+        self.scale = dim ** -0.5
+
+        self.to_qkv = nn.Linear(dim, dim * 3, bias = False)
+        self.to_out = nn.Sequential(
+            nn.Linear(dim, dim),
+            nn.Dropout(dropout)
+        )
+
+    def forward(self, x, mask = None):
+        b, n, _, h = *x.shape, self.heads
+        qkv = self.to_qkv(x).chunk(3, dim = -1)
+        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
+
+        dots = torch.einsum('bhid,bhjd->bhij', q, k) * self.scale
+
+        if mask is not None:
+            mask = F.pad(mask.flatten(1), (1, 0), value = True)
+            assert mask.shape[-1] == dots.shape[-1], 'mask has incorrect dimensions'
+            mask = mask[:, None, :] * mask[:, :, None]
+            dots.masked_fill_(~mask, float('-inf'))
+            del mask
+
+        attn = dots.softmax(dim=-1)
+
+        out = torch.einsum('bhij,bhjd->bhid', attn, v)
+        out = rearrange(out, 'b h n d -> b n (h d)')
+        out = self.to_out(out)
+        return out
+
+class Transformer(nn.Module):
+    def __init__(self, dim, depth, heads, mlp_dim, dropout):
+        super().__init__()
+        self.layers = nn.ModuleList([])
+        for _ in range(depth):
+            self.layers.append(nn.ModuleList([
+                Residual(PreNorm(dim, Attention(dim, heads = heads, dropout = dropout))),
+                Residual(PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)))
+            ]))
+    def forward(self, x, mask = None):
+        for attn, ff in self.layers:
+            x = attn(x, mask = mask)
+            x = ff(x)
+        return x
+
+class ViT(nn.Module):
+    def __init__(self, *, image_size, patch_size, dim, depth, heads, mlp_dim, channels = 3, dropout = 0., emb_dropout = 0.):
+        super().__init__()
+        assert image_size % patch_size == 0, 'image dimensions must be divisible by the patch size'
+        num_patches = (image_size // patch_size) ** 2
+        patch_dim = channels * patch_size ** 2
+        assert num_patches > MIN_NUM_PATCHES, f'your number of patches ({num_patches}) is way too small for attention to be effective. try decreasing your patch size'
+
+        self.patch_size = patch_size
+
+        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
+        self.patch_to_embedding = nn.Linear(patch_dim, dim)
+        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
+        self.dropout = nn.Dropout(emb_dropout)
+
+        self.transformer = Transformer(dim, depth, heads, mlp_dim, dropout)
+
+        self.to_cls_token = nn.Identity()
+
+        self.mlp_head = nn.Sequential(
+            nn.LayerNorm(dim),
+            nn.Linear(dim, mlp_dim),
+            nn.GELU(),
+            nn.Dropout(dropout)
+        )
+
+    def forward(self, img, mask = None):
+        p = self.patch_size
+
+        x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = p, p2 = p)
+        x = self.patch_to_embedding(x)
+        b, n, _ = x.shape
+
+        cls_tokens = self.cls_token.expand(b, -1, -1)
+        x = torch.cat((cls_tokens, x), dim=1)
+        x += self.pos_embedding[:, :(n + 1)]
+        x = self.dropout(x)
+
+        x = self.transformer(x, mask)
+
+        x = self.to_cls_token(x[:, 0])
+        return self.mlp_head(x)
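A quick sketch of the ViT module (not part of the committed file). The commented-out configuration in custom_model uses patch_size=3, which would fail the divisibility assertion for a 32-pixel image, so patch_size=4 is used here instead; the other hyperparameters follow that commented configuration:

    import torch
    from models.vit import ViT

    vit = ViT(image_size=32, patch_size=4, dim=512, depth=6, heads=8,
              mlp_dim=1000, dropout=0.1, emb_dropout=0.1)
    imgs = torch.randn(2, 3, 32, 32)
    out = vit(imgs)          # cls-token features passed through the MLP head
    print(out.shape)         # torch.Size([2, 1000]) = (batch, mlp_dim)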