sedrickkeh commited on
Commit
016285f
1 Parent(s): 4be744b

Upload 13 files

Browse files
app.py CHANGED
@@ -1,80 +1,192 @@
1
  import gradio as gr
2
  from response_db import StResponseDb
 
 
 
 
 
 
 
 
3
  db = StResponseDb()
4
- a = gr.Number(0)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
- def get_next_question(history):
7
- if len(history)==2:
8
- question = "What is the man doing?"
9
- elif len(history)==4:
10
- question = "How many apples are there?"
11
- else:
12
- question = "What color is the cat?"
13
- return question
14
-
15
- def ask_a_question(input, taskid, history=[]):
16
- history.append(input)
17
- db.add(int(a.value), taskid, len(history)//2-1, history[-2], history[-1])
18
- history.append(get_next_question(history))
19
-
20
  # write some HTML
21
  html = "<div class='chatbot'>"
22
- for m, msg in enumerate(history):
 
23
  cls = "bot" if m%2 == 0 else "user"
24
  html += "<div class='msg {}'> {}</div>".format(cls, msg)
25
  html += "</div>"
26
- return html, history
27
 
 
 
 
 
 
 
 
 
 
28
 
29
- css = """
30
- .chatbox {display:flex;flex-direction:column}
31
- .msg {padding:4px;margin-bottom:4px;border-radius:4px;width:80%}
32
- .msg.user {background-color:cornflowerblue;color:white}
33
- .msg.bot {background-color:lightgray;align-self:self-end}
34
- .footer {display:none !important}
35
- """
36
 
37
  def set_images(taskid):
38
- id1 = f"images/img_{int(10*(taskid-1)+1)}.jpg"
39
- id2 = f"images/img_{int(10*(taskid-1)+2)}.jpg"
40
- id3 = f"images/img_{int(10*(taskid-1)+3)}.jpg"
41
- id4 = f"images/img_{int(10*(taskid-1)+4)}.jpg"
42
- id5 = f"images/img_{int(10*(taskid-1)+5)}.jpg"
43
- id6 = f"images/img_{int(10*(taskid-1)+6)}.jpg"
44
- id7 = f"images/img_{int(10*(taskid-1)+7)}.jpg"
45
- id8 = f"images/img_{int(10*(taskid-1)+8)}.jpg"
46
- id9 = f"images/img_{int(10*(taskid-1)+9)}.jpg"
47
- id10 = f"images/img_{int(10*(taskid-1)+10)}.jpg"
48
- first_question = "How many dogs are there?"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
  first_question_html = f"<div class='chatbot'><div class='msg bot'>{first_question}</div></div>"
50
- a.value = a.value+1
51
- return id1, id2, id3, id4, id5, id6, id7, id8, id9, id10, [first_question], first_question_html
 
 
 
 
 
52
 
53
- with gr.Blocks(css=css) as demo:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
 
55
  with gr.Column() as img_block:
56
  with gr.Row():
57
- img1 = gr.Image()
58
- img2 = gr.Image()
59
- img3 = gr.Image()
60
- img4 = gr.Image()
61
- img5 = gr.Image()
62
  with gr.Row():
63
- img6 = gr.Image()
64
- img7 = gr.Image()
65
- img8 = gr.Image()
66
- img9 = gr.Image()
67
- img10 = gr.Image()
68
  conversation = gr.HTML()
69
- history_log = gr.State([])
70
 
 
 
 
 
 
 
 
 
 
71
  with gr.Column():
72
  with gr.Row():
73
- taskid = gr.Number(label="Task ID (Choose from [1,2,3])")
74
- btn1 = gr.Button("Enter")
75
- btn1.click(set_images, inputs=taskid, outputs=[img1, img2, img3, img4, img5, img6, img7, img8, img9, img10, history_log, conversation])
76
- answer = gr.inputs.Textbox(placeholder="Insert answer here.", label="Answer the given question.")
77
- submit = gr.Button("Submit")
78
- submit.click(fn=ask_a_question, inputs=[answer, taskid, history_log], outputs=[conversation, history_log])
79
-
80
- demo.launch()
 
 
 
1
  import gradio as gr
2
  from response_db import StResponseDb
3
+ from create_cache import Game_Cache
4
+ import numpy as np
5
+ from PIL import Image
6
+ import pandas as pd
7
+ import torch
8
+ import pickle
9
+ import uuid
10
+
11
  db = StResponseDb()
12
+ css = """
13
+ .chatbot {display:flex;flex-direction:column}
14
+ .msg {padding:4px;margin-bottom:4px;border-radius:4px;width:80%}
15
+ .msg.user {background-color:cornflowerblue;color:white;align-self:self-end}
16
+ .msg.bot {background-color:lightgray}
17
+ .na_button {background-color:red;color:red}
18
+ """
19
+
20
+ from model.run_question_asking_model import return_modules, return_modules_yn
21
+ question_model, response_model_simul, _, caption_model = return_modules()
22
+ question_model_yn, response_model_simul_yn, _, caption_model_yn = return_modules_yn()
23
+
24
+ class Game_Session:
25
+ def __init__(self, taskid, yn, hard_setting):
26
+ self.yn = yn
27
+ self.hard_setting = hard_setting
28
+
29
+ global question_model, response_model_simul, caption_model
30
+ global question_model_yn, response_model_simul_yn, caption_model_yn
31
+ self.question_model = question_model
32
+ self.response_model_simul = response_model_simul
33
+ self.caption_model = caption_model
34
+ self.question_model_yn = question_model_yn
35
+ self.response_model_simul_yn = response_model_simul_yn
36
+ self.caption_model_yn = caption_model_yn
37
+
38
+ global image_files, images_np, p_y_x, p_r_qy, p_y_xqr, captions, questions, target_questions
39
+ self.image_files, self.image_np, self.p_y_x, self.p_r_qy, self.p_y_xqr = None, None, None, None, None
40
+ self.captions, self.questions, self.target_questions = None, None, None
41
+
42
+ self.history = []
43
+ self.game_id = str(uuid.uuid4())
44
+ self.set_curr_models()
45
+
46
+ def set_curr_models(self):
47
+ if self.yn:
48
+ self.curr_question_model, self.curr_caption_model, self.curr_response_model_simul = self.question_model_yn, self.caption_model_yn, self.response_model_simul_yn
49
+ else:
50
+ self.curr_question_model, self.curr_caption_model, self.curr_response_model_simul = self.question_model, self.caption_model, self.response_model_simul
51
+
52
+ def get_next_question(self):
53
+ return self.curr_question_model.select_best_question(self.p_y_x, self.questions, self.images_np, self.captions, self.curr_response_model_simul)
54
+
55
+
56
+ def ask_a_question(input, taskid, gs):
57
+ gs.history.append(input)
58
+ gs.p_r_qy = gs.curr_response_model_simul.get_p_r_qy(input, gs.history[-2], gs.images_np, gs.captions)
59
+ gs.p_y_xqr = gs.p_y_x*gs.p_r_qy
60
+ gs.p_y_xqr = gs.p_y_xqr/torch.sum(gs.p_y_xqr)if torch.sum(gs.p_y_xqr) != 0 else torch.zeros_like(gs.p_y_xqr)
61
+ gs.p_y_x = gs.p_y_xqr
62
+ gs.questions.remove(gs.history[-2])
63
+ db.add(gs.game_id, taskid, len(gs.history)//2-1, gs.history[-2], gs.history[-1])
64
+ gs.history.append(gs.get_next_question())
65
+
66
+ top_prob = torch.max(gs.p_y_x).item()
67
+ top_pred = torch.argmax(gs.p_y_x).item()
68
+ if top_prob > 0.8:
69
+ gs.history = gs.history[:-1]
70
+ db.add(gs.game_id, taskid, len(gs.history)//2, f"Guess: Image {top_pred}", "")
71
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72
  # write some HTML
73
  html = "<div class='chatbot'>"
74
+ for m, msg in enumerate(gs.history):
75
+ if msg=="nothing": msg="n/a"
76
  cls = "bot" if m%2 == 0 else "user"
77
  html += "<div class='msg {}'> {}</div>".format(cls, msg)
78
  html += "</div>"
 
79
 
80
+ ### Game finished:
81
+ if top_prob > 0.8:
82
+ html += f"<p>The model identified <b>Image {top_pred+1}</b> as the image. Please select a new task ID to continue.</p>"
83
+ return html, gs, gr.Textbox.update(visible=False), gr.Button.update(visible=False), gr.Button.update(visible=False), gr.Number.update(visible=True), gr.Button.update(visible=True), gr.Number.update(visible=False), gr.Button.update(visible=False)
84
+ else:
85
+ if not gs.yn:
86
+ return html, gs, gr.Textbox.update(visible=True), gr.Button.update(visible=True), gr.Button.update(visible=True), gr.Number.update(visible=False), gr.Button.update(visible=False), gr.Number.update(visible=False), gr.Button.update(visible=False)
87
+ else:
88
+ return html, gs, gr.Textbox.update(visible=False), gr.Button.update(visible=False), gr.Button.update(visible=False), gr.Number.update(visible=False), gr.Button.update(visible=False), gr.Number.update(visible=True), gr.Button.update(visible=True)
89
 
 
 
 
 
 
 
 
90
 
91
  def set_images(taskid):
92
+ pilot_study = pd.read_csv("pilot-study.csv")
93
+ taskid_original = taskid
94
+ taskid = pilot_study['mscoco-id'].tolist()[int(taskid)]
95
+
96
+ with open(f'cache/{int(taskid)}.p', 'rb') as fp:
97
+ game_cache = pickle.load(fp)
98
+ gs = Game_Session(int(taskid), game_cache.yn, game_cache.hard_setting)
99
+ id1 = f"./mscoco-images/val2014/{game_cache.image_files[0]}"
100
+ id2 = f"./mscoco-images/val2014/{game_cache.image_files[1]}"
101
+ id3 = f"./mscoco-images/val2014/{game_cache.image_files[2]}"
102
+ id4 = f"./mscoco-images/val2014/{game_cache.image_files[3]}"
103
+ id5 = f"./mscoco-images/val2014/{game_cache.image_files[4]}"
104
+ id6 = f"./mscoco-images/val2014/{game_cache.image_files[5]}"
105
+ id7 = f"./mscoco-images/val2014/{game_cache.image_files[6]}"
106
+ id8 = f"./mscoco-images/val2014/{game_cache.image_files[7]}"
107
+ id9 = f"./mscoco-images/val2014/{game_cache.image_files[8]}"
108
+ id10 = f"./mscoco-images/val2014/{game_cache.image_files[9]}"
109
+
110
+ gs.image_files = [id1, id2, id3, id4, id5, id6, id7, id8, id9, id10]
111
+ gs.image_files = [x[15:] for x in gs.image_files]
112
+ gs.images_np = [np.asarray(Image.open(f"./mscoco-images/{i}")) for i in gs.image_files]
113
+ gs.images_np = [np.dstack([i]*3) if len(i.shape)==2 else i for i in gs.images_np]
114
+
115
+ gs.p_y_x = (torch.ones(10)/10).to(gs.curr_question_model.device)
116
+ gs.captions = gs.curr_caption_model.get_captions(gs.image_files)
117
+ gs.questions, gs.target_questions = gs.curr_question_model.get_questions(gs.image_files, gs.captions, 0)
118
+ gs.curr_question_model.reset_question_bank()
119
+ gs.curr_question_model.question_bank = game_cache.question_dict
120
+ first_question = gs.curr_question_model.select_best_question(gs.p_y_x, gs.questions, gs.images_np, gs.captions, gs.curr_response_model_simul)
121
  first_question_html = f"<div class='chatbot'><div class='msg bot'>{first_question}</div></div>"
122
+ gs.history.append(first_question)
123
+ html = f"<p>Current Task ID: <b>{int(taskid_original)}</b></p>"
124
+ if not gs.yn:
125
+ return id1, id2, id3, id4, id5, id6, id7, id8, id9, id10, gs, first_question_html, gr.HTML.update(value=html, visible=True), gr.Textbox.update(visible=True, value=''), gr.Button.update(visible=True), gr.Button.update(visible=True), gr.Number.update(visible=False), gr.Button.update(visible=False), gr.Button.update(visible=False), gr.Button.update(visible=False)
126
+ else:
127
+ return id1, id2, id3, id4, id5, id6, id7, id8, id9, id10, gs, first_question_html, gr.HTML.update(value=html, visible=True), gr.Textbox.update(visible=False), gr.Button.update(visible=False), gr.Button.update(visible=False), gr.Number.update(visible=False), gr.Button.update(visible=False), gr.Button.update(visible=True), gr.Button.update(visible=True)
128
+
129
 
130
+ with gr.Blocks(title="Image Q&A Guessing Game", css=css) as demo:
131
+ gr.HTML("<h1>Image Q&A Guessing Game</h1>\
132
+ <p style='font-size:120%;'>\
133
+ Imagine you are playing 20-questions with an AI model.<br>\
134
+ The AI model plays the role of the question asker. You play the role of the responder. <br>\
135
+ There are 10 images. <b>Your image is Image 1</b>. The other images are distraction images.\
136
+ The model can see all 10 images and all the questions and answers for the current set of images. It will ask a question based on the available information.<br>\
137
+ <span style='color: #0000ff'>The goal of the model is to accurately guess the correct image (i.e. <b><span style='color: #0000ff'>Image 1</span></b>) in as few turns as possible.<br>\
138
+ Your goal is to help the model guess the image by answering as clearly and accurately as possible.</span><br><br>\
139
+ <b>Guidelines:</b><br>\
140
+ <ol style='font-size:120%;'>\
141
+ <li>It is best to keep your answers short (a single word or a short phrase). No need to answer in full sentences.</li>\
142
+ <li>If you feel that the question cannot be answered or does not apply to Image 1, please select N/A.</li>\
143
+ </ol> \
144
+ <br>\
145
+ (Note: We are testing multiple game settings. In some instances, the game will be open-ended, while in other instances, the answer choices will be limited to yes/no.)<br></p>\
146
+ <br>\
147
+ <h2>Please enter a TaskID to start</h2>")
148
+
149
+ with gr.Column():
150
+ with gr.Row():
151
+ taskid = gr.Number(label="Task ID (Enter a number from 0 to 160)", value=0)
152
+ start_button = gr.Button("Enter")
153
+ with gr.Row():
154
+ task_text = gr.HTML()
155
 
156
  with gr.Column() as img_block:
157
  with gr.Row():
158
+ img1 = gr.Image(label="Image 1", show_label=True)
159
+ img2 = gr.Image(label="Image 2", show_label=True)
160
+ img3 = gr.Image(label="Image 3", show_label=True)
161
+ img4 = gr.Image(label="Image 4", show_label=True)
162
+ img5 = gr.Image(label="Image 5", show_label=True)
163
  with gr.Row():
164
+ img6 = gr.Image(label="Image 6", show_label=True)
165
+ img7 = gr.Image(label="Image 7", show_label=True)
166
+ img8 = gr.Image(label="Image 8", show_label=True)
167
+ img9 = gr.Image(label="Image 9", show_label=True)
168
+ img10 = gr.Image(label="Image 10", show_label=True)
169
  conversation = gr.HTML()
170
+ game_session_state = gr.State()
171
 
172
+ answer = gr.Textbox(placeholder="Insert answer here.", label="Answer the given question.", visible=False)
173
+ null_answer = gr.Textbox("nothing", visible=False)
174
+ yes_answer = gr.Textbox("yes", visible=False)
175
+ no_answer = gr.Textbox("no", visible=False)
176
+
177
+ with gr.Column():
178
+ with gr.Row():
179
+ yes_box = gr.Button("Yes", visible=False)
180
+ no_box = gr.Button("No", visible=False)
181
  with gr.Column():
182
  with gr.Row():
183
+ na_box = gr.Button("N/A", visible=False, elem_classes="na_button")
184
+ submit = gr.Button("Submit", visible=False)
185
+ ### Button click events
186
+ start_button.click(fn=set_images, inputs=taskid, outputs=[img1, img2, img3, img4, img5, img6, img7, img8, img9, img10, game_session_state, conversation, task_text, answer, na_box, submit, taskid, start_button, yes_box, no_box])
187
+ submit.click(fn=ask_a_question, inputs=[answer, taskid, game_session_state], outputs=[conversation, game_session_state, answer, na_box, submit, taskid, start_button, yes_box, no_box])
188
+ na_box.click(fn=ask_a_question, inputs=[null_answer, taskid, game_session_state], outputs=[conversation, game_session_state, answer, na_box, submit, taskid, start_button, yes_box, no_box])
189
+ yes_box.click(fn=ask_a_question, inputs=[yes_answer, taskid, game_session_state], outputs=[conversation, game_session_state, answer, na_box, submit, taskid, start_button, yes_box, no_box])
190
+ no_box.click(fn=ask_a_question, inputs=[no_answer, taskid, game_session_state], outputs=[conversation, game_session_state, answer, na_box, submit, taskid, start_button, yes_box, no_box])
191
+
192
+ demo.launch()
create_cache.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import numpy as np
3
+ from PIL import Image
4
+ import torch
5
+ import pickle
6
+
7
+ class Game_Cache:
8
+ def __init__(self, question_dict, image_files, yn, hard_setting):
9
+ self.question_dict = question_dict
10
+ self.image_files = image_files
11
+ self.yn = yn
12
+ self.hard_setting = hard_setting
13
+
14
+ image_list = []
15
+ with open('./mscoco/mscoco_images.txt', 'r') as f:
16
+ for line in f.readlines():
17
+ image_list.append(line.strip())
18
+ image_list_hard = []
19
+ with open('./mscoco/mscoco_images_attribute_n=1.txt', 'r') as f:
20
+ for line in f.readlines():
21
+ image_list_hard.append(line.strip())
22
+
23
+ yn_indices = list(range(40,80))+list(range(120,160))
24
+ hard_setting_indices = list(range(80,160))
25
+
26
+
27
+ from model.run_question_asking_model import return_modules, return_modules_yn
28
+ global image_files, images_np, p_y_x, p_r_qy, p_y_xqr, captions, questions, target_questions
29
+ global question_model, response_model_simul, caption_model
30
+ question_model, response_model_simul, _, caption_model = return_modules()
31
+ global question_model_yn, response_model_simul_yn, caption_model_yn
32
+ question_model_yn, response_model_simul_yn, _, caption_model_yn = return_modules_yn()
33
+
34
+ def create_cache(taskid):
35
+ original_taskid = taskid
36
+ global question_model, response_model_simul, caption_model
37
+ global question_model_yn, response_model_simul_yn, caption_model_yn
38
+ if taskid in yn_indices:
39
+ yn = True
40
+ curr_question_model, curr_response_model_simul, curr_caption_model = question_model, response_model_simul, caption_model
41
+ taskid-=40
42
+ else:
43
+ yn = False
44
+ curr_question_model, curr_response_model_simul, curr_caption_model = question_model_yn, response_model_simul_yn, caption_model_yn
45
+ if taskid in hard_setting_indices:
46
+ hard_setting = True
47
+ image_list_curr = image_list_hard
48
+ taskid -= 80
49
+ else:
50
+ hard_setting = False
51
+ image_list_curr = image_list
52
+
53
+ id1 = f"./mscoco-images/val2014/{image_list_curr[int(taskid)*10+0]}"
54
+ id2 = f"./mscoco-images/val2014/{image_list_curr[int(taskid)*10+1]}"
55
+ id3 = f"./mscoco-images/val2014/{image_list_curr[int(taskid)*10+2]}"
56
+ id4 = f"./mscoco-images/val2014/{image_list_curr[int(taskid)*10+3]}"
57
+ id5 = f"./mscoco-images/val2014/{image_list_curr[int(taskid)*10+4]}"
58
+ id6 = f"./mscoco-images/val2014/{image_list_curr[int(taskid)*10+5]}"
59
+ id7 = f"./mscoco-images/val2014/{image_list_curr[int(taskid)*10+6]}"
60
+ id8 = f"./mscoco-images/val2014/{image_list_curr[int(taskid)*10+7]}"
61
+ id9 = f"./mscoco-images/val2014/{image_list_curr[int(taskid)*10+8]}"
62
+ id10 = f"./mscoco-images/val2014/{image_list_curr[int(taskid)*10+9]}"
63
+ image_names = []
64
+ for i in range(10):
65
+ image_names.append(image_list_curr[int(taskid)*10+i])
66
+ image_files = [id1, id2, id3, id4, id5, id6, id7, id8, id9, id10]
67
+ image_files = [x[15:] for x in image_files]
68
+ images_np = [np.asarray(Image.open(f"./mscoco-images/{i}")) for i in image_files]
69
+ images_np = [np.dstack([i]*3) if len(i.shape)==2 else i for i in images_np]
70
+ p_y_x = (torch.ones(10)/10).to(curr_question_model.device)
71
+ captions = curr_caption_model.get_captions(image_files)
72
+ questions, target_questions = curr_question_model.get_questions(image_files, captions, 0)
73
+ curr_question_model.reset_question_bank()
74
+ first_question = curr_question_model.select_best_question(p_y_x, questions, images_np, captions, curr_response_model_simul)
75
+
76
+ gc = Game_Cache(curr_question_model.question_bank, image_names, yn, hard_setting)
77
+ with open(f'./cache{int(taskid)}.p', 'wb') as fp:
78
+ pickle.dump(gc, fp, protocol=pickle.HIGHEST_PROTOCOL)
79
+
80
+ if __name__=="__main__":
81
+ for i in range(160):
82
+ create_cache(i)
83
+
model/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ import model.run_question_asking_model
2
+ import model.model.caption_model
3
+ import model.model.question_asking_model
4
+ import model.model.question_generator
5
+ import model.model.question_model_base
6
+ import model.model.response_model
model/model/caption_model.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ from pycocotools.coco import COCO
3
+
4
+ def get_caption_model(args, question_model):
5
+ if args.caption_strategy=="simple":
6
+ return CaptionModelSimple(question_model)
7
+ elif args.caption_strategy=="granular":
8
+ return CaptionModelGranular()
9
+ elif args.caption_strategy=="gtruth":
10
+ return CaptionModelCOCO()
11
+ else:
12
+ raise ValueError(f"{args.caption_strategy} is not a valid caption strategy.")
13
+
14
+
15
+ class CaptionModel():
16
+ # Class for the other CaptionModels to inherit from
17
+ def __init__(self):
18
+ pass
19
+
20
+ def get_captions(self, images, **kwargs):
21
+ raise NotImplemented
22
+
23
+ class CaptionModelCOCO():
24
+ # Ground truth annotations from COCO dataset
25
+ def __init__(self):
26
+ dataDir='./mscoco'
27
+ val_file = '{}/annotations/captions_val2014.json'.format(dataDir)
28
+ self.coco_caps_val = COCO(val_file)
29
+ val_file = '{}/annotations/instances_val2014.json'.format(dataDir)
30
+ self.coco_anns_val = COCO(val_file)
31
+
32
+ def get_captions(self, images, return_all=False):
33
+ captions = []
34
+ for i, image in enumerate(images):
35
+ image_id = int(image.split('_')[-1].split('.')[0].lstrip("0"))
36
+ annIds = self.coco_caps_val.getAnnIds(imgIds=image_id)
37
+ anns_val = self.coco_caps_val.loadAnns(annIds)
38
+ # annIds = self.coco_caps_train.getAnnIds(imgIds=image_id)
39
+ # anns_train = self.coco_caps_train.loadAnns(annIds)
40
+ # anns = anns_val + anns_train
41
+ anns = anns_val
42
+ anns = [d['caption'] for d in anns]
43
+ if return_all: captions.append(anns)
44
+ else: captions.append(anns[0])
45
+ return captions
46
+
47
+ def get_subjects(self, images):
48
+ subjects = []
49
+ for i, image in enumerate(images):
50
+ image_id = int(image.split('_')[-1].split('.')[0].lstrip("0"))
51
+ annIds = self.coco_anns_val.getAnnIds(imgIds=image_id)
52
+ anns_val = self.coco_anns_val.loadAnns(annIds)
53
+ cats_val = list(set([d['category_id'] for d in anns_val]))
54
+ annIds = self.coco_caps_train.getAnnIds(imgIds=image_id)
55
+ anns_train = self.coco_caps_train.loadAnns(annIds)
56
+ cats_train = list(set([d['category_id'] for d in anns_train]))
57
+ cats = self.coco_anns_val.loadCats(ids=cats_val+cats_train)
58
+ cats1, cats2 = [d['supercategory'] for d in cats], [d['name'] for d in cats]
59
+ cats = list(set(cats1+cats2))
60
+ subjects.append(cats)
61
+ return subjects
62
+
63
+
64
+ class CaptionModelSimple():
65
+ def __init__(self, qa_model):
66
+ self.qa_model = qa_model
67
+
68
+ def get_captions(self, images):
69
+ captions = []
70
+ for i, image in enumerate(images):
71
+ caption = self.qa_model.generate_description(image, images)
72
+ captions.append(caption)
73
+ return captions
74
+
75
+ class CaptionModelGranular():
76
+ def __init__(self):
77
+ df_train = pd.read_json("captions/coco_train_captions.jsonl", lines=True)
78
+ df_val = pd.read_json("captions/coco_val_captions.jsonl", lines=True)
79
+ self.caption_dict = {}
80
+ for i in range(len(df_train)):
81
+ self.caption_dict[str(df_train.image_id[i])] = df_train.caption[i]
82
+ for i in range(len(df_val)):
83
+ self.caption_dict[str(df_val.image_id[i])] = df_val.caption[i]
84
+
85
+ def get_captions(self, images):
86
+ captions = []
87
+ for i, image in enumerate(images):
88
+ captions.append(self.caption_dict[image.split('.')[0].split('_')[-1].lstrip('0')])
89
+ return captions
model/model/question_asking_model.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ import pandas as pd
3
+ from model.model.question_model_base import QuestionAskingModel
4
+ import openai
5
+
6
+ def get_question_model(args):
7
+ if args.question_strategy=="rule":
8
+ return QuestionAskingModelSimple(args)
9
+ elif args.question_strategy=="gpt3":
10
+ return QuestionAskingModelGPT3(args)
11
+ else:
12
+ raise ValueError(f"{args.question_strategy} is not a valid question strategy.")
13
+
14
+
15
+ class QuestionAskingModelSimple(QuestionAskingModel):
16
+ def __init__(self, args):
17
+ super(QuestionAskingModelSimple, self).__init__(args)
18
+
19
+ def get_questions(self, images, captions, target_idx=0):
20
+ questions = []
21
+ for i, (image, caption) in enumerate(zip(images, captions)):
22
+ image_questions = self.question_generator.generate_is_there_question(caption)
23
+ if i == target_idx: target_questions = image_questions
24
+ questions += image_questions
25
+ questions = list(set(questions))
26
+ # random.shuffle(questions)
27
+ return questions, target_questions
28
+
29
+
30
+ class QuestionAskingModelGPT3(QuestionAskingModel):
31
+ def __init__(self, args):
32
+ super(QuestionAskingModelGPT3, self).__init__(args)
33
+ self.gpt3_path = f"data/{args.gpt3_save_name}.csv"
34
+ try: self.gpt3_captions = pd.read_csv(self.gpt3_path) # cache locally to save GPT3 compute
35
+ except: self.gpt3_captions = pd.DataFrame({"caption":[], "question":[]})
36
+
37
+ def generate_gpt3_questions(self, caption):
38
+ # c1="Two people sitting on grass. The man in a blue shirt is playing a guitar and is on the left. The woman on the right is eating a sandwich."
39
+ # q1="What are the two people doing? How many people are there? What color is the man's shirt? Where is the man? Where is the woman? What is the man doing? What is the woman doing? Who is playing the guitar? Who is eating a sandwich?"
40
+ # c2="There is a table in the middle of the room. On the left there is a bowl of red apples. To the right of the bowl, there is a glass of juice, as well as a bowl of salad. Beside the table there is a bookshelf with ten books of various colors."
41
+ # q2="What is beside the table? What is on the left of the table? What color are the apples? How many bowls are there? What is inside the bowl? What is inside the glass? How many books are there? What color are the books?"
42
+ c1="A living room with a couch, coffee table and two large windows with white curtains."
43
+ q1="What color is the couch? How many windows are there? How many tables are there? What color is the table? What color are the curtains? What is next to the table? What is next to the couch?"
44
+ c2="A large, shiny, stainless, side by side refrigerator in a kitchen."
45
+ q2="Where is the refrigerator? What color is the refrigerator?"
46
+ c3="A stop sign with a skeleton painted on it, next to a car."
47
+ q3="What color is the sign? What color is the car? What is next to the sign? What is next to the car? What is on the sign? Where is the car?"
48
+ c4="A man brushing his teeth with a toothbrush"
49
+ q4="What is the man doing? Where is the man? What color is the toothbrush? How many people are there?"
50
+ prompt=f"Generate questions for the following caption:\nCaption: {c1}\nQuestions: {q1}\n"
51
+ prompt+=f"Generate questions for the following caption:\nCaption: {c2}\nQuestions: {q2}\n"
52
+ prompt+=f"Generate questions for the following caption:\nCaption: {c3}\nQuestions: {q3}\n"
53
+ prompt+=f"Generate questions for the following caption:\nCaption: {c4}\nQuestions: {q4}\n"
54
+ prompt+=f"Generate questions for the following caption:\nCaption: {caption}\nQuestions:"
55
+ response = openai.Completion.create(
56
+ model="text-davinci-003",
57
+ prompt=prompt,
58
+ temperature=0,
59
+ max_tokens=1024,
60
+ top_p=1,
61
+ frequency_penalty=0,
62
+ presence_penalty=0
63
+ )
64
+ questions = response["choices"][0]["text"]
65
+ questions = [i.strip() for i in questions.split('?') if len(i.strip())>1]
66
+ questions = [i+"?" for i in questions]
67
+ return questions
68
+
69
+ def get_questions(self, images, captions, target_idx=0):
70
+ questions = []
71
+ for i, (image, caption) in enumerate(zip(images, captions)):
72
+ if caption in self.gpt3_captions.caption.tolist():
73
+ image_questions = self.gpt3_captions[self.gpt3_captions.caption==caption].question.tolist()
74
+ else:
75
+ image_questions = self.generate_gpt3_questions(caption)
76
+ image_df = pd.DataFrame({'caption':[caption for _ in image_questions], 'question':image_questions})
77
+ self.gpt3_captions = pd.concat([self.gpt3_captions, image_df])
78
+ self.gpt3_captions.to_csv(self.gpt3_path, index=False)
79
+ if i == target_idx: target_questions = image_questions
80
+ questions += image_questions
81
+ questions = list(set(questions))
82
+ # random.shuffle(questions)
83
+ return questions, target_questions
model/model/question_generator.py ADDED
@@ -0,0 +1,194 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import spacy
2
+ import nltk.tree
3
+ import collections
4
+
5
+ import benepar
6
+ ### Run this the first time if benepar_en3 is not yet downloaded
7
+ # benepar.download('benepar_en3')
8
+
9
+ def find_word(tree, kind, parents=None):
10
+ if parents is None:
11
+ parents = []
12
+ if not isinstance(tree, nltk.tree.Tree):
13
+ return None, None
14
+ if tree.label() == kind:
15
+ return tree[0], parents
16
+ parents.append(tree)
17
+ for st in tree:
18
+ n, p = find_word(st, kind, parents)
19
+ if n is not None:
20
+ return n, p
21
+ parents.pop()
22
+ return None, None
23
+
24
+ def find_subtrees(tree, kind, blocking_kinds=()):
25
+ result = []
26
+ if not isinstance(tree, nltk.tree.Tree):
27
+ return result
28
+ if tree.label() == kind:
29
+ result.append(tree)
30
+ if tree.label() not in blocking_kinds:
31
+ for st in tree:
32
+ result.extend(find_subtrees(st, kind))
33
+ return result
34
+
35
+ def tree_to_str(tree, transform=lambda w: w):
36
+ l = []
37
+ def list_words(tree):
38
+ if isinstance(tree, str):
39
+ l.append(transform(tree))
40
+ else:
41
+ for st in tree:
42
+ list_words(st)
43
+ list_words(tree)
44
+ if l[-1] == '.':
45
+ l = l[:-1]
46
+ return ' '.join(l)
47
+
48
+ def tree_to_nouns(tree, transform=lambda w: w):
49
+ l = []
50
+ def list_words(tree, noun=False):
51
+ if isinstance(tree, str):
52
+ if noun == True:
53
+ l.append(transform(tree))
54
+ else:
55
+ for st in tree:
56
+ if not isinstance(st, str):
57
+ if st.label() == 'NN':
58
+ noun = True
59
+ else:
60
+ noun = False
61
+ list_words(st, noun)
62
+ list_words(tree)
63
+ if l[-1] == '.':
64
+ l = l[:-1]
65
+ return l
66
+
67
+ def make_determinate(w):
68
+ if w.lower() in ('a', 'an'):
69
+ return 'the'
70
+ return w
71
+
72
+ def make_indeterminate(w):
73
+ if w.lower() in ('the', 'his', 'her', 'their', 'its'):
74
+ return 'a'
75
+ return w
76
+
77
+ def pluralize(singular, plural, number):
78
+ if number <= 1:
79
+ return singular
80
+ return plural
81
+
82
+ def count_labels(tree):
83
+ counts = collections.defaultdict(int)
84
+ def update_counts(node):
85
+ counts[node.label()] += 1
86
+ for child in node:
87
+ if isinstance(child, nltk.tree.Tree):
88
+ update_counts(child)
89
+ update_counts(tree)
90
+ return counts
91
+
92
+ def get_number(tree):
93
+ if not isinstance(tree, nltk.tree.Tree):
94
+ return 0
95
+ if tree.label() == 'NN':
96
+ return 1
97
+ if tree.label() == 'NNS':
98
+ return 2
99
+ first_noun_number = None
100
+ n_np_children = 0
101
+ for subtree in tree:
102
+ label = subtree.label() if isinstance(subtree, nltk.tree.Tree) else None
103
+ if label == 'NP':
104
+ n_np_children += 1
105
+ if label in ('NP', 'NN', 'NNS') and first_noun_number is None:
106
+ first_noun_number = get_number(subtree)
107
+ if tree.label() == 'NP' and n_np_children > 1:
108
+ return 2
109
+ return first_noun_number or 0
110
+
111
+ def is_present_continuous(verb):
112
+ return verb.endswith('ing')
113
+
114
+ class QuestionGenerator:
115
+ def __init__(self):
116
+ self.parser = benepar.Parser("benepar_en3")
117
+
118
+ def generate_what_question(self, s):
119
+ tree = self.parser.parse(s)[0]
120
+ questions = []
121
+ try:
122
+ if len(tree) >= 2 and tree[0].label() == 'NP' and tree[1].label() == 'VP':
123
+ np = tree[0]
124
+ verb = None
125
+ vp = tree[1]
126
+ vnp = None
127
+
128
+ while True:
129
+ verb, verb_parents = find_word(vp, 'VBG')
130
+ if verb is None:
131
+ break
132
+ if is_present_continuous(verb):
133
+ if len(verb_parents[-1]) > 1:
134
+ vnp = verb_parents[-1][1]
135
+ break
136
+ else:
137
+ vp = verb_parents[-1][1]
138
+ to_be = pluralize('is', 'are', get_number(np))
139
+ if vnp is not None and vnp.label() == 'NP':
140
+ questions.append(('What {} {} {}?'
141
+ .format(to_be,
142
+ tree_to_str(np, make_determinate).lower(),
143
+ verb),
144
+ tree_to_nouns(vnp)[-1]))
145
+ except Exception as e:
146
+ print(e)
147
+ return questions
148
+
149
+ def generate_is_there_question(self, s):
150
+ tree = self.parser.parse(s)
151
+ questions = []
152
+ nps = find_subtrees(tree, 'NP', ('PP',))
153
+ for np in nps:
154
+ only_child_label = len(np) == 1 and next(iter(np)).label()
155
+ if only_child_label in ('PRP', 'EX'):
156
+ continue
157
+ try:
158
+ to_be = pluralize('Is', 'Are', get_number(np))
159
+ questions.append('{} there {}?'
160
+ .format(to_be,
161
+ tree_to_str(np, make_indeterminate).lower()))
162
+ except Exception as e:
163
+ print(e)
164
+ return questions
165
+
166
+ def generate_is_there_question_v2(self, s):
167
+ tree = self.parser.parse(s)
168
+ questions = []
169
+ nps = find_subtrees(tree, 'NP', ('PP',))
170
+ for np in nps:
171
+ only_child_label = len(np) == 1 and next(iter(np)).label()
172
+ if only_child_label in ('PRP', 'EX'):
173
+ continue
174
+ try:
175
+ to_be = pluralize('Is', 'Are', get_number(np))
176
+ questions.append('{} there {}?'
177
+ .format(to_be,
178
+ tree_to_str(np, make_indeterminate).lower()))
179
+ except Exception as e:
180
+ print(e)
181
+ if len(questions)==0:
182
+ nps = find_subtrees(tree, 'NNS', ('PP',))
183
+ for np in nps:
184
+ only_child_label = len(np) == 1
185
+ if only_child_label in ('PRP', 'EX'):
186
+ continue
187
+ try:
188
+ to_be = pluralize('Is', 'Are', get_number(np))
189
+ questions.append('{} there {}?'
190
+ .format(to_be,
191
+ tree_to_str(np, make_indeterminate).lower()))
192
+ except Exception as e:
193
+ print(e)
194
+ return questions
model/model/question_model_base.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import scipy
3
+ import numpy as np
4
+ import operator
5
+ import random
6
+ from model.model.question_generator import QuestionGenerator
7
+
8
+
9
+ class QuestionAskingModel():
10
+ def __init__(self, args):
11
+ self.device = args.device
12
+ self.include_what = args.include_what
13
+ self.max_length = 128
14
+ self.eps = 1e-25
15
+ self.multiplier_mode = args.multiplier_mode
16
+ self.num_images = args.num_images
17
+
18
+ # Initialize question generation model
19
+ class Namespace:
20
+ def __init__(self, **kwargs):
21
+ self.__dict__.update(kwargs)
22
+
23
+ self.question_generator = QuestionGenerator()
24
+ self.reset_question_bank()
25
+
26
+ def get_questions(self, images, captions, target_idx=0):
27
+ raise NotImplemented
28
+
29
+ def get_negative_information_gain(self, p_y_x, question, image_set, caption_set, response_model):
30
+ if question in self.question_bank:
31
+ p_r_qy = self.question_bank[question]
32
+ else:
33
+ is_a_questions = self.question_generator.generate_is_there_question_v2(question)
34
+ is_a_multiplier = []
35
+ if self.multiplier_mode != "none":
36
+ for is_a_q in is_a_questions:
37
+ # print(f"IsA Question: {is_a_q}")
38
+ p_r_qy = response_model.get_p_r_qy(None, is_a_q, image_set, caption_set, is_a=True)
39
+ p_r_qy = p_r_qy.detach().cpu().numpy()
40
+ is_a_multiplier.append(p_r_qy)
41
+ if len(is_a_multiplier)==0: is_a_multiplier.append([0 for _ in range(self.num_images)])
42
+ is_a_multiplier = torch.tensor(scipy.stats.mstats.gmean(is_a_multiplier, axis=0)).to("cuda")
43
+ if self.multiplier_mode=="hard":
44
+ for i in range(is_a_multiplier.shape[0]):
45
+ if is_a_multiplier[i]<0.5: is_a_multiplier[i]=1e-6
46
+ else: is_a_multiplier[i]=0.9
47
+ elif self.multiplier_mode=="soft":
48
+ pass
49
+ elif self.multiplier_mode == "none":
50
+ is_a_multiplier = torch.tensor([1 for _ in range(self.num_images)]).to("cuda")
51
+ p_r_qy = response_model.get_p_r_qy(None, question, image_set, caption_set)
52
+ p_r_qy = torch.stack([is_a_multiplier*p_r_qy[r] for r in range(len(p_r_qy))])
53
+ for i in range(self.num_images):
54
+ if is_a_multiplier[i] < 0.5:
55
+ p_r_qy[i] = 1-is_a_multiplier
56
+ else:
57
+ p_r_qy[i] *= is_a_multiplier
58
+ if self.multiplier_mode=="none":
59
+ p_r_qy = response_model.get_p_r_qy(None, question, image_set, caption_set)
60
+ if not self.include_what:
61
+ p_r_qy = [p_r_qy, 1-p_r_qy]
62
+ self.question_bank[question] = p_r_qy
63
+
64
+ p_y_xqr = torch.stack([p_y_x*p_r_qy[r] for r in range(len(p_r_qy))])
65
+ p_y_xqr = [p_y_xqr[r]/torch.sum(p_y_xqr[r]) if torch.sum(p_y_xqr[r]) != 0 \
66
+ else [0]*len(p_y_xqr[r]) for r in range(len(p_y_xqr))]
67
+ return torch.sum(torch.stack([p_r_qy[r]*p_y_x*torch.log2(1/(p_y_xqr[r]+self.eps)) for r in range(len(p_r_qy))]))
68
+
69
+ def get_question_ranking(self, p_y_x, question_set, image_set, caption_set, response_model):
70
+ H_y_rxq = [0]*len(question_set)
71
+
72
+ for i, question in enumerate(question_set):
73
+ H_y_rxq[i] = self.get_negative_information_gain(p_y_x, question, image_set, caption_set, response_model)
74
+
75
+ IG = - torch.stack(H_y_rxq).unsqueeze(1)
76
+ ranked_questions = sorted(zip(list(IG.data.cpu().numpy()), question_set),
77
+ key = operator.itemgetter(0))[::-1]
78
+ return ranked_questions
79
+
80
+ def select_best_question(self, p_y_x, question_set, image_set, caption_set, response_model):
81
+ ranked_questions = self.get_question_ranking(p_y_x, question_set, image_set, caption_set, response_model)
82
+ return ranked_questions[0][1]
83
+
84
+ def reset_question_bank(self):
85
+ self.question_bank = {}
model/model/response_model.py ADDED
@@ -0,0 +1,190 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from transformers import AutoProcessor, AutoTokenizer, AutoModelForQuestionAnswering, pipeline
4
+ from transformers import ViltProcessor, ViltForQuestionAnswering
5
+ from transformers import BlipProcessor, BlipForQuestionAnswering
6
+ from sentence_transformers import SentenceTransformer
7
+ import openai
8
+ from PIL import Image
9
+
10
+ def get_response_model(args, response_type):
11
+ if response_type=="QA":
12
+ return ResponseModelQA(args.device, args.include_what)
13
+ elif response_type=="VQA1":
14
+ return ResponseModelVQA(args.device, args.include_what, args.question_generator, vqa_type="vilt1")
15
+ elif response_type=="VQA2":
16
+ return ResponseModelVQA(args.device, args.include_what, args.question_generator, vqa_type="vilt2")
17
+ elif response_type=="VQA3":
18
+ return ResponseModelVQA(args.device, args.include_what, args.question_generator, vqa_type="blip")
19
+ elif response_type=="VQA4":
20
+ return ResponseModelVQA(args.device, args.include_what, args.question_generator, vqa_type="git")
21
+ else:
22
+ raise ValueError(f"{response_type} is not a valid response type.")
23
+
24
+
25
+ class ResponseModel(nn.Module):
26
+ # Class for the other ResponseModels to inherit from
27
+ def __init__(self, device, include_what):
28
+ super(ResponseModel, self).__init__()
29
+ self.device = device
30
+ self.include_what = include_what
31
+ self.model = None
32
+
33
+ def get_response(self, question, image, caption, target_questions, **kwargs):
34
+ raise NotImplemented
35
+
36
+ def get_p_r_qy(self, response, question, images, captions, **kwargs):
37
+ raise NotImplemented
38
+
39
+ class ResponseModelQA(ResponseModel):
40
+ def __init__(self, device, include_what):
41
+ super(ResponseModelQA, self).__init__(device, include_what)
42
+ if not self.include_what:
43
+ tokenizer = AutoTokenizer.from_pretrained("AmazonScience/qanlu")
44
+ model = AutoModelForQuestionAnswering.from_pretrained("AmazonScience/qanlu")
45
+ self.model = pipeline('question-answering', model=model, tokenizer=tokenizer, device=0) # remove device=0 for cpu
46
+ elif self.include_what:
47
+ tokenizer = AutoTokenizer.from_pretrained("deepset/roberta-base-squad2")
48
+ model = AutoModelForQuestionAnswering.from_pretrained("deepset/roberta-base-squad2")
49
+ self.model_wh = pipeline('question-answering', model=model, tokenizer=tokenizer, device=0) # remove device=0 for cpu
50
+
51
+ def get_response(self, question, image, caption, target_questions, **kwargs):
52
+ if self.include_what:
53
+ answer = self.model({'context':caption, 'question':question})
54
+ return answer['answer'].split(' ')[-1]
55
+ else:
56
+ answer = self.model({'context':f"Yes. No. {caption}", 'question':question})
57
+ response, score = answer['answer'], answer['score']
58
+ if score>0.5:
59
+ response = response.lower().replace('.','')
60
+ if "yes" in response.split() and "no" not in response.split():
61
+ response = 'yes'
62
+ elif "no" in response.split() and "yes" not in response.split():
63
+ response = 'no'
64
+ else:
65
+ response = 'yes' if question in target_questions else 'no'
66
+ else:
67
+ response = 'yes' if question in target_questions else 'no'
68
+ return response
69
+
70
+ def get_p_r_qy(self, response, question, images, captions, **kwargs):
71
+ if self.include_what:
72
+ raise NotImplementedError
73
+ else:
74
+ p_r_qy = torch.zeros(len(captions))
75
+ qa_input = {'context':[f"Yes. No. {c}" for c in captions], 'question':[question for _ in captions]}
76
+ answers = self.model(qa_input)
77
+ for idx, answer in enumerate(answers):
78
+ curr_ans, score = answer['answer'], answer['score']
79
+ if curr_ans.strip() in ["Yes.", "No."]:
80
+ if response==None:
81
+ if curr_ans.strip()=="No.": p_r_qy[idx] = 1-score
82
+ if curr_ans.strip()=="Yes.": p_r_qy[idx] = score
83
+ elif curr_ans.strip().lower().replace('.','')==response: p_r_qy[idx]=score
84
+ else: p_r_qy[idx]=1-score
85
+ else:
86
+ p_r_qy[idx]=0.5
87
+ return p_r_qy.to(self.device)
88
+
89
+ class ResponseModelVQA(ResponseModel):
90
+ def __init__(self, device, include_what, question_generator, vqa_type):
91
+ super(ResponseModelVQA, self).__init__(device, include_what)
92
+ self.vqa_type = vqa_type
93
+ self.question_generator = question_generator
94
+ self.sentence_transformer = SentenceTransformer('all-MiniLM-L6-v2')
95
+ if vqa_type=="vilt1":
96
+ self.processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
97
+ self.model = ViltForQuestionAnswering.from_pretrained("dandelin/vilt-b32-finetuned-vqa").to(device)
98
+ self.vocab = list(self.model.config.label2id.keys())
99
+ elif vqa_type=="vilt2":
100
+ self.processor = AutoProcessor.from_pretrained("tufa15nik/vilt-finetuned-vqasi")
101
+ self.model = ViltForQuestionAnswering.from_pretrained("tufa15nik/vilt-finetuned-vqasi").to("cuda")
102
+ self.vocab = list(self.model.config.label2id.keys())
103
+ elif vqa_type=="blip":
104
+ self.processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base")
105
+ self.model = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base")
106
+ elif vqa_type=="git":
107
+ pass
108
+ else:
109
+ raise ValueError(f"{vqa_type} is not a valid vqa_type.")
110
+
111
+
112
+ def get_response(self, question, image, caption, target_questions, is_a=False):
113
+ encoding = self.processor(image, question, return_tensors="pt").to(self.device)
114
+ if is_a==False:
115
+ is_a_questions = self.question_generator.generate_is_there_question_v2(question)
116
+ is_a_responses = []
117
+ if question in ["What is in the photo?", "What is in the picture?", "What is in the background?"]:
118
+ is_a_questions = []
119
+ for q in is_a_questions:
120
+ is_a_responses.append(self.get_response(q, image, caption, target_questions, is_a=True))
121
+ no_cnt = sum([i.lower()=="no" for i in is_a_responses])
122
+ if len(is_a_responses)>0 and no_cnt/len(is_a_responses)>=0.5:
123
+ if question[:8]=="How many": return "0"
124
+ else: return "nothing"
125
+ if self.vqa_type in ["vilt1", "vilt2"]:
126
+ outputs = self.model(**encoding)
127
+ logits = torch.nn.functional.softmax(outputs.logits, dim=1)
128
+ idx = logits.argmax(-1).item()
129
+ response = self.model.config.id2label[idx]
130
+ response = response.lower().replace('.','').strip()
131
+ elif self.vqa_type == "blip":
132
+ outputs = self.model.generate(**encoding)
133
+ response = self.processor.decode(outputs[0], skip_special_tokens=True)
134
+ return response
135
+
136
+ def get_p_r_qy(self, response, question, images, captions, is_a=False):
137
+ p_r_qy = torch.zeros(len(captions))
138
+ logits_arr = []
139
+ for i, image in enumerate(images):
140
+ with torch.no_grad():
141
+ if len(question) > 150: question="" # ignore question if too long
142
+ encoding = self.processor(image, question, return_tensors="pt").to(self.device)
143
+ outputs = self.model(**encoding)
144
+ logits = torch.nn.functional.softmax(outputs.logits, dim=1)
145
+ idx = logits.argmax(-1).item()
146
+ curr_response = self.model.config.id2label[idx]
147
+ curr_response = curr_response.lower().replace('.','').strip()
148
+ if self.include_what==False or is_a==True:
149
+ if response==None:
150
+ if curr_response=="yes": p_r_qy[i] = logits[0][3].item()
151
+ elif curr_response=="no": p_r_qy[i] = 1-logits[0][9].item()
152
+ else: p_r_qy[i] = 0.5
153
+ elif curr_response==response: p_r_qy[i] = logits[0][idx].item()
154
+ else: p_r_qy[i] = 1-logits[0][idx].item()
155
+ else:
156
+ logits_arr.append(logits)
157
+ if self.include_what==False or is_a==True:
158
+ return p_r_qy.to(self.device)
159
+ else:
160
+ logits = torch.concat(logits_arr)
161
+ if response==None:
162
+ top_answers = logits.argmax(1)
163
+ p_r_qy = logits[:,top_answers]
164
+ else:
165
+ response_idx = self.get_response_idx(response)
166
+ p_r_qy = logits[:,response_idx]
167
+
168
+ # check if this
169
+ # consider rerunning also without the geometric mean
170
+ if response=="nothing":
171
+ is_a_questions = self.question_generator.generate_is_there_question_v2(question)
172
+ for idx, (caption, image) in enumerate(zip(captions, images)):
173
+ current_responses = []
174
+ for is_a_q in is_a_questions:
175
+ current_responses.append(self.get_response(is_a_q, image, caption, None, is_a=True))
176
+ no_cnt = sum([i.lower()=="no" for i in current_responses])
177
+ if len(current_responses)>0 and no_cnt/len(current_responses)>=0.5:
178
+ p_r_qy[idx] = 1.0
179
+ return p_r_qy.to(self.device)
180
+
181
+ def get_response_idx(self, response):
182
+ if response in self.model.config.label2id:
183
+ return self.model.config.label2id[response]
184
+ else:
185
+ embs = self.sentence_transformer.encode(self.vocab, convert_to_tensor=True)
186
+ emb_response = self.sentence_transformer.encode([response], convert_to_tensor=True)
187
+ dists = torch.nn.CosineSimilarity(-1)(emb_response, embs)
188
+ best_response_idx = torch.argmax(dists)
189
+ best_response = self.vocab[best_response_idx]
190
+ return self.model.config.label2id[best_response]
model/run_question_asking_model.py ADDED
@@ -0,0 +1,186 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from model.model.question_asking_model import get_question_model
2
+ from model.model.caption_model import get_caption_model
3
+ from model.model.response_model import get_response_model
4
+
5
+ import torch
6
+ from torch.utils.data import Dataset, DataLoader
7
+ from PIL import Image
8
+
9
+ import argparse
10
+ import random
11
+ from tqdm.auto import tqdm
12
+ import numpy as np
13
+ import pandas as pd
14
+ import logging
15
+ from model.utils import logging_handler, image_saver, assert_checks
16
+
17
+ random.seed(123)
18
+
19
+ parser = argparse.ArgumentParser()
20
+ parser.add_argument('--device', type=str, default='cuda')
21
+ parser.add_argument('--include_what', action='store_true')
22
+ parser.add_argument('--target_idx', type=int, default=0)
23
+ parser.add_argument('--max_num_questions', type=int, default=25)
24
+ parser.add_argument('--num_images', type=int, default=10)
25
+ parser.add_argument('--beam', type=int, default=1)
26
+ parser.add_argument('--num_samples', type=int, default=100)
27
+ parser.add_argument('--threshold', type=float, default=0.9)
28
+
29
+ parser.add_argument('--caption_strategy', type=str, default='simple', choices=['simple', 'granular', 'gtruth'])
30
+ parser.add_argument('--sample_strategy', type=str, default='random', choices=['random', 'attribute', 'clip'])
31
+ parser.add_argument('--attribute_n', type=int, default=1) # Number of attributes to split
32
+ parser.add_argument('--response_type_simul', type=str, default='VQA1', choices=['simple', 'QA', 'VQA1', 'VQA2', 'VQA3', 'VQA4'])
33
+ parser.add_argument('--response_type_gtruth', type=str, default='VQA2', choices=['simple', 'QA', 'VQA1', 'VQA2', 'VQA3', 'VQA4'])
34
+ parser.add_argument('--question_strategy', type=str, default='gpt3', choices=['rule', 'gpt3'])
35
+ parser.add_argument('--multiplier_mode', type=str, default='soft', choices=['soft', 'hard', 'none'])
36
+
37
+ parser.add_argument('--gpt3_save_name', type=str, default='questions_gpt3')
38
+ parser.add_argument('--save_name', type=str, default=None)
39
+ parser.add_argument('--verbose', action='store_true')
40
+ args = parser.parse_args()
41
+ args.question_strategy='gpt3'
42
+ args.include_what=True
43
+ args.response_type_simul='VQA1'
44
+ args.response_type_gtruth='VQA3'
45
+ args.multiplier_mode='soft'
46
+ args.sample_strategy='attribute'
47
+ args.attribute_n=1
48
+ args.caption_strategy='gtruth'
49
+ assert_checks(args)
50
+ if args.save_name is None: logger = logging_handler(args.verbose, args.save_name)
51
+ args.load_response_model = True
52
+
53
+ print("1. Loading question model ...")
54
+ question_model = get_question_model(args)
55
+ args.question_generator = question_model.question_generator
56
+ print("2. Loading response model simul ...")
57
+ response_model_simul = get_response_model(args, args.response_type_simul)
58
+ response_model_simul.to(args.device)
59
+ print("3. Loading response model gtruth ...")
60
+ response_model_gtruth = get_response_model(args, args.response_type_gtruth)
61
+ response_model_gtruth.to(args.device)
62
+ print("4. Loading caption model ...")
63
+ caption_model = get_caption_model(args, question_model)
64
+
65
+
66
+
67
+ def return_modules():
68
+ return question_model, response_model_simul, response_model_gtruth, caption_model
69
+
70
+
71
+
72
+ args.question_strategy='rule'
73
+ args.include_what=False
74
+ args.response_type_simul='VQA1'
75
+ args.response_type_gtruth='VQA3'
76
+ args.multiplier_mode='none'
77
+ args.sample_strategy='attribute'
78
+ args.attribute_n=1
79
+ args.caption_strategy='gtruth'
80
+
81
+ print("1. Loading question model ...")
82
+ question_model_yn = get_question_model(args)
83
+ args.question_generator_yn = question_model_yn.question_generator
84
+ print("2. Loading response model simul ...")
85
+ response_model_simul_yn = get_response_model(args, args.response_type_simul)
86
+ response_model_simul_yn.to(args.device)
87
+ print("3. Loading response model gtruth ...")
88
+ response_model_gtruth_yn = get_response_model(args, args.response_type_gtruth)
89
+ response_model_gtruth_yn.to(args.device)
90
+ print("4. Loading caption model ...")
91
+ caption_model_yn = get_caption_model(args, question_model_yn)
92
+
93
+
94
+ def return_modules_yn():
95
+ return question_model_yn, response_model_simul_yn, response_model_gtruth_yn, caption_model_yn
96
+
97
+
98
+
99
+ # args.question_strategy='gpt3'
100
+ # args.include_what=True
101
+ # args.response_type_simul='VQA1'
102
+ # args.response_type_gtruth='VQA3'
103
+ # args.multiplier_mode='none'
104
+ # args.sample_strategy='attribute'
105
+ # args.attribute_n=1
106
+ # args.caption_strategy='gtruth'
107
+ # assert_checks(args)
108
+ # if args.save_name is None: logger = logging_handler(args.verbose, args.save_name)
109
+ # args.load_response_model = True
110
+
111
+ # print("1. Loading question model ...")
112
+ # question_model = get_question_model(args)
113
+ # args.question_generator = question_model.question_generator
114
+ # print("2. Loading response model simul ...")
115
+ # response_model_simul = get_response_model(args, args.response_type_simul)
116
+ # response_model_simul.to(args.device)
117
+ # print("3. Loading response model gtruth ...")
118
+ # response_model_gtruth = get_response_model(args, args.response_type_gtruth)
119
+ # response_model_gtruth.to(args.device)
120
+ # print("4. Loading caption model ...")
121
+ # caption_model = get_caption_model(args, question_model)
122
+
123
+ # # dataloader = DataLoader(dataset=ReferenceGameData(split='test',
124
+ # # num_images=args.num_images,
125
+ # # num_samples=args.num_samples,
126
+ # # sample_strategy=args.sample_strategy,
127
+ # # attribute_n=args.attribute_n))
128
+
129
+ # def return_modules():
130
+ # return question_model, response_model_simul, response_model_gtruth, caption_model
131
+ # # game_lens, game_preds = [], []
132
+ # for t, batch in enumerate(tqdm(dataloader)):
133
+ # image_files = [image[0] for image in batch['images'][:args.num_images]]
134
+ # image_files = [str(i).split('/')[1] for i in image_files]
135
+ # with open("mscoco_images_attribute_n=1.txt", 'a') as f:
136
+ # for i in image_files:
137
+ # f.write(str(i)+"\n")
138
+ # images = [np.asarray(Image.open(f"./../../../data/ms-coco/images/{i}")) for i in image_files]
139
+ # images = [np.dstack([i]*3) if len(i.shape)==2 else i for i in images]
140
+ # p_y_x = (torch.ones(args.num_images)/args.num_images).to(question_model.device)
141
+
142
+ # if args.save_name is not None:
143
+ # logger = logging_handler(args.verbose, args.save_name, t)
144
+ # image_saver(images, args.save_name, t)
145
+
146
+ # captions = caption_model.get_captions(image_files)
147
+ # questions, target_questions = question_model.get_questions(image_files, captions, args.target_idx)
148
+
149
+ # question_model.reset_question_bank()
150
+ # logger.info(questions)
151
+ # for idx, c in enumerate(captions): logger.info(f"Image_{idx}: {c}")
152
+
153
+ # num_questions_original = len(questions)
154
+ # for j in range(min(args.max_num_questions, num_questions_original)):
155
+ # # Select best question
156
+ # question = question_model.select_best_question(p_y_x, questions, images, captions, response_model_simul)
157
+ # logger.info(f"Question: {question}")
158
+
159
+ # # Ask the question and get the model's response
160
+ # response = response_model_gtruth.get_response(question, images[args.target_idx], captions[args.target_idx], target_questions, is_a=1-args.include_what)
161
+ # logger.info(f"Response: {response}")
162
+
163
+ # # Update the probabilities
164
+ # p_r_qy = response_model_simul.get_p_r_qy(response, question, images, captions)
165
+ # logger.info(f"P(r|q,y):\n{np.around(p_r_qy.cpu().detach().numpy(), 3)}")
166
+ # p_y_xqr = p_y_x*p_r_qy
167
+ # p_y_xqr = p_y_xqr/torch.sum(p_y_xqr)if torch.sum(p_y_xqr) != 0 else torch.zeros_like(p_y_xqr)
168
+ # p_y_x = p_y_xqr
169
+ # logger.info(f"Updated distribution:\n{np.around(p_y_x.cpu().detach().numpy(), 3)}\n")
170
+
171
+ # # Don't repeat the same question again in the future
172
+ # questions.remove(question)
173
+
174
+ # # Terminate if probability exceeds threshold or if out of questions to ask
175
+ # top_prob = torch.max(p_y_x).item()
176
+ # if top_prob >= args.threshold or j==min(args.max_num_questions, num_questions_original)-1:
177
+ # game_preds.append(torch.argmax(p_y_x).item())
178
+ # game_lens.append(j+1)
179
+ # logger.info(f"pred:{game_preds[-1]} game_len:{game_lens[-1]}")
180
+ # break
181
+
182
+ # logger = logging_handler(args.verbose, args.save_name, "final_results")
183
+ # logger.info(f"Game lenths:\n{game_lens}")
184
+ # logger.info(sum(game_lens)/len(game_lens))
185
+ # logger.info(f"Predictions:\n{game_preds}")
186
+ # logger.info(f"Accuracy:\n{sum([i==args.target_idx for i in game_preds])/len(game_preds)}")
model/utils.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import logging
3
+ import matplotlib.pyplot as plt
4
+ from PIL import Image
5
+ import nltk
6
+
7
+ def logging_handler(verbose, save_name, idx=0):
8
+ logger = logging.getLogger(str(idx))
9
+ logger.setLevel(logging.INFO)
10
+
11
+ stream_logger = logging.StreamHandler()
12
+ stream_logger.setFormatter(logging.Formatter("%(message)s"))
13
+ logger.addHandler(stream_logger)
14
+
15
+ if save_name is not None:
16
+ savepath = f"results/{save_name}"
17
+ if not os.path.exists(savepath):
18
+ os.makedirs(savepath)
19
+ file_logger = logging.FileHandler(f"{savepath}/{idx}.log")
20
+ file_logger.setFormatter(logging.Formatter("%(message)s"))
21
+ logger.addHandler(file_logger)
22
+
23
+ return logger
24
+
25
+
26
+ def image_saver(images, save_name, idx=0, interactive=True):
27
+ fig, a = plt.subplots(2,5)
28
+ fig.set_size_inches(30, 15)
29
+ for i in range(10):
30
+ a[i//5][i%5].imshow(images[i])
31
+ a[i//5][i%5].axis('off')
32
+ a[i//5][i%5].set_aspect('equal')
33
+ plt.tight_layout()
34
+ plt.subplots_adjust(wspace=0, hspace=0)
35
+ if not interactive:
36
+ plt.savefig(f"results/{save_name}/{idx}.png")
37
+ else:
38
+ plt.savefig(f"{save_name}.png")
39
+
40
+ def assert_checks(args):
41
+ if args.question_strategy=="gpt3":
42
+ assert args.include_what
43
+
44
+ def extract_nouns(sents):
45
+ noun_list = []
46
+ for idx, s in enumerate(sents):
47
+ curr = []
48
+ sent = (nltk.pos_tag(s.split()))
49
+ for word in sent:
50
+ if word[1] not in ["NN", "NNS"]: continue
51
+ currword = word[0].replace('.','')
52
+ curr.append(currword.lower())
53
+ noun_list.append(curr)
54
+ return noun_list
open_db.py CHANGED
@@ -4,9 +4,4 @@ import pandas as pd
4
  db = sqlite3.connect("response.db")
5
  df = pd.read_sql('SELECT * from responses', db)
6
  print(df)
7
- # def importdb(db):
8
- # conn = sqlite3.connect(db)
9
- # c = conn.cursor()
10
- # c.execute("SELECT name FROM sqlite_master WHERE type='table';")
11
- # for table in c.fetchall()
12
- # yield list(c.execute('SELECT * from ?;', (table[0],)))
 
4
  db = sqlite3.connect("response.db")
5
  df = pd.read_sql('SELECT * from responses', db)
6
  print(df)
7
+ df.to_csv("responses.csv", index=False)
 
 
 
 
 
pilot-study.csv ADDED
@@ -0,0 +1,161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ taskID,mscoco-id,task-type
2
+ 0,35,wh- standard
3
+ 1,102,wh- hard (n=1)
4
+ 2,25,wh- standard
5
+ 3,133,yes/no hard (n=1)
6
+ 4,80,wh- hard (n=1)
7
+ 5,15,wh- standard
8
+ 6,92,wh- hard (n=1)
9
+ 7,108,wh- hard (n=1)
10
+ 8,60,yes/no standard
11
+ 9,57,yes/no standard
12
+ 10,125,yes/no hard (n=1)
13
+ 11,56,yes/no standard
14
+ 12,137,yes/no hard (n=1)
15
+ 13,40,yes/no standard
16
+ 14,134,yes/no hard (n=1)
17
+ 15,130,yes/no hard (n=1)
18
+ 16,89,wh- hard (n=1)
19
+ 17,19,wh- standard
20
+ 18,58,yes/no standard
21
+ 19,81,wh- hard (n=1)
22
+ 20,5,wh- standard
23
+ 21,73,yes/no standard
24
+ 22,54,yes/no standard
25
+ 23,0,wh- standard
26
+ 24,14,wh- standard
27
+ 25,113,wh- hard (n=1)
28
+ 26,34,wh- standard
29
+ 27,159,yes/no hard (n=1)
30
+ 28,135,yes/no hard (n=1)
31
+ 29,2,wh- standard
32
+ 30,156,yes/no hard (n=1)
33
+ 31,30,wh- standard
34
+ 32,104,wh- hard (n=1)
35
+ 33,128,yes/no hard (n=1)
36
+ 34,18,wh- standard
37
+ 35,157,yes/no hard (n=1)
38
+ 36,1,wh- standard
39
+ 37,42,yes/no standard
40
+ 38,131,yes/no hard (n=1)
41
+ 39,115,wh- hard (n=1)
42
+ 40,120,yes/no hard (n=1)
43
+ 41,3,wh- standard
44
+ 42,63,yes/no standard
45
+ 43,65,yes/no standard
46
+ 44,103,wh- hard (n=1)
47
+ 45,124,yes/no hard (n=1)
48
+ 46,21,wh- standard
49
+ 47,72,yes/no standard
50
+ 48,62,yes/no standard
51
+ 49,47,yes/no standard
52
+ 50,78,yes/no standard
53
+ 51,109,wh- hard (n=1)
54
+ 52,136,yes/no hard (n=1)
55
+ 53,158,yes/no hard (n=1)
56
+ 54,61,yes/no standard
57
+ 55,27,wh- standard
58
+ 56,24,wh- standard
59
+ 57,123,yes/no hard (n=1)
60
+ 58,70,yes/no standard
61
+ 59,91,wh- hard (n=1)
62
+ 60,55,yes/no standard
63
+ 61,87,wh- hard (n=1)
64
+ 62,46,yes/no standard
65
+ 63,33,wh- standard
66
+ 64,16,wh- standard
67
+ 65,147,yes/no hard (n=1)
68
+ 66,85,wh- hard (n=1)
69
+ 67,59,yes/no standard
70
+ 68,99,wh- hard (n=1)
71
+ 69,117,wh- hard (n=1)
72
+ 70,9,wh- standard
73
+ 71,122,yes/no hard (n=1)
74
+ 72,53,yes/no standard
75
+ 73,22,wh- standard
76
+ 74,8,wh- standard
77
+ 75,29,wh- standard
78
+ 76,83,wh- hard (n=1)
79
+ 77,37,wh- standard
80
+ 78,66,yes/no standard
81
+ 79,41,yes/no standard
82
+ 80,94,wh- hard (n=1)
83
+ 81,98,wh- hard (n=1)
84
+ 82,110,wh- hard (n=1)
85
+ 83,77,yes/no standard
86
+ 84,151,yes/no hard (n=1)
87
+ 85,121,yes/no hard (n=1)
88
+ 86,6,wh- standard
89
+ 87,45,yes/no standard
90
+ 88,155,yes/no hard (n=1)
91
+ 89,88,wh- hard (n=1)
92
+ 90,96,wh- hard (n=1)
93
+ 91,75,yes/no standard
94
+ 92,112,wh- hard (n=1)
95
+ 93,49,yes/no standard
96
+ 94,152,yes/no hard (n=1)
97
+ 95,38,wh- standard
98
+ 96,7,wh- standard
99
+ 97,52,yes/no standard
100
+ 98,101,wh- hard (n=1)
101
+ 99,76,yes/no standard
102
+ 100,28,wh- standard
103
+ 101,114,wh- hard (n=1)
104
+ 102,139,yes/no hard (n=1)
105
+ 103,74,yes/no standard
106
+ 104,149,yes/no hard (n=1)
107
+ 105,84,wh- hard (n=1)
108
+ 106,79,yes/no standard
109
+ 107,127,yes/no hard (n=1)
110
+ 108,126,yes/no hard (n=1)
111
+ 109,116,wh- hard (n=1)
112
+ 110,71,yes/no standard
113
+ 111,67,yes/no standard
114
+ 112,10,wh- standard
115
+ 113,143,yes/no hard (n=1)
116
+ 114,132,yes/no hard (n=1)
117
+ 115,90,wh- hard (n=1)
118
+ 116,140,yes/no hard (n=1)
119
+ 117,144,yes/no hard (n=1)
120
+ 118,106,wh- hard (n=1)
121
+ 119,32,wh- standard
122
+ 120,154,yes/no hard (n=1)
123
+ 121,11,wh- standard
124
+ 122,17,wh- standard
125
+ 123,145,yes/no hard (n=1)
126
+ 124,118,wh- hard (n=1)
127
+ 125,48,yes/no standard
128
+ 126,148,yes/no hard (n=1)
129
+ 127,26,wh- standard
130
+ 128,51,yes/no standard
131
+ 129,13,wh- standard
132
+ 130,39,wh- standard
133
+ 131,153,yes/no hard (n=1)
134
+ 132,12,wh- standard
135
+ 133,93,wh- hard (n=1)
136
+ 134,107,wh- hard (n=1)
137
+ 135,86,wh- hard (n=1)
138
+ 136,31,wh- standard
139
+ 137,95,wh- hard (n=1)
140
+ 138,44,yes/no standard
141
+ 139,69,yes/no standard
142
+ 140,150,yes/no hard (n=1)
143
+ 141,4,wh- standard
144
+ 142,142,yes/no hard (n=1)
145
+ 143,43,yes/no standard
146
+ 144,50,yes/no standard
147
+ 145,100,wh- hard (n=1)
148
+ 146,129,yes/no hard (n=1)
149
+ 147,68,yes/no standard
150
+ 148,146,yes/no hard (n=1)
151
+ 149,64,yes/no standard
152
+ 150,23,wh- standard
153
+ 151,82,wh- hard (n=1)
154
+ 152,111,wh- hard (n=1)
155
+ 153,97,wh- hard (n=1)
156
+ 154,119,wh- hard (n=1)
157
+ 155,141,yes/no hard (n=1)
158
+ 156,20,wh- standard
159
+ 157,36,wh- standard
160
+ 158,138,yes/no hard (n=1)
161
+ 159,105,wh- hard (n=1)
response_db.py CHANGED
@@ -75,5 +75,4 @@ class StResponseDb(ResponseDb):
75
 
76
  if __name__ == "__main__":
77
  db = ResponseDb()
78
- print(db.get_all())
79
-
 
75
 
76
  if __name__ == "__main__":
77
  db = ResponseDb()
78
+ print(db.get_all())