taesiri committed on
Commit bbd199b
1 Parent(s): 56f7845

initial commit

Files changed (5)
  1. ExtractEmbedding.py +59 -0
  2. README.md +2 -2
  3. SaveEmbedding.py +100 -0
  4. SimSearch.py +66 -0
  5. app.py +79 -0
ExtractEmbedding.py ADDED
@@ -0,0 +1,59 @@
+ import time
+ import os
+ import torch
+ import numpy as np
+ import torchvision
+ import torch.nn.functional as F
+ from torchvision.datasets import ImageFolder
+ import torchvision.transforms as transforms
+ from tqdm import tqdm
+ import pickle
+ import argparse
+ from PIL import Image
+
+ concat = lambda x: np.concatenate(x, axis=0)
+ to_np = lambda x: x.data.to("cpu").numpy()
+
+
+ class Wrapper(torch.nn.Module):
+     """Wraps a CNN and captures its avgpool output via a forward hook."""
+
+     def __init__(self, model):
+         super(Wrapper, self).__init__()
+         self.model = model
+         self.avgpool_output = None
+         self.query = None
+         self.cossim_value = {}
+
+         def fw_hook(module, input, output):
+             self.avgpool_output = output.squeeze()
+
+         self.model.avgpool.register_forward_hook(fw_hook)
+
+     def forward(self, input):
+         _ = self.model(input)
+         return self.avgpool_output
+
+     def __repr__(self):
+         return "Wrapper"
+
+
+ def QueryToEmbedding(query_pil):
+     # Standard ImageNet preprocessing: resize, center-crop, normalize
+     dataset_transform = transforms.Compose(
+         [
+             transforms.Resize(256),
+             transforms.CenterCrop(224),
+             transforms.ToTensor(),
+             transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
+         ]
+     )
+
+     # Note: this reloads ResNet-50 on every call
+     model = torchvision.models.resnet50(pretrained=True)
+     model.eval()
+     myw = Wrapper(model)
+
+     query_pt = dataset_transform(query_pil)
+
+     with torch.no_grad():
+         embedding = to_np(myw(query_pt.unsqueeze(0)))
+
+     return np.asarray([embedding])
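
A usage sketch (not part of the commit): QueryToEmbedding takes a PIL image and returns a (1, 2048) float32 array, the query shape the FAISS searchers in SimSearch.py expect. The image file name below is hypothetical.

    from PIL import Image
    from ExtractEmbedding import QueryToEmbedding

    query = Image.open("bird_photo.jpg").convert("RGB")  # hypothetical input image
    embedding = QueryToEmbedding(query)
    print(embedding.shape)  # (1, 2048)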
README.md CHANGED
@@ -1,8 +1,8 @@
  ---
- title: CHM Corr
+ title: CHM-Corr
  emoji: 🐨
  colorFrom: yellow
- colorTo: red
+ colorTo: blue
  sdk: gradio
  sdk_version: 3.1.1
  app_file: app.py
SaveEmbedding.py ADDED
@@ -0,0 +1,100 @@
+ import time
+ import os
+ import torch
+ import numpy as np
+ import torchvision
+ import torch.nn.functional as F
+ from torchvision.datasets import ImageFolder
+ import torchvision.transforms as transforms
+ from tqdm import tqdm
+ import pickle
+ import argparse
+
+
+ concat = lambda x: np.concatenate(x, axis=0)
+ to_np = lambda x: x.data.to("cpu").numpy()
+
+
+ class Wrapper(torch.nn.Module):
+     """Wraps a CNN and captures its avgpool output via a forward hook."""
+
+     def __init__(self, model):
+         super(Wrapper, self).__init__()
+         self.model = model
+         self.avgpool_output = None
+         self.query = None
+         self.cossim_value = {}
+
+         def fw_hook(module, input, output):
+             self.avgpool_output = output.squeeze()
+
+         self.model.avgpool.register_forward_hook(fw_hook)
+
+     def forward(self, input):
+         _ = self.model(input)
+         return self.avgpool_output
+
+     def __repr__(self):
+         return "Wrapper"
+
+
+ def run(training_set_path):
+     # Standard ImageNet preprocessing: resize, center-crop, and normalize
+     # with ImageNet statistics
+     dataset_transform = transforms.Compose(
+         [
+             transforms.Resize(256),
+             transforms.CenterCrop(224),
+             transforms.ToTensor(),
+             transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
+         ]
+     )
+
+     training_imagefolder = ImageFolder(
+         root=training_set_path, transform=dataset_transform
+     )
+     train_loader = torch.utils.data.DataLoader(
+         training_imagefolder,
+         batch_size=512,
+         shuffle=False,
+         num_workers=2,
+         pin_memory=True,
+     )
+     print(f"# of training folder samples: {len(training_imagefolder)}")
+
+     model = torchvision.models.resnet50(pretrained=True)
+     model.eval()
+     myw = Wrapper(model)
+
+     training_embeddings = []
+     training_labels = []
+
+     with torch.no_grad():
+         for _, (data, target) in enumerate(tqdm(train_loader)):
+             embeddings = to_np(myw(data))
+             labels = to_np(target)
+
+             training_embeddings.append(embeddings)
+             training_labels.append(labels)
+
+     training_embeddings_concatted = concat(training_embeddings)
+     training_labels_concatted = concat(training_labels)
+
+     print(training_embeddings_concatted.shape)
+     return training_embeddings_concatted, training_labels_concatted
+
+
+ def main():
+     parser = argparse.ArgumentParser(description="Saving Embeddings")
+     parser.add_argument("--train", help="Path to the dataset", type=str, required=True)
+     args = parser.parse_args()
+
+     embeddings, labels = run(args.train)
+
+     # Save embeddings and labels to disk
+     with open("embeddings.pickle", "wb") as f:
+         pickle.dump(embeddings, f)
+
+     with open("labels.pickle", "wb") as f:
+         pickle.dump(labels, f)
+
+
+ if __name__ == "__main__":
+     main()
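
A usage sketch (not part of the commit): the script is run once over an ImageFolder-style directory (one subfolder of images per class) and writes the two pickles that app.py loads. The dataset path below is hypothetical.

    # Command line:
    #   python SaveEmbedding.py --train ./Training/train
    # Programmatic equivalent:
    from SaveEmbedding import run

    embeddings, labels = run("./Training/train")  # hypothetical ImageFolder root
    # embeddings: (N, 2048) float32; labels: (N,) class indices; these are
    # the arrays that end up in embeddings.pickle and labels.pickle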
SimSearch.py ADDED
@@ -0,0 +1,66 @@
+ import faiss
+ import numpy as np
+
+
+ class FaissNeighbors:
+     """Exact L2 nearest-neighbor search over labeled embeddings."""
+
+     def __init__(self):
+         self.index = None
+         self.y = None
+
+     def fit(self, X, y):
+         self.index = faiss.IndexFlatL2(X.shape[1])
+         self.index.add(X.astype(np.float32))
+         self.y = y
+
+     def get_distances_and_indices(self, X, top_K=1000):
+         distances, indices = self.index.search(X.astype(np.float32), k=top_K)
+         return np.copy(distances), np.copy(indices), np.copy(self.y[indices])
+
+     def get_nearest_labels(self, X, top_K=1000):
+         distances, indices = self.index.search(X.astype(np.float32), k=top_K)
+         return np.copy(self.y[indices])
+
+
+ class FaissCosineNeighbors:
+     """Cosine-similarity search: L2-normalized vectors + inner-product index."""
+
+     def __init__(self):
+         self.cindex = None
+         self.y = None
+
+     def fit(self, X, y):
+         self.cindex = faiss.index_factory(
+             X.shape[1], "Flat", faiss.METRIC_INNER_PRODUCT
+         )
+         X = np.copy(X)
+         X = X.astype(np.float32)
+         faiss.normalize_L2(X)
+         self.cindex.add(X)
+         self.y = y
+
+     def get_distances_and_indices(self, Q, topK):
+         Q = np.copy(Q)
+         faiss.normalize_L2(Q)
+         distances, indices = self.cindex.search(Q.astype(np.float32), k=topK)
+         return np.copy(distances), np.copy(indices), np.copy(self.y[indices])
+
+     def get_nearest_labels(self, Q, topK=1000):
+         Q = np.copy(Q)
+         faiss.normalize_L2(Q)
+         distances, indices = self.cindex.search(Q.astype(np.float32), k=topK)
+         return np.copy(self.y[indices])
+
+
+ class SearchableTrainingSet:
+     def __init__(self, embeddings, labels):
+         self.simsearcher = FaissCosineNeighbors()
+         self.X_train = embeddings
+         self.y_train = labels
+
+     def build_index(self):
+         self.simsearcher.fit(self.X_train, self.y_train)
+
+     def search(self, query, k=20):
+         # NOTE: a fixed pool of 100 neighbors is retrieved; k is currently
+         # unused and callers post-filter the results themselves
+         nearest_data_points = self.simsearcher.get_distances_and_indices(
+             Q=query, topK=100
+         )
+         # topKs = [x[0] for x in Counter(nearest_data_points[0]).most_common(k)]
+         return nearest_data_points
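
A minimal end-to-end sketch of the searcher (not part of the commit), with random arrays standing in for real embeddings:

    import numpy as np
    from SimSearch import SearchableTrainingSet

    # Toy stand-ins: 1000 samples, 2048-d (ResNet-50 embedding size), 10 classes
    X = np.random.rand(1000, 2048).astype(np.float32)
    y = np.random.randint(0, 10, size=1000)

    searcher = SearchableTrainingSet(X, y)
    searcher.build_index()

    query = np.random.rand(1, 2048).astype(np.float32)
    distances, indices, labels = searcher.search(query)  # each has shape (1, 100)
    print(labels[0][:5])  # labels of the 5 most similar training samples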
app.py ADDED
@@ -0,0 +1,79 @@
+ import pickle
+ from collections import Counter
+ import numpy as np
+ import gradio as gr
+ import gdown
+ import torchvision
+ from torchvision.datasets import ImageFolder
+
+ from SimSearch import FaissCosineNeighbors, SearchableTrainingSet
+ from ExtractEmbedding import QueryToEmbedding
+
+ concat = lambda x: np.concatenate(x, axis=0)
+
+ # Download precomputed embeddings and labels (see SaveEmbedding.py)
+ gdown.download(id="116CiA_cXciGSl72tbAUDoN-f1B9Frp89")
+ gdown.download(id="1SDtq6ap7LPPpYfLbAxaMGGmj0EAV_m_e")
+
+ # CUB training set
+ gdown.cached_download(
+     url="https://drive.google.com/uc?id=1iR19j7532xqPefWYT-BdtcaKnsEokIqo",
+     path="./CUB_train.zip",
+     quiet=False,
+     md5="1bd99e73b2fea8e4c2ebcb0e7722f1b1",
+ )
+
+ # Extract the training set
+ torchvision.datasets.utils.extract_archive(
+     from_path="CUB_train.zip",
+     to_path="Training/",
+     remove_finished=False,
+ )
+
+
+ # Load precomputed training-set embeddings and labels
+ with open("./embeddings.pickle", "rb") as f:
+     Xtrain = pickle.load(f)
+ # FIXME: re-run the code to get the embeddings in the right format
+ with open("./labels.pickle", "rb") as f:
+     ytrain = pickle.load(f)
+
+ searcher = SearchableTrainingSet(Xtrain, ytrain)
+ searcher.build_index()
+
+ # Extract label names
+ training_folder = ImageFolder(root="./Training/train/")
+ id_to_bird_name = {
+     x[1]: x[0].split("/")[-2].replace(".", " ") for x in training_folder.imgs
+ }
+
+
+ def search(query_image, searcher=searcher):
+     query_embedding = QueryToEmbedding(query_image)
+     distances, indices, labels = searcher.search(query_embedding, k=50)
+
+     # Majority vote over the 20 nearest neighbors
+     result_ctr = Counter(labels[0][:20]).most_common(5)
+
+     top1_label = result_ctr[0][0]
+     top_indices = []
+
+     # Collect training-set indices of neighbors that share the top-1 label
+     for label, index in zip(labels[0][:20], indices[0][:20]):
+         if label == top1_label:
+             top_indices.append(index)
+
+     gallery_images = [training_folder.imgs[int(X)][0] for X in top_indices[:5]]
+     predicted_labels = {id_to_bird_name[X[0]]: X[1] / 20.0 for X in result_ctr}
+
+     return predicted_labels, gallery_images
+
+
+ demo = gr.Interface(
+     search,
+     gr.Image(type="pil"),
+     ["label", "gallery"],
+     examples=[["./examples/bird.jpg"]],
+     description="WIP - kNN on CUB dataset",
+     title="Work in Progress - CHM-Corr",
+ )
+
+ if __name__ == "__main__":
+     demo.launch()
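
A smoke-test sketch (not part of the commit), assuming the bundled example image exists; note that importing app runs the module-level download and index-building steps first:

    # Launching the UI:  python app.py
    # Or call the classifier directly, bypassing Gradio:
    from PIL import Image
    from app import search

    preds, gallery = search(Image.open("./examples/bird.jpg"))
    print(preds)    # {class name: votes among the 20 nearest neighbors / 20.0}
    print(gallery)  # paths of up to 5 training images from the top-1 class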