Sijuade commited on
Commit
a1539a5
1 Parent(s): fae2821

Create model/network.py

Browse files
Files changed (1) hide show
  1. model/network.py +138 -0
model/network.py ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ from torch.nn import functional as F
4
+ import pytorch_lightning as pl
5
+ import torchmetrics
6
+ from torch.optim.lr_scheduler import OneCycleLR
7
+ from torchmetrics.functional import accuracy
8
+
9
+
10
+ class ResBlock(nn.Module):
11
+
12
+ def __init__(self, in_channel, out_channel, stride=1):
13
+ super(ResBlock, self).__init__()
14
+ self.conv = nn.Sequential(
15
+ nn.Conv2d(in_channel, in_channel, kernel_size=3, stride=1, padding=1, bias=False),
16
+ nn.BatchNorm2d(in_channel),
17
+ nn.ReLU(),
18
+
19
+ nn.Conv2d(in_channel, out_channel, kernel_size=3, stride=1, padding=1, bias=False),
20
+ nn.BatchNorm2d(out_channel),
21
+ nn.ReLU(),
22
+ )
23
+
24
+ def forward(self, x):
25
+ return(self.conv(x))
26
+
27
+
28
+
29
+ class ResNet18(pl.LightningModule):
30
+ def __init__(self, train_loader_len, criterion, num_classes=10, lr=0.001, max_lr=1.45E-03):
31
+ super().__init__()
32
+ self.save_hyperparameters(ignore=['criterion'])
33
+
34
+ self.criterion = criterion
35
+ self.train_loader_len = train_loader_len
36
+ self.accuracy = torchmetrics.Accuracy(task="multiclass", num_classes=self.hparams.num_classes)
37
+
38
+ self.prep_layer = nn.Sequential(
39
+ nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False),
40
+ nn.BatchNorm2d(64),
41
+ nn.ReLU()
42
+ )
43
+
44
+ self.layer_one = nn.Sequential(
45
+ nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1, bias=False),
46
+ nn.MaxPool2d(2,2),
47
+ nn.BatchNorm2d(128),
48
+ nn.ReLU()
49
+ )
50
+
51
+ self.res_block1 = ResBlock(128, 128)
52
+
53
+ self.layer_two = nn.Sequential(
54
+ nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1, bias=False),
55
+ nn.MaxPool2d(2,2),
56
+ nn.BatchNorm2d(256),
57
+ nn.ReLU()
58
+ )
59
+
60
+ self.layer_three = nn.Sequential(
61
+ nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1, bias=False),
62
+ nn.MaxPool2d(2,2),
63
+ nn.BatchNorm2d(512),
64
+ nn.ReLU()
65
+ )
66
+
67
+ self.res_block2 = ResBlock(512, 512)
68
+
69
+ self.max_pool = nn.MaxPool2d(4,4)
70
+ self.fc = nn.Linear(512, num_classes, bias=False)
71
+
72
+ def forward(self, x):
73
+ x = self.prep_layer(x)
74
+ x = self.layer_one(x)
75
+ R1 = self.res_block1(x)
76
+ x = x + R1
77
+
78
+ x = self.layer_two(x)
79
+
80
+ x = self.layer_three(x)
81
+ R2 = self.res_block2(x)
82
+ x = x + R2
83
+
84
+ x = self.max_pool(x)
85
+
86
+ x = x.view(x.size(0), -1)
87
+ x = self.fc(x)
88
+
89
+ return(x)
90
+
91
+ def configure_optimizers(self):
92
+ optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.lr, weight_decay=1e-4)
93
+ scheduler = OneCycleLR(
94
+ optimizer,
95
+ max_lr=self.hparams.max_lr,
96
+ epochs=self.trainer.max_epochs,
97
+ steps_per_epoch=self.train_loader_len,
98
+ pct_start=5/self.trainer.max_epochs,
99
+ div_factor=100,
100
+ three_phase=False,
101
+ )
102
+ if self.hparams.max_lr==1.45E-03:
103
+ return(optimizer)
104
+ else:
105
+ return([optimizer], [scheduler])
106
+
107
+ def training_step(self, train_batch, batch_idx):
108
+ data, target = train_batch
109
+ y_pred = self(data)
110
+ loss = self.criterion(y_pred, target)
111
+
112
+ pred = torch.argmax(y_pred.squeeze(), dim=1)
113
+ acc = accuracy(pred, target, task="multiclass", num_classes=self.hparams.num_classes)
114
+
115
+ self.log('train_loss', loss, prog_bar=True, on_step=False, on_epoch=True)
116
+ self.log('train_acc', acc, prog_bar=True, on_step=False, on_epoch=True)
117
+
118
+ return(loss)
119
+
120
+ def validation_step(self, batch, batch_idx):
121
+ return(self.evaluate(batch, 'val'))
122
+
123
+ def test_step(self, batch, batch_idx):
124
+ return(self.evaluate(batch, 'test'))
125
+
126
+ def evaluate(self, batch, stage=None):
127
+ data, target = batch
128
+ y_pred = self(data)
129
+
130
+ loss = self.criterion(y_pred, target).item()
131
+ pred = torch.argmax(y_pred.squeeze(), dim=1)
132
+ acc = accuracy(pred, target, task="multiclass", num_classes=self.hparams.num_classes)
133
+
134
+ if stage:
135
+ self.log(f"{stage}_loss", loss, prog_bar=True, on_step=False, on_epoch=True)
136
+ self.log(f"{stage}_acc", acc, prog_bar=True, on_step=False, on_epoch=True)
137
+
138
+ return pred, target