Spaces:
Runtime error
Runtime error
sedrickkeh
commited on
Commit
•
5a72dbb
1
Parent(s):
2bf46a7
major updates demo v2
Browse files- __pycache__/response_db.cpython-37.pyc +0 -0
- app.py +88 -39
- create_cache.py +6 -3
- data/questions_gpt4.csv +0 -0
- model/model/question_asking_model.py +47 -1
- model/model/question_model_base.py +1 -3
- model/model/response_model.py +2 -4
- model/run_question_asking_model.py +1 -5
- mscoco-images/val2014/COCO_val2014_000000564366.jpg +0 -0
- open_db.py +9 -1
- response.db +0 -0
- response_db.py +20 -0
__pycache__/response_db.cpython-37.pyc
DELETED
Binary file (2.86 kB)
|
|
app.py
CHANGED
@@ -1,5 +1,6 @@
|
|
1 |
import gradio as gr
|
2 |
from response_db import ResponseDb
|
|
|
3 |
from create_cache import Game_Cache
|
4 |
import numpy as np
|
5 |
from PIL import Image
|
@@ -7,7 +8,6 @@ import pandas as pd
|
|
7 |
import torch
|
8 |
import pickle
|
9 |
import uuid
|
10 |
-
|
11 |
import nltk
|
12 |
nltk.download('punkt')
|
13 |
|
@@ -19,23 +19,34 @@ css = """
|
|
19 |
.msg.bot {background-color:lightgray}
|
20 |
.na_button {background-color:red;color:red}
|
21 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
22 |
|
23 |
from model.run_question_asking_model import return_modules, return_modules_yn
|
24 |
-
question_model, response_model_simul,
|
25 |
-
question_model_yn, response_model_simul_yn,
|
26 |
|
27 |
class Game_Session:
|
28 |
def __init__(self, taskid, yn, hard_setting):
|
29 |
self.yn = yn
|
30 |
self.hard_setting = hard_setting
|
31 |
|
32 |
-
global question_model, response_model_simul, caption_model
|
33 |
-
global question_model_yn, response_model_simul_yn, caption_model_yn
|
34 |
self.question_model = question_model
|
35 |
self.response_model_simul = response_model_simul
|
|
|
36 |
self.caption_model = caption_model
|
37 |
self.question_model_yn = question_model_yn
|
38 |
self.response_model_simul_yn = response_model_simul_yn
|
|
|
39 |
self.caption_model_yn = caption_model_yn
|
40 |
|
41 |
global image_files, images_np, p_y_x, p_r_qy, p_y_xqr, captions, questions, target_questions
|
@@ -48,55 +59,69 @@ class Game_Session:
|
|
48 |
|
49 |
def set_curr_models(self):
|
50 |
if self.yn:
|
51 |
-
self.curr_question_model, self.curr_caption_model, self.curr_response_model_simul = self.question_model_yn, self.caption_model_yn, self.response_model_simul_yn
|
52 |
else:
|
53 |
-
self.curr_question_model, self.curr_caption_model, self.curr_response_model_simul = self.question_model, self.caption_model, self.response_model_simul
|
54 |
|
55 |
def get_next_question(self):
|
56 |
return self.curr_question_model.select_best_question(self.p_y_x, self.questions, self.images_np, self.captions, self.curr_response_model_simul)
|
57 |
|
|
|
|
|
|
|
58 |
|
59 |
def ask_a_question(input, taskid, gs):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
60 |
gs.history.append(input)
|
61 |
gs.p_r_qy = gs.curr_response_model_simul.get_p_r_qy(input, gs.history[-2], gs.images_np, gs.captions)
|
62 |
gs.p_y_xqr = gs.p_y_x*gs.p_r_qy
|
63 |
gs.p_y_xqr = gs.p_y_xqr/torch.sum(gs.p_y_xqr)if torch.sum(gs.p_y_xqr) != 0 else torch.zeros_like(gs.p_y_xqr)
|
64 |
gs.p_y_x = gs.p_y_xqr
|
65 |
gs.questions.remove(gs.history[-2])
|
66 |
-
db.add(gs.game_id, taskid, len(gs.history)//2-1, gs.history[-2], gs.history[-1])
|
67 |
gs.history.append(gs.get_next_question())
|
68 |
|
69 |
top_prob = torch.max(gs.p_y_x).item()
|
70 |
top_pred = torch.argmax(gs.p_y_x).item()
|
71 |
-
if top_prob > 0.8:
|
72 |
gs.history = gs.history[:-1]
|
73 |
-
db.add(gs.game_id, taskid, len(gs.history)//2, f"Guess: Image {top_pred}", "")
|
74 |
|
75 |
# write some HTML
|
76 |
html = "<div class='chatbot'>"
|
77 |
for m, msg in enumerate(gs.history):
|
78 |
-
if msg=="nothing": msg="n/a"
|
79 |
cls = "bot" if m%2 == 0 else "user"
|
80 |
html += "<div class='msg {}'> {}</div>".format(cls, msg)
|
81 |
html += "</div>"
|
82 |
|
83 |
### Game finished:
|
84 |
-
if top_prob > 0.8:
|
85 |
html += f"<p>The model identified <b>Image {top_pred+1}</b> as the image. Please select a new task ID to continue.</p>"
|
86 |
-
|
|
|
87 |
else:
|
88 |
if not gs.yn:
|
89 |
-
return html, gs, gr.
|
90 |
else:
|
91 |
-
return html, gs, gr.
|
92 |
|
93 |
|
94 |
def set_images(taskid):
|
95 |
pilot_study = pd.read_csv("pilot-study.csv")
|
96 |
taskid_original = taskid
|
|
|
97 |
taskid = pilot_study['mscoco-id'].tolist()[int(taskid)]
|
98 |
|
99 |
-
with open(f'cache/{int(taskid)}.p', 'rb') as fp:
|
100 |
game_cache = pickle.load(fp)
|
101 |
gs = Game_Session(int(taskid), game_cache.yn, game_cache.hard_setting)
|
102 |
id1 = f"./mscoco-images/val2014/{game_cache.image_files[0]}"
|
@@ -123,12 +148,14 @@ def set_images(taskid):
|
|
123 |
first_question = gs.curr_question_model.select_best_question(gs.p_y_x, gs.questions, gs.images_np, gs.captions, gs.curr_response_model_simul)
|
124 |
first_question_html = f"<div class='chatbot'><div class='msg bot'>{first_question}</div></div>"
|
125 |
gs.history.append(first_question)
|
126 |
-
html = f"<p>Current Task ID: <b>{int(taskid_original)}</b></p>"
|
127 |
if not gs.yn:
|
128 |
-
return id1, id2, id3, id4, id5, id6, id7, id8, id9, id10, gs, first_question_html, gr.
|
129 |
else:
|
130 |
-
return id1, id2, id3, id4, id5, id6, id7, id8, id9, id10, gs, first_question_html, gr.
|
131 |
|
|
|
|
|
132 |
|
133 |
with gr.Blocks(title="Image Q&A Guessing Game", css=css) as demo:
|
134 |
gr.HTML("<h1>Image Q&A Guessing Game</h1>\
|
@@ -139,22 +166,19 @@ with gr.Blocks(title="Image Q&A Guessing Game", css=css) as demo:
|
|
139 |
The model can see all 10 images and all the questions and answers for the current set of images. It will ask a question based on the available information.<br>\
|
140 |
<span style='color: #0000ff'>The goal of the model is to accurately guess the correct image (i.e. <b><span style='color: #0000ff'>Image 1</span></b>) in as few turns as possible.<br>\
|
141 |
Your goal is to help the model guess the image by answering as clearly and accurately as possible.</span><br><br>\
|
142 |
-
<
|
143 |
-
<
|
144 |
-
|
145 |
-
<li>
|
146 |
-
|
147 |
-
|
148 |
-
|
149 |
-
<br
|
150 |
-
<h2>Please enter a TaskID to start</h2>")
|
151 |
|
152 |
with gr.Column():
|
153 |
with gr.Row():
|
154 |
-
taskid = gr.Number(label="Task ID (Enter a number from 0 to 160)",
|
155 |
-
start_button = gr.Button("
|
156 |
-
with gr.Row():
|
157 |
-
task_text = gr.HTML()
|
158 |
|
159 |
with gr.Column() as img_block:
|
160 |
with gr.Row():
|
@@ -172,8 +196,22 @@ with gr.Blocks(title="Image Q&A Guessing Game", css=css) as demo:
|
|
172 |
conversation = gr.HTML()
|
173 |
game_session_state = gr.State()
|
174 |
|
175 |
-
answer = gr.Textbox(placeholder="Insert answer here.", label="Answer the given question.", visible=False)
|
176 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
177 |
yes_answer = gr.Textbox("yes", visible=False)
|
178 |
no_answer = gr.Textbox("no", visible=False)
|
179 |
|
@@ -185,11 +223,22 @@ with gr.Blocks(title="Image Q&A Guessing Game", css=css) as demo:
|
|
185 |
with gr.Row():
|
186 |
na_box = gr.Button("N/A", visible=False, elem_classes="na_button")
|
187 |
submit = gr.Button("Submit", visible=False)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
188 |
### Button click events
|
189 |
-
start_button.click(fn=set_images, inputs=taskid, outputs=[img1, img2, img3, img4, img5, img6, img7, img8, img9, img10, game_session_state, conversation,
|
190 |
-
submit.click(fn=ask_a_question, inputs=[answer, taskid, game_session_state], outputs=[conversation, game_session_state, answer, na_box, submit, taskid, start_button, yes_box, no_box])
|
191 |
-
na_box.click(fn=ask_a_question, inputs=[null_answer, taskid, game_session_state], outputs=[conversation, game_session_state, answer, na_box, submit, taskid, start_button, yes_box, no_box])
|
192 |
-
yes_box.click(fn=ask_a_question, inputs=[yes_answer, taskid, game_session_state], outputs=[conversation, game_session_state, answer, na_box, submit, taskid, start_button, yes_box, no_box])
|
193 |
-
no_box.click(fn=ask_a_question, inputs=[no_answer, taskid, game_session_state], outputs=[conversation, game_session_state, answer, na_box, submit, taskid, start_button, yes_box, no_box])
|
|
|
|
|
|
|
|
|
194 |
|
195 |
demo.launch()
|
|
|
1 |
import gradio as gr
|
2 |
from response_db import ResponseDb
|
3 |
+
from response_db import get_code
|
4 |
from create_cache import Game_Cache
|
5 |
import numpy as np
|
6 |
from PIL import Image
|
|
|
8 |
import torch
|
9 |
import pickle
|
10 |
import uuid
|
|
|
11 |
import nltk
|
12 |
nltk.download('punkt')
|
13 |
|
|
|
19 |
.msg.bot {background-color:lightgray}
|
20 |
.na_button {background-color:red;color:red}
|
21 |
"""
|
22 |
+
get_window_url_params = """
|
23 |
+
function(url_params) {
|
24 |
+
console.log(url_params);
|
25 |
+
const params = new URLSearchParams(window.location.search);
|
26 |
+
url_params = Object.fromEntries(params);
|
27 |
+
return url_params;
|
28 |
+
}
|
29 |
+
"""
|
30 |
+
quals = {1001:99, 1002:136, 1003:56, 1004:105}
|
31 |
|
32 |
from model.run_question_asking_model import return_modules, return_modules_yn
|
33 |
+
question_model, response_model_simul, response_model_gtruth, caption_model = return_modules()
|
34 |
+
question_model_yn, response_model_simul_yn, response_model_gtruth_yn, caption_model_yn = return_modules_yn()
|
35 |
|
36 |
class Game_Session:
|
37 |
def __init__(self, taskid, yn, hard_setting):
|
38 |
self.yn = yn
|
39 |
self.hard_setting = hard_setting
|
40 |
|
41 |
+
global question_model, response_model_simul, response_model_gtruth, caption_model
|
42 |
+
global question_model_yn, response_model_simul_yn, response_model_gtruth_yn, caption_model_yn
|
43 |
self.question_model = question_model
|
44 |
self.response_model_simul = response_model_simul
|
45 |
+
self.response_model_gtruth = response_model_gtruth
|
46 |
self.caption_model = caption_model
|
47 |
self.question_model_yn = question_model_yn
|
48 |
self.response_model_simul_yn = response_model_simul_yn
|
49 |
+
self.response_model_gtruth_yn = response_model_gtruth_yn
|
50 |
self.caption_model_yn = caption_model_yn
|
51 |
|
52 |
global image_files, images_np, p_y_x, p_r_qy, p_y_xqr, captions, questions, target_questions
|
|
|
59 |
|
60 |
def set_curr_models(self):
|
61 |
if self.yn:
|
62 |
+
self.curr_question_model, self.curr_caption_model, self.curr_response_model_simul, self.curr_response_model_gtruth = self.question_model_yn, self.caption_model_yn, self.response_model_simul_yn, self.response_model_gtruth_yn
|
63 |
else:
|
64 |
+
self.curr_question_model, self.curr_caption_model, self.curr_response_model_simul, self.curr_response_model_gtruth = self.question_model, self.caption_model, self.response_model_simul, self.response_model_gtruth
|
65 |
|
66 |
def get_next_question(self):
|
67 |
return self.curr_question_model.select_best_question(self.p_y_x, self.questions, self.images_np, self.captions, self.curr_response_model_simul)
|
68 |
|
69 |
+
def get_model_gtruth_response(self, question):
|
70 |
+
return self.response_model_gtruth.get_response(question, self.images_np[0], self.captions[0], self.target_questions, is_a=self.yn)
|
71 |
+
|
72 |
|
73 |
def ask_a_question(input, taskid, gs):
|
74 |
+
# input = gs.get_model_gtruth_response(gs.history[-1])
|
75 |
+
|
76 |
+
if input not in ["n/a", "yes", "no"] and input not in gs.curr_response_model_simul.model.config.label2id:
|
77 |
+
html = "<div class='chatbot'>"
|
78 |
+
for m, msg in enumerate(gs.history):
|
79 |
+
cls = "bot" if m%2 == 0 else "user"
|
80 |
+
html += "<div class='msg {}'> {}</div>".format(cls, msg)
|
81 |
+
html += "</div>"
|
82 |
+
return html, gs, gr.Dropdown.update(visible=True, value=''), gr.Button.update(visible=True), gr.Button.update(visible=True), gr.Button.update(visible=True), gr.Number.update(visible=False), gr.Button.update(visible=False), gr.Number.update(visible=False), gr.Button.update(visible=False), gr.Textbox.update(visible=False), gr.HTML.update(visible=True)
|
83 |
+
|
84 |
gs.history.append(input)
|
85 |
gs.p_r_qy = gs.curr_response_model_simul.get_p_r_qy(input, gs.history[-2], gs.images_np, gs.captions)
|
86 |
gs.p_y_xqr = gs.p_y_x*gs.p_r_qy
|
87 |
gs.p_y_xqr = gs.p_y_xqr/torch.sum(gs.p_y_xqr)if torch.sum(gs.p_y_xqr) != 0 else torch.zeros_like(gs.p_y_xqr)
|
88 |
gs.p_y_x = gs.p_y_xqr
|
89 |
gs.questions.remove(gs.history[-2])
|
90 |
+
if taskid not in quals: db.add(gs.game_id, taskid, len(gs.history)//2-1, gs.history[-2], gs.history[-1])
|
91 |
gs.history.append(gs.get_next_question())
|
92 |
|
93 |
top_prob = torch.max(gs.p_y_x).item()
|
94 |
top_pred = torch.argmax(gs.p_y_x).item()
|
95 |
+
if top_prob > 0.8 or len(gs.history) > 19:
|
96 |
gs.history = gs.history[:-1]
|
97 |
+
if taskid not in quals: db.add(gs.game_id, taskid, len(gs.history)//2, f"Guess: Image {top_pred}", "")
|
98 |
|
99 |
# write some HTML
|
100 |
html = "<div class='chatbot'>"
|
101 |
for m, msg in enumerate(gs.history):
|
|
|
102 |
cls = "bot" if m%2 == 0 else "user"
|
103 |
html += "<div class='msg {}'> {}</div>".format(cls, msg)
|
104 |
html += "</div>"
|
105 |
|
106 |
### Game finished:
|
107 |
+
if top_prob > 0.8 or len(gs.history) > 19:
|
108 |
html += f"<p>The model identified <b>Image {top_pred+1}</b> as the image. Please select a new task ID to continue.</p>"
|
109 |
+
finish_html = "<h2>Congratulations on finishing the game! Please copy the Task Finish Code below to MTurk to complete your task. You can now exit this window.</h2>"
|
110 |
+
return html, gs, gr.Dropdown.update(visible=False), gr.Button.update(visible=False), gr.Button.update(visible=False), gr.Button.update(visible=False), gr.Number.update(visible=False), gr.Button.update(visible=False), gr.Number.update(visible=False), gr.Button.update(visible=False), gr.Textbox.update(value=get_code(taskid, gs.history, top_pred), visible=True), gr.HTML.update(value=finish_html, visible=True)
|
111 |
else:
|
112 |
if not gs.yn:
|
113 |
+
return html, gs, gr.Dropdown.update(visible=True, value=''), gr.Button.update(visible=True), gr.Button.update(visible=True), gr.Button.update(visible=True), gr.Number.update(visible=False), gr.Button.update(visible=False), gr.Number.update(visible=False), gr.Button.update(visible=False), gr.Textbox.update(visible=False), gr.HTML.update(visible=False)
|
114 |
else:
|
115 |
+
return html, gs, gr.Dropdown.update(visible=False), gr.Button.update(visible=False), gr.Button.update(visible=False), gr.Button.update(visible=False), gr.Number.update(visible=False), gr.Button.update(visible=False), gr.Number.update(visible=True), gr.Button.update(visible=True), gr.Textbox.update(visible=False), gr.HTML.update(visible=False)
|
116 |
|
117 |
|
118 |
def set_images(taskid):
|
119 |
pilot_study = pd.read_csv("pilot-study.csv")
|
120 |
taskid_original = taskid
|
121 |
+
if int(taskid) in quals: taskid = quals[int(taskid)]
|
122 |
taskid = pilot_study['mscoco-id'].tolist()[int(taskid)]
|
123 |
|
124 |
+
with open(f'cache-soft/{int(taskid)}.p', 'rb') as fp:
|
125 |
game_cache = pickle.load(fp)
|
126 |
gs = Game_Session(int(taskid), game_cache.yn, game_cache.hard_setting)
|
127 |
id1 = f"./mscoco-images/val2014/{game_cache.image_files[0]}"
|
|
|
148 |
first_question = gs.curr_question_model.select_best_question(gs.p_y_x, gs.questions, gs.images_np, gs.captions, gs.curr_response_model_simul)
|
149 |
first_question_html = f"<div class='chatbot'><div class='msg bot'>{first_question}</div></div>"
|
150 |
gs.history.append(first_question)
|
151 |
+
# html = f"<p>Current Task ID: <b>{int(taskid_original)}</b></p>"
|
152 |
if not gs.yn:
|
153 |
+
return id1, id1, id2, id3, id4, id5, id6, id7, id8, id9, id10, gs, first_question_html, gr.Dropdown.update(visible=True, value=''), gr.Button.update(visible=True), gr.Button.update(visible=True), gr.Button.update(visible=True), gr.Number.update(visible=False), gr.Button.update(visible=False), gr.Button.update(visible=False), gr.Button.update(visible=False)
|
154 |
else:
|
155 |
+
return id1, id1, id2, id3, id4, id5, id6, id7, id8, id9, id10, gs, first_question_html, gr.Dropdown.update(visible=False), gr.Button.update(visible=False), gr.Button.update(visible=False), gr.Button.update(visible=False), gr.Number.update(visible=False), gr.Button.update(visible=False), gr.Button.update(visible=True), gr.Button.update(visible=True)
|
156 |
|
157 |
+
def reset_dropdown():
|
158 |
+
return gr.Dropdown.update(visible=True, value='')
|
159 |
|
160 |
with gr.Blocks(title="Image Q&A Guessing Game", css=css) as demo:
|
161 |
gr.HTML("<h1>Image Q&A Guessing Game</h1>\
|
|
|
166 |
The model can see all 10 images and all the questions and answers for the current set of images. It will ask a question based on the available information.<br>\
|
167 |
<span style='color: #0000ff'>The goal of the model is to accurately guess the correct image (i.e. <b><span style='color: #0000ff'>Image 1</span></b>) in as few turns as possible.<br>\
|
168 |
Your goal is to help the model guess the image by answering as clearly and accurately as possible.</span><br><br>\
|
169 |
+
(Note: We are testing multiple game settings. In some instances, the game will be open-ended, while in other instances, the answer choices will be limited to yes/no.)<br><br>\
|
170 |
+
<b>Selecting N/A:</b><br>\
|
171 |
+
<ul style='font-size:120%;'>\
|
172 |
+
<li>In some games, there will be an N/A option. Please select N/A only if the question is unanswerable BECAUSE IT DOES NOT APPLY TO THE IMAGE.</li>\
|
173 |
+
<li>Otherwise, please select the closest possible option.</li>\
|
174 |
+
<li>e.g. Q:\"What is the dog doing?\" Please select N/A if there is no dog in the image.\
|
175 |
+
</ul> \
|
176 |
+
<br>")
|
|
|
177 |
|
178 |
with gr.Column():
|
179 |
with gr.Row():
|
180 |
+
taskid = gr.Number(label="Task ID (Enter a number from 0 to 160)", visible=False)
|
181 |
+
start_button = gr.Button("Start")
|
|
|
|
|
182 |
|
183 |
with gr.Column() as img_block:
|
184 |
with gr.Row():
|
|
|
196 |
conversation = gr.HTML()
|
197 |
game_session_state = gr.State()
|
198 |
|
199 |
+
# answer = gr.Textbox(placeholder="Insert answer here.", label="Answer the given question.", visible=False)
|
200 |
+
full_vocab_dict = response_model_simul_yn.model.config.label2id
|
201 |
+
vocab_list_numbers, vocab_list_letters = [], []
|
202 |
+
for i in full_vocab_dict:
|
203 |
+
if i=="None" or i is None: continue
|
204 |
+
if i[0] in ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']:
|
205 |
+
vocab_list_numbers.append(i)
|
206 |
+
else:
|
207 |
+
vocab_list_letters.append(i)
|
208 |
+
with gr.Row():
|
209 |
+
answer = gr.Dropdown(vocab_list_letters+vocab_list_numbers, label="Answer the given question.", \
|
210 |
+
info="If you cannot find your exact answer, pick the word you feel would be most appropriate. ONLY SELECT N/A IF THE QUESTION DOES NOT APPLY TO THE IMAGE.", visible=False)
|
211 |
+
clear_box = gr.Button("Reset Selection \n(Use this to clear the dropdown selection.)", visible=False)
|
212 |
+
with gr.Row():
|
213 |
+
vocab_warning = gr.HTML("<h3>The word you typed in is not a valid word in the model vocabulary. Please clear it and select a valid word from the dropdown menu.</h3>", visible=False)
|
214 |
+
null_answer = gr.Textbox("n/a", visible=False)
|
215 |
yes_answer = gr.Textbox("yes", visible=False)
|
216 |
no_answer = gr.Textbox("no", visible=False)
|
217 |
|
|
|
223 |
with gr.Row():
|
224 |
na_box = gr.Button("N/A", visible=False, elem_classes="na_button")
|
225 |
submit = gr.Button("Submit", visible=False)
|
226 |
+
with gr.Row():
|
227 |
+
reward_code = gr.Textbox("", label="Task Finish Code", visible=False)
|
228 |
+
|
229 |
+
with gr.Column() as img_block0:
|
230 |
+
with gr.Row():
|
231 |
+
img0 = gr.Image(label="Image 1", show_label=True).style(height=700, width=700)
|
232 |
+
|
233 |
### Button click events
|
234 |
+
start_button.click(fn=set_images, inputs=taskid, outputs=[img0, img1, img2, img3, img4, img5, img6, img7, img8, img9, img10, game_session_state, conversation, answer, na_box, clear_box, submit, taskid, start_button, yes_box, no_box])
|
235 |
+
submit.click(fn=ask_a_question, inputs=[answer, taskid, game_session_state], outputs=[conversation, game_session_state, answer, na_box, clear_box, submit, taskid, start_button, yes_box, no_box, reward_code, vocab_warning])
|
236 |
+
na_box.click(fn=ask_a_question, inputs=[null_answer, taskid, game_session_state], outputs=[conversation, game_session_state, answer, na_box, clear_box, submit, taskid, start_button, yes_box, no_box, reward_code, vocab_warning])
|
237 |
+
yes_box.click(fn=ask_a_question, inputs=[yes_answer, taskid, game_session_state], outputs=[conversation, game_session_state, answer, na_box, clear_box, submit, taskid, start_button, yes_box, no_box, reward_code, vocab_warning])
|
238 |
+
no_box.click(fn=ask_a_question, inputs=[no_answer, taskid, game_session_state], outputs=[conversation, game_session_state, answer, na_box, clear_box, submit, taskid, start_button, yes_box, no_box, reward_code, vocab_warning])
|
239 |
+
clear_box.click(fn=reset_dropdown, inputs=[], outputs=[answer])
|
240 |
+
|
241 |
+
url_params = gr.JSON({}, visible=False, label="URL Params")
|
242 |
+
demo.load(fn = lambda url_params : gr.Number.update(value=int(url_params['p'])), inputs=[url_params], outputs=taskid, _js=get_window_url_params)
|
243 |
|
244 |
demo.launch()
|
create_cache.py
CHANGED
@@ -37,11 +37,11 @@ def create_cache(taskid):
|
|
37 |
global question_model_yn, response_model_simul_yn, caption_model_yn
|
38 |
if taskid in yn_indices:
|
39 |
yn = True
|
40 |
-
curr_question_model, curr_response_model_simul, curr_caption_model =
|
41 |
taskid-=40
|
42 |
else:
|
43 |
yn = False
|
44 |
-
curr_question_model, curr_response_model_simul, curr_caption_model =
|
45 |
if taskid in hard_setting_indices:
|
46 |
hard_setting = True
|
47 |
image_list_curr = image_list_hard
|
@@ -65,6 +65,9 @@ def create_cache(taskid):
|
|
65 |
image_names.append(image_list_curr[int(taskid)*10+i])
|
66 |
image_files = [id1, id2, id3, id4, id5, id6, id7, id8, id9, id10]
|
67 |
image_files = [x[15:] for x in image_files]
|
|
|
|
|
|
|
68 |
images_np = [np.asarray(Image.open(f"./mscoco-images/{i}")) for i in image_files]
|
69 |
images_np = [np.dstack([i]*3) if len(i.shape)==2 else i for i in images_np]
|
70 |
p_y_x = (torch.ones(10)/10).to(curr_question_model.device)
|
@@ -74,7 +77,7 @@ def create_cache(taskid):
|
|
74 |
first_question = curr_question_model.select_best_question(p_y_x, questions, images_np, captions, curr_response_model_simul)
|
75 |
|
76 |
gc = Game_Cache(curr_question_model.question_bank, image_names, yn, hard_setting)
|
77 |
-
with open(f'./cache{int(
|
78 |
pickle.dump(gc, fp, protocol=pickle.HIGHEST_PROTOCOL)
|
79 |
|
80 |
if __name__=="__main__":
|
|
|
37 |
global question_model_yn, response_model_simul_yn, caption_model_yn
|
38 |
if taskid in yn_indices:
|
39 |
yn = True
|
40 |
+
curr_question_model, curr_response_model_simul, curr_caption_model = question_model_yn, response_model_simul_yn, caption_model_yn
|
41 |
taskid-=40
|
42 |
else:
|
43 |
yn = False
|
44 |
+
curr_question_model, curr_response_model_simul, curr_caption_model = question_model, response_model_simul, caption_model
|
45 |
if taskid in hard_setting_indices:
|
46 |
hard_setting = True
|
47 |
image_list_curr = image_list_hard
|
|
|
65 |
image_names.append(image_list_curr[int(taskid)*10+i])
|
66 |
image_files = [id1, id2, id3, id4, id5, id6, id7, id8, id9, id10]
|
67 |
image_files = [x[15:] for x in image_files]
|
68 |
+
import os
|
69 |
+
for i in image_files:
|
70 |
+
os.system(f"cp ./../../../data/ms-coco/images/{i} ./mscoco-images/val2014/")
|
71 |
images_np = [np.asarray(Image.open(f"./mscoco-images/{i}")) for i in image_files]
|
72 |
images_np = [np.dstack([i]*3) if len(i.shape)==2 else i for i in images_np]
|
73 |
p_y_x = (torch.ones(10)/10).to(curr_question_model.device)
|
|
|
77 |
first_question = curr_question_model.select_best_question(p_y_x, questions, images_np, captions, curr_response_model_simul)
|
78 |
|
79 |
gc = Game_Cache(curr_question_model.question_bank, image_names, yn, hard_setting)
|
80 |
+
with open(f'./cache-soft/{int(original_taskid)}.p', 'wb') as fp:
|
81 |
pickle.dump(gc, fp, protocol=pickle.HIGHEST_PROTOCOL)
|
82 |
|
83 |
if __name__=="__main__":
|
data/questions_gpt4.csv
ADDED
The diff for this file is too large to render.
See raw diff
|
|
model/model/question_asking_model.py
CHANGED
@@ -66,13 +66,59 @@ class QuestionAskingModelGPT3(QuestionAskingModel):
|
|
66 |
questions = [i+"?" for i in questions]
|
67 |
return questions
|
68 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
69 |
def get_questions(self, images, captions, target_idx=0):
|
70 |
questions = []
|
71 |
for i, (image, caption) in enumerate(zip(images, captions)):
|
72 |
if caption in self.gpt3_captions.caption.tolist():
|
73 |
image_questions = self.gpt3_captions[self.gpt3_captions.caption==caption].question.tolist()
|
74 |
else:
|
75 |
-
image_questions = self.
|
76 |
image_df = pd.DataFrame({'caption':[caption for _ in image_questions], 'question':image_questions})
|
77 |
self.gpt3_captions = pd.concat([self.gpt3_captions, image_df])
|
78 |
self.gpt3_captions.to_csv(self.gpt3_path, index=False)
|
|
|
66 |
questions = [i+"?" for i in questions]
|
67 |
return questions
|
68 |
|
69 |
+
def generate_gpt4_questions(self, caption):
|
70 |
+
print("generating gpt4 captions")
|
71 |
+
instructions_prompt="You are tasked to produce reasonable questions from a given caption.\
|
72 |
+
The questions you ask must be answerable only using visual information.\
|
73 |
+
As such, never ask questions that involve exact measurement such as \"How tall\", \"How big\", or \"How far\", since these cannot be easily inferred from just looking at an object.\
|
74 |
+
Likewise, never ask questions that involve age (\"How old\"), composition (\"What is it made of\"), material, emotion, or personal relationship.\
|
75 |
+
When asking \"Where\" questions, the subject of your question must be a person or a small object.\
|
76 |
+
Never ask questions that can be answered with yes or no.\
|
77 |
+
When refering to objects, try to be general. For example, instead of saying \"cat\", you should say \"animal\". Instead of saying \"cake\", you should say \"food\".\
|
78 |
+
I repeat, when refering to object, try to be general!\
|
79 |
+
Good questions to ask include general \"What color\", as well as general probing questions such as \"What is the man doing?\" or \"What is the main subject of the image?\"\
|
80 |
+
For each caption, please generate 3-5 reasonable questions."
|
81 |
+
c1="A living room with a couch, coffee table and two large windows with white curtains."
|
82 |
+
q1="What color is the couch? How many windows are there? How many tables are there? What color is the table? What color are the curtains? What is next to the table? What is next to the couch?"
|
83 |
+
c2="A cat is wearing a pink wool hat."
|
84 |
+
q2="What color is the animal? What color is the hat? What is the cat wearing? "
|
85 |
+
c3="A stop sign with a skeleton painted on it, next to a car."
|
86 |
+
q3="What color is the sign? What color is the car? What is next to the sign? What is next to the car? What is on the sign? Where is the car?"
|
87 |
+
c4="A man brushing his teeth with a toothbrush"
|
88 |
+
q4="What is the man doing? Where is the man? What color is the toothbrush?"
|
89 |
+
prompt = f"{instructions_prompt}\n"
|
90 |
+
prompt+=f"\nCaption: {c1}\nQuestions: {q1}\n"
|
91 |
+
prompt+=f"\nCaption: {c2}\nQuestions: {q2}\n"
|
92 |
+
prompt+=f"\nCaption: {c3}\nQuestions: {q3}\n"
|
93 |
+
prompt+=f"\nCaption: {c4}\nQuestions: {q4}\n"
|
94 |
+
prompt+=f"\nCaption: {caption}\nQuestions:"
|
95 |
+
response = openai.ChatCompletion.create(
|
96 |
+
model="gpt-4-0314",
|
97 |
+
messages=[
|
98 |
+
{"role": "system", "content": instructions_prompt},
|
99 |
+
{"role": "user", "content": c1},
|
100 |
+
{"role": "assistant", "content": q1},
|
101 |
+
{"role": "user", "content": c2},
|
102 |
+
{"role": "assistant", "content": q2},
|
103 |
+
{"role": "user", "content": c3},
|
104 |
+
{"role": "assistant", "content": q3},
|
105 |
+
{"role": "user", "content": c4},
|
106 |
+
{"role": "assistant", "content": q4},
|
107 |
+
{"role": "user", "content": caption}
|
108 |
+
]
|
109 |
+
)
|
110 |
+
questions = response["choices"][0]["message"]["content"]
|
111 |
+
questions = [i.strip() for i in questions.split('?') if len(i.strip())>1]
|
112 |
+
questions = [i+"?" for i in questions]
|
113 |
+
return questions
|
114 |
+
|
115 |
def get_questions(self, images, captions, target_idx=0):
|
116 |
questions = []
|
117 |
for i, (image, caption) in enumerate(zip(images, captions)):
|
118 |
if caption in self.gpt3_captions.caption.tolist():
|
119 |
image_questions = self.gpt3_captions[self.gpt3_captions.caption==caption].question.tolist()
|
120 |
else:
|
121 |
+
image_questions = self.generate_gpt4_questions(caption)
|
122 |
image_df = pd.DataFrame({'caption':[caption for _ in image_questions], 'question':image_questions})
|
123 |
self.gpt3_captions = pd.concat([self.gpt3_captions, image_df])
|
124 |
self.gpt3_captions.to_csv(self.gpt3_path, index=False)
|
model/model/question_model_base.py
CHANGED
@@ -1,8 +1,6 @@
|
|
1 |
import torch
|
2 |
-
import scipy
|
3 |
import numpy as np
|
4 |
import operator
|
5 |
-
import random
|
6 |
from model.model.question_generator import QuestionGenerator
|
7 |
|
8 |
|
@@ -39,7 +37,7 @@ class QuestionAskingModel():
|
|
39 |
p_r_qy = p_r_qy.detach().cpu().numpy()
|
40 |
is_a_multiplier.append(p_r_qy)
|
41 |
if len(is_a_multiplier)==0: is_a_multiplier.append([0 for _ in range(self.num_images)])
|
42 |
-
is_a_multiplier = torch.tensor(
|
43 |
if self.multiplier_mode=="hard":
|
44 |
for i in range(is_a_multiplier.shape[0]):
|
45 |
if is_a_multiplier[i]<0.5: is_a_multiplier[i]=1e-6
|
|
|
1 |
import torch
|
|
|
2 |
import numpy as np
|
3 |
import operator
|
|
|
4 |
from model.model.question_generator import QuestionGenerator
|
5 |
|
6 |
|
|
|
37 |
p_r_qy = p_r_qy.detach().cpu().numpy()
|
38 |
is_a_multiplier.append(p_r_qy)
|
39 |
if len(is_a_multiplier)==0: is_a_multiplier.append([0 for _ in range(self.num_images)])
|
40 |
+
is_a_multiplier = torch.tensor(np.mean(is_a_multiplier, axis=0)).to("cuda")
|
41 |
if self.multiplier_mode=="hard":
|
42 |
for i in range(is_a_multiplier.shape[0]):
|
43 |
if is_a_multiplier[i]<0.5: is_a_multiplier[i]=1e-6
|
model/model/response_model.py
CHANGED
@@ -121,7 +121,7 @@ class ResponseModelVQA(ResponseModel):
|
|
121 |
no_cnt = sum([i.lower()=="no" for i in is_a_responses])
|
122 |
if len(is_a_responses)>0 and no_cnt/len(is_a_responses)>=0.5:
|
123 |
if question[:8]=="How many": return "0"
|
124 |
-
else: return "
|
125 |
if self.vqa_type in ["vilt1", "vilt2"]:
|
126 |
outputs = self.model(**encoding)
|
127 |
logits = torch.nn.functional.softmax(outputs.logits, dim=1)
|
@@ -165,9 +165,7 @@ class ResponseModelVQA(ResponseModel):
|
|
165 |
response_idx = self.get_response_idx(response)
|
166 |
p_r_qy = logits[:,response_idx]
|
167 |
|
168 |
-
|
169 |
-
# consider rerunning also without the geometric mean
|
170 |
-
if response=="nothing":
|
171 |
is_a_questions = self.question_generator.generate_is_there_question_v2(question)
|
172 |
for idx, (caption, image) in enumerate(zip(captions, images)):
|
173 |
current_responses = []
|
|
|
121 |
no_cnt = sum([i.lower()=="no" for i in is_a_responses])
|
122 |
if len(is_a_responses)>0 and no_cnt/len(is_a_responses)>=0.5:
|
123 |
if question[:8]=="How many": return "0"
|
124 |
+
else: return "n/a"
|
125 |
if self.vqa_type in ["vilt1", "vilt2"]:
|
126 |
outputs = self.model(**encoding)
|
127 |
logits = torch.nn.functional.softmax(outputs.logits, dim=1)
|
|
|
165 |
response_idx = self.get_response_idx(response)
|
166 |
p_r_qy = logits[:,response_idx]
|
167 |
|
168 |
+
if response=="n/a":
|
|
|
|
|
169 |
is_a_questions = self.question_generator.generate_is_there_question_v2(question)
|
170 |
for idx, (caption, image) in enumerate(zip(captions, images)):
|
171 |
current_responses = []
|
model/run_question_asking_model.py
CHANGED
@@ -34,7 +34,7 @@ parser.add_argument('--response_type_gtruth', type=str, default='VQA2', choices=
|
|
34 |
parser.add_argument('--question_strategy', type=str, default='gpt3', choices=['rule', 'gpt3'])
|
35 |
parser.add_argument('--multiplier_mode', type=str, default='soft', choices=['soft', 'hard', 'none'])
|
36 |
|
37 |
-
parser.add_argument('--gpt3_save_name', type=str, default='
|
38 |
parser.add_argument('--save_name', type=str, default=None)
|
39 |
parser.add_argument('--verbose', action='store_true')
|
40 |
args = parser.parse_args()
|
@@ -43,8 +43,6 @@ args.include_what=True
|
|
43 |
args.response_type_simul='VQA1'
|
44 |
args.response_type_gtruth='VQA3'
|
45 |
args.multiplier_mode='soft'
|
46 |
-
args.sample_strategy='attribute'
|
47 |
-
args.attribute_n=1
|
48 |
args.caption_strategy='gtruth'
|
49 |
assert_checks(args)
|
50 |
if args.save_name is None: logger = logging_handler(args.verbose, args.save_name)
|
@@ -74,8 +72,6 @@ args.include_what=False
|
|
74 |
args.response_type_simul='VQA1'
|
75 |
args.response_type_gtruth='VQA3'
|
76 |
args.multiplier_mode='none'
|
77 |
-
args.sample_strategy='attribute'
|
78 |
-
args.attribute_n=1
|
79 |
args.caption_strategy='gtruth'
|
80 |
|
81 |
print("1. Loading question model ...")
|
|
|
34 |
parser.add_argument('--question_strategy', type=str, default='gpt3', choices=['rule', 'gpt3'])
|
35 |
parser.add_argument('--multiplier_mode', type=str, default='soft', choices=['soft', 'hard', 'none'])
|
36 |
|
37 |
+
parser.add_argument('--gpt3_save_name', type=str, default='questions_gpt4')
|
38 |
parser.add_argument('--save_name', type=str, default=None)
|
39 |
parser.add_argument('--verbose', action='store_true')
|
40 |
args = parser.parse_args()
|
|
|
43 |
args.response_type_simul='VQA1'
|
44 |
args.response_type_gtruth='VQA3'
|
45 |
args.multiplier_mode='soft'
|
|
|
|
|
46 |
args.caption_strategy='gtruth'
|
47 |
assert_checks(args)
|
48 |
if args.save_name is None: logger = logging_handler(args.verbose, args.save_name)
|
|
|
72 |
args.response_type_simul='VQA1'
|
73 |
args.response_type_gtruth='VQA3'
|
74 |
args.multiplier_mode='none'
|
|
|
|
|
75 |
args.caption_strategy='gtruth'
|
76 |
|
77 |
print("1. Loading question model ...")
|
mscoco-images/val2014/COCO_val2014_000000564366.jpg
CHANGED
open_db.py
CHANGED
@@ -4,4 +4,12 @@ import pandas as pd
|
|
4 |
db = ResponseDb()
|
5 |
df = pd.DataFrame(list(db.get()))
|
6 |
print(df)
|
7 |
-
df.to_csv("
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
4 |
db = ResponseDb()
|
5 |
df = pd.DataFrame(list(db.get()))
|
6 |
print(df)
|
7 |
+
df.to_csv("responses_new.csv", index=False)
|
8 |
+
|
9 |
+
import sqlite3
|
10 |
+
import pandas as pd
|
11 |
+
|
12 |
+
db = sqlite3.connect("response.db")
|
13 |
+
df = pd.read_sql('SELECT * from responses', db)
|
14 |
+
print(df)
|
15 |
+
df.to_csv("responses_old.csv", index=False)
|
response.db
DELETED
Binary file (8.19 kB)
|
|
response_db.py
CHANGED
@@ -24,3 +24,23 @@ class ResponseDb:
|
|
24 |
|
25 |
def get(self):
|
26 |
return self.collection.find()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
24 |
|
25 |
def get(self):
|
26 |
return self.collection.find()
|
27 |
+
|
28 |
+
|
29 |
+
def get_code(taskid, history, top_pred):
|
30 |
+
taskid = int(taskid)
|
31 |
+
mongodb_username=os.environ['mongodb_username_2']
|
32 |
+
mongodb_pw=os.environ['mongodb_pw_2']
|
33 |
+
mongodb_cluster_url=os.environ['mongodb_cluster_url_2']
|
34 |
+
client = MongoClient(f"mongodb+srv://{mongodb_username}:{mongodb_pw}@{mongodb_cluster_url}/?retryWrites=true&w=majority")
|
35 |
+
db = client['vqa-codes']
|
36 |
+
collection = db['vqa-codes']
|
37 |
+
|
38 |
+
threshold_dict = {1001: 6, 1002: 2, 1003: 4, 1004: 2}
|
39 |
+
if int(taskid) in threshold_dict:
|
40 |
+
threshold = threshold_dict[int(taskid)]
|
41 |
+
if len(history)<=threshold and top_pred == 0:
|
42 |
+
return list(collection.find({"taskid":int(taskid)}))[0]['code']
|
43 |
+
else:
|
44 |
+
return list(collection.find({"taskid":3000-int(taskid)}))[0]['code']
|
45 |
+
|
46 |
+
return list(collection.find({"taskid":taskid}))[0]['code']
|