Alberto Carmona committed
Commit 35df8d2
1 Parent(s): 2773b59

Add required folders and files
Files changed:
- configs/phase2/FineCapEval_clipRN50_clips_grammar.yml +64 -0
- configs/phase2/clipRN50_clips_grammar.yml +64 -0
- configs/phase2/transformer.yml +41 -0
- data/README.md +1 -0
- retrieval/README.md +5 -0
- retrieval/caption_data.py +500 -0
- retrieval/clip_model.py +350 -0
- retrieval/configs/clip_negative_text.yaml +14 -0
- retrieval/param.py +209 -0
- retrieval/pth_loader.py +334 -0
- retrieval/text_utils.py +74 -0
- retrieval/train_pl.py +661 -0
- save/README.md +1 -0
configs/phase2/FineCapEval_clipRN50_clips_grammar.yml
ADDED
@@ -0,0 +1,64 @@
```yaml
caption_model: transformer
noamopt: true
noamopt_warmup: 20000
label_smoothing: 0.0
input_json: data/FineCapEval.json
input_label_h5: none
input_fc_dir: data/FineCapEval_clip_RN50_fc
input_att_dir: data/FineCapEval_clip_RN50_att
input_clipscore_vis_dir: data/FineCapEval_clipscore_vis
seq_per_img: 5
batch_size: 160
learning_rate: 0.0005

checkpoint_path: ./save/clipRN50_clips_grammar/clipRN50_clips_grammar

use_multi_rewards: true
use_grammar: true
use_grammar_baseline: true
# clip_load_path: '/scratch-space/retrieval/save/clip_negative_text/clip_negative_text-epoch=10.ckpt'

# Notice: because I'm too lazy, I reuse the RNN option names to set the transformer hyperparameters:
# N=num_layers
# d_model=input_encoding_size
# d_ff=rnn_size

# will be ignored
num_layers: 6
input_encoding_size: 512
rnn_size: 2048

# Transformer config
N_enc: 6
N_dec: 6
d_model: 512
d_ff: 2048
num_att_heads: 8
dropout: 0.1


learning_rate_decay_start: 0
scheduled_sampling_start: -1
save_checkpoint_every: 3000
language_eval: 0
val_images_use: 5000
max_epochs: 15
train_sample_n: 5

REFORWARD: false

# _BASE_: transformer.yml
reduce_on_plateau: false
noamopt: false
learning_rate: 0.000005
learning_rate_decay_start: -1

self_critical_after: 15
max_epochs: 50

verbose: false
precision: 32

# use_clipscore: true
use_clipscore: false
clipscore_reward_weight: 2.0
```
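A note on reading this config (and the nearly identical one that follows): several keys appear twice (`noamopt`, `learning_rate`, `learning_rate_decay_start`, `max_epochs`) because the file is written as a base block followed by overrides, mirroring the commented-out `# _BASE_: transformer.yml` line. With a standard YAML loader the last occurrence of a key wins, so the effective values come from the lower block. A minimal sketch of that behavior, assuming the config is read with `yaml.safe_load` as elsewhere in this commit (e.g. `retrieval/param.py`):

```python
import yaml

# Duplicate keys in one YAML mapping: PyYAML keeps the last value it sees.
doc = """
noamopt: true
learning_rate: 0.0005
max_epochs: 15
noamopt: false
learning_rate: 0.000005
max_epochs: 50
"""
print(yaml.safe_load(doc))
# {'noamopt': False, 'learning_rate': 5e-06, 'max_epochs': 50}
```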
configs/phase2/clipRN50_clips_grammar.yml
ADDED
@@ -0,0 +1,64 @@
```yaml
caption_model: transformer
noamopt: true
noamopt_warmup: 20000
label_smoothing: 0.0
input_json: data/cocotalk.json
input_label_h5: data/cocotalk_label.h5
input_fc_dir: data/cocotalk_clip_RN50_fc
input_att_dir: data/cocotalk_clip_RN50_att
input_clipscore_vis_dir: data/cocotalk_clipscore_vis
seq_per_img: 5
batch_size: 160
learning_rate: 0.0005

checkpoint_path: save/clipRN50_clips_grammar/clipRN50_clips_grammar

use_multi_rewards: true
use_grammar: true
use_grammar_baseline: true
# clip_load_path: '/scratch-space/retrieval/save/clip_negative_text/clip_negative_text-epoch=10.ckpt'
clip_load_path: 'retrieval/save/clip_negative_text/clip_negative_text-epoch=12.ckpt'

# Notice: because I'm too lazy, I reuse the RNN option names to set the transformer hyperparameters:
# N=num_layers
# d_model=input_encoding_size
# d_ff=rnn_size

# will be ignored
num_layers: 6
input_encoding_size: 512
rnn_size: 2048

# Transformer config
N_enc: 6
N_dec: 6
d_model: 512
d_ff: 2048
num_att_heads: 8
dropout: 0.1


learning_rate_decay_start: 0
scheduled_sampling_start: -1
save_checkpoint_every: 3000
language_eval: 1
val_images_use: 5000
max_epochs: 15
train_sample_n: 5

REFORWARD: false

# _BASE_: transformer.yml
reduce_on_plateau: false
noamopt: false
learning_rate: 0.000005
learning_rate_decay_start: -1

self_critical_after: 15
max_epochs: 40

verbose: false
precision: 32

use_clipscore: true
clipscore_reward_weight: 2.0
```
configs/phase2/transformer.yml
ADDED
@@ -0,0 +1,41 @@
```yaml
caption_model: transformer
noamopt: true
noamopt_warmup: 20000
label_smoothing: 0.0
input_json: data/cocotalk.json
input_label_h5: data/cocotalk_label.h5
input_att_dir: data/cocotalk_att
seq_per_img: 5
batch_size: 10
learning_rate: 0.0005

checkpoint_path: ./save/trans_rn50_sc

# Notice: because I'm too lazy, I reuse the RNN option names to set the transformer hyperparameters:
# N=num_layers
# d_model=input_encoding_size
# d_ff=rnn_size

# will be ignored
num_layers: 6
input_encoding_size: 512
rnn_size: 2048

# Transformer config
N_enc: 6
N_dec: 6
d_model: 512
d_ff: 2048
num_att_heads: 8
dropout: 0.1


learning_rate_decay_start: 0
scheduled_sampling_start: -1
save_checkpoint_every: 3000
language_eval: 1
val_images_use: 5000
max_epochs: 15
train_sample_n: 5

REFORWARD: false
```
data/README.md
ADDED
@@ -0,0 +1 @@
directory to store preprocessed files
retrieval/README.md
ADDED
@@ -0,0 +1,5 @@
# Finetuning CLIP reward model

```bash
python train_pl.py --cfg clip_negative_text --id clip_negative_text
```
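For context, `--cfg clip_negative_text` is resolved by `retrieval/param.py` (added below) to `configs/clip_negative_text.yaml`, and values loaded from that file take precedence over the parsed command-line defaults. A small sketch of the resolution step, assuming it is run from inside `retrieval/`:

```python
import yaml

cfg_name = 'clip_negative_text'        # value passed to --cfg in the command above
cfg_path = f'configs/{cfg_name}.yaml'  # same pattern parse_args() uses in param.py
with open(cfg_path) as f:
    loaded_kwargs = yaml.safe_load(f)
print(loaded_kwargs['batch_size'], loaded_kwargs['use_grammar'])  # 1600 True
```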
retrieval/caption_data.py
ADDED
@@ -0,0 +1,500 @@
```python
from torch.utils.data import DataLoader, Dataset, Sampler
from pathlib import Path
import json
from multiprocessing import Pool
from tqdm import tqdm
from PIL import Image
import random
import numpy as np
import torch
import torchvision
import torchvision.transforms as T

from torch.utils.data.distributed import DistributedSampler

from transformers import T5Tokenizer, BertTokenizer, BertTokenizerFast, CLIPTokenizer

import text_utils

project_dir = Path(__file__).parent.resolve()
workspace_dir = project_dir.parent.parent
dataset_dir = workspace_dir.joinpath('datasets/').resolve()
# coco_dir = dataset_dir.joinpath('COCO')
# vg_dir = dataset_dir.joinpath('VG')
coco_img_dir = dataset_dir.joinpath('COCO/images/')
coco_data_dir = project_dir.parent.joinpath('CLIP-ViL/CLIP-ViL-Direct/caption/data/')
# coco_feature_dir = coco_dir.joinpath('features')


class COCORetrievalDataset(Dataset):
    def __init__(self, split='karpathy_train', rank=-1, topk=-1, verbose=True, args=None, mode='train'):
        super().__init__()

        self.topk = topk
        self.verbose = verbose
        self.args = args
        self.rank = rank
        self.mode = mode

        # Loading datasets to data
        self.source = split
        if self.verbose:
            print('Data source: ', self.source)

        # if self.args.tokenizer is None:
        #     self.args.tokenizer = self.args.decoder_backbone

        # if 'bert' in self.args.tokenizer:
        #     self.tokenizer = BertTokenizerFast.from_pretrained(
        #         self.args.tokenizer,
        #         # max_length=self.args.max_text_length,
        #         # do_lower_case=self.args.do_lower_case
        #         )
        # elif 'clip' in self.args.tokenizer:
        #     self.tokenizer = CLIPTokenizer.from_pretrained(
        #         self.args.tokenizer,
        #         # max_length=self.args.max_text_length,
        #         # do_lower_case=self.args.do_lower_case
        #         )

        self.tokenizer = CLIPTokenizer.from_pretrained(
            self.args.tokenizer,
            # max_length=self.args.max_text_length,
            # do_lower_case=self.args.do_lower_case
            )

        with open(coco_data_dir.joinpath('cocotalk.json')) as f:
            self.vocab = list(json.load(f)['ix_to_word'].values())
            popped = self.vocab.pop(-1)
            assert popped == 'UNK'
        if self.verbose:
            print('vocab size: ', len(self.vocab))


        data_info_path = coco_data_dir.joinpath('dataset_coco.json')
        with open(data_info_path) as f:
            karpathy_data = json.load(f)

        split_rename = {
            'train': 'train',
            'restval': 'train',
            'val': 'val',
            'test': 'test'
        }

        n_images = 0

        data = []
        # self.vocab = set()
        for datum in karpathy_data['images']:
            re_split = split_rename[datum['split']]

            # if re_split == 'train':
            #     for d in datum['sentences']:
            #         self.vocab = self.vocab.union(set(d['tokens']))

            if re_split != self.source.split('_')[-1]:
                continue

            if re_split == 'train':
                # for d in datum['sentences']:
                #     img_id = datum['filename'].split('.')[0]
                #     new_datum = {
                #         'filename': datum['filename'],
                #         'img_id': img_id,
                #         'sent': d['raw'].strip(),
                #         'targets': [d['raw'].strip() for d in datum['sentences']],
                #         'is_train': True,
                #         'cocoid': datum['cocoid']
                #     }
                #     data.append(new_datum)
                img_id = datum['filename'].split('.')[0]
                new_datum = {
                    'filename': datum['filename'],
                    'img_id': img_id,
                    # 'sent': d['raw'],
                    # 'targets': [d['raw'].strip() for d in datum['sentences']],
                    'targets': [" ".join(d['tokens']) for d in datum['sentences']],
                    'is_train': True,
                    'cocoid': datum['cocoid']
                }
                data.append(new_datum)

            else:
                img_id = datum['filename'].split('.')[0]
                new_datum = {
                    'filename': datum['filename'],
                    'img_id': img_id,
                    # 'sent': d['raw'],
                    # 'targets': [d['raw'].strip() for d in datum['sentences']],
                    'targets': [" ".join(d['tokens']) for d in datum['sentences']],
                    'is_train': False,
                    'cocoid': datum['cocoid']
                }
                data.append(new_datum)

            n_images += 1

        if self.verbose:
            print(f"{self.source} has {n_images} images")
            # print(f"Loaded {len(data)} data from", split)

        self.n_gpus = torch.cuda.device_count()

        if self.topk > 0:
            data = data[:self.topk]
            if self.verbose:
                print(f"Use only {self.topk} data")

        self.data = data

        # if self.verbose:
        #     print("# all sentences:", len(self.data))

        if self.args.load_feat:
            # feat_dir = coco_dir.joinpath(''
            # self.feat_loader = HybridLoader('/scratch-space/CLIP-ViL/CLIP-ViL-Direct/caption/data/cocotalk_clipscore_vis', ext='.npy', in_memory=False)
            self.feat_loader = HybridLoader(
                coco_data_dir.joinpath('cocotalk_clipscore_vis'),
                ext='.npy', in_memory=False)
        else:
            if 'openai/clip' in self.args.encoder_backbone:
                # from transformers import CLIPProcessor
                # self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32",
                #     size=args.image_size,
                #     do_resize=True,
                #     do_center_crop=False,
                # )
                # self.img_transform = lambda image: self.processor.feature_extractor(
                #     image,
                #     return_tensors='pt')['pixel_values'][0]

                self.image_mean = [0.48145466, 0.4578275, 0.40821073]
                self.image_std = [0.26862954, 0.26130258, 0.27577711]

                # captioning
                # self.img_transform = T.Compose([
                #     T.Resize((self.args.image_size, self.args.image_size))
                # ])

                # retrieval
                self.img_transform = T.Compose([
                    T.Resize(self.args.image_size, interpolation=T.functional.InterpolationMode.BICUBIC),
                    T.CenterCrop(self.args.image_size)
                ])

                self.img_tensor_transform = T.Compose([
                    # T.RandomCrop(224),
                    # T.RandomHorizontalFlip(p=0.3),
                    T.ConvertImageDtype(torch.float),
                    T.Normalize(self.image_mean, self.image_std)
                ]
                )
            # elif 'google/vit' in self.args.encoder_backbone:
            #     self.image_mean = [0.5, 0.5, 0.5]
            #     self.image_std = [0.5, 0.5, 0.5]

            #     self.img_transform = T.Compose([
            #         # T.PILToTensor(),
            #         T.Resize((self.args.image_size, self.args.image_size))
            #     ])

            #     self.img_tensor_transform = T.Compose([
            #         # T.RandomCrop(224),
            #         # T.RandomHorizontalFlip(p=0.3),
            #         T.ConvertImageDtype(torch.float),
            #         T.Normalize(self.image_mean, self.image_std)
            #     ]
            #     )

    def get_negative_text(self, text):
        neg_type = random.choice(['repeat', 'remove', 'insert', 'swap', 'shuffle'])

        if neg_type == 'repeat':
            text = text_utils.repeat(text)
        elif neg_type == 'remove':
            text = text_utils.remove(text)
        elif neg_type == 'insert':
            text = text_utils.insert(text, self.vocab)
        elif neg_type == 'swap':
            text = text_utils.swap(text, self.vocab)
        elif neg_type == 'shuffle':
            text = text_utils.shuffle(text)

        return text, neg_type

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        datum = self.data[idx]
        return self.process_datum(datum)

    def process_datum(self, datum):
        out_dict = {}

        ###### Image ######

        if self.args.load_feat:
            cocoid = datum['cocoid']
            out_dict['cocoid'] = str(cocoid)
            img_feat = self.feat_loader.get(str(cocoid))
            out_dict['img_feat'] = torch.from_numpy(img_feat)

        else:
            img_id = datum['img_id']
            out_dict['img_id'] = img_id

            if 'train' in datum['filename']:
                img_split = 'train2014'
            elif 'val' in datum['filename']:
                img_split = 'val2014'
            img_path = coco_img_dir.joinpath(img_split).joinpath(datum['filename']).with_suffix('.jpg')
            assert img_path.exists()
            img_path = str(img_path)
            out_dict['img_path'] = img_path

            img_tensor = torchvision.io.read_image(img_path)
            # out_dict['img_tensor'] = img

            # img = Image.open(img_path).convert('RGB')
            # img_tensor = torch.as_tensor(np.asarray(img))
            out_dict['img_tensor'] = self.img_transform(img_tensor)
            # self.img_transform(img_tensor)
            # out_dict['img_tensor'] = self.img_transform(img)

        ###### Text #####
        # if datum['is_train']:
        #     sent = datum['sent'].strip()

        sent = random.choice(datum['targets'])

        # target_ids = self.tokenizer.encode(
        #     sent, max_length=self.args.gen_max_length, truncation=True)

        # assert len(target_ids) <= self.args.gen_max_length, len(target_ids)
        out_dict['sent'] = sent
        # out_dict['target_ids'] = torch.LongTensor(target_ids)
        # out_dict['target_length'] = len(target_ids)


        # negative sample
        neg_sent, neg_type = self.get_negative_text(sent)

        # neg_target_ids = self.tokenizer.encode(
        #     neg_sent, max_length=self.args.gen_max_length, truncation=True)

        # assert len(neg_target_ids) <= self.args.gen_max_length, len(neg_target_ids)
        out_dict['neg_sent'] = neg_sent
        out_dict['neg_type'] = neg_type
        # out_dict['neg_target_ids'] = torch.LongTensor(neg_target_ids)
        # out_dict['neg_target_length'] = len(neg_target_ids)


        if 'targets' in datum:
            out_dict['targets'] = datum['targets']

        return out_dict

    def collate_fn(self, batch):
        batch_entry = {}

        B = len(batch)

        # if 'target_ids' in batch[0]:
        #     T_W_L = max(entry['target_length'] for entry in batch)
        #     target_ids = torch.ones(
        #         B, T_W_L, dtype=torch.long) * self.tokenizer.pad_token_id

        # if 'target_ids' in batch[0]:
        #     T_W_L = max(entry['target_length'] for entry in batch)
        #     target_ids = torch.ones(
        #         B, T_W_L, dtype=torch.long) * self.tokenizer.pad_token_id



        targets = []
        img_ids = []
        img_paths = []

        coco_ids = []

        if self.args.load_feat:
            img_feats = torch.zeros(B, 512, dtype=torch.float)
        else:
            # imgs = []
            img_tensor = torch.zeros(B, 3, self.args.image_size, self.args.image_size, dtype=torch.uint8)

        for i, entry in enumerate(batch):

            if self.args.load_feat:
                coco_ids.append(entry['cocoid'])
                img_feats[i] = entry['img_feat']

            else:

                img_ids.append(entry['img_id'])
                img_paths.append(entry['img_path'])
                img_tensor[i] = entry['img_tensor']

            # if 'target_ids' in entry:
            #     target_ids[i, :entry['target_length']] = entry['target_ids']

            if 'targets' in entry:
                targets.append(entry['targets'])

        if 'sent' in batch[0]:
            # word_mask = target_ids != self.tokenizer.pad_token_id
            # target_ids[~word_mask] = -100
            # batch_entry['target_ids'] = target_ids

            tokenized = self.tokenizer([entry['sent'] for entry in batch], truncation=True, padding=True, return_tensors='pt')
            neg_tokenized = self.tokenizer([entry['neg_sent'] for entry in batch], truncation=True, padding=True, return_tensors='pt')
            #     sent, max_length=self.args.gen_max_length, truncation=True)

            batch_entry['text'] = (tokenized.input_ids, tokenized.attention_mask)
            batch_entry['neg_text'] = (neg_tokenized.input_ids, neg_tokenized.attention_mask)


        if self.args.load_feat:
            batch_entry['coco_ids'] = coco_ids
            batch_entry['img_feats'] = img_feats

        else:

            img_tensor = self.img_tensor_transform(img_tensor)

            batch_entry['img_id'] = img_ids
            batch_entry['img_paths'] = img_paths
            batch_entry['img_tensor'] = img_tensor

        batch_entry['targets'] = targets

        # print('batch created')

        # batch_entry['task'] = 'caption'

        return batch_entry


# def get_loader(args, split='karpathy_train', mode='train',
#                batch_size=32, workers=4, distributed=False, gpu=0,
#                topk=-1):

#     verbose = (gpu == 0)

#     dataset = COCORetrievalDataset(
#         split,
#         rank=gpu,
#         topk=topk,
#         verbose=verbose,
#         args=args,
#         mode=mode)

#     # if distributed:
#     #     sampler = DistributedSampler(dataset)
#     # else:
#     #     sampler = None

#     if mode == 'train':
#         loader = DataLoader(
#             dataset, batch_size=batch_size, shuffle=(sampler is None),
#             num_workers=workers, pin_memory=True, sampler=sampler,
#             collate_fn=dataset.collate_fn)
#     else:
#         loader = DataLoader(
#             dataset,
#             batch_size=batch_size, shuffle=False,
#             num_workers=workers, pin_memory=True,
#             sampler=sampler,
#             collate_fn=dataset.collate_fn,
#             drop_last=False)

#     # if verbose:
#     #     loader.evaluator = COCOCaptionEvaluator()

#     # loader.task = 'caption'

#     return loader


# class COCOCaptionEvaluator:
#     def __init__(self):
#         import language_evaluation
#         self.evaluator = language_evaluation.CocoEvaluator(verbose=False)

#     def evaluate(self, predicts, answers):

#         results = self.evaluator.run_evaluation(predicts, answers)

#         return results

import six
import os
import h5py

class HybridLoader:
    """
    If db_path is a directory, then use normal file loading.
    If lmdb, then load from lmdb.
    The loading method depends on the extension.

    in_memory: if in_memory is True, we save all the features in memory
        For individual np(y|z)s, we don't need to do that because the system will do this for us.
        Should be useful for lmdb or h5.
        (Copied this idea from vilbert)
    """

    def __init__(self, db_path, ext='.npy', in_memory=False):
        self.db_path = db_path
        self.ext = ext
        if self.ext == '.npy':
            self.loader = lambda x: np.load(six.BytesIO(x))
        else:
            self.loader = lambda x: np.load(six.BytesIO(x))['feat']
        # if db_path.endswith('.lmdb'):
        #     self.db_type = 'lmdb'
        #     self.lmdb = lmdbdict(db_path, unsafe=True)
        #     self.lmdb._key_dumps = DUMPS_FUNC['ascii']
        #     self.lmdb._value_loads = LOADS_FUNC['identity']
        # elif db_path.endswith('.pth'):  # Assume a key,value dictionary
        #     self.db_type = 'pth'
        #     self.feat_file = torch.load(db_path)
        #     self.loader = lambda x: x
        #     print('HybridLoader: ext is ignored')
        # elif db_path.endswith('h5'):
        #     self.db_type = 'h5'
        #     self.loader = lambda x: np.array(x).astype('float32')
        # else:
        #     self.db_type = 'dir'

        self.in_memory = in_memory
        if self.in_memory:
            self.features = {}

    def get(self, key):

        # if self.in_memory and key in self.features:
        #     # We save f_input because we want to save the
        #     # compressed bytes to save memory
        #     f_input = self.features[key]
        # elif self.db_type == 'lmdb':
        #     f_input = self.lmdb[key]
        # elif self.db_type == 'pth':
        #     f_input = self.feat_file[key]
        # elif self.db_type == 'h5':
        #     f_input = h5py.File(self.db_path, 'r')[key]
        # else:
        #     f_input = open(os.path.join(
        #         self.db_path, key + self.ext), 'rb').read()

        f_input = open(os.path.join(
            self.db_path, key + self.ext), 'rb').read()

        if self.in_memory and key not in self.features:
            self.features[key] = f_input

        # load image
        feat = self.loader(f_input)

        return feat
```
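As a quick orientation for the `HybridLoader` defined at the end of this file: with `load_feat: true` (see `retrieval/configs/clip_negative_text.yaml`), the dataset reads one precomputed CLIP visual feature per COCO id from `.npy` files. A minimal sketch, where the feature directory and the example id are placeholders rather than values taken from this commit:

```python
from caption_data import HybridLoader  # run from within retrieval/

# Placeholder directory and key: one <cocoid>.npy file per image is assumed.
loader = HybridLoader('data/cocotalk_clipscore_vis', ext='.npy', in_memory=False)
feat = loader.get('391895')  # hypothetical COCO id; reads .../391895.npy
print(feat.shape)            # collate_fn above expects a 512-dim feature per image
```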
retrieval/clip_model.py
ADDED
@@ -0,0 +1,350 @@
```python
from transformers import CLIPModel, CLIPTokenizer
import os
import json
import argparse
from random import shuffle, seed
import string
# non-standard dependencies:
import h5py
from six.moves import cPickle
import numpy as np
import torch
import torchvision.models as models
import skimage.io

from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
from PIL import Image
from torch import nn


class CLIPScore(nn.Module):
    def __init__(self, clipscore_w=2.5, image_size=224, mode='clip_s', use_grammar=False, joint_out=False):
        super(CLIPScore, self).__init__()
        # from transformers import CLIPModel, CLIPTokenizer
        self.clip_model = CLIPModel.from_pretrained(
            'openai/clip-vit-base-patch32')
        self.tokenizer = CLIPTokenizer.from_pretrained(
            'openai/clip-vit-base-patch32')

        self.clip_model.eval()

        self.clipscore_w = clipscore_w

        self.image_transform = self._transform(image_size)

        self.mode = mode
        assert mode in ['clip_s', 'refclip_s']

        self.use_grammar = use_grammar
        self.joint_out = joint_out

        if self.use_grammar and self.joint_out is False:
            self.grammar_score_head = nn.Sequential(
                nn.Linear(self.clip_model.text_embed_dim, self.clip_model.projection_dim, bias=False),
                nn.ReLU(),
                nn.Linear(self.clip_model.projection_dim, 2, bias=False)
            )

    def _transform(self, n_px):
        return Compose([
            Resize(n_px, interpolation=Image.BICUBIC),
            CenterCrop(n_px),
            lambda image: image.convert("RGB"),
            ToTensor(),
            Normalize((0.48145466, 0.4578275, 0.40821073),
                      (0.26862954, 0.26130258, 0.27577711)),
        ])

    def load_image(self, image_path):
        image = Image.open(image_path)
        return image

    # @torch.no_grad()
    def image_extract(self, image):
        if isinstance(image, str):
            image = self.load_image(image)
        if not isinstance(image, torch.Tensor):
            image = self.image_transform(image)

        img_tensor = image.view(-1, 3, 224, 224)
        device = next(self.clip_model.parameters()).device
        img_tensor = img_tensor.to(device)

        clip_model = self.clip_model

        img_feat = clip_model.vision_model(img_tensor).pooler_output
        img_feat = clip_model.visual_projection(img_feat)
        img_feat = img_feat / img_feat.norm(dim=-1, keepdim=True)

        return img_feat

    # @torch.no_grad()
    def text_extract(self, text, prompt="A photo depicts", proj_norm=True):
        if isinstance(text, str):
            text_batch = [" ".join([prompt, text])]
        elif isinstance(text, list):
            text_batch = [" ".join([prompt, txt]) for txt in text]

        if isinstance(text, tuple) and isinstance(text[0], torch.Tensor):
            input_ids, attention_mask = text
        else:
            input_text = text_batch

            tokenized = self.tokenizer(
                input_text, return_tensors='pt', padding=True)

            input_ids = tokenized.input_ids
            attention_mask = tokenized.attention_mask

        clip_model = self.clip_model
        device = next(self.clip_model.parameters()).device
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)

        text_feat = clip_model.text_model(input_ids, attention_mask).pooler_output

        if proj_norm:
            text_feat = clip_model.text_projection(text_feat)
            text_feat = text_feat / text_feat.norm(dim=-1, keepdim=True)

        return text_feat

    # @torch.no_grad()
    def calc_clip_s(self, img_feat, text_feat):
        return self.clipscore_w * torch.relu((img_feat * text_feat).sum(dim=-1))

    # @torch.no_grad()
    def calc_refclip_s(self, img_feat=None, text_feat=None, ref_text_feat=None, ref_text_mask=None, clip_s=None):

        if clip_s is None:
            clip_s = self.calc_clip_s(img_feat, text_feat)

        B, dim = img_feat.size()

        ref_text_feat = ref_text_feat.view(B, -1, dim)

        K = ref_text_feat.size(1)

        text_feat = text_feat.view(B, 1, dim).expand(-1, K, -1)
        assert ref_text_feat.size() == text_feat.size(
        ), (ref_text_feat.size(), text_feat.size())

        ref_score = self.calc_clip_s(text_feat, ref_text_feat)
        if ref_text_mask is not None:
            if not isinstance(ref_text_mask, torch.Tensor):
                ref_text_mask = torch.tensor(
                    ref_text_mask, dtype=ref_score.dtype, device=ref_score.device)
            ref_score = ref_score.view(B, K) * ref_text_mask.view(B, K)

        ref_score = ref_score.view(B, K).max(dim=1).values

        assert clip_s.size() == (B,)
        assert clip_s.size() == ref_score.size()

        # harmonic mean
        refclip_s = 2 / (1 / clip_s + 1 / ref_score)
        return refclip_s

    # # @torch.no_grad()
    # def forward(self,
    #             images=None, text=None,
    #             img_feat=None, text_feat=None,
    #             ref_text=None, ref_text_feat=None, ref_text_mask=None,
    #             prompt="A photo depicts",
    #             mode=None):
    #     if img_feat is None:
    #         img_feat = self.image_extract(images)
    #     img_feat = img_feat.view(-1, 512)

    #     if text_feat is None:
    #         text_feat = self.text_extract(text, prompt=prompt)
    #     text_feat = text_feat.view(-1, 512)

    #     if mode is None:
    #         mode = self.mode
    #     assert mode in ['clip_s', 'refclip_s']

    #     if mode == 'clip_s':
    #         clip_s = self.calc_clip_s(img_feat, text_feat)
    #         return clip_s
    #     elif mode == 'refclip_s':
    #         if ref_text_feat is None:
    #             ref_text_feat = self.text_extract(ref_text, prompt=prompt)
    #         ref_text_feat = ref_text_feat.view(-1, 512)

    #         refclip_s = self.calc_refclip_s(
    #             img_feat, text_feat, ref_text_feat, ref_text_mask=ref_text_mask)
    #         return refclip_s


    def train_step(self,
                   images=None, text=None,
                   img_feat=None, text_feat=None,
                   neg_text=None, neg_text_feat=None,
                   # ref_text=None, ref_text_feat=None, ref_text_mask=None,
                   prompt="A photo depicts",
                   # return_loss=True,
                   **kwargs):

        if img_feat is None:
            img_feat = self.image_extract(images)
        img_feat = img_feat.view(-1, 512)

        B = img_feat.size(0)

        if self.joint_out:
            pos_text_feat = self.text_extract(text, prompt=prompt, proj_norm=False).view(B, 512)
            neg_text_feat = self.text_extract(neg_text, prompt=prompt, proj_norm=False).view(-1, 512)
            neg_B = neg_text_feat.size(0)

            # [B+neg_B, 512]
            text_feat = torch.cat([pos_text_feat, neg_text_feat], dim=0)

            text_cont_feat = self.clip_model.text_projection(text_feat)
            text_cont_feat = text_cont_feat / text_cont_feat.norm(dim=-1, keepdim=True)

            text_cont_feat = text_cont_feat.view(B + neg_B, 512)

            logit_scale = self.clip_model.logit_scale.exp()

            # [B+neg_B * B]
            logits_per_text = torch.matmul(text_cont_feat, img_feat.t()) * logit_scale

            # image-to-text label: positive text
            caption_loss = -torch.diag(nn.functional.log_softmax(logits_per_text, dim=0)[:B]).mean()

            # calculate text-to-image only on positive text
            image_loss = -torch.diag(nn.functional.log_softmax(logits_per_text[:B], dim=1)).mean()

            clip_loss = (caption_loss + image_loss) / 2.0

            out = {
                'clip_loss': clip_loss,
                'img_feat': img_feat,
                'text_feat': text_cont_feat[:B].detach(),
                # 'neg_text_feat': neg_text_feat,
            }

            return out


        else:
            if text_feat is None:
                text_feat = self.text_extract(text, prompt=prompt, proj_norm=False)

                text_cont_feat = self.clip_model.text_projection(text_feat)
                text_cont_feat = text_cont_feat / \
                    text_cont_feat.norm(dim=-1, keepdim=True)

                text_cont_feat = text_cont_feat.view(B, 512)


            # cosine similarity as logits
            logit_scale = self.clip_model.logit_scale.exp()
            logits_per_text = torch.matmul(text_cont_feat, img_feat.t()) * logit_scale
            # logits_per_image = logits_per_text.T

            clip_loss = clip_loss_fn(logits_per_text)


            # negative sampling
            pos_text_feat = text_feat.view(B, 512)
            neg_text_feat = self.text_extract(neg_text, prompt=prompt, proj_norm=False).view(B, 512)

            grammar_text_feat = torch.cat([pos_text_feat, neg_text_feat], dim=0)

            # 2B, 1
            grammar_text_logit = self.grammar_score_head(grammar_text_feat)
            grammar_labels = torch.LongTensor([1] * B + [0] * B).to(grammar_text_logit.device).view(2 * B)

            grammar_loss = torch.nn.functional.cross_entropy(grammar_text_logit, grammar_labels)

            grammar_pred = grammar_text_logit.argmax(dim=1, keepdim=False)
            grammar_pos_pred = grammar_pred[:B]
            grammar_neg_pred = grammar_pred[B:]
            # grammar_acc = (grammar_pred == grammar_labels).float().mean()

            out = {
                'clip_loss': clip_loss,
                'grammar_loss': grammar_loss,
                'img_feat': img_feat,
                'text_feat': text_cont_feat,
                'neg_text_feat': neg_text_feat,
                'grammar_pos_pred': grammar_pos_pred,
                'grammar_neg_pred': grammar_neg_pred,
            }

            return out

    def train_step_old(self,
                       images=None, text=None,
                       img_feat=None, text_feat=None,
                       neg_text=None, neg_text_feat=None,
                       # ref_text=None, ref_text_feat=None, ref_text_mask=None,
                       prompt="A photo depicts",
                       # return_loss=True,
                       **kwargs):

        if img_feat is None:
            img_feat = self.image_extract(images)
        img_feat = img_feat.view(-1, 512)

        B = img_feat.size(0)



        if text_feat is None:
            text_feat = self.text_extract(text, prompt=prompt, proj_norm=False)

            text_cont_feat = self.clip_model.text_projection(text_feat)
            text_cont_feat = text_cont_feat / text_cont_feat.norm(dim=-1, keepdim=True)
            text_cont_feat = text_cont_feat.view(B, 512)

        # cosine similarity as logits
        logit_scale = self.clip_model.logit_scale.exp()
        logits_per_text = torch.matmul(text_cont_feat, img_feat.t()) * logit_scale
        # logits_per_image = logits_per_text.T

        clip_loss = clip_loss_fn(logits_per_text)


        # negative sampling
        pos_text_feat = text_feat.view(B, 512)
        neg_text_feat = self.text_extract(neg_text, prompt=prompt, proj_norm=False).view(B, 512)

        grammar_text_feat = torch.cat([pos_text_feat, neg_text_feat], dim=0)

        # 2B, 1
        grammar_text_logit = self.grammar_score_head(grammar_text_feat)
        grammar_labels = torch.LongTensor([1] * B + [0] * B).to(grammar_text_logit.device).view(2 * B)

        grammar_loss = torch.nn.functional.cross_entropy(grammar_text_logit, grammar_labels)

        grammar_pred = grammar_text_logit.argmax(dim=1, keepdim=False)
        grammar_pos_pred = grammar_pred[:B]
        grammar_neg_pred = grammar_pred[B:]
        # grammar_acc = (grammar_pred == grammar_labels).float().mean()

        out = {
            'clip_loss': clip_loss,
            'grammar_loss': grammar_loss,
            'img_feat': img_feat,
            'text_feat': text_cont_feat,
            'neg_text_feat': neg_text_feat,
            'grammar_pos_pred': grammar_pos_pred,
            'grammar_neg_pred': grammar_neg_pred,
        }

        return out

# contrastive loss function, adapted from
# https://sachinruk.github.io/blog/pytorch/pytorch%20lightning/loss%20function/gpu/2021/03/07/CLIP.html
def contrastive_loss(logits: torch.Tensor, dim: int) -> torch.Tensor:
    neg_ce = torch.diag(nn.functional.log_softmax(logits, dim=dim))
    return -neg_ce.mean()


def clip_loss_fn(similarity: torch.Tensor) -> torch.Tensor:
    caption_loss = contrastive_loss(similarity, dim=0)
    image_loss = contrastive_loss(similarity, dim=1)
    return (caption_loss + image_loss) / 2.0
```
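For orientation, a minimal usage sketch of `CLIPScore.train_step` in the grammar-head configuration. The random image tensor, the hand-written captions, and the plain sum of the two losses are placeholders for illustration only; the actual optimization loop lives in `retrieval/train_pl.py`, which is not shown in this diff:

```python
import torch
from clip_model import CLIPScore  # run from within retrieval/

model = CLIPScore(use_grammar=True, joint_out=False)

# Placeholder batch: 2 random "images" plus positive and perturbed captions.
images = torch.rand(2, 3, 224, 224)
text = ["a dog runs on the beach", "a man rides a bicycle"]
neg_text = ["beach the on runs dog a", "a man rides rides a bicycle"]

out = model.train_step(images=images, text=text, neg_text=neg_text)

# How the two losses are weighted is an assumption, not taken from this commit.
loss = out['clip_loss'] + out['grammar_loss']
loss.backward()
print(out['grammar_pos_pred'], out['grammar_neg_pred'])
```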
retrieval/configs/clip_negative_text.yaml
ADDED
@@ -0,0 +1,14 @@
```yaml
checkpoint_dir: ./save/clip_negative_text/

losses_log_every: 25
precision: 32
load_feat: true
data_in_memory: false

batch_size: 1600
valid_batch_size: 200
clip_grad_norm: 0

epochs: 30
use_grammar: true
joint_out: false
```
retrieval/param.py
ADDED
@@ -0,0 +1,209 @@
```python
import argparse
import random

import numpy as np
import torch

import pprint
import yaml


def str2bool(v):
    if v.lower() in ('yes', 'true', 't', 'y', '1'):
        return True
    elif v.lower() in ('no', 'false', 'f', 'n', '0'):
        return False
    else:
        raise argparse.ArgumentTypeError('Boolean value expected.')


def is_interactive():
    import __main__ as main
    return not hasattr(main, '__file__')


def get_optimizer(optim, verbose=False):
    # Bind the optimizer
    if optim == 'rms':
        if verbose:
            print("Optimizer: Using RMSProp")
        optimizer = torch.optim.RMSprop
    elif optim == 'adam':
        if verbose:
            print("Optimizer: Using Adam")
        optimizer = torch.optim.Adam
    elif optim == 'adamw':
        if verbose:
            print("Optimizer: Using AdamW")
        # optimizer = torch.optim.AdamW
        optimizer = 'adamw'
    elif optim == 'adamax':
        if verbose:
            print("Optimizer: Using Adamax")
        optimizer = torch.optim.Adamax
    elif optim == 'sgd':
        if verbose:
            print("Optimizer: SGD")
        optimizer = torch.optim.SGD
    else:
        assert False, "Please add your optimizer %s in the list." % optim

    return optimizer


def parse_args(parse=True, **optional_kwargs):
    parser = argparse.ArgumentParser()

    parser.add_argument('--seed', type=int, default=9595, help='random seed')

    # Data Splits
    parser.add_argument("--train", default='karpathy_train')
    parser.add_argument("--valid", default='karpathy_val')
    parser.add_argument("--test", default='karpathy_test')
    # parser.add_argument('--test_only', action='store_true')

    # Quick experiments
    parser.add_argument('--train_topk', type=int, default=-1)
    parser.add_argument('--valid_topk', type=int, default=-1)

    # Checkpoint
    parser.add_argument('--output', type=str, default='snap/test')
    parser.add_argument('--load', type=str, default=None, help='Load the model (usually the fine-tuned model).')
    parser.add_argument('--from_scratch', action='store_true')

    # CPU/GPU
    parser.add_argument("--multiGPU", action='store_const', default=False, const=True)
    parser.add_argument('--fp16', action='store_true')
    parser.add_argument("--distributed", action='store_true')
    parser.add_argument("--num_workers", default=0, type=int)
    parser.add_argument('--local_rank', type=int, default=-1)
    # parser.add_argument('--rank', type=int, default=-1)

    # Model Config
    # parser.add_argument('--encoder_backbone', type=str, default='openai/clip-vit-base-patch32')
    # parser.add_argument('--decoder_backbone', type=str, default='bert-base-uncased')
    parser.add_argument('--tokenizer', type=str, default='openai/clip-vit-base-patch32')

    # parser.add_argument('--position_embedding_type', type=str, default='absolute')

    # parser.add_argument('--encoder_transform', action='store_true')

    parser.add_argument('--max_text_length', type=int, default=40)

    # parser.add_argument('--image_size', type=int, default=224)
    # parser.add_argument('--patch_size', type=int, default=32)

    # parser.add_argument('--decoder_num_layers', type=int, default=12)

    # Training
    parser.add_argument('--batch_size', type=int, default=256)
    parser.add_argument('--valid_batch_size', type=int, default=None)

    parser.add_argument('--optim', default='adamw')

    parser.add_argument('--warmup_ratio', type=float, default=0.05)
    parser.add_argument('--weight_decay', type=float, default=0.01)
    parser.add_argument('--clip_grad_norm', type=float, default=-1.0)
    parser.add_argument('--gradient_accumulation_steps', type=int, default=1)
    parser.add_argument('--lr', type=float, default=1e-4)
    parser.add_argument('--adam_eps', type=float, default=1e-6)
    parser.add_argument('--adam_beta1', type=float, default=0.9)
    parser.add_argument('--adam_beta2', type=float, default=0.999)

    parser.add_argument('--epochs', type=int, default=20)
    # parser.add_argument('--dropout', type=float, default=0.1)


    # Inference
    # parser.add_argument('--num_beams', type=int, default=1)
    # parser.add_argument('--gen_max_length', type=int, default=20)

    parser.add_argument('--start_from', type=str, default=None)

    # Data
    # parser.add_argument('--do_lower_case', type=str2bool, default=None)

    # parser.add_argument('--prefix', type=str, default=None)


    # COCO Caption
    # parser.add_argument('--no_prefix', action='store_true')

    parser.add_argument('--no_cls', action='store_true')

    parser.add_argument('--cfg', type=str, default=None)
    parser.add_argument('--id', type=str, default=None)

    # Etc.
    parser.add_argument('--comment', type=str, default='')
    parser.add_argument("--dry", action='store_true')

    # Parse the arguments.
    if parse:
        args = parser.parse_args()
    # For interactive environments (e.g. Jupyter)
    else:
        args = parser.parse_known_args()[0]

    loaded_kwargs = {}
    if args.cfg is not None:
        cfg_path = f'configs/{args.cfg}.yaml'
        with open(cfg_path, 'r') as f:
            loaded_kwargs = yaml.safe_load(f)

    # Namespace => Dictionary
    parsed_kwargs = vars(args)
    parsed_kwargs.update(optional_kwargs)

    kwargs = {}
    kwargs.update(parsed_kwargs)
    kwargs.update(loaded_kwargs)

    args = Config(**kwargs)

    # Bind optimizer class.
    verbose = False
    args.optimizer = get_optimizer(args.optim, verbose=verbose)

    # Set seeds
    torch.manual_seed(args.seed)
    random.seed(args.seed)
    np.random.seed(args.seed)

    return args


class Config(object):
    def __init__(self, **kwargs):
        """Configuration Class: set kwargs as class attributes with setattr"""
        for k, v in kwargs.items():
            setattr(self, k, v)

    @property
    def config_str(self):
        return pprint.pformat(self.__dict__)

    def __repr__(self):
        """Pretty-print configurations in alphabetical order"""
        config_str = 'Configurations\n'
        config_str += self.config_str
        return config_str

    # def update(self, **kwargs):
    #     for k, v in kwargs.items():
    #         setattr(self, k, v)

    # def save(self, path):
    #     with open(path, 'w') as f:
    #         yaml.dump(self.__dict__, f, default_flow_style=False)

    # @classmethod
    # def load(cls, path):
    #     with open(path, 'r') as f:
    #         kwargs = yaml.load(f)

    #     return Config(**kwargs)


if __name__ == '__main__':
    args = parse_args(True)
```
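A small, hedged example of how `parse_args` and `Config` behave outside a full training run; the keyword override shown here is illustrative only. Note that when `--cfg` is given on the command line, the YAML values are applied last and therefore take precedence over both CLI arguments and these keyword overrides:

```python
from param import parse_args  # run from within retrieval/

# parse=False uses parse_known_args, so this also works inside notebooks.
args = parse_args(parse=False, batch_size=32)  # keyword overrides the CLI default of 256
print(args.batch_size)   # 32
print(args.optimizer)    # 'adamw', bound by get_optimizer from --optim
print(args)              # Config pretty-prints every resolved option
```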
retrieval/pth_loader.py
ADDED
@@ -0,0 +1,334 @@
```python
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import json
import h5py
from lmdbdict import lmdbdict
from lmdbdict.methods import DUMPS_FUNC, LOADS_FUNC
import os
import numpy as np
import numpy.random as npr
import random

import torch
import torch.utils.data as data

import multiprocessing
import six

verbose = True
# import torch
# if torch.cuda.current_device() in [0, -1]:
if 'LOCAL_RANK' in os.environ and os.environ['LOCAL_RANK'] != '0':
    verbose = False

class HybridLoader:
    """
    If db_path is a directory, then use normal file loading.
    If lmdb, then load from lmdb.
    The loading method depends on the extension.

    in_memory: if in_memory is True, we save all the features in memory
        For individual np(y|z)s, we don't need to do that because the system will do this for us.
        Should be useful for lmdb or h5.
        (Copied this idea from vilbert)
    """
    def __init__(self, db_path, ext, in_memory=False):
        self.db_path = db_path
        self.ext = ext
        if self.ext == '.npy':
            self.loader = lambda x: np.load(six.BytesIO(x))
        else:
            self.loader = lambda x: np.load(six.BytesIO(x))['feat']
        if db_path.endswith('.lmdb'):
            self.db_type = 'lmdb'
            self.lmdb = lmdbdict(db_path, unsafe=True)
            self.lmdb._key_dumps = DUMPS_FUNC['ascii']
            self.lmdb._value_loads = LOADS_FUNC['identity']
        elif db_path.endswith('.pth'):  # Assume a key,value dictionary
            self.db_type = 'pth'
            self.feat_file = torch.load(db_path)
            self.loader = lambda x: x
            print('HybridLoader: ext is ignored')
        elif db_path.endswith('h5'):
            self.db_type = 'h5'
            self.loader = lambda x: np.array(x).astype('float32')
        else:
            self.db_type = 'dir'

        self.in_memory = in_memory
        if self.in_memory:
            self.features = {}

    def get(self, key):

        if self.in_memory and key in self.features:
            # We save f_input because we want to save the
            # compressed bytes to save memory
            f_input = self.features[key]
        elif self.db_type == 'lmdb':
            f_input = self.lmdb[key]
        elif self.db_type == 'pth':
            f_input = self.feat_file[key]
        elif self.db_type == 'h5':
            f_input = h5py.File(self.db_path, 'r')[key]
        else:
            f_input = open(os.path.join(self.db_path, key + self.ext), 'rb').read()

        if self.in_memory and key not in self.features:
            self.features[key] = f_input

        # load image
        feat = self.loader(f_input)

        return feat

class CaptionDataset(data.Dataset):

    def get_vocab_size(self):
        return self.vocab_size

    def get_vocab(self):
        return self.ix_to_word

    def get_seq_length(self):
        return self.seq_length

    def __init__(self, opt):
        self.opt = opt
        self.seq_per_img = opt.seq_per_img

        # feature related options
        self.use_fc = getattr(opt, 'use_fc', True)
        self.use_att = getattr(opt, 'use_att', True)
        self.use_box = getattr(opt, 'use_box', 0)
        self.norm_att_feat = getattr(opt, 'norm_att_feat', 0)
        self.norm_box_feat = getattr(opt, 'norm_box_feat', 0)

        # load the json file which contains additional information about the dataset
        if verbose:
            print('DataLoader loading json file: ', opt.input_json)
        self.info = json.load(open(self.opt.input_json))
        if 'ix_to_word' in self.info:
            self.ix_to_word = self.info['ix_to_word']
            self.vocab_size = len(self.ix_to_word)
            if verbose:
                print('vocab size is ', self.vocab_size)

        # open the hdf5 file
        if verbose:
            print('DataLoader loading h5 file: ', opt.input_fc_dir, opt.input_att_dir, opt.input_box_dir, opt.input_label_h5)
        """
        Setting input_label_h5 to none is used when only doing generation.
        For example, when you need to test on coco test set.
        """
        if self.opt.input_label_h5 != 'none':
            self.h5_label_file = h5py.File(self.opt.input_label_h5, 'r', driver='core')
            # load in the sequence data
            seq_size = self.h5_label_file['labels'].shape
            self.label = self.h5_label_file['labels'][:]
            self.seq_length = seq_size[1]
            if verbose:
                print('max sequence length in data is', self.seq_length)
            # load the pointers in full to RAM (should be small enough)
            self.label_start_ix = self.h5_label_file['label_start_ix'][:]
            self.label_end_ix = self.h5_label_file['label_end_ix'][:]
        else:
            self.seq_length = 1

        self.data_in_memory = getattr(opt, 'data_in_memory', False)
        self.fc_loader = HybridLoader(self.opt.input_fc_dir, '.npy', in_memory=self.data_in_memory)
        self.att_loader = HybridLoader(self.opt.input_att_dir, '.npz', in_memory=self.data_in_memory)
        self.box_loader = HybridLoader(self.opt.input_box_dir, '.npy', in_memory=self.data_in_memory)

        self.use_clipscore = getattr(opt, 'use_clipscore', False)
        if self.use_clipscore:
            self.clipscore_loader = HybridLoader(self.opt.input_clipscore_vis_dir, '.npy', in_memory=self.data_in_memory)


        self.num_images = len(self.info['images'])  # self.label_start_ix.shape[0]
        if verbose:
            print('read %d image features' % (self.num_images))

        # separate out indexes for each of the provided splits
        self.split_ix = {'train': [], 'val': [], 'test': []}
        for ix in range(len(self.info['images'])):
            img = self.info['images'][ix]
            if not 'split' in img:
                self.split_ix['train'].append(ix)
                self.split_ix['val'].append(ix)
                self.split_ix['test'].append(ix)
            elif img['split'] == 'train':
                self.split_ix['train'].append(ix)
            elif img['split'] == 'val':
                self.split_ix['val'].append(ix)
            elif img['split'] == 'test':
                self.split_ix['test'].append(ix)
            elif opt.train_only == 0:  # restval
                self.split_ix['train'].append(ix)

        if verbose:
            print('assigned %d images to split train' % len(self.split_ix['train']))
            print('assigned %d images to split val' % len(self.split_ix['val']))
            print('assigned %d images to split test' % len(self.split_ix['test']))

    def get_captions(self, ix, seq_per_img):
        # fetch the sequence labels
        ix1 = self.label_start_ix[ix] - 1  # label_start_ix starts from 1
        ix2 = self.label_end_ix[ix] - 1
        ncap = ix2 - ix1 + 1  # number of captions available for this image
        assert ncap > 0, 'an image does not have any label. this can be handled but right now isn\'t'

        if ncap < seq_per_img:
            # we need to subsample (with replacement)
            seq = np.zeros([seq_per_img, self.seq_length], dtype='int')
            for q in range(seq_per_img):
                ixl = random.randint(ix1, ix2)
                seq[q, :] = self.label[ixl, :self.seq_length]
        else:
            ixl = random.randint(ix1, ix2 - seq_per_img + 1)
            seq = self.label[ixl: ixl + seq_per_img, :self.seq_length]

        return seq

    def collate_func(self, batch):
        seq_per_img = self.seq_per_img

        fc_batch = []
        att_batch = []
        label_batch = []

        clip_vis_feat_batch = []

        wrapped = False

        infos = []
        gts = []

        for sample in batch:
            # fetch image
            if self.use_clipscore:
                tmp_fc, tmp_att, tmp_seq, \
                    ix, tmp_clip_vis_feat = sample
```
214 |
+
|
215 |
+
clip_vis_feat_batch.append(tmp_clip_vis_feat)
|
216 |
+
else:
|
217 |
+
tmp_fc, tmp_att, tmp_seq, \
|
218 |
+
ix = sample
|
219 |
+
|
220 |
+
fc_batch.append(tmp_fc)
|
221 |
+
att_batch.append(tmp_att)
|
222 |
+
|
223 |
+
tmp_label = np.zeros([seq_per_img, self.seq_length + 2], dtype = 'int')
|
224 |
+
if hasattr(self, 'h5_label_file'):
|
225 |
+
# if there is ground truth
|
226 |
+
tmp_label[:, 1 : self.seq_length + 1] = tmp_seq
|
227 |
+
label_batch.append(tmp_label)
|
228 |
+
|
229 |
+
# Used for reward evaluation
|
230 |
+
if hasattr(self, 'h5_label_file'):
|
231 |
+
# if there is ground truth
|
232 |
+
gts.append(self.label[self.label_start_ix[ix] - 1: self.label_end_ix[ix]])
|
233 |
+
else:
|
234 |
+
gts.append([])
|
235 |
+
|
236 |
+
# record associated info as well
|
237 |
+
info_dict = {}
|
238 |
+
info_dict['ix'] = ix
|
239 |
+
info_dict['id'] = self.info['images'][ix]['id']
|
240 |
+
info_dict['file_path'] = self.info['images'][ix].get('file_path', '')
|
241 |
+
infos.append(info_dict)
|
242 |
+
|
243 |
+
# #sort by att_feat length
|
244 |
+
# fc_batch, att_batch, label_batch, gts, infos = \
|
245 |
+
# zip(*sorted(zip(fc_batch, att_batch, np.vsplit(label_batch, batch_size), gts, infos), key=lambda x: len(x[1]), reverse=True))
|
246 |
+
if self.use_clipscore:
|
247 |
+
fc_batch, att_batch, label_batch, clip_vis_feat_batch, gts, infos = \
|
248 |
+
zip(*sorted(zip(fc_batch, att_batch, label_batch, clip_vis_feat_batch, gts, infos), key=lambda x: 0, reverse=True))
|
249 |
+
else:
|
250 |
+
fc_batch, att_batch, label_batch, gts, infos = \
|
251 |
+
zip(*sorted(zip(fc_batch, att_batch, label_batch, gts, infos), key=lambda x: 0, reverse=True))
|
252 |
+
data = {}
|
253 |
+
data['fc_feats'] = np.stack(fc_batch)
|
254 |
+
# merge att_feats
|
255 |
+
max_att_len = max([_.shape[0] for _ in att_batch])
|
256 |
+
data['att_feats'] = np.zeros([len(att_batch), max_att_len, att_batch[0].shape[1]], dtype = 'float32')
|
257 |
+
for i in range(len(att_batch)):
|
258 |
+
data['att_feats'][i, :att_batch[i].shape[0]] = att_batch[i]
|
259 |
+
data['att_masks'] = np.zeros(data['att_feats'].shape[:2], dtype='float32')
|
260 |
+
for i in range(len(att_batch)):
|
261 |
+
data['att_masks'][i, :att_batch[i].shape[0]] = 1
|
262 |
+
# set att_masks to None if attention features have same length
|
263 |
+
if data['att_masks'].sum() == data['att_masks'].size:
|
264 |
+
data['att_masks'] = None
|
265 |
+
|
266 |
+
if self.use_clipscore:
|
267 |
+
data['clip_vis_feats'] = np.stack(clip_vis_feat_batch)
|
268 |
+
|
269 |
+
data['labels'] = np.vstack(label_batch)
|
270 |
+
# generate mask
|
271 |
+
nonzeros = np.array(list(map(lambda x: (x != 0).sum()+2, data['labels'])))
|
272 |
+
mask_batch = np.zeros([data['labels'].shape[0], self.seq_length + 2], dtype = 'float32')
|
273 |
+
for ix, row in enumerate(mask_batch):
|
274 |
+
row[:nonzeros[ix]] = 1
|
275 |
+
data['masks'] = mask_batch
|
276 |
+
data['labels'] = data['labels'].reshape(len(batch), seq_per_img, -1)
|
277 |
+
data['masks'] = data['masks'].reshape(len(batch), seq_per_img, -1)
|
278 |
+
|
279 |
+
data['gts'] = gts # all ground truth captions of each images
|
280 |
+
data['infos'] = infos
|
281 |
+
|
282 |
+
data = {k:torch.from_numpy(v) if type(v) is np.ndarray else v for k,v in data.items()} # Turn all ndarray to torch tensor
|
283 |
+
|
284 |
+
return data
|
285 |
+
|
286 |
+
def __getitem__(self, ix):
|
287 |
+
"""This function returns a tuple that is further passed to collate_fn
|
288 |
+
"""
|
289 |
+
if self.use_att:
|
290 |
+
att_feat = self.att_loader.get(str(self.info['images'][ix]['id']))
|
291 |
+
# Reshape to K x C
|
292 |
+
att_feat = att_feat.reshape(-1, att_feat.shape[-1])
|
293 |
+
if self.norm_att_feat:
|
294 |
+
att_feat = att_feat / np.linalg.norm(att_feat, 2, 1, keepdims=True)
|
295 |
+
if self.use_box:
|
296 |
+
box_feat = self.box_loader.get(str(self.info['images'][ix]['id']))
|
297 |
+
# devided by image width and height
|
298 |
+
x1,y1,x2,y2 = np.hsplit(box_feat, 4)
|
299 |
+
h,w = self.info['images'][ix]['height'], self.info['images'][ix]['width']
|
300 |
+
box_feat = np.hstack((x1/w, y1/h, x2/w, y2/h, (x2-x1)*(y2-y1)/(w*h))) # question? x2-x1+1??
|
301 |
+
if self.norm_box_feat:
|
302 |
+
box_feat = box_feat / np.linalg.norm(box_feat, 2, 1, keepdims=True)
|
303 |
+
att_feat = np.hstack([att_feat, box_feat])
|
304 |
+
# sort the features by the size of boxes
|
305 |
+
att_feat = np.stack(sorted(att_feat, key=lambda x:x[-1], reverse=True))
|
306 |
+
else:
|
307 |
+
att_feat = np.zeros((0,0), dtype='float32')
|
308 |
+
if self.use_fc:
|
309 |
+
try:
|
310 |
+
fc_feat = self.fc_loader.get(str(self.info['images'][ix]['id']))
|
311 |
+
except:
|
312 |
+
# Use average of attention when there is no fc provided (For bottomup feature)
|
313 |
+
fc_feat = att_feat.mean(0)
|
314 |
+
else:
|
315 |
+
fc_feat = np.zeros((0), dtype='float32')
|
316 |
+
if hasattr(self, 'h5_label_file'):
|
317 |
+
seq = self.get_captions(ix, self.seq_per_img)
|
318 |
+
else:
|
319 |
+
seq = None
|
320 |
+
|
321 |
+
if self.use_clipscore:
|
322 |
+
clip_vis_feat = self.clipscore_loader.get(
|
323 |
+
str(self.info['images'][ix]['id']))
|
324 |
+
|
325 |
+
return (fc_feat,
|
326 |
+
att_feat, seq,
|
327 |
+
ix, clip_vis_feat)
|
328 |
+
|
329 |
+
return (fc_feat,
|
330 |
+
att_feat, seq,
|
331 |
+
ix)
|
332 |
+
|
333 |
+
def __len__(self):
|
334 |
+
return len(self.info['images'])
|
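For orientation, a minimal sketch of how this dataset is typically wired into a PyTorch DataLoader. The `opt` namespace, its paths, and the batch size below are hypothetical placeholders (the real options come from param/opts parsing and the phase2 configs), and the preprocessed feature folders must already exist on disk; `collate_func` is passed explicitly because it pads `att_feats` and builds the label masks.

# Illustrative usage sketch, not part of the committed file.
import torch
from types import SimpleNamespace

opt = SimpleNamespace(                       # hypothetical option values
    input_json='data/example_talk.json',
    input_label_h5='data/example_label.h5',
    input_fc_dir='data/example_fc',
    input_att_dir='data/example_att',
    input_box_dir='data/example_box',        # only read when use_box=1
    input_clipscore_vis_dir='data/example_clipscore_vis',
    seq_per_img=5, use_clipscore=False, train_only=0,
)

dataset = CaptionDataset(opt)
train_subset = torch.utils.data.Subset(dataset, dataset.split_ix['train'])
loader = torch.utils.data.DataLoader(
    train_subset, batch_size=16, shuffle=True,
    collate_fn=dataset.collate_func)          # pads att_feats, builds masks/gts/infos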
retrieval/text_utils.py
ADDED
@@ -0,0 +1,74 @@
import random

def repeat(text, n_max_gram=3, n_max_repeat=3):
    """repeat n-grams"""
    tokens = text.split()

    n_gram = random.randint(1, n_max_gram)

    repeat_token_idx = random.randint(0, len(tokens) - n_gram)

    repeated_tokens = tokens[repeat_token_idx:repeat_token_idx+n_gram]

    n_repeat = random.randint(1, n_max_repeat)
    for _ in range(n_repeat):
        insert_idx = random.randint(0, len(tokens))
        tokens = tokens[:insert_idx] + \
            repeated_tokens + tokens[insert_idx:]

    new_text = " ".join(tokens)
    return new_text

def remove(text, n_max_gram=3):
    """remove n-grams"""
    tokens = text.split()

    n_gram = random.randint(1, n_max_gram)

    remove_token_idx = random.randint(0, len(tokens) - n_gram)

    tokens = tokens[:remove_token_idx] + tokens[remove_token_idx + n_gram:]

    new_text = " ".join(tokens)
    return new_text

def insert(text, vocab, n_max_tokens=3):
    """Insert tokens"""
    tokens = text.split()

    n_insert_token = random.randint(1, n_max_tokens)

    for _ in range(n_insert_token):
        insert_token_idx = random.randint(0, len(tokens) - 1)
        insert_token = random.choice(vocab)
        tokens = tokens[:insert_token_idx] + [insert_token] + tokens[insert_token_idx:]

    new_text = " ".join(tokens)
    return new_text

def swap(text, vocab, n_max_tokens=3):
    """Swap tokens"""
    tokens = text.split()

    n_swap_tokens = random.randint(1, n_max_tokens)

    for _ in range(n_swap_tokens):
        swap_token_idx = random.randint(0, len(tokens) - 1)

        swap_token = random.choice(vocab)
        while swap_token == tokens[swap_token_idx]:
            swap_token = random.choice(vocab)

        tokens[swap_token_idx] = swap_token

    new_text = " ".join(tokens)
    return new_text

def shuffle(text):
    """shuffle tokens"""
    tokens = text.split()

    random.shuffle(tokens)

    new_text = " ".join(tokens)
    return new_text
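These helpers produce corrupted "negative" captions for contrastive/grammar training. A small, hedged example of how they could be composed (the caption, vocab list, and seed below are toy stand-ins, not values from this repository):

# Illustrative usage sketch, not part of the committed file.
import random
random.seed(0)

caption = "a man riding a wave on top of a surfboard"
vocab = ["dog", "red", "running", "table"]         # toy vocabulary

print(repeat(caption))          # duplicates a random n-gram somewhere in the sentence
print(remove(caption))          # drops a random n-gram
print(insert(caption, vocab))   # splices random vocab tokens into the sentence
print(swap(caption, vocab))     # replaces random tokens with vocab tokens
print(shuffle(caption))         # permutes the word order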
retrieval/train_pl.py
ADDED
@@ -0,0 +1,661 @@
from ast import parse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import numpy as np

import time
import os
from collections import defaultdict

# import captioning.utils.opts as opts
# import captioning.models as models
# from captioning.data.pth_loader import CaptionDataset
# import captioning.utils.eval_utils as eval_utils
# import captioning.utils.misc as utils
# from captioning.utils.rewards import init_scorer, get_self_critical_reward
# from captioning.modules.loss_wrapper import LossWrapper

from clip_model import CLIPScore
from caption_data import COCORetrievalDataset

import pytorch_lightning as pl

import detectron2.utils.comm as d2comm
from detectron2.utils.env import seed_all_rng
seed_all_rng(1234)


class LitModel(pl.LightningModule):
    def __init__(self, opt):
        super().__init__()
        self.opt = opt
        self.args = args
        # Initialize dataset
        # self.dataset = CaptionDataset(opt)

        # self.dataset =

        # opt.vocab_size = self.dataset.vocab_size
        # opt.seq_length = self.dataset.seq_length
        # self.batch_size = opt.batch_size

        # Build model
        # opt.vocab = self.dataset.get_vocab()
        # model = models.setup(opt)
        # print(model)
        # del opt.vocab

        # wrapper with loss in it.
        # lw_model = LossWrapper(model, opt)

        self.model = CLIPScore(use_grammar=opt.use_grammar, joint_out=opt.joint_out)
        # self.lw_model = lw_model

        # Freeze the CLIP visual encoder; only the text side is fine-tuned.
        for p in self.model.clip_model.vision_model.parameters():
            p.requires_grad = False
        for p in self.model.clip_model.visual_projection.parameters():
            p.requires_grad = False

        # self.struc_flag = None
        # self.sc_flag = None

    def forward(self, *args, **kwargs):
        """
        I hate this design. Never pretend it as a nn.Module
        """
        raise NotImplementedError

    def train_dataloader(self):
        # train_dataset = torch.utils.data.Subset(
        #     self.dataset,
        #     self.dataset.split_ix['train']
        # )

        # train_loader = torch.utils.data.DataLoader(
        #     dataset=train_dataset,
        #     batch_size=self.batch_size,
        #     shuffle=True,
        #     num_workers=4,
        #     collate_fn=self.dataset.collate_func
        # )

        train_dataset = COCORetrievalDataset(
            split='karpathy_train', mode='train',
            args=opt,
            verbose=verbose
        )

        train_loader = torch.utils.data.DataLoader(
            dataset=train_dataset,
            batch_size=opt.batch_size,
            shuffle=True,
            num_workers=4,
            collate_fn=train_dataset.collate_fn
        )

        return train_loader

    def val_dataloader(self, split='karpathy_val'):
        # val_dataset = torch.utils.data.Subset(
        #     self.dataset,
        #     self.dataset.split_ix[split]
        # )
        # val_loader = torch.utils.data.DataLoader(
        #     val_dataset,
        #     batch_size=self.batch_size,
        #     shuffle=False,
        #     num_workers=4,
        #     drop_last=False,
        #     collate_fn=self.dataset.collate_func
        # )

        val_dataset = COCORetrievalDataset(
            split=split, mode='val',
            args=opt,
            verbose=verbose
        )

        val_loader = torch.utils.data.DataLoader(
            dataset=val_dataset,
            batch_size=opt.valid_batch_size,
            shuffle=False,
            num_workers=4,
            drop_last=False,
            collate_fn=val_dataset.collate_fn
        )

        return val_loader

    def test_dataloader(self):

        return self.val_dataloader('karpathy_test')

    def training_step(self, data, batch_idx):

        batch = data
        self.model.train()

        model_out = self.model.train_step(
            img_feat=batch['img_feats'],
            text=batch['text'],
            neg_text=batch['neg_text'],
        )

        clip_loss = model_out['clip_loss']

        if self.opt.joint_out:
            loss = clip_loss
        else:
            grammar_loss = model_out['grammar_loss']
            loss = clip_loss + grammar_loss

        data_time = self.trainer.profiler.recorded_durations["get_train_batch"][-1]
        data_time = torch.tensor(data_time)

        # print('batch_idx', batch_idx)
        # print('loss:', loss)

        # logger_logs = model_out.copy()
        logger_logs = {}

        logger_logs['loss'] = loss.detach()

        logger_logs['clip_loss'] = clip_loss.detach()

        if not self.opt.joint_out:
            logger_logs['grammar_loss'] = grammar_loss.detach()

        logger_logs['data_time'] = data_time.detach()

        # UserWarning: The {progress_bar:dict keyword} was deprecated in 0.9.1 and will be removed in 1.0.0
        # Please use self.log(...) inside the LightningModule instead.

        # # log on a step or aggregate epoch metric to the logger and/or progress bar
        # # (inside LightningModule)
        # self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True)
        # warnings.warn(*args, **kwargs)
        # UserWarning: The {log:dict keyword} was deprecated in 0.9.1 and will be removed in 1.0.0
        # Please use self.log(...) inside the LightningModule instead.

        # output = {
        #     'loss': loss,
        #     'log': logger_logs,
        #     'progress_bar': {'data_time': data_time}
        # }

        for k, v in logger_logs.items():
            if k in ['data_time', 'clip_loss', 'grammar_loss']:
                self.log('train/' + k, v, prog_bar=True)
            else:
                self.log('train/' + k, v)

        # print('training step logged')

        return loss

    def validation_step(self, data, batch_idx):

        batch = data
        self.model.eval()

        with torch.no_grad():
            model_out = self.model.train_step(
                img_feat=batch['img_feats'],
                text=batch['text'],
                neg_text=batch['neg_text'],
            )

        if self.opt.joint_out:
            clip_loss = model_out['clip_loss']
            loss = clip_loss

            output = {
                # 'val_loss': loss,
                'loss': loss.detach(),
                'clip_loss': clip_loss.detach(),
                # 'grammar_loss': grammar_loss.detach(),

                'img_feat': model_out['img_feat'].detach(),
                'text_feat': model_out['text_feat'].detach(),
                # 'neg_text_feat': model_out['neg_text_feat'].detach(),
                # 'grammar_pos_pred': model_out['grammar_pos_pred'].detach(),
                # 'grammar_neg_pred': model_out['grammar_neg_pred'].detach(),
                # 'predictions': predictions,
                # 'n_predictions': n_predictions,
            }
        else:
            clip_loss = model_out['clip_loss']
            grammar_loss = model_out['grammar_loss']
            loss = clip_loss + grammar_loss

            output = {
                # 'val_loss': loss,
                'loss': loss.detach(),
                'clip_loss': clip_loss.detach(),
                'grammar_loss': grammar_loss.detach(),

                'img_feat': model_out['img_feat'].detach(),
                'text_feat': model_out['text_feat'].detach(),
                # 'neg_text_feat': model_out['neg_text_feat'].detach(),
                'grammar_pos_pred': model_out['grammar_pos_pred'].detach(),
                'grammar_neg_pred': model_out['grammar_neg_pred'].detach(),
                # 'predictions': predictions,
                # 'n_predictions': n_predictions,
            }
        return output

    def test_step(self, *args, **kwargs):
        return self.validation_step(*args, **kwargs)

    def validation_epoch_end(self, outputs, split='val'):
        outputs = d2comm.gather(outputs)
        # master node
        if d2comm.is_main_process():
            assert self.trainer.node_rank == 0 and self.trainer.local_rank == 0
            outputs = sum(outputs, [])

            out = {}

            val_loss_mean = sum([_['loss'].cpu() for _ in outputs]) / len(outputs)
            val_clip_loss_mean = sum([_['clip_loss'].cpu() for _ in outputs]) / len(outputs)
            if not self.opt.joint_out:
                val_grammar_loss_mean = sum([_['grammar_loss'].cpu() for _ in outputs]) / len(outputs)

            print('loss', val_loss_mean.item())
            print('clip_loss', val_clip_loss_mean.item())
            if not self.opt.joint_out:
                print('grammar_loss', val_grammar_loss_mean.item())

            logit_scale = self.model.clip_model.logit_scale.exp().cpu()

            text_feats = torch.cat([_['text_feat'].cpu() for _ in outputs], dim=0)
            img_feats = torch.cat([_['img_feat'].cpu() for _ in outputs], dim=0)

            assert text_feats.size() == (5000, 512), text_feats.size()
            assert img_feats.size() == (5000, 512), img_feats.size()

            logits_per_text = torch.matmul(text_feats, img_feats.t()) * logit_scale
            logits_per_image = logits_per_text.T

            # text-to-image retrieval
            print('Text-to-Image retrieval')
            for k in [1, 5, 10]:
                text_to_image_topk = logits_per_text.topk(k, dim=1).indices

                n_text = len(text_to_image_topk)

                labels = torch.arange(0, n_text).view(-1, 1)

                n_retrieved = ((text_to_image_topk == labels).sum(dim=1) > 0).sum()

                recall_k = n_retrieved / n_text * 100

                out[f'text_to_image_recall_{k}'] = recall_k.item()

                print(f'R@{k}: {recall_k.item():.2f}%')

            # image-to-text retrieval
            print('Image-to-Text retrieval')
            for k in [1, 5, 10]:
                image_to_text_topk = logits_per_image.topk(k, dim=1).indices

                n_image = len(image_to_text_topk)

                labels = torch.arange(0, n_image).view(-1, 1)

                n_retrieved = ((image_to_text_topk == labels).sum(dim=1) > 0).sum()

                recall_k = n_retrieved / n_image * 100

                out[f'image_to_text_recall_{k}'] = recall_k.item()

                print(f'R@{k}: {recall_k.item():.2f}%')

            out.update({
                'loss': val_loss_mean.item(),
                'clip_loss': val_clip_loss_mean.item()
            })

            if not self.opt.joint_out:
                # grammar scoring
                grammar_pos_pred = torch.cat([_['grammar_pos_pred'].cpu() for _ in outputs], dim=0)
                grammar_neg_pred = torch.cat([_['grammar_neg_pred'].cpu() for _ in outputs], dim=0)

                TP = (grammar_pos_pred == 1).sum().item()
                FP = (grammar_pos_pred == 0).sum().item()
                FN = (grammar_neg_pred == 1).sum().item()
                TN = (grammar_neg_pred == 0).sum().item()
                print('Grammar check')
                print(f'TP: {TP} FP: {FP} FN: {FN} TN: {TN}')

                precision = TP / (TP + FP) * 100
                recall = TP / (TP + FN) * 100
                accuracy = (TP + TN) / (TP + FP + FN + TN) * 100
                f1 = 2 * precision * recall / (precision + recall)
                print(f'Precision: {precision:.2f}%')
                print(f'Recall: {recall:.2f}%')
                print(f'Accuracy: {accuracy:.2f}%')
                print(f'F1: {f1:.2f}%')
                print('Total: {}'.format(len(grammar_pos_pred)))

                out.update({
                    'grammar_loss': val_grammar_loss_mean,

                    'grammar_precision': precision,
                    'grammar_recall': recall,
                    'grammar_accuracy': accuracy,
                    'grammar_f1': f1,

                })

        else:
            out = {}

        out = d2comm.all_gather(out)[0]  # Only the one from the master node
        assert len(out) > 0  # make sure the head has index 0

        # must all be tensors
        out = {k: torch.tensor(v) if not torch.is_tensor(
            v) else v for k, v in out.items()}

        for k, v in out.items():
            self.log(f'{split}/{k}', v)

    def test_epoch_end(self, outputs):

        self.validation_epoch_end(outputs, 'test')

    def configure_optimizers(self):
        # opt = self.opt
        # model = self.model

        # parameters = [p for p in model.parameters() if p.requires_grad]

        # if opt.noamopt:
        #     # assert opt.caption_model in ['transformer', 'bert', 'm2transformer'], 'noamopt can only work with transformer'
        #     optimizer = utils.get_std_opt(
        #         model, optim_func=opt.optim, factor=opt.noamopt_factor, warmup=opt.noamopt_warmup)
        # elif opt.reduce_on_plateau:
        #     # optimizer = utils.build_optimizer(model.parameters(), opt)
        #     optimizer = utils.build_optimizer(parameters, opt)
        #     optimizer = utils.ReduceLROnPlateau(optimizer,
        #                                         factor=opt.reduce_on_plateau_factor,
        #                                         patience=opt.reduce_on_plateau_patience)
        # else:
        #     # optimizer = utils.build_optimizer(model.parameters(), opt)
        #     optimizer = utils.build_optimizer(parameters, opt)

        # from transformers.optimization import AdamW, get_linear_schedule_with_warmup
        # batch_per_epoch = len(self.train_loader)
        # t_total = batch_per_epoch // self.args.gradient_accumulation_steps * self.args.epochs
        # warmup_ratio = self.args.warmup_ratio
        # warmup_iters = int(t_total * warmup_ratio)
        # if self.verbose:
        #     print("Batch per epoch: %d" % batch_per_epoch)
        #     print("Total Iters: %d" % t_total)
        #     print('Warmup ratio:', warmup_ratio)
        #     print("Warm up Iters: %d" % warmup_iters)

        if self.args.optim == 'adamw':
            no_decay = ["bias", "LayerNorm.weight"]
            optimizer_grouped_parameters = [
                {
                    "params": [p for n, p in self.model.named_parameters() if not any(nd in n for nd in no_decay)],
                    "weight_decay": self.args.weight_decay,
                },
                {
                    "params": [p for n, p in self.model.named_parameters() if any(nd in n for nd in no_decay)],
                    "weight_decay": 0.0,
                },
            ]

            # drop frozen parameters (the CLIP vision tower) from the optimizer
            for group in optimizer_grouped_parameters:
                group['params'] = [p for p in group['params'] if p.requires_grad]

            from transformers.optimization import AdamW
            optim = AdamW(optimizer_grouped_parameters,
                          lr=self.args.lr, eps=self.args.adam_eps)
            # lr_scheduler = get_linear_schedule_with_warmup(
            #     optim, warmup_iters, t_total)

        # optimizers = []
        optimizers = [optim]
        lr_schedulers = []

        return optimizers, lr_schedulers

    def optimizer_step(self, epoch, batch_idx, optimizer,
                       optimizer_idx, *args, **kwargs):
        # # warm up lr
        # opt = self.opt
        # iteration = self.trainer.global_step
        # if opt.use_warmup and (iteration < opt.noamopt_warmup):
        #     opt.current_lr = opt.learning_rate * \
        #         (iteration+1) / opt.noamopt_warmup
        #     utils.set_lr(optimizer, opt.current_lr)

        super().optimizer_step(epoch, batch_idx, optimizer,
                               optimizer_idx, *args, **kwargs)

        # print('optimizer step')

    def state_dict(self):
        """
        Save the model state dict as well as opt and vocab
        """
        state_dict = self.model.state_dict()
        device = next(iter(state_dict.values())).device
        assert '_vocab' not in state_dict and '_opt' not in state_dict, 'Just in case'
        # state_dict.update({
        #     '_vocab': utils.serialize_to_tensor(self.model.vocab).to(device),
        #     '_opt': utils.serialize_to_tensor(self.opt).to(device)
        # })
        return state_dict

    def load_state_dict(self, state_dict=None, strict=True):
        # if '_vocab' in state_dict:
        #     self.model.vocab = utils.deserialize(state_dict['_vocab'])
        #     del state_dict['_vocab']
        # elif strict:
        #     raise KeyError
        # if '_opt' in state_dict:
        #     saved_model_opt = utils.deserialize(state_dict['_opt'])
        #     del state_dict['_opt']
        #     opt = self.opt
        #     # Make sure the saved opt is compatible with the current opt
        #     need_be_same = ["caption_model",
        #                     "rnn_type", "rnn_size", "num_layers"]
        #     for checkme in need_be_same:
        #         if getattr(saved_model_opt, checkme) in ['updown', 'topdown'] and \
        #                 getattr(opt, checkme) in ['updown', 'topdown']:
        #             continue
        #         assert getattr(saved_model_opt, checkme) == getattr(
        #             opt, checkme), "Command line argument and saved model disagree on '%s' " % checkme
        # elif strict:
        #     raise KeyError
        self.model.load_state_dict(state_dict, strict)


class OnEpochStartCallback(pl.Callback):

    def on_epoch_start(self, trainer, pl_module):
        # Update lr/training stage/scheduled sampling prob etc.
        opt = pl_module.opt
        model = pl_module.model
        epoch = trainer.current_epoch
        optimizer = trainer.optimizers[0]

        # if not opt.noamopt and not opt.reduce_on_plateau:
        #     # Assign the learning rate
        #     if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
        #         frac = (
        #             epoch - opt.learning_rate_decay_start) // opt.learning_rate_decay_every
        #         decay_factor = opt.learning_rate_decay_rate ** frac
        #         opt.current_lr = opt.learning_rate * decay_factor
        #     else:
        #         opt.current_lr = opt.learning_rate
        #     utils.set_lr(optimizer, opt.current_lr)  # set the decayed rate
        # # Assign the scheduled sampling prob
        # if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
        #     frac = (
        #         epoch - opt.scheduled_sampling_start) // opt.scheduled_sampling_increase_every
        #     opt.ss_prob = min(opt.scheduled_sampling_increase_prob *
        #                       frac, opt.scheduled_sampling_max_prob)
        #     model.ss_prob = opt.ss_prob

        # # If start self critical training
        # if opt.self_critical_after != -1 and epoch >= opt.self_critical_after:
        #     sc_flag = True
        #     init_scorer(opt.cached_tokens)
        # else:
        #     sc_flag = False

        # # If start structure loss training
        # if opt.structure_after != -1 and epoch >= opt.structure_after:
        #     struc_flag = True
        #     init_scorer(opt.cached_tokens)
        # else:
        #     struc_flag = False

        # pl_module.struc_flag = struc_flag
        # pl_module.sc_flag = sc_flag


class ModelCheckpoint(pl.callbacks.ModelCheckpoint):

    def on_keyboard_interrupt(self, trainer, pl_module):
        # Save model on keyboard interrupt
        filepath = os.path.join(self.dirpath, self.prefix + 'interrupt.ckpt')
        self._save_model(filepath)


from param import parse_args
# opt = opts.parse_opt()
args = parse_args()
opt = args

checkpoint_callback = ModelCheckpoint(
    filepath=opt.checkpoint_dir + '{epoch:02d}',
    # dirpath=opt.checkpoint_path,
    save_last=True,
    save_top_k=1,
    verbose=True,
    # monitor='to_monitor',
    # monitor='val/to_monitor',
    # monitor='val/CIDEr',
    monitor='val/loss',
    mode='min',
    # prefix=opt.id+'_',
    prefix=opt.id,
    # filename=f'{opt.id}_',
)

verbose = True
# import torch
# if torch.cuda.current_device() in [0, -1]:
if 'LOCAL_RANK' in os.environ and os.environ['LOCAL_RANK'] != '0':
    verbose = False

# if verbose:
#     print(opt)
#     print("""
#     val_image_use,
#     save_checkpoint_every,
#     save_every_epoch,
#     save_history-ckpt will be ignored.
#     """)

# Lightning defines batch size as batch size per GPU
assert opt.batch_size % torch.cuda.device_count() == 0
opt.batch_size = opt.batch_size // torch.cuda.device_count()
opt.valid_batch_size = opt.valid_batch_size // torch.cuda.device_count()

# If resuming from the last checkpoint
# if opt.start_from is not None and os.path.isfile(os.path.join(opt.start_from, f'{opt.id}_last.ckpt')):
#     resume_from = os.path.join(opt.start_from, f'{opt.id}_last.ckpt')
if opt.start_from is not None and os.path.isfile(os.path.join(opt.start_from, f'{opt.id}-last.ckpt')):
    resume_from = os.path.join(opt.start_from, f'{opt.id}-last.ckpt')
    if verbose:
        print('resume from', resume_from)
else:
    resume_from = None

from pytorch_lightning.loggers import WandbLogger
wandb_logger = WandbLogger(
    # project='CLIP-ViL-COCOCaption',
    project='CLIP-Finetune-COCO',
    name=opt.id,
)

if verbose:
    wandb_logger.experiment.config.update(opt)
    from pathlib import Path
    import glob
    import wandb
    # src_dir = Path(__file__).resolve().parent.parent
    glob_str = "*.py"
    base_path = './'
    wandb.save(glob_str=glob_str, base_path=base_path)

    glob_str = "**/*.yaml"
    base_path = './'
    wandb.save(glob_str=glob_str, base_path=base_path)

    # code = wandb.Artifact('project-source', type='code')
    # for path in glob.glob('**/*.py', recursive=True):
    #     code.add_file(path, name='source/'+path)
    #     print(path)
    # wandb.run.use_artifact(code)


lit = LitModel(opt)
# warning: grad_clip_mode is ignored.
trainer = pl.Trainer(
    callbacks=[
        OnEpochStartCallback(),
        # pl.callbacks.lr_logger.LearningRateLogger()
        pl.callbacks.LearningRateMonitor()
    ],
    default_root_dir=opt.checkpoint_dir,
    resume_from_checkpoint=resume_from,

    distributed_backend='ddp',
    gpus=torch.cuda.device_count(),

    # gpus=1,

    check_val_every_n_epoch=1,
    # max_epochs=opt.max_epochs,
    max_epochs=opt.epochs,
    # gradient_clip_val=opt.grad_clip_value,
    gradient_clip_val=opt.clip_grad_norm,

    checkpoint_callback=checkpoint_callback,
    log_gpu_memory='min_max',
    # log_save_interval=opt.losses_log_every,
    log_every_n_steps=opt.losses_log_every,
    profiler=True,
    # profiler='simple',
    # row_log_interval=10,  # what is it?
    flush_logs_every_n_steps=10,
    num_sanity_val_steps=0,
    # val_check_interval=0.01,
    # limit_train_batches=500,
    # progress_bar_refresh_rate=0,
    # fast_dev_run=True,
    precision=opt.precision,
    logger=wandb_logger
)

if os.getenv('EVALUATE', '0') == '1':
    trainer.test(lit)
else:
    trainer.fit(lit)
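The retrieval metric logged in validation_epoch_end is recall@k with the diagonal as ground truth: row i of the text features should retrieve image i. A small self-contained toy illustration (random features, not values from this repository) of the same computation:

# Illustrative sketch, not part of the committed file.
import torch

torch.manual_seed(0)
img_feats = torch.nn.functional.normalize(torch.randn(8, 512), dim=-1)
# text features are noisy copies of the image features, so retrieval mostly succeeds
text_feats = torch.nn.functional.normalize(img_feats + 0.1 * torch.randn(8, 512), dim=-1)

logits_per_text = text_feats @ img_feats.t()            # (n_text, n_image) similarities
k = 5
topk = logits_per_text.topk(k, dim=1).indices            # top-k image indices per caption
labels = torch.arange(len(topk)).view(-1, 1)             # ground-truth image index = row index
recall_k = (topk == labels).any(dim=1).float().mean() * 100
print(f'R@{k}: {recall_k.item():.2f}%')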
save/README.md
ADDED
@@ -0,0 +1 @@
Directory for checkpoints