sedrickkeh committed on
Commit 5a72dbb
1 Parent(s): 2bf46a7

major updates demo v2

__pycache__/response_db.cpython-37.pyc DELETED
Binary file (2.86 kB)
 
app.py CHANGED
@@ -1,5 +1,6 @@
import gradio as gr
from response_db import ResponseDb
+ from response_db import get_code
from create_cache import Game_Cache
import numpy as np
from PIL import Image
@@ -7,7 +8,6 @@ import pandas as pd
import torch
import pickle
import uuid
-
import nltk
nltk.download('punkt')

@@ -19,23 +19,34 @@ css = """
.msg.bot {background-color:lightgray}
.na_button {background-color:red;color:red}
"""
+ get_window_url_params = """
+ function(url_params) {
+     console.log(url_params);
+     const params = new URLSearchParams(window.location.search);
+     url_params = Object.fromEntries(params);
+     return url_params;
+ }
+ """
+ quals = {1001:99, 1002:136, 1003:56, 1004:105}

from model.run_question_asking_model import return_modules, return_modules_yn
- question_model, response_model_simul, _, caption_model = return_modules()
- question_model_yn, response_model_simul_yn, _, caption_model_yn = return_modules_yn()
+ question_model, response_model_simul, response_model_gtruth, caption_model = return_modules()
+ question_model_yn, response_model_simul_yn, response_model_gtruth_yn, caption_model_yn = return_modules_yn()

class Game_Session:
    def __init__(self, taskid, yn, hard_setting):
        self.yn = yn
        self.hard_setting = hard_setting

-         global question_model, response_model_simul, caption_model
-         global question_model_yn, response_model_simul_yn, caption_model_yn
+         global question_model, response_model_simul, response_model_gtruth, caption_model
+         global question_model_yn, response_model_simul_yn, response_model_gtruth_yn, caption_model_yn
        self.question_model = question_model
        self.response_model_simul = response_model_simul
+         self.response_model_gtruth = response_model_gtruth
        self.caption_model = caption_model
        self.question_model_yn = question_model_yn
        self.response_model_simul_yn = response_model_simul_yn
+         self.response_model_gtruth_yn = response_model_gtruth_yn
        self.caption_model_yn = caption_model_yn

        global image_files, images_np, p_y_x, p_r_qy, p_y_xqr, captions, questions, target_questions
@@ -48,55 +59,69 @@ class Game_Session:

    def set_curr_models(self):
        if self.yn:
-             self.curr_question_model, self.curr_caption_model, self.curr_response_model_simul = self.question_model_yn, self.caption_model_yn, self.response_model_simul_yn
+             self.curr_question_model, self.curr_caption_model, self.curr_response_model_simul, self.curr_response_model_gtruth = self.question_model_yn, self.caption_model_yn, self.response_model_simul_yn, self.response_model_gtruth_yn
        else:
-             self.curr_question_model, self.curr_caption_model, self.curr_response_model_simul = self.question_model, self.caption_model, self.response_model_simul
+             self.curr_question_model, self.curr_caption_model, self.curr_response_model_simul, self.curr_response_model_gtruth = self.question_model, self.caption_model, self.response_model_simul, self.response_model_gtruth

    def get_next_question(self):
        return self.curr_question_model.select_best_question(self.p_y_x, self.questions, self.images_np, self.captions, self.curr_response_model_simul)

+     def get_model_gtruth_response(self, question):
+         return self.response_model_gtruth.get_response(question, self.images_np[0], self.captions[0], self.target_questions, is_a=self.yn)
+

def ask_a_question(input, taskid, gs):
+     # input = gs.get_model_gtruth_response(gs.history[-1])
+
+     if input not in ["n/a", "yes", "no"] and input not in gs.curr_response_model_simul.model.config.label2id:
+         html = "<div class='chatbot'>"
+         for m, msg in enumerate(gs.history):
+             cls = "bot" if m%2 == 0 else "user"
+             html += "<div class='msg {}'> {}</div>".format(cls, msg)
+         html += "</div>"
+         return html, gs, gr.Dropdown.update(visible=True, value=''), gr.Button.update(visible=True), gr.Button.update(visible=True), gr.Button.update(visible=True), gr.Number.update(visible=False), gr.Button.update(visible=False), gr.Number.update(visible=False), gr.Button.update(visible=False), gr.Textbox.update(visible=False), gr.HTML.update(visible=True)
+
    gs.history.append(input)
    gs.p_r_qy = gs.curr_response_model_simul.get_p_r_qy(input, gs.history[-2], gs.images_np, gs.captions)
    gs.p_y_xqr = gs.p_y_x*gs.p_r_qy
    gs.p_y_xqr = gs.p_y_xqr/torch.sum(gs.p_y_xqr) if torch.sum(gs.p_y_xqr) != 0 else torch.zeros_like(gs.p_y_xqr)
    gs.p_y_x = gs.p_y_xqr
    gs.questions.remove(gs.history[-2])
-     db.add(gs.game_id, taskid, len(gs.history)//2-1, gs.history[-2], gs.history[-1])
+     if taskid not in quals: db.add(gs.game_id, taskid, len(gs.history)//2-1, gs.history[-2], gs.history[-1])
    gs.history.append(gs.get_next_question())

    top_prob = torch.max(gs.p_y_x).item()
    top_pred = torch.argmax(gs.p_y_x).item()
-     if top_prob > 0.8:
+     if top_prob > 0.8 or len(gs.history) > 19:
        gs.history = gs.history[:-1]
-         db.add(gs.game_id, taskid, len(gs.history)//2, f"Guess: Image {top_pred}", "")
+         if taskid not in quals: db.add(gs.game_id, taskid, len(gs.history)//2, f"Guess: Image {top_pred}", "")

    # write some HTML
    html = "<div class='chatbot'>"
    for m, msg in enumerate(gs.history):
-         if msg=="nothing": msg="n/a"
        cls = "bot" if m%2 == 0 else "user"
        html += "<div class='msg {}'> {}</div>".format(cls, msg)
    html += "</div>"

    ### Game finished:
-     if top_prob > 0.8:
+     if top_prob > 0.8 or len(gs.history) > 19:
        html += f"<p>The model identified <b>Image {top_pred+1}</b> as the image. Please select a new task ID to continue.</p>"
-         return html, gs, gr.Textbox.update(visible=False), gr.Button.update(visible=False), gr.Button.update(visible=False), gr.Number.update(visible=True), gr.Button.update(visible=True), gr.Number.update(visible=False), gr.Button.update(visible=False)
+         finish_html = "<h2>Congratulations on finishing the game! Please copy the Task Finish Code below to MTurk to complete your task. You can now exit this window.</h2>"
+         return html, gs, gr.Dropdown.update(visible=False), gr.Button.update(visible=False), gr.Button.update(visible=False), gr.Button.update(visible=False), gr.Number.update(visible=False), gr.Button.update(visible=False), gr.Number.update(visible=False), gr.Button.update(visible=False), gr.Textbox.update(value=get_code(taskid, gs.history, top_pred), visible=True), gr.HTML.update(value=finish_html, visible=True)
    else:
        if not gs.yn:
-             return html, gs, gr.Textbox.update(visible=True), gr.Button.update(visible=True), gr.Button.update(visible=True), gr.Number.update(visible=False), gr.Button.update(visible=False), gr.Number.update(visible=False), gr.Button.update(visible=False)
+             return html, gs, gr.Dropdown.update(visible=True, value=''), gr.Button.update(visible=True), gr.Button.update(visible=True), gr.Button.update(visible=True), gr.Number.update(visible=False), gr.Button.update(visible=False), gr.Number.update(visible=False), gr.Button.update(visible=False), gr.Textbox.update(visible=False), gr.HTML.update(visible=False)
        else:
-             return html, gs, gr.Textbox.update(visible=False), gr.Button.update(visible=False), gr.Button.update(visible=False), gr.Number.update(visible=False), gr.Button.update(visible=False), gr.Number.update(visible=True), gr.Button.update(visible=True)
+             return html, gs, gr.Dropdown.update(visible=False), gr.Button.update(visible=False), gr.Button.update(visible=False), gr.Button.update(visible=False), gr.Number.update(visible=False), gr.Button.update(visible=False), gr.Number.update(visible=True), gr.Button.update(visible=True), gr.Textbox.update(visible=False), gr.HTML.update(visible=False)


def set_images(taskid):
    pilot_study = pd.read_csv("pilot-study.csv")
    taskid_original = taskid
+     if int(taskid) in quals: taskid = quals[int(taskid)]
    taskid = pilot_study['mscoco-id'].tolist()[int(taskid)]

-     with open(f'cache/{int(taskid)}.p', 'rb') as fp:
+     with open(f'cache-soft/{int(taskid)}.p', 'rb') as fp:
        game_cache = pickle.load(fp)
    gs = Game_Session(int(taskid), game_cache.yn, game_cache.hard_setting)
    id1 = f"./mscoco-images/val2014/{game_cache.image_files[0]}"
@@ -123,12 +148,14 @@ def set_images(taskid):
    first_question = gs.curr_question_model.select_best_question(gs.p_y_x, gs.questions, gs.images_np, gs.captions, gs.curr_response_model_simul)
    first_question_html = f"<div class='chatbot'><div class='msg bot'>{first_question}</div></div>"
    gs.history.append(first_question)
-     html = f"<p>Current Task ID: <b>{int(taskid_original)}</b></p>"
+     # html = f"<p>Current Task ID: <b>{int(taskid_original)}</b></p>"
    if not gs.yn:
-         return id1, id2, id3, id4, id5, id6, id7, id8, id9, id10, gs, first_question_html, gr.HTML.update(value=html, visible=True), gr.Textbox.update(visible=True, value=''), gr.Button.update(visible=True), gr.Button.update(visible=True), gr.Number.update(visible=False), gr.Button.update(visible=False), gr.Button.update(visible=False), gr.Button.update(visible=False)
+         return id1, id1, id2, id3, id4, id5, id6, id7, id8, id9, id10, gs, first_question_html, gr.Dropdown.update(visible=True, value=''), gr.Button.update(visible=True), gr.Button.update(visible=True), gr.Button.update(visible=True), gr.Number.update(visible=False), gr.Button.update(visible=False), gr.Button.update(visible=False), gr.Button.update(visible=False)
    else:
-         return id1, id2, id3, id4, id5, id6, id7, id8, id9, id10, gs, first_question_html, gr.HTML.update(value=html, visible=True), gr.Textbox.update(visible=False), gr.Button.update(visible=False), gr.Button.update(visible=False), gr.Number.update(visible=False), gr.Button.update(visible=False), gr.Button.update(visible=True), gr.Button.update(visible=True)
+         return id1, id1, id2, id3, id4, id5, id6, id7, id8, id9, id10, gs, first_question_html, gr.Dropdown.update(visible=False), gr.Button.update(visible=False), gr.Button.update(visible=False), gr.Button.update(visible=False), gr.Number.update(visible=False), gr.Button.update(visible=False), gr.Button.update(visible=True), gr.Button.update(visible=True)

+ def reset_dropdown():
+     return gr.Dropdown.update(visible=True, value='')

with gr.Blocks(title="Image Q&A Guessing Game", css=css) as demo:
    gr.HTML("<h1>Image Q&A Guessing Game</h1>\
@@ -139,22 +166,19 @@ with gr.Blocks(title="Image Q&A Guessing Game", css=css) as demo:
        The model can see all 10 images and all the questions and answers for the current set of images. It will ask a question based on the available information.<br>\
        <span style='color: #0000ff'>The goal of the model is to accurately guess the correct image (i.e. <b><span style='color: #0000ff'>Image 1</span></b>) in as few turns as possible.<br>\
        Your goal is to help the model guess the image by answering as clearly and accurately as possible.</span><br><br>\
-         <b>Guidelines:</b><br>\
-         <ol style='font-size:120%;'>\
-         <li>It is best to keep your answers short (a single word or a short phrase). No need to answer in full sentences.</li>\
-         <li>If you feel that the question cannot be answered or does not apply to Image 1, please select N/A.</li>\
-         </ol> \
-         <br>\
-         (Note: We are testing multiple game settings. In some instances, the game will be open-ended, while in other instances, the answer choices will be limited to yes/no.)<br></p>\
-         <br>\
-         <h2>Please enter a TaskID to start</h2>")
+         (Note: We are testing multiple game settings. In some instances, the game will be open-ended, while in other instances, the answer choices will be limited to yes/no.)<br><br>\
+         <b>Selecting N/A:</b><br>\
+         <ul style='font-size:120%;'>\
+         <li>In some games, there will be an N/A option. Please select N/A only if the question is unanswerable BECAUSE IT DOES NOT APPLY TO THE IMAGE.</li>\
+         <li>Otherwise, please select the closest possible option.</li>\
+         <li>e.g. Q:\"What is the dog doing?\" Please select N/A if there is no dog in the image.\
+         </ul> \
+         <br>")

    with gr.Column():
        with gr.Row():
-             taskid = gr.Number(label="Task ID (Enter a number from 0 to 160)", value=0)
-             start_button = gr.Button("Enter")
-         with gr.Row():
-             task_text = gr.HTML()
+             taskid = gr.Number(label="Task ID (Enter a number from 0 to 160)", visible=False)
+             start_button = gr.Button("Start")

    with gr.Column() as img_block:
        with gr.Row():
@@ -172,8 +196,22 @@ with gr.Blocks(title="Image Q&A Guessing Game", css=css) as demo:
    conversation = gr.HTML()
    game_session_state = gr.State()

-     answer = gr.Textbox(placeholder="Insert answer here.", label="Answer the given question.", visible=False)
-     null_answer = gr.Textbox("nothing", visible=False)
+     # answer = gr.Textbox(placeholder="Insert answer here.", label="Answer the given question.", visible=False)
+     full_vocab_dict = response_model_simul_yn.model.config.label2id
+     vocab_list_numbers, vocab_list_letters = [], []
+     for i in full_vocab_dict:
+         if i=="None" or i is None: continue
+         if i[0] in ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']:
+             vocab_list_numbers.append(i)
+         else:
+             vocab_list_letters.append(i)
+     with gr.Row():
+         answer = gr.Dropdown(vocab_list_letters+vocab_list_numbers, label="Answer the given question.", \
+             info="If you cannot find your exact answer, pick the word you feel would be most appropriate. ONLY SELECT N/A IF THE QUESTION DOES NOT APPLY TO THE IMAGE.", visible=False)
+         clear_box = gr.Button("Reset Selection \n(Use this to clear the dropdown selection.)", visible=False)
+     with gr.Row():
+         vocab_warning = gr.HTML("<h3>The word you typed in is not a valid word in the model vocabulary. Please clear it and select a valid word from the dropdown menu.</h3>", visible=False)
+     null_answer = gr.Textbox("n/a", visible=False)
    yes_answer = gr.Textbox("yes", visible=False)
    no_answer = gr.Textbox("no", visible=False)

@@ -185,11 +223,22 @@ with gr.Blocks(title="Image Q&A Guessing Game", css=css) as demo:
        with gr.Row():
            na_box = gr.Button("N/A", visible=False, elem_classes="na_button")
            submit = gr.Button("Submit", visible=False)
+         with gr.Row():
+             reward_code = gr.Textbox("", label="Task Finish Code", visible=False)
+
+     with gr.Column() as img_block0:
+         with gr.Row():
+             img0 = gr.Image(label="Image 1", show_label=True).style(height=700, width=700)
+
    ### Button click events
-     start_button.click(fn=set_images, inputs=taskid, outputs=[img1, img2, img3, img4, img5, img6, img7, img8, img9, img10, game_session_state, conversation, task_text, answer, na_box, submit, taskid, start_button, yes_box, no_box])
-     submit.click(fn=ask_a_question, inputs=[answer, taskid, game_session_state], outputs=[conversation, game_session_state, answer, na_box, submit, taskid, start_button, yes_box, no_box])
-     na_box.click(fn=ask_a_question, inputs=[null_answer, taskid, game_session_state], outputs=[conversation, game_session_state, answer, na_box, submit, taskid, start_button, yes_box, no_box])
-     yes_box.click(fn=ask_a_question, inputs=[yes_answer, taskid, game_session_state], outputs=[conversation, game_session_state, answer, na_box, submit, taskid, start_button, yes_box, no_box])
-     no_box.click(fn=ask_a_question, inputs=[no_answer, taskid, game_session_state], outputs=[conversation, game_session_state, answer, na_box, submit, taskid, start_button, yes_box, no_box])
+     start_button.click(fn=set_images, inputs=taskid, outputs=[img0, img1, img2, img3, img4, img5, img6, img7, img8, img9, img10, game_session_state, conversation, answer, na_box, clear_box, submit, taskid, start_button, yes_box, no_box])
+     submit.click(fn=ask_a_question, inputs=[answer, taskid, game_session_state], outputs=[conversation, game_session_state, answer, na_box, clear_box, submit, taskid, start_button, yes_box, no_box, reward_code, vocab_warning])
+     na_box.click(fn=ask_a_question, inputs=[null_answer, taskid, game_session_state], outputs=[conversation, game_session_state, answer, na_box, clear_box, submit, taskid, start_button, yes_box, no_box, reward_code, vocab_warning])
+     yes_box.click(fn=ask_a_question, inputs=[yes_answer, taskid, game_session_state], outputs=[conversation, game_session_state, answer, na_box, clear_box, submit, taskid, start_button, yes_box, no_box, reward_code, vocab_warning])
+     no_box.click(fn=ask_a_question, inputs=[no_answer, taskid, game_session_state], outputs=[conversation, game_session_state, answer, na_box, clear_box, submit, taskid, start_button, yes_box, no_box, reward_code, vocab_warning])
+     clear_box.click(fn=reset_dropdown, inputs=[], outputs=[answer])
+
+     url_params = gr.JSON({}, visible=False, label="URL Params")
+     demo.load(fn = lambda url_params : gr.Number.update(value=int(url_params['p'])), inputs=[url_params], outputs=taskid, _js=get_window_url_params)

demo.launch()
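
For readers skimming the app.py diff above: the belief-update step itself is unchanged by this commit (the stopping rule only gains a turn limit, and database writes are skipped for qualification tasks). A minimal sketch of that update, with illustrative tensor names rather than the app's globals:

import torch

# Posterior over the 10 candidate images: elementwise product of the prior
# belief (p_y_x) and the answer likelihood from the simulated response model
# (p_r_qy), renormalized. Variable names here are illustrative.
prior = torch.ones(10) / 10
answer_likelihood = torch.rand(10)
posterior = prior * answer_likelihood
total = torch.sum(posterior)
posterior = posterior / total if total != 0 else torch.zeros_like(posterior)
# The game ends once the top probability exceeds 0.8 (or, after this commit,
# once the conversation history grows past 19 entries).
print(posterior.argmax().item(), posterior.max().item())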
create_cache.py CHANGED
@@ -37,11 +37,11 @@ def create_cache(taskid):
    global question_model_yn, response_model_simul_yn, caption_model_yn
    if taskid in yn_indices:
        yn = True
-         curr_question_model, curr_response_model_simul, curr_caption_model = question_model, response_model_simul, caption_model
+         curr_question_model, curr_response_model_simul, curr_caption_model = question_model_yn, response_model_simul_yn, caption_model_yn
        taskid-=40
    else:
        yn = False
-         curr_question_model, curr_response_model_simul, curr_caption_model = question_model_yn, response_model_simul_yn, caption_model_yn
+         curr_question_model, curr_response_model_simul, curr_caption_model = question_model, response_model_simul, caption_model
    if taskid in hard_setting_indices:
        hard_setting = True
        image_list_curr = image_list_hard
@@ -65,6 +65,9 @@ def create_cache(taskid):
        image_names.append(image_list_curr[int(taskid)*10+i])
    image_files = [id1, id2, id3, id4, id5, id6, id7, id8, id9, id10]
    image_files = [x[15:] for x in image_files]
+     import os
+     for i in image_files:
+         os.system(f"cp ./../../../data/ms-coco/images/{i} ./mscoco-images/val2014/")
    images_np = [np.asarray(Image.open(f"./mscoco-images/{i}")) for i in image_files]
    images_np = [np.dstack([i]*3) if len(i.shape)==2 else i for i in images_np]
    p_y_x = (torch.ones(10)/10).to(curr_question_model.device)
@@ -74,7 +77,7 @@ def create_cache(taskid):
    first_question = curr_question_model.select_best_question(p_y_x, questions, images_np, captions, curr_response_model_simul)

    gc = Game_Cache(curr_question_model.question_bank, image_names, yn, hard_setting)
-     with open(f'./cache{int(taskid)}.p', 'wb') as fp:
+     with open(f'./cache-soft/{int(original_taskid)}.p', 'wb') as fp:
        pickle.dump(gc, fp, protocol=pickle.HIGHEST_PROTOCOL)

if __name__=="__main__":
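
The create_cache.py change above shells out to cp to copy the ten task images from a local MS-COCO dump into ./mscoco-images/val2014/. A hypothetical, more portable version of that copy step using shutil (the file list below is a placeholder; the source path mirrors the one in the diff):

import shutil
from pathlib import Path

# Portable alternative to os.system(f"cp ...") for the image-copy step;
# the file names here are placeholders, not the real task list.
src_root = Path("./../../../data/ms-coco/images")
dst_root = Path("./mscoco-images/val2014")
dst_root.mkdir(parents=True, exist_ok=True)
image_files = ["val2014/COCO_val2014_000000564366.jpg"]
for rel in image_files:
    shutil.copy(src_root / rel, dst_root / Path(rel).name)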
data/questions_gpt4.csv ADDED
The diff for this file is too large to render. See raw diff
 
model/model/question_asking_model.py CHANGED
@@ -66,13 +66,59 @@ class QuestionAskingModelGPT3(QuestionAskingModel):
        questions = [i+"?" for i in questions]
        return questions

+     def generate_gpt4_questions(self, caption):
+         print("generating gpt4 captions")
+         instructions_prompt="You are tasked to produce reasonable questions from a given caption.\
+             The questions you ask must be answerable only using visual information.\
+             As such, never ask questions that involve exact measurement such as \"How tall\", \"How big\", or \"How far\", since these cannot be easily inferred from just looking at an object.\
+             Likewise, never ask questions that involve age (\"How old\"), composition (\"What is it made of\"), material, emotion, or personal relationship.\
+             When asking \"Where\" questions, the subject of your question must be a person or a small object.\
+             Never ask questions that can be answered with yes or no.\
+             When refering to objects, try to be general. For example, instead of saying \"cat\", you should say \"animal\". Instead of saying \"cake\", you should say \"food\".\
+             I repeat, when refering to object, try to be general!\
+             Good questions to ask include general \"What color\", as well as general probing questions such as \"What is the man doing?\" or \"What is the main subject of the image?\"\
+             For each caption, please generate 3-5 reasonable questions."
+         c1="A living room with a couch, coffee table and two large windows with white curtains."
+         q1="What color is the couch? How many windows are there? How many tables are there? What color is the table? What color are the curtains? What is next to the table? What is next to the couch?"
+         c2="A cat is wearing a pink wool hat."
+         q2="What color is the animal? What color is the hat? What is the cat wearing? "
+         c3="A stop sign with a skeleton painted on it, next to a car."
+         q3="What color is the sign? What color is the car? What is next to the sign? What is next to the car? What is on the sign? Where is the car?"
+         c4="A man brushing his teeth with a toothbrush"
+         q4="What is the man doing? Where is the man? What color is the toothbrush?"
+         prompt = f"{instructions_prompt}\n"
+         prompt+=f"\nCaption: {c1}\nQuestions: {q1}\n"
+         prompt+=f"\nCaption: {c2}\nQuestions: {q2}\n"
+         prompt+=f"\nCaption: {c3}\nQuestions: {q3}\n"
+         prompt+=f"\nCaption: {c4}\nQuestions: {q4}\n"
+         prompt+=f"\nCaption: {caption}\nQuestions:"
+         response = openai.ChatCompletion.create(
+             model="gpt-4-0314",
+             messages=[
+                 {"role": "system", "content": instructions_prompt},
+                 {"role": "user", "content": c1},
+                 {"role": "assistant", "content": q1},
+                 {"role": "user", "content": c2},
+                 {"role": "assistant", "content": q2},
+                 {"role": "user", "content": c3},
+                 {"role": "assistant", "content": q3},
+                 {"role": "user", "content": c4},
+                 {"role": "assistant", "content": q4},
+                 {"role": "user", "content": caption}
+             ]
+         )
+         questions = response["choices"][0]["message"]["content"]
+         questions = [i.strip() for i in questions.split('?') if len(i.strip())>1]
+         questions = [i+"?" for i in questions]
+         return questions
+
    def get_questions(self, images, captions, target_idx=0):
        questions = []
        for i, (image, caption) in enumerate(zip(images, captions)):
            if caption in self.gpt3_captions.caption.tolist():
                image_questions = self.gpt3_captions[self.gpt3_captions.caption==caption].question.tolist()
            else:
-                 image_questions = self.generate_gpt3_questions(caption)
+                 image_questions = self.generate_gpt4_questions(caption)
            image_df = pd.DataFrame({'caption':[caption for _ in image_questions], 'question':image_questions})
            self.gpt3_captions = pd.concat([self.gpt3_captions, image_df])
            self.gpt3_captions.to_csv(self.gpt3_path, index=False)
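
The new generate_gpt4_questions method returns the assistant reply as a single string and recovers individual questions by splitting on '?'. A small sketch of just that post-processing step (the reply string is made up):

# Same split-and-reterminate logic as in generate_gpt4_questions.
reply = "What color is the couch? How many windows are there?  "
questions = [q.strip() for q in reply.split('?') if len(q.strip()) > 1]
questions = [q + "?" for q in questions]
print(questions)  # ['What color is the couch?', 'How many windows are there?']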
model/model/question_model_base.py CHANGED
@@ -1,8 +1,6 @@
import torch
- import scipy
import numpy as np
import operator
- import random
from model.model.question_generator import QuestionGenerator


@@ -39,7 +37,7 @@ class QuestionAskingModel():
            p_r_qy = p_r_qy.detach().cpu().numpy()
            is_a_multiplier.append(p_r_qy)
        if len(is_a_multiplier)==0: is_a_multiplier.append([0 for _ in range(self.num_images)])
-         is_a_multiplier = torch.tensor(scipy.stats.mstats.gmean(is_a_multiplier, axis=0)).to("cuda")
+         is_a_multiplier = torch.tensor(np.mean(is_a_multiplier, axis=0)).to("cuda")
        if self.multiplier_mode=="hard":
            for i in range(is_a_multiplier.shape[0]):
                if is_a_multiplier[i]<0.5: is_a_multiplier[i]=1e-6
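
This hunk replaces the geometric mean (scipy.stats.mstats.gmean) with an arithmetic mean (np.mean) when aggregating the per-image "is there a ...?" multipliers, which is also why the scipy import can be dropped. A toy comparison with made-up numbers showing the practical difference: with the arithmetic mean, a single near-zero answer no longer collapses the whole multiplier.

import numpy as np
from scipy.stats import mstats  # only needed to reproduce the old behaviour

# Two hypothetical is-a questions scored against two images.
per_question = np.array([[0.9, 0.8],
                         [0.01, 0.7]])
print(mstats.gmean(per_question, axis=0))  # ~[0.095, 0.748]  (old aggregation)
print(np.mean(per_question, axis=0))       # [0.455, 0.75]    (new aggregation)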
model/model/response_model.py CHANGED
@@ -121,7 +121,7 @@ class ResponseModelVQA(ResponseModel):
        no_cnt = sum([i.lower()=="no" for i in is_a_responses])
        if len(is_a_responses)>0 and no_cnt/len(is_a_responses)>=0.5:
            if question[:8]=="How many": return "0"
-             else: return "nothing"
+             else: return "n/a"
        if self.vqa_type in ["vilt1", "vilt2"]:
            outputs = self.model(**encoding)
            logits = torch.nn.functional.softmax(outputs.logits, dim=1)
@@ -165,9 +165,7 @@ class ResponseModelVQA(ResponseModel):
        response_idx = self.get_response_idx(response)
        p_r_qy = logits[:,response_idx]

-         # check if this
-         # consider rerunning also without the geometric mean
-         if response=="nothing":
+         if response=="n/a":
            is_a_questions = self.question_generator.generate_is_there_question_v2(question)
            for idx, (caption, image) in enumerate(zip(captions, images)):
                current_responses = []
model/run_question_asking_model.py CHANGED
@@ -34,7 +34,7 @@ parser.add_argument('--response_type_gtruth', type=str, default='VQA2', choices=
parser.add_argument('--question_strategy', type=str, default='gpt3', choices=['rule', 'gpt3'])
parser.add_argument('--multiplier_mode', type=str, default='soft', choices=['soft', 'hard', 'none'])

- parser.add_argument('--gpt3_save_name', type=str, default='questions_gpt3')
+ parser.add_argument('--gpt3_save_name', type=str, default='questions_gpt4')
parser.add_argument('--save_name', type=str, default=None)
parser.add_argument('--verbose', action='store_true')
args = parser.parse_args()
@@ -43,8 +43,6 @@ args.include_what=True
args.response_type_simul='VQA1'
args.response_type_gtruth='VQA3'
args.multiplier_mode='soft'
- args.sample_strategy='attribute'
- args.attribute_n=1
args.caption_strategy='gtruth'
assert_checks(args)
if args.save_name is None: logger = logging_handler(args.verbose, args.save_name)
@@ -74,8 +72,6 @@ args.include_what=False
args.response_type_simul='VQA1'
args.response_type_gtruth='VQA3'
args.multiplier_mode='none'
- args.sample_strategy='attribute'
- args.attribute_n=1
args.caption_strategy='gtruth'

print("1. Loading question model ...")
mscoco-images/val2014/COCO_val2014_000000564366.jpg CHANGED
open_db.py CHANGED
@@ -4,4 +4,12 @@ import pandas as pd
db = ResponseDb()
df = pd.DataFrame(list(db.get()))
print(df)
- df.to_csv("responses.csv", index=False)
+ df.to_csv("responses_new.csv", index=False)
+
+ import sqlite3
+ import pandas as pd
+
+ db = sqlite3.connect("response.db")
+ df = pd.read_sql('SELECT * from responses', db)
+ print(df)
+ df.to_csv("responses_old.csv", index=False)
response.db DELETED
Binary file (8.19 kB)
 
response_db.py CHANGED
@@ -24,3 +24,23 @@ class ResponseDb:

    def get(self):
        return self.collection.find()
+
+
+ def get_code(taskid, history, top_pred):
+     taskid = int(taskid)
+     mongodb_username=os.environ['mongodb_username_2']
+     mongodb_pw=os.environ['mongodb_pw_2']
+     mongodb_cluster_url=os.environ['mongodb_cluster_url_2']
+     client = MongoClient(f"mongodb+srv://{mongodb_username}:{mongodb_pw}@{mongodb_cluster_url}/?retryWrites=true&w=majority")
+     db = client['vqa-codes']
+     collection = db['vqa-codes']
+
+     threshold_dict = {1001: 6, 1002: 2, 1003: 4, 1004: 2}
+     if int(taskid) in threshold_dict:
+         threshold = threshold_dict[int(taskid)]
+         if len(history)<=threshold and top_pred == 0:
+             return list(collection.find({"taskid":int(taskid)}))[0]['code']
+         else:
+             return list(collection.find({"taskid":3000-int(taskid)}))[0]['code']
+
+     return list(collection.find({"taskid":taskid}))[0]['code']
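
The new get_code helper gates the MTurk finish code for the qualification tasks (1001-1004) on how quickly the game ended and on whether the model's top guess was Image 1. A toy illustration of that gate without any database call (the passed helper below is made up for this note):

# Workers pass a qualification task only if the history stayed within the
# task's threshold and the model's top prediction was image index 0.
threshold_dict = {1001: 6, 1002: 2, 1003: 4, 1004: 2}

def passed(taskid, history_len, top_pred):
    return history_len <= threshold_dict[taskid] and top_pred == 0

print(passed(1001, 6, 0))   # True  -> code stored under taskid 1001
print(passed(1001, 8, 0))   # False -> code stored under taskid 3000-1001 = 1999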