ibrim committed on
Commit
6bea3f9
1 Parent(s): 8fbb358

Upload 7 files

Files changed (7)
  1. CLIP.py +66 -0
  2. app.py +75 -0
  3. best.pt +3 -0
  4. config.py +32 -0
  5. implement.py +332 -0
  6. main.py +115 -0
  7. requirements.txt +14 -0
CLIP.py ADDED
@@ -0,0 +1,66 @@
+ import torch
+ from torch import nn
+ import torch.nn.functional as F
+
+ import config as CFG
+ # modules.py is not part of this upload; ImageEncoder, TextEncoder and
+ # ProjectionHead are defined in implement.py, so they are imported from there.
+ from implement import ImageEncoder, TextEncoder, ProjectionHead
+
+
+ class CLIPModel(nn.Module):
+     def __init__(
+         self,
+         temperature=CFG.temperature,
+         image_embedding=CFG.image_embedding,
+         text_embedding=CFG.text_embedding,
+     ):
+         super().__init__()
+         self.image_encoder = ImageEncoder()
+         self.text_encoder = TextEncoder()
+         self.image_projection = ProjectionHead(embedding_dim=image_embedding)
+         self.text_projection = ProjectionHead(embedding_dim=text_embedding)
+         self.temperature = temperature
+
+     def forward(self, batch):
+         # Getting Image and Text Features
+         image_features = self.image_encoder(batch["image"])
+         text_features = self.text_encoder(
+             input_ids=batch["input_ids"], attention_mask=batch["attention_mask"]
+         )
+         # Getting Image and Text Embeddings (with same dimension)
+         image_embeddings = self.image_projection(image_features)
+         text_embeddings = self.text_projection(text_features)
+
+         # Calculating the Loss
+         logits = (text_embeddings @ image_embeddings.T) / self.temperature
+         images_similarity = image_embeddings @ image_embeddings.T
+         texts_similarity = text_embeddings @ text_embeddings.T
+         targets = F.softmax(
+             (images_similarity + texts_similarity) / 2 * self.temperature, dim=-1
+         )
+         texts_loss = cross_entropy(logits, targets, reduction='none')
+         images_loss = cross_entropy(logits.T, targets.T, reduction='none')
+         loss = (images_loss + texts_loss) / 2.0  # shape: (batch_size)
+         return loss.mean()
+
+
+ def cross_entropy(preds, targets, reduction='none'):
+     log_softmax = nn.LogSoftmax(dim=-1)
+     loss = (-targets * log_softmax(preds)).sum(1)
+     if reduction == "none":
+         return loss
+     elif reduction == "mean":
+         return loss.mean()
+
+
+ if __name__ == '__main__':
+     # quick smoke test with random tensors
+     images = torch.randn(8, 3, 224, 224)
+     input_ids = torch.randint(5, 300, size=(8, 25))
+     attention_mask = torch.ones(8, 25)
+     batch = {
+         'image': images,
+         'input_ids': input_ids,
+         'attention_mask': attention_mask
+     }
+
+     CLIP = CLIPModel()
+     loss = CLIP(batch)
+     print(loss)
app.py ADDED
@@ -0,0 +1,75 @@
+ import gradio as gr
+ import gc
+ import cv2
+ import torch
+ import torch.nn.functional as F
+ from tqdm import tqdm
+ from transformers import DistilBertTokenizer
+ import matplotlib.pyplot as plt
+ from implement import *
+ import config as CFG
+ from main import build_loaders
+ from CLIP import CLIPModel
+ import os
+ os.environ['HTTPS_PROXY'] = "http://185.46.212.90:80/"
+ os.environ['HTTP_PROXY'] = "http://185.46.212.90:80/"
+
+ with gr.Blocks(css="style.css") as demo:
+     def get_image_embeddings(valid_df, model_path):
+         tokenizer = DistilBertTokenizer.from_pretrained(CFG.text_tokenizer)
+         valid_loader = build_loaders(valid_df, tokenizer, mode="valid")
+
+         model = CLIPModel().to(CFG.device)
+         model.load_state_dict(torch.load(model_path, map_location=CFG.device))
+         model.eval()
+
+         valid_image_embeddings = []
+         with torch.no_grad():
+             for batch in tqdm(valid_loader):
+                 image_features = model.image_encoder(batch["image"].to(CFG.device))
+                 image_embeddings = model.image_projection(image_features)
+                 valid_image_embeddings.append(image_embeddings)
+         return model, torch.cat(valid_image_embeddings)
+
+     _, valid_df = make_train_valid_dfs()
+     model, image_embeddings = get_image_embeddings(valid_df, "best.pt")
+
+     def find_matches(query, n=9):
+         tokenizer = DistilBertTokenizer.from_pretrained(CFG.text_tokenizer)
+         encoded_query = tokenizer([query])
+         batch = {
+             key: torch.tensor(values).to(CFG.device)
+             for key, values in encoded_query.items()
+         }
+         with torch.no_grad():
+             text_features = model.text_encoder(
+                 input_ids=batch["input_ids"], attention_mask=batch["attention_mask"]
+             )
+             text_embeddings = model.text_projection(text_features)
+
+         image_embeddings_n = F.normalize(image_embeddings, p=2, dim=-1)
+         text_embeddings_n = F.normalize(text_embeddings, p=2, dim=-1)
+         dot_similarity = text_embeddings_n @ image_embeddings_n.T
+
+         _, indices = torch.topk(dot_similarity.squeeze(0), n * 5)
+         matches = [valid_df['image'].values[idx] for idx in indices[::5]]
+
+         images = []
+         for match in matches:
+             image = cv2.imread(f"{CFG.image_path}/{match}")
+             image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+             # images.append(image)
+
+         # the append above is commented out, so only the last match is returned
+         # and displayed in the single gradio Image component below
+         return image
+
+     with gr.Row():
+         textbox = gr.Textbox(label="Enter a query to find matching images using a CLIP model.")
+         image = gr.Image(type="numpy")
+
+     button = gr.Button("Press")
+     button.click(
+         fn=find_matches,
+         inputs=textbox,
+         outputs=image
+     )
+
+ # Launch the Gradio interface
+ demo.launch(share=True)
best.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:7643c035e44a5bee1abd2dd82ecc7232803751b9fc87c00d456f848ca1d0e385
+ size 363250624
config.py ADDED
@@ -0,0 +1,32 @@
+ import torch
+
+ debug = True
+ image_path = "/raid/users/mohammadibrahim-st/TSAI/OpenAI-CLIP/Flicker-8k/Images"
+ captions_path = "/raid/users/mohammadibrahim-st/TSAI/OpenAI-CLIP/Flicker-8k"
+ batch_size = 20
+ num_workers = 0
+ lr = 1e-3
+ weight_decay = 1e-3
+ patience = 2
+ factor = 0.5
+ epochs = 5
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ model_name = 'resnet50'
+ image_embedding = 2048
+ text_encoder_model = "/raid/users/mohammadibrahim-st/Models/BertDistil"
+ text_embedding = 768
+ text_tokenizer = "/raid/users/mohammadibrahim-st/Models/BertDistil"
+ max_length = 200
+
+ pretrained = False  # for both image encoder and text encoder
+ trainable = False  # for both image encoder and text encoder
+ temperature = 1.0
+
+ # image size
+ size = 224
+
+ # for projection head; used for both image and text encoders
+ num_projection_layers = 1
+ projection_dim = 256
+ dropout = 0.1
implement.py ADDED
@@ -0,0 +1,332 @@
+ import os
+ import cv2
+ import gc
+ import numpy as np
+ import pandas as pd
+ import itertools
+ from tqdm.autonotebook import tqdm
+ import albumentations as A
+
+ import torch
+ from torch import nn
+ import torch.nn.functional as F
+ import timm
+ from transformers import DistilBertModel, DistilBertConfig, DistilBertTokenizer
+
+ os.environ['HTTPS_PROXY'] = "http://185.46.212.90:80/"
+ os.environ['HTTP_PROXY'] = "http://185.46.212.90:80/"
+
+
+ class CFG:
+     debug = False
+     image_path = "/raid/users/mohammadibrahim-st/TSAI/OpenAI-CLIP/Flicker-8k/Images"
+     captions_path = "/raid/users/mohammadibrahim-st/TSAI/OpenAI-CLIP/Flicker-8k"
+     batch_size = 30
+     num_workers = 4
+     head_lr = 1e-3
+     image_encoder_lr = 1e-4
+     text_encoder_lr = 1e-5
+     weight_decay = 1e-3
+     patience = 1
+     factor = 0.8
+     epochs = 4
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+     model_name = 'resnet50'
+     image_embedding = 2048
+     text_encoder_model = "/raid/users/mohammadibrahim-st/Models/BertDistil"
+     text_embedding = 768
+     text_tokenizer = "/raid/users/mohammadibrahim-st/Models/BertDistil"
+     max_length = 200
+
+     pretrained = True  # for both image encoder and text encoder
+     trainable = True  # for both image encoder and text encoder
+     temperature = 1.0
+
+     # image size
+     size = 224
+
+     # for projection head; used for both image and text encoders
+     num_projection_layers = 1
+     projection_dim = 256
+     dropout = 0.1
+
+
+ class AvgMeter:
+     def __init__(self, name="Metric"):
+         self.name = name
+         self.reset()
+
+     def reset(self):
+         self.avg, self.sum, self.count = [0] * 3
+
+     def update(self, val, count=1):
+         self.count += count
+         self.sum += val * count
+         self.avg = self.sum / self.count
+
+     def __repr__(self):
+         text = f"{self.name}: {self.avg:.4f}"
+         return text
+
+
+ def get_lr(optimizer):
+     for param_group in optimizer.param_groups:
+         return param_group["lr"]
+
+
+ class CLIPDataset(torch.utils.data.Dataset):
+     def __init__(self, image_filenames, captions, tokenizer, transforms):
+         """
+         image_filenames and captions must have the same length; so, if there are
+         multiple captions for each image, the image_filenames must have repetitive
+         file names
+         """
+
+         self.image_filenames = image_filenames
+         self.captions = list(captions)
+         self.encoded_captions = tokenizer(
+             list(captions), padding=True, truncation=True, max_length=CFG.max_length
+         )
+         self.transforms = transforms
+
+     def __getitem__(self, idx):
+         item = {
+             key: torch.tensor(values[idx])
+             for key, values in self.encoded_captions.items()
+         }
+
+         image = cv2.imread(f"{CFG.image_path}/{self.image_filenames[idx]}")
+         image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+         image = self.transforms(image=image)['image']
+         item['image'] = torch.tensor(image).permute(2, 0, 1).float()
+         item['caption'] = self.captions[idx]
+
+         return item
+
+     def __len__(self):
+         return len(self.captions)
+
+
+ def get_transforms(mode="train"):
+     if mode == "train":
+         return A.Compose(
+             [
+                 A.Resize(CFG.size, CFG.size, always_apply=True),
+                 A.Normalize(max_pixel_value=255.0, always_apply=True),
+             ]
+         )
+     else:
+         return A.Compose(
+             [
+                 A.Resize(CFG.size, CFG.size, always_apply=True),
+                 A.Normalize(max_pixel_value=255.0, always_apply=True),
+             ]
+         )
+
+
+ class ImageEncoder(nn.Module):
+     """
+     Encode images to a fixed size vector
+     """
+
+     def __init__(
+         self, model_name=CFG.model_name, pretrained=CFG.pretrained, trainable=CFG.trainable
+     ):
+         super().__init__()
+         self.model = timm.create_model(
+             model_name, pretrained, num_classes=0, global_pool="avg"
+         )
+         for p in self.model.parameters():
+             p.requires_grad = trainable
+
+     def forward(self, x):
+         return self.model(x)
+
+
+ class TextEncoder(nn.Module):
+     def __init__(self, model_name=CFG.text_encoder_model, pretrained=CFG.pretrained, trainable=CFG.trainable):
+         super().__init__()
+         if pretrained:
+             self.model = DistilBertModel.from_pretrained(model_name, use_safetensors=True)  # added use_safetensors
+         else:
+             self.model = DistilBertModel(config=DistilBertConfig())
+
+         for p in self.model.parameters():
+             p.requires_grad = trainable
+
+         # we are using the CLS token hidden representation as the sentence's embedding
+         self.target_token_idx = 0
+
+     def forward(self, input_ids, attention_mask):
+         output = self.model(input_ids=input_ids, attention_mask=attention_mask)
+         last_hidden_state = output.last_hidden_state
+         return last_hidden_state[:, self.target_token_idx, :]
+
+
+ class ProjectionHead(nn.Module):
+     def __init__(
+         self,
+         embedding_dim,
+         projection_dim=CFG.projection_dim,
+         dropout=CFG.dropout
+     ):
+         super().__init__()
+         self.projection = nn.Linear(embedding_dim, projection_dim)
+         self.gelu = nn.GELU()
+         self.fc = nn.Linear(projection_dim, projection_dim)
+         self.dropout = nn.Dropout(dropout)
+         self.layer_norm = nn.LayerNorm(projection_dim)
+
+     def forward(self, x):
+         projected = self.projection(x)
+         x = self.gelu(projected)
+         x = self.fc(x)
+         x = self.dropout(x)
+         x = x + projected
+         x = self.layer_norm(x)
+         return x
+
+
+ class CLIPModel(nn.Module):
+     def __init__(
+         self,
+         temperature=CFG.temperature,
+         image_embedding=CFG.image_embedding,
+         text_embedding=CFG.text_embedding,
+     ):
+         super().__init__()
+         self.image_encoder = ImageEncoder()
+         self.text_encoder = TextEncoder()
+         self.image_projection = ProjectionHead(embedding_dim=image_embedding)
+         self.text_projection = ProjectionHead(embedding_dim=text_embedding)
+         self.temperature = temperature
+
+     def forward(self, batch):
+         # Getting Image and Text Features
+         image_features = self.image_encoder(batch["image"])
+         text_features = self.text_encoder(
+             input_ids=batch["input_ids"], attention_mask=batch["attention_mask"]
+         )
+         # Getting Image and Text Embeddings (with same dimension)
+         image_embeddings = self.image_projection(image_features)
+         text_embeddings = self.text_projection(text_features)
+
+         # Calculating the Loss
+         logits = (text_embeddings @ image_embeddings.T) / self.temperature
+         images_similarity = image_embeddings @ image_embeddings.T
+         texts_similarity = text_embeddings @ text_embeddings.T
+         targets = F.softmax(
+             (images_similarity + texts_similarity) / 2 * self.temperature, dim=-1
+         )
+         texts_loss = cross_entropy(logits, targets, reduction='none')
+         images_loss = cross_entropy(logits.T, targets.T, reduction='none')
+         loss = (images_loss + texts_loss) / 2.0  # shape: (batch_size)
+         return loss.mean()
+
+
+ def cross_entropy(preds, targets, reduction='none'):
+     log_softmax = nn.LogSoftmax(dim=-1)
+     loss = (-targets * log_softmax(preds)).sum(1)
+     if reduction == "none":
+         return loss
+     elif reduction == "mean":
+         return loss.mean()
+
+
+ def make_train_valid_dfs():
+     dataframe = pd.read_csv(f"{CFG.captions_path}/captions.csv")
+     dataframe['id'] = dataframe.index  # newly added
+     max_id = dataframe["id"].max() + 1 if not CFG.debug else 100
+     image_ids = np.arange(0, max_id)
+     np.random.seed(42)
+     valid_ids = np.random.choice(
+         image_ids, size=int(0.2 * len(image_ids)), replace=False
+     )
+     train_ids = [id_ for id_ in image_ids if id_ not in valid_ids]
+     train_dataframe = dataframe[dataframe["id"].isin(train_ids)].reset_index(drop=True)
+     valid_dataframe = dataframe[dataframe["id"].isin(valid_ids)].reset_index(drop=True)
+     return train_dataframe, valid_dataframe
+
+
+ def build_loaders(dataframe, tokenizer, mode):
+     transforms = get_transforms(mode=mode)
+     dataset = CLIPDataset(
+         dataframe["image"].values,
+         dataframe["caption"].values,
+         tokenizer=tokenizer,
+         transforms=transforms,
+     )
+     dataloader = torch.utils.data.DataLoader(
+         dataset,
+         batch_size=CFG.batch_size,
+         num_workers=CFG.num_workers,
+         shuffle=True if mode == "train" else False,
+     )
+     return dataloader
+
+
+ def train_epoch(model, train_loader, optimizer, lr_scheduler, step):
+     loss_meter = AvgMeter()
+     tqdm_object = tqdm(train_loader, total=len(train_loader))
+     for batch in tqdm_object:
+         batch = {k: v.to(CFG.device) for k, v in batch.items() if k != "caption"}
+         loss = model(batch)
+         optimizer.zero_grad()
+         loss.backward()
+         optimizer.step()
+         if step == "batch":
+             lr_scheduler.step()
+
+         count = batch["image"].size(0)
+         loss_meter.update(loss.item(), count)
+
+         tqdm_object.set_postfix(train_loss=loss_meter.avg, lr=get_lr(optimizer))
+     return loss_meter
+
+
+ def valid_epoch(model, valid_loader):
+     loss_meter = AvgMeter()
+
+     tqdm_object = tqdm(valid_loader, total=len(valid_loader))
+     for batch in tqdm_object:
+         batch = {k: v.to(CFG.device) for k, v in batch.items() if k != "caption"}
+         loss = model(batch)
+
+         count = batch["image"].size(0)
+         loss_meter.update(loss.item(), count)
+
+         tqdm_object.set_postfix(valid_loss=loss_meter.avg)
+     return loss_meter
+
+
+ def main():
+     train_df, valid_df = make_train_valid_dfs()
+     tokenizer = DistilBertTokenizer.from_pretrained(CFG.text_tokenizer)
+     train_loader = build_loaders(train_df, tokenizer, mode="train")
+     valid_loader = build_loaders(valid_df, tokenizer, mode="valid")
+
+     model = CLIPModel().to(CFG.device)
+     params = [
+         {"params": model.image_encoder.parameters(), "lr": CFG.image_encoder_lr},
+         {"params": model.text_encoder.parameters(), "lr": CFG.text_encoder_lr},
+         {"params": itertools.chain(
+             model.image_projection.parameters(), model.text_projection.parameters()
+         ), "lr": CFG.head_lr, "weight_decay": CFG.weight_decay}
+     ]
+     optimizer = torch.optim.AdamW(params, weight_decay=0.)
+     lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
+         optimizer, mode="min", patience=CFG.patience, factor=CFG.factor
+     )
+     step = "epoch"
+
+     best_loss = float('inf')
+     for epoch in range(CFG.epochs):
+         print(f"Epoch: {epoch + 1}")
+         model.train()
+         train_loss = train_epoch(model, train_loader, optimizer, lr_scheduler, step)
+         model.eval()
+         with torch.no_grad():
+             valid_loss = valid_epoch(model, valid_loader)
+
+         if valid_loss.avg < best_loss:
+             best_loss = valid_loss.avg
+             torch.save(model.state_dict(), "best.pt")
+             print("Saved Best Model!")
+
+         lr_scheduler.step(valid_loss.avg)
+
+
+ if __name__ == "__main__":
+     main()
main.py ADDED
@@ -0,0 +1,115 @@
+ import os
+ import gc
+ import numpy as np
+ import pandas as pd
+ from tqdm import tqdm
+
+ import torch
+ from torch import nn
+ from transformers import DistilBertTokenizer
+
+ import config as CFG
+ # dataset.py and utils.py are not part of this upload; the equivalent
+ # classes and helpers are defined in implement.py, so import them from there.
+ from implement import CLIPDataset, get_transforms, AvgMeter, get_lr
+ from CLIP import CLIPModel
+
+
+ def make_train_valid_dfs():
+     dataframe = pd.read_csv(f"{CFG.captions_path}/captions.csv")
+     dataframe['id'] = dataframe.index  # newly added
+     max_id = dataframe["id"].max() + 1 if not CFG.debug else 100
+     image_ids = np.arange(0, max_id)
+     np.random.seed(42)
+     valid_ids = np.random.choice(
+         image_ids, size=int(0.2 * len(image_ids)), replace=False
+     )
+     train_ids = [id_ for id_ in image_ids if id_ not in valid_ids]
+     train_dataframe = dataframe[dataframe["id"].isin(train_ids)].reset_index(drop=True)
+     valid_dataframe = dataframe[dataframe["id"].isin(valid_ids)].reset_index(drop=True)
+     return train_dataframe, valid_dataframe
+
+
+ def build_loaders(dataframe, tokenizer, mode):
+     transforms = get_transforms(mode=mode)
+     dataset = CLIPDataset(
+         dataframe["image"].values,
+         dataframe["caption"].values,
+         tokenizer=tokenizer,
+         transforms=transforms,
+     )
+     dataloader = torch.utils.data.DataLoader(
+         dataset,
+         batch_size=CFG.batch_size,
+         num_workers=CFG.num_workers,
+         shuffle=True if mode == "train" else False,
+     )
+     return dataloader
+
+
+ def train_epoch(model, train_loader, optimizer, lr_scheduler, step):
+     loss_meter = AvgMeter()
+     tqdm_object = tqdm(train_loader, total=len(train_loader))
+     for batch in tqdm_object:
+         batch = {k: v.to(CFG.device) for k, v in batch.items() if k != "caption"}
+         loss = model(batch)
+         optimizer.zero_grad()
+         loss.backward()
+         optimizer.step()
+         if step == "batch":
+             lr_scheduler.step()
+
+         count = batch["image"].size(0)
+         loss_meter.update(loss.item(), count)
+
+         tqdm_object.set_postfix(train_loss=loss_meter.avg, lr=get_lr(optimizer))
+     return loss_meter
+
+
+ def valid_epoch(model, valid_loader):
+     loss_meter = AvgMeter()
+
+     tqdm_object = tqdm(valid_loader, total=len(valid_loader))
+     for batch in tqdm_object:
+         batch = {k: v.to(CFG.device) for k, v in batch.items() if k != "caption"}
+         loss = model(batch)
+
+         count = batch["image"].size(0)
+         loss_meter.update(loss.item(), count)
+
+         tqdm_object.set_postfix(valid_loss=loss_meter.avg)
+     return loss_meter
+
+
+ def main():
+     train_df, valid_df = make_train_valid_dfs()
+     tokenizer = DistilBertTokenizer.from_pretrained(CFG.text_tokenizer)
+     train_loader = build_loaders(train_df, tokenizer, mode="train")
+     valid_loader = build_loaders(valid_df, tokenizer, mode="valid")
+
+     model = CLIPModel().to(CFG.device)
+     optimizer = torch.optim.AdamW(
+         model.parameters(), lr=CFG.lr, weight_decay=CFG.weight_decay
+     )
+     lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
+         optimizer, mode="min", patience=CFG.patience, factor=CFG.factor
+     )
+     step = "epoch"
+
+     best_loss = float('inf')
+     for epoch in range(CFG.epochs):
+         print(f"Epoch: {epoch + 1}")
+         model.train()
+         train_loss = train_epoch(model, train_loader, optimizer, lr_scheduler, step)
+         model.eval()
+         with torch.no_grad():
+             valid_loss = valid_epoch(model, valid_loader)
+
+         if valid_loss.avg < best_loss:
+             best_loss = valid_loss.avg
+             torch.save(model.state_dict(), "best2.pt")
+             print("Saved Best Model!")
+
+
+ if __name__ == "__main__":
+     main()
requirements.txt ADDED
@@ -0,0 +1,14 @@
+ torch
+ opencv-python
+ matplotlib
+ transformers
+ tqdm
+ # also imported by the scripts in this upload (not pinned here):
+ # gradio, timm, albumentations, pandas, numpy