sanket03 commited on
Commit
0efbdc7
·
1 Parent(s): 37e373f

Added Custom Resnet file

Browse files
Files changed (1) hide show
  1. custom_resnet.py +69 -0
custom_resnet.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+ import torch.nn.functional as F
3
+
4
+ def normalization(norm_type, embedding):
5
+ if norm_type=='batch':
6
+ return nn.BatchNorm2d(embedding)
7
+ elif norm_type=='layer':
8
+ return nn.GroupNorm(1, embedding)
9
+ else:
10
+ return nn.GroupNorm(4, embedding)
11
+
12
+ def custom_conv_layer(in_channels,
13
+ out_channels,
14
+ pool,
15
+ norm_type,
16
+ ):
17
+ conv_layer = [
18
+ nn.Conv2d(in_channels, out_channels, kernel_size=(3, 3), padding=1, stride=1, bias=False)
19
+ ]
20
+ if pool :
21
+ conv_layer.append(
22
+ nn.MaxPool2d(2, 2),
23
+ )
24
+ conv_layer.append(
25
+ normalization(norm_type, out_channels),
26
+ )
27
+ conv_layer.append(
28
+ nn.ReLU()
29
+ )
30
+ block = nn.Sequential(*conv_layer)
31
+ return block
32
+
33
+ class Net(nn.Module):
34
+ def __init__(self, normtype):
35
+ super(Net, self).__init__()
36
+ # prep layer
37
+ self.prep_layer = custom_conv_layer(3, 64, False, 'batch')
38
+ # layer 1
39
+ self.layer1_x = custom_conv_layer(64, 128, True, 'batch')
40
+ self.layer1_r1 = nn.Sequential(
41
+ custom_conv_layer(128, 128, False, 'batch'),
42
+ custom_conv_layer(128, 128, False, 'batch')
43
+ )
44
+ # layer 2
45
+ self.layer2 = custom_conv_layer(128, 256, True, 'batch')
46
+ # Layer 3
47
+ self.layer3_x = custom_conv_layer(256, 512, True, 'batch')
48
+ self.layer3_r3 = nn.Sequential(
49
+ custom_conv_layer(512, 512, False, 'batch'),
50
+ custom_conv_layer(512, 512, False, 'batch')
51
+ )
52
+ # MaxPooling with Kernel Size 4
53
+ self.pool = nn.MaxPool2d(4, 4)
54
+ # FC Layer
55
+ self.fc = nn.Linear(512, 10)
56
+
57
+ def forward(self, x):
58
+ x = self.prep_layer(x)
59
+ x1 = self.layer1_x(x)
60
+ r1 = self.layer1_r1(x1)
61
+ x = x1 + r1
62
+ x = self.layer2(x)
63
+ x3 = self.layer3_x(x)
64
+ r3 = self.layer3_r3(x3)
65
+ x = x3 + r3
66
+ x = self.pool(x)
67
+ x = x.view(-1, 512)
68
+ x = self.fc(x)
69
+ return F.softmax(x, dim=-1)