Spaces (Runtime error)
sedrickkeh committed • Commit 016285f
1 Parent(s): 4be744b
Upload 13 files
Browse files:
- app.py +169 -57
- create_cache.py +83 -0
- model/__init__.py +6 -0
- model/model/caption_model.py +89 -0
- model/model/question_asking_model.py +83 -0
- model/model/question_generator.py +194 -0
- model/model/question_model_base.py +85 -0
- model/model/response_model.py +190 -0
- model/run_question_asking_model.py +186 -0
- open_db.py +1 -6
- pilot-study.csv +161 -0
- response_db.py +1 -2
app.py CHANGED
@@ -1,80 +1,192 @@
 import gradio as gr
 from response_db import StResponseDb
 db = StResponseDb()
-
 
-def get_next_question(history):
-    if len(history)==2:
-        question = "What is the man doing?"
-    elif len(history)==4:
-        question = "How many apples are there?"
-    else:
-        question = "What color is the cat?"
-    return question
-
-def ask_a_question(input, taskid, history=[]):
-    history.append(input)
-    db.add(int(a.value), taskid, len(history)//2-1, history[-2], history[-1])
-    history.append(get_next_question(history))
-
     # write some HTML
     html = "<div class='chatbot'>"
-    for m, msg in enumerate(history):
         cls = "bot" if m%2 == 0 else "user"
         html += "<div class='msg {}'> {}</div>".format(cls, msg)
     html += "</div>"
-    return html, history
 
 
-css = """
-.chatbox {display:flex;flex-direction:column}
-.msg {padding:4px;margin-bottom:4px;border-radius:4px;width:80%}
-.msg.user {background-color:cornflowerblue;color:white}
-.msg.bot {background-color:lightgray;align-self:self-end}
-.footer {display:none !important}
-"""
 
 def set_images(taskid):
-    [… 11 removed lines not recovered from the diff view …]
     first_question_html = f"<div class='chatbot'><div class='msg bot'>{first_question}</div></div>"
-
-
 
-with gr.Blocks(css=css) as demo:
 
     with gr.Column() as img_block:
         with gr.Row():
-            img1 = gr.Image()
-            img2 = gr.Image()
-            img3 = gr.Image()
-            img4 = gr.Image()
-            img5 = gr.Image()
         with gr.Row():
-            img6 = gr.Image()
-            img7 = gr.Image()
-            img8 = gr.Image()
-            img9 = gr.Image()
-            img10 = gr.Image()
     conversation = gr.HTML()
-
 
     with gr.Column():
         with gr.Row():
-            [… 4 removed lines not recovered from the diff view …]
-            submit =
-            [… 3 removed lines not recovered from the diff view …]
 import gradio as gr
 from response_db import StResponseDb
+from create_cache import Game_Cache
+import numpy as np
+from PIL import Image
+import pandas as pd
+import torch
+import pickle
+import uuid
+
 db = StResponseDb()
+css = """
+.chatbot {display:flex;flex-direction:column}
+.msg {padding:4px;margin-bottom:4px;border-radius:4px;width:80%}
+.msg.user {background-color:cornflowerblue;color:white;align-self:self-end}
+.msg.bot {background-color:lightgray}
+.na_button {background-color:red;color:red}
+"""
+
+from model.run_question_asking_model import return_modules, return_modules_yn
+question_model, response_model_simul, _, caption_model = return_modules()
+question_model_yn, response_model_simul_yn, _, caption_model_yn = return_modules_yn()
+
+class Game_Session:
+    def __init__(self, taskid, yn, hard_setting):
+        self.yn = yn
+        self.hard_setting = hard_setting
+
+        global question_model, response_model_simul, caption_model
+        global question_model_yn, response_model_simul_yn, caption_model_yn
+        self.question_model = question_model
+        self.response_model_simul = response_model_simul
+        self.caption_model = caption_model
+        self.question_model_yn = question_model_yn
+        self.response_model_simul_yn = response_model_simul_yn
+        self.caption_model_yn = caption_model_yn
+
+        global image_files, images_np, p_y_x, p_r_qy, p_y_xqr, captions, questions, target_questions
+        self.image_files, self.images_np, self.p_y_x, self.p_r_qy, self.p_y_xqr = None, None, None, None, None
+        self.captions, self.questions, self.target_questions = None, None, None
+
+        self.history = []
+        self.game_id = str(uuid.uuid4())
+        self.set_curr_models()
+
+    def set_curr_models(self):
+        if self.yn:
+            self.curr_question_model, self.curr_caption_model, self.curr_response_model_simul = self.question_model_yn, self.caption_model_yn, self.response_model_simul_yn
+        else:
+            self.curr_question_model, self.curr_caption_model, self.curr_response_model_simul = self.question_model, self.caption_model, self.response_model_simul
+
+    def get_next_question(self):
+        return self.curr_question_model.select_best_question(self.p_y_x, self.questions, self.images_np, self.captions, self.curr_response_model_simul)
+
+
+def ask_a_question(input, taskid, gs):
+    gs.history.append(input)
+    gs.p_r_qy = gs.curr_response_model_simul.get_p_r_qy(input, gs.history[-2], gs.images_np, gs.captions)
+    gs.p_y_xqr = gs.p_y_x*gs.p_r_qy
+    gs.p_y_xqr = gs.p_y_xqr/torch.sum(gs.p_y_xqr) if torch.sum(gs.p_y_xqr) != 0 else torch.zeros_like(gs.p_y_xqr)
+    gs.p_y_x = gs.p_y_xqr
+    gs.questions.remove(gs.history[-2])
+    db.add(gs.game_id, taskid, len(gs.history)//2-1, gs.history[-2], gs.history[-1])
+    gs.history.append(gs.get_next_question())
+
+    top_prob = torch.max(gs.p_y_x).item()
+    top_pred = torch.argmax(gs.p_y_x).item()
+    if top_prob > 0.8:
+        gs.history = gs.history[:-1]
+        db.add(gs.game_id, taskid, len(gs.history)//2, f"Guess: Image {top_pred}", "")
 
     # write some HTML
     html = "<div class='chatbot'>"
+    for m, msg in enumerate(gs.history):
+        if msg=="nothing": msg="n/a"
         cls = "bot" if m%2 == 0 else "user"
         html += "<div class='msg {}'> {}</div>".format(cls, msg)
     html += "</div>"
 
+    ### Game finished:
+    if top_prob > 0.8:
+        html += f"<p>The model identified <b>Image {top_pred+1}</b> as the image. Please select a new task ID to continue.</p>"
+        return html, gs, gr.Textbox.update(visible=False), gr.Button.update(visible=False), gr.Button.update(visible=False), gr.Number.update(visible=True), gr.Button.update(visible=True), gr.Number.update(visible=False), gr.Button.update(visible=False)
+    else:
+        if not gs.yn:
+            return html, gs, gr.Textbox.update(visible=True), gr.Button.update(visible=True), gr.Button.update(visible=True), gr.Number.update(visible=False), gr.Button.update(visible=False), gr.Number.update(visible=False), gr.Button.update(visible=False)
+        else:
+            return html, gs, gr.Textbox.update(visible=False), gr.Button.update(visible=False), gr.Button.update(visible=False), gr.Number.update(visible=False), gr.Button.update(visible=False), gr.Number.update(visible=True), gr.Button.update(visible=True)
 
 
 def set_images(taskid):
+    pilot_study = pd.read_csv("pilot-study.csv")
+    taskid_original = taskid
+    taskid = pilot_study['mscoco-id'].tolist()[int(taskid)]
+
+    with open(f'cache/{int(taskid)}.p', 'rb') as fp:
+        game_cache = pickle.load(fp)
+    gs = Game_Session(int(taskid), game_cache.yn, game_cache.hard_setting)
+    id1 = f"./mscoco-images/val2014/{game_cache.image_files[0]}"
+    id2 = f"./mscoco-images/val2014/{game_cache.image_files[1]}"
+    id3 = f"./mscoco-images/val2014/{game_cache.image_files[2]}"
+    id4 = f"./mscoco-images/val2014/{game_cache.image_files[3]}"
+    id5 = f"./mscoco-images/val2014/{game_cache.image_files[4]}"
+    id6 = f"./mscoco-images/val2014/{game_cache.image_files[5]}"
+    id7 = f"./mscoco-images/val2014/{game_cache.image_files[6]}"
+    id8 = f"./mscoco-images/val2014/{game_cache.image_files[7]}"
+    id9 = f"./mscoco-images/val2014/{game_cache.image_files[8]}"
+    id10 = f"./mscoco-images/val2014/{game_cache.image_files[9]}"
+
+    gs.image_files = [id1, id2, id3, id4, id5, id6, id7, id8, id9, id10]
+    gs.image_files = [x[15:] for x in gs.image_files]
+    gs.images_np = [np.asarray(Image.open(f"./mscoco-images/{i}")) for i in gs.image_files]
+    gs.images_np = [np.dstack([i]*3) if len(i.shape)==2 else i for i in gs.images_np]
+
+    gs.p_y_x = (torch.ones(10)/10).to(gs.curr_question_model.device)
+    gs.captions = gs.curr_caption_model.get_captions(gs.image_files)
+    gs.questions, gs.target_questions = gs.curr_question_model.get_questions(gs.image_files, gs.captions, 0)
+    gs.curr_question_model.reset_question_bank()
+    gs.curr_question_model.question_bank = game_cache.question_dict
+    first_question = gs.curr_question_model.select_best_question(gs.p_y_x, gs.questions, gs.images_np, gs.captions, gs.curr_response_model_simul)
     first_question_html = f"<div class='chatbot'><div class='msg bot'>{first_question}</div></div>"
+    gs.history.append(first_question)
+    html = f"<p>Current Task ID: <b>{int(taskid_original)}</b></p>"
+    if not gs.yn:
+        return id1, id2, id3, id4, id5, id6, id7, id8, id9, id10, gs, first_question_html, gr.HTML.update(value=html, visible=True), gr.Textbox.update(visible=True, value=''), gr.Button.update(visible=True), gr.Button.update(visible=True), gr.Number.update(visible=False), gr.Button.update(visible=False), gr.Button.update(visible=False), gr.Button.update(visible=False)
+    else:
+        return id1, id2, id3, id4, id5, id6, id7, id8, id9, id10, gs, first_question_html, gr.HTML.update(value=html, visible=True), gr.Textbox.update(visible=False), gr.Button.update(visible=False), gr.Button.update(visible=False), gr.Number.update(visible=False), gr.Button.update(visible=False), gr.Button.update(visible=True), gr.Button.update(visible=True)
+
 
+with gr.Blocks(title="Image Q&A Guessing Game", css=css) as demo:
+    gr.HTML("<h1>Image Q&A Guessing Game</h1>\
+            <p style='font-size:120%;'>\
+            Imagine you are playing 20-questions with an AI model.<br>\
+            The AI model plays the role of the question asker. You play the role of the responder.<br>\
+            There are 10 images. <b>Your image is Image 1</b>. The other images are distraction images.\
+            The model can see all 10 images and all the questions and answers for the current set of images. It will ask a question based on the available information.<br>\
+            <span style='color: #0000ff'>The goal of the model is to accurately guess the correct image (i.e. <b><span style='color: #0000ff'>Image 1</span></b>) in as few turns as possible.<br>\
+            Your goal is to help the model guess the image by answering as clearly and accurately as possible.</span><br><br>\
+            <b>Guidelines:</b><br>\
+            <ol style='font-size:120%;'>\
+            <li>It is best to keep your answers short (a single word or a short phrase). No need to answer in full sentences.</li>\
+            <li>If you feel that the question cannot be answered or does not apply to Image 1, please select N/A.</li>\
+            </ol> \
+            <br>\
+            (Note: We are testing multiple game settings. In some instances, the game will be open-ended, while in other instances, the answer choices will be limited to yes/no.)<br></p>\
+            <br>\
+            <h2>Please enter a TaskID to start</h2>")
+
+    with gr.Column():
+        with gr.Row():
+            taskid = gr.Number(label="Task ID (Enter a number from 0 to 160)", value=0)
+            start_button = gr.Button("Enter")
+        with gr.Row():
+            task_text = gr.HTML()
 
     with gr.Column() as img_block:
         with gr.Row():
+            img1 = gr.Image(label="Image 1", show_label=True)
+            img2 = gr.Image(label="Image 2", show_label=True)
+            img3 = gr.Image(label="Image 3", show_label=True)
+            img4 = gr.Image(label="Image 4", show_label=True)
+            img5 = gr.Image(label="Image 5", show_label=True)
         with gr.Row():
+            img6 = gr.Image(label="Image 6", show_label=True)
+            img7 = gr.Image(label="Image 7", show_label=True)
+            img8 = gr.Image(label="Image 8", show_label=True)
+            img9 = gr.Image(label="Image 9", show_label=True)
+            img10 = gr.Image(label="Image 10", show_label=True)
     conversation = gr.HTML()
+    game_session_state = gr.State()
 
+    answer = gr.Textbox(placeholder="Insert answer here.", label="Answer the given question.", visible=False)
+    null_answer = gr.Textbox("nothing", visible=False)
+    yes_answer = gr.Textbox("yes", visible=False)
+    no_answer = gr.Textbox("no", visible=False)
+
+    with gr.Column():
+        with gr.Row():
+            yes_box = gr.Button("Yes", visible=False)
+            no_box = gr.Button("No", visible=False)
     with gr.Column():
         with gr.Row():
+            na_box = gr.Button("N/A", visible=False, elem_classes="na_button")
+            submit = gr.Button("Submit", visible=False)
+    ### Button click events
+    start_button.click(fn=set_images, inputs=taskid, outputs=[img1, img2, img3, img4, img5, img6, img7, img8, img9, img10, game_session_state, conversation, task_text, answer, na_box, submit, taskid, start_button, yes_box, no_box])
+    submit.click(fn=ask_a_question, inputs=[answer, taskid, game_session_state], outputs=[conversation, game_session_state, answer, na_box, submit, taskid, start_button, yes_box, no_box])
+    na_box.click(fn=ask_a_question, inputs=[null_answer, taskid, game_session_state], outputs=[conversation, game_session_state, answer, na_box, submit, taskid, start_button, yes_box, no_box])
+    yes_box.click(fn=ask_a_question, inputs=[yes_answer, taskid, game_session_state], outputs=[conversation, game_session_state, answer, na_box, submit, taskid, start_button, yes_box, no_box])
+    no_box.click(fn=ask_a_question, inputs=[no_answer, taskid, game_session_state], outputs=[conversation, game_session_state, answer, na_box, submit, taskid, start_button, yes_box, no_box])
+
+demo.launch()
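Note: ask_a_question above performs a Bayesian belief update over the ten candidate images. A minimal, self-contained sketch of just that step, assuming p_y_x (prior over images) and p_r_qy (per-image likelihood of the player's answer) are length-10 torch vectors:

import torch

def update_belief(p_y_x, p_r_qy):
    # Posterior over candidate images: p(y|x,q,r) is proportional to p(y|x) * p(r|q,y),
    # renormalized; an all-zero product is returned as a zero vector (degenerate case).
    p_y_xqr = p_y_x * p_r_qy
    total = torch.sum(p_y_xqr)
    return p_y_xqr / total if total != 0 else torch.zeros_like(p_y_xqr)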
create_cache.py ADDED
@@ -0,0 +1,83 @@
+import gradio as gr
+import numpy as np
+from PIL import Image
+import torch
+import pickle
+
+class Game_Cache:
+    def __init__(self, question_dict, image_files, yn, hard_setting):
+        self.question_dict = question_dict
+        self.image_files = image_files
+        self.yn = yn
+        self.hard_setting = hard_setting
+
+image_list = []
+with open('./mscoco/mscoco_images.txt', 'r') as f:
+    for line in f.readlines():
+        image_list.append(line.strip())
+image_list_hard = []
+with open('./mscoco/mscoco_images_attribute_n=1.txt', 'r') as f:
+    for line in f.readlines():
+        image_list_hard.append(line.strip())
+
+yn_indices = list(range(40,80))+list(range(120,160))
+hard_setting_indices = list(range(80,160))
+
+
+from model.run_question_asking_model import return_modules, return_modules_yn
+global image_files, images_np, p_y_x, p_r_qy, p_y_xqr, captions, questions, target_questions
+global question_model, response_model_simul, caption_model
+question_model, response_model_simul, _, caption_model = return_modules()
+global question_model_yn, response_model_simul_yn, caption_model_yn
+question_model_yn, response_model_simul_yn, _, caption_model_yn = return_modules_yn()
+
+def create_cache(taskid):
+    original_taskid = taskid
+    global question_model, response_model_simul, caption_model
+    global question_model_yn, response_model_simul_yn, caption_model_yn
+    if taskid in yn_indices:
+        yn = True
+        curr_question_model, curr_response_model_simul, curr_caption_model = question_model, response_model_simul, caption_model
+        taskid-=40
+    else:
+        yn = False
+        curr_question_model, curr_response_model_simul, curr_caption_model = question_model_yn, response_model_simul_yn, caption_model_yn
+    if taskid in hard_setting_indices:
+        hard_setting = True
+        image_list_curr = image_list_hard
+        taskid -= 80
+    else:
+        hard_setting = False
+        image_list_curr = image_list
+
+    id1 = f"./mscoco-images/val2014/{image_list_curr[int(taskid)*10+0]}"
+    id2 = f"./mscoco-images/val2014/{image_list_curr[int(taskid)*10+1]}"
+    id3 = f"./mscoco-images/val2014/{image_list_curr[int(taskid)*10+2]}"
+    id4 = f"./mscoco-images/val2014/{image_list_curr[int(taskid)*10+3]}"
+    id5 = f"./mscoco-images/val2014/{image_list_curr[int(taskid)*10+4]}"
+    id6 = f"./mscoco-images/val2014/{image_list_curr[int(taskid)*10+5]}"
+    id7 = f"./mscoco-images/val2014/{image_list_curr[int(taskid)*10+6]}"
+    id8 = f"./mscoco-images/val2014/{image_list_curr[int(taskid)*10+7]}"
+    id9 = f"./mscoco-images/val2014/{image_list_curr[int(taskid)*10+8]}"
+    id10 = f"./mscoco-images/val2014/{image_list_curr[int(taskid)*10+9]}"
+    image_names = []
+    for i in range(10):
+        image_names.append(image_list_curr[int(taskid)*10+i])
+    image_files = [id1, id2, id3, id4, id5, id6, id7, id8, id9, id10]
+    image_files = [x[15:] for x in image_files]
+    images_np = [np.asarray(Image.open(f"./mscoco-images/{i}")) for i in image_files]
+    images_np = [np.dstack([i]*3) if len(i.shape)==2 else i for i in images_np]
+    p_y_x = (torch.ones(10)/10).to(curr_question_model.device)
+    captions = curr_caption_model.get_captions(image_files)
+    questions, target_questions = curr_question_model.get_questions(image_files, captions, 0)
+    curr_question_model.reset_question_bank()
+    first_question = curr_question_model.select_best_question(p_y_x, questions, images_np, captions, curr_response_model_simul)
+
+    gc = Game_Cache(curr_question_model.question_bank, image_names, yn, hard_setting)
+    with open(f'./cache{int(taskid)}.p', 'wb') as fp:
+        pickle.dump(gc, fp, protocol=pickle.HIGHEST_PROTOCOL)
+
+if __name__=="__main__":
+    for i in range(160):
+        create_cache(i)
+
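Note on the consumer side: app.py's set_images loads these pickles from cache/{taskid}.p, while this script writes to ./cache{int(taskid)}.p (and with the decremented task id), so the written files need to be placed or renamed accordingly for the Space to find them. A minimal load sketch, assuming a cache file 0.p exists under cache/:

import pickle
from create_cache import Game_Cache  # class must be importable for unpickling

with open('cache/0.p', 'rb') as fp:  # path as read by app.py's set_images
    game_cache = pickle.load(fp)
print(game_cache.yn, game_cache.hard_setting, len(game_cache.image_files))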
model/__init__.py ADDED
@@ -0,0 +1,6 @@
+import model.run_question_asking_model
+import model.model.caption_model
+import model.model.question_asking_model
+import model.model.question_generator
+import model.model.question_model_base
+import model.model.response_model
model/model/caption_model.py ADDED
@@ -0,0 +1,89 @@
+import pandas as pd
+from pycocotools.coco import COCO
+
+def get_caption_model(args, question_model):
+    if args.caption_strategy=="simple":
+        return CaptionModelSimple(question_model)
+    elif args.caption_strategy=="granular":
+        return CaptionModelGranular()
+    elif args.caption_strategy=="gtruth":
+        return CaptionModelCOCO()
+    else:
+        raise ValueError(f"{args.caption_strategy} is not a valid caption strategy.")
+
+
+class CaptionModel():
+    # Class for the other CaptionModels to inherit from
+    def __init__(self):
+        pass
+
+    def get_captions(self, images, **kwargs):
+        raise NotImplementedError
+
+class CaptionModelCOCO():
+    # Ground truth annotations from COCO dataset
+    def __init__(self):
+        dataDir='./mscoco'
+        val_file = '{}/annotations/captions_val2014.json'.format(dataDir)
+        self.coco_caps_val = COCO(val_file)
+        val_file = '{}/annotations/instances_val2014.json'.format(dataDir)
+        self.coco_anns_val = COCO(val_file)
+
+    def get_captions(self, images, return_all=False):
+        captions = []
+        for i, image in enumerate(images):
+            image_id = int(image.split('_')[-1].split('.')[0].lstrip("0"))
+            annIds = self.coco_caps_val.getAnnIds(imgIds=image_id)
+            anns_val = self.coco_caps_val.loadAnns(annIds)
+            # annIds = self.coco_caps_train.getAnnIds(imgIds=image_id)
+            # anns_train = self.coco_caps_train.loadAnns(annIds)
+            # anns = anns_val + anns_train
+            anns = anns_val
+            anns = [d['caption'] for d in anns]
+            if return_all: captions.append(anns)
+            else: captions.append(anns[0])
+        return captions
+
+    def get_subjects(self, images):
+        subjects = []
+        for i, image in enumerate(images):
+            image_id = int(image.split('_')[-1].split('.')[0].lstrip("0"))
+            annIds = self.coco_anns_val.getAnnIds(imgIds=image_id)
+            anns_val = self.coco_anns_val.loadAnns(annIds)
+            cats_val = list(set([d['category_id'] for d in anns_val]))
+            annIds = self.coco_caps_train.getAnnIds(imgIds=image_id)
+            anns_train = self.coco_caps_train.loadAnns(annIds)
+            cats_train = list(set([d['category_id'] for d in anns_train]))
+            cats = self.coco_anns_val.loadCats(ids=cats_val+cats_train)
+            cats1, cats2 = [d['supercategory'] for d in cats], [d['name'] for d in cats]
+            cats = list(set(cats1+cats2))
+            subjects.append(cats)
+        return subjects
+
+
+class CaptionModelSimple():
+    def __init__(self, qa_model):
+        self.qa_model = qa_model
+
+    def get_captions(self, images):
+        captions = []
+        for i, image in enumerate(images):
+            caption = self.qa_model.generate_description(image, images)
+            captions.append(caption)
+        return captions
+
+class CaptionModelGranular():
+    def __init__(self):
+        df_train = pd.read_json("captions/coco_train_captions.jsonl", lines=True)
+        df_val = pd.read_json("captions/coco_val_captions.jsonl", lines=True)
+        self.caption_dict = {}
+        for i in range(len(df_train)):
+            self.caption_dict[str(df_train.image_id[i])] = df_train.caption[i]
+        for i in range(len(df_val)):
+            self.caption_dict[str(df_val.image_id[i])] = df_val.caption[i]
+
+    def get_captions(self, images):
+        captions = []
+        for i, image in enumerate(images):
+            captions.append(self.caption_dict[image.split('.')[0].split('_')[-1].lstrip('0')])
+        return captions
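Note: CaptionModelCOCO.get_captions keys the COCO annotation lookup on the numeric id embedded in the filename. A standalone check of that parsing (the filename below is an example COCO val2014 name, used only for illustration):

name = "COCO_val2014_000000391895.jpg"  # example COCO val2014 filename
image_id = int(name.split('_')[-1].split('.')[0].lstrip("0"))
assert image_id == 391895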
model/model/question_asking_model.py ADDED
@@ -0,0 +1,83 @@
+import random
+import pandas as pd
+from model.model.question_model_base import QuestionAskingModel
+import openai
+
+def get_question_model(args):
+    if args.question_strategy=="rule":
+        return QuestionAskingModelSimple(args)
+    elif args.question_strategy=="gpt3":
+        return QuestionAskingModelGPT3(args)
+    else:
+        raise ValueError(f"{args.question_strategy} is not a valid question strategy.")
+
+
+class QuestionAskingModelSimple(QuestionAskingModel):
+    def __init__(self, args):
+        super(QuestionAskingModelSimple, self).__init__(args)
+
+    def get_questions(self, images, captions, target_idx=0):
+        questions = []
+        for i, (image, caption) in enumerate(zip(images, captions)):
+            image_questions = self.question_generator.generate_is_there_question(caption)
+            if i == target_idx: target_questions = image_questions
+            questions += image_questions
+        questions = list(set(questions))
+        # random.shuffle(questions)
+        return questions, target_questions
+
+
+class QuestionAskingModelGPT3(QuestionAskingModel):
+    def __init__(self, args):
+        super(QuestionAskingModelGPT3, self).__init__(args)
+        self.gpt3_path = f"data/{args.gpt3_save_name}.csv"
+        try: self.gpt3_captions = pd.read_csv(self.gpt3_path) # cache locally to save GPT3 compute
+        except: self.gpt3_captions = pd.DataFrame({"caption":[], "question":[]})
+
+    def generate_gpt3_questions(self, caption):
+        # c1="Two people sitting on grass. The man in a blue shirt is playing a guitar and is on the left. The woman on the right is eating a sandwich."
+        # q1="What are the two people doing? How many people are there? What color is the man's shirt? Where is the man? Where is the woman? What is the man doing? What is the woman doing? Who is playing the guitar? Who is eating a sandwich?"
+        # c2="There is a table in the middle of the room. On the left there is a bowl of red apples. To the right of the bowl, there is a glass of juice, as well as a bowl of salad. Beside the table there is a bookshelf with ten books of various colors."
+        # q2="What is beside the table? What is on the left of the table? What color are the apples? How many bowls are there? What is inside the bowl? What is inside the glass? How many books are there? What color are the books?"
+        c1="A living room with a couch, coffee table and two large windows with white curtains."
+        q1="What color is the couch? How many windows are there? How many tables are there? What color is the table? What color are the curtains? What is next to the table? What is next to the couch?"
+        c2="A large, shiny, stainless, side by side refrigerator in a kitchen."
+        q2="Where is the refrigerator? What color is the refrigerator?"
+        c3="A stop sign with a skeleton painted on it, next to a car."
+        q3="What color is the sign? What color is the car? What is next to the sign? What is next to the car? What is on the sign? Where is the car?"
+        c4="A man brushing his teeth with a toothbrush"
+        q4="What is the man doing? Where is the man? What color is the toothbrush? How many people are there?"
+        prompt=f"Generate questions for the following caption:\nCaption: {c1}\nQuestions: {q1}\n"
+        prompt+=f"Generate questions for the following caption:\nCaption: {c2}\nQuestions: {q2}\n"
+        prompt+=f"Generate questions for the following caption:\nCaption: {c3}\nQuestions: {q3}\n"
+        prompt+=f"Generate questions for the following caption:\nCaption: {c4}\nQuestions: {q4}\n"
+        prompt+=f"Generate questions for the following caption:\nCaption: {caption}\nQuestions:"
+        response = openai.Completion.create(
+            model="text-davinci-003",
+            prompt=prompt,
+            temperature=0,
+            max_tokens=1024,
+            top_p=1,
+            frequency_penalty=0,
+            presence_penalty=0
+        )
+        questions = response["choices"][0]["text"]
+        questions = [i.strip() for i in questions.split('?') if len(i.strip())>1]
+        questions = [i+"?" for i in questions]
+        return questions
+
+    def get_questions(self, images, captions, target_idx=0):
+        questions = []
+        for i, (image, caption) in enumerate(zip(images, captions)):
+            if caption in self.gpt3_captions.caption.tolist():
+                image_questions = self.gpt3_captions[self.gpt3_captions.caption==caption].question.tolist()
+            else:
+                image_questions = self.generate_gpt3_questions(caption)
+                image_df = pd.DataFrame({'caption':[caption for _ in image_questions], 'question':image_questions})
+                self.gpt3_captions = pd.concat([self.gpt3_captions, image_df])
+                self.gpt3_captions.to_csv(self.gpt3_path, index=False)
+            if i == target_idx: target_questions = image_questions
+            questions += image_questions
+        questions = list(set(questions))
+        # random.shuffle(questions)
+        return questions, target_questions
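Note: the GPT-3 completion comes back as one string of questions. A standalone check of the post-processing used in generate_gpt3_questions:

text = " What color is the couch? How many windows are there? "
questions = [i.strip() for i in text.split('?') if len(i.strip()) > 1]
questions = [i + "?" for i in questions]
assert questions == ["What color is the couch?", "How many windows are there?"]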
model/model/question_generator.py ADDED
@@ -0,0 +1,194 @@
+import spacy
+import nltk.tree
+import collections
+
+import benepar
+### Run this the first time if benepar_en3 is not yet downloaded
+# benepar.download('benepar_en3')
+
+def find_word(tree, kind, parents=None):
+    if parents is None:
+        parents = []
+    if not isinstance(tree, nltk.tree.Tree):
+        return None, None
+    if tree.label() == kind:
+        return tree[0], parents
+    parents.append(tree)
+    for st in tree:
+        n, p = find_word(st, kind, parents)
+        if n is not None:
+            return n, p
+    parents.pop()
+    return None, None
+
+def find_subtrees(tree, kind, blocking_kinds=()):
+    result = []
+    if not isinstance(tree, nltk.tree.Tree):
+        return result
+    if tree.label() == kind:
+        result.append(tree)
+    if tree.label() not in blocking_kinds:
+        for st in tree:
+            result.extend(find_subtrees(st, kind))
+    return result
+
+def tree_to_str(tree, transform=lambda w: w):
+    l = []
+    def list_words(tree):
+        if isinstance(tree, str):
+            l.append(transform(tree))
+        else:
+            for st in tree:
+                list_words(st)
+    list_words(tree)
+    if l[-1] == '.':
+        l = l[:-1]
+    return ' '.join(l)
+
+def tree_to_nouns(tree, transform=lambda w: w):
+    l = []
+    def list_words(tree, noun=False):
+        if isinstance(tree, str):
+            if noun == True:
+                l.append(transform(tree))
+        else:
+            for st in tree:
+                if not isinstance(st, str):
+                    if st.label() == 'NN':
+                        noun = True
+                    else:
+                        noun = False
+                list_words(st, noun)
+    list_words(tree)
+    if l[-1] == '.':
+        l = l[:-1]
+    return l
+
+def make_determinate(w):
+    if w.lower() in ('a', 'an'):
+        return 'the'
+    return w
+
+def make_indeterminate(w):
+    if w.lower() in ('the', 'his', 'her', 'their', 'its'):
+        return 'a'
+    return w
+
+def pluralize(singular, plural, number):
+    if number <= 1:
+        return singular
+    return plural
+
+def count_labels(tree):
+    counts = collections.defaultdict(int)
+    def update_counts(node):
+        counts[node.label()] += 1
+        for child in node:
+            if isinstance(child, nltk.tree.Tree):
+                update_counts(child)
+    update_counts(tree)
+    return counts
+
+def get_number(tree):
+    if not isinstance(tree, nltk.tree.Tree):
+        return 0
+    if tree.label() == 'NN':
+        return 1
+    if tree.label() == 'NNS':
+        return 2
+    first_noun_number = None
+    n_np_children = 0
+    for subtree in tree:
+        label = subtree.label() if isinstance(subtree, nltk.tree.Tree) else None
+        if label == 'NP':
+            n_np_children += 1
+        if label in ('NP', 'NN', 'NNS') and first_noun_number is None:
+            first_noun_number = get_number(subtree)
+    if tree.label() == 'NP' and n_np_children > 1:
+        return 2
+    return first_noun_number or 0
+
+def is_present_continuous(verb):
+    return verb.endswith('ing')
+
+class QuestionGenerator:
+    def __init__(self):
+        self.parser = benepar.Parser("benepar_en3")
+
+    def generate_what_question(self, s):
+        tree = self.parser.parse(s)[0]
+        questions = []
+        try:
+            if len(tree) >= 2 and tree[0].label() == 'NP' and tree[1].label() == 'VP':
+                np = tree[0]
+                verb = None
+                vp = tree[1]
+                vnp = None
+
+                while True:
+                    verb, verb_parents = find_word(vp, 'VBG')
+                    if verb is None:
+                        break
+                    if is_present_continuous(verb):
+                        if len(verb_parents[-1]) > 1:
+                            vnp = verb_parents[-1][1]
+                        break
+                    else:
+                        vp = verb_parents[-1][1]
+                to_be = pluralize('is', 'are', get_number(np))
+                if vnp is not None and vnp.label() == 'NP':
+                    questions.append(('What {} {} {}?'
+                                      .format(to_be,
+                                              tree_to_str(np, make_determinate).lower(),
+                                              verb),
+                                      tree_to_nouns(vnp)[-1]))
+        except Exception as e:
+            print(e)
+        return questions
+
+    def generate_is_there_question(self, s):
+        tree = self.parser.parse(s)
+        questions = []
+        nps = find_subtrees(tree, 'NP', ('PP',))
+        for np in nps:
+            only_child_label = len(np) == 1 and next(iter(np)).label()
+            if only_child_label in ('PRP', 'EX'):
+                continue
+            try:
+                to_be = pluralize('Is', 'Are', get_number(np))
+                questions.append('{} there {}?'
+                                 .format(to_be,
+                                         tree_to_str(np, make_indeterminate).lower()))
+            except Exception as e:
+                print(e)
+        return questions
+
+    def generate_is_there_question_v2(self, s):
+        tree = self.parser.parse(s)
+        questions = []
+        nps = find_subtrees(tree, 'NP', ('PP',))
+        for np in nps:
+            only_child_label = len(np) == 1 and next(iter(np)).label()
+            if only_child_label in ('PRP', 'EX'):
+                continue
+            try:
+                to_be = pluralize('Is', 'Are', get_number(np))
+                questions.append('{} there {}?'
+                                 .format(to_be,
+                                         tree_to_str(np, make_indeterminate).lower()))
+            except Exception as e:
+                print(e)
+        if len(questions)==0:
+            nps = find_subtrees(tree, 'NNS', ('PP',))
+            for np in nps:
+                only_child_label = len(np) == 1
+                if only_child_label in ('PRP', 'EX'):
+                    continue
+                try:
+                    to_be = pluralize('Is', 'Are', get_number(np))
+                    questions.append('{} there {}?'
+                                     .format(to_be,
+                                             tree_to_str(np, make_indeterminate).lower()))
+                except Exception as e:
+                    print(e)
+        return questions
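Note: a minimal, hand-built demonstration of the helpers above (find_subtrees, get_number, pluralize, tree_to_str, make_indeterminate), assuming they are imported from this module; the parse tree is a hypothetical benepar output for "A cat sits on a mat.", so no parser download is needed:

import nltk.tree

tree = nltk.tree.Tree.fromstring(
    "(S (NP (DT A) (NN cat)) (VP (VBZ sits) (PP (IN on) (NP (DT a) (NN mat)))))")
for np in find_subtrees(tree, 'NP', ('PP',)):
    to_be = pluralize('Is', 'Are', get_number(np))
    print('{} there {}?'.format(to_be, tree_to_str(np, make_indeterminate).lower()))
# Prints "Is there a cat?" and "Is there a mat?". Both NPs are returned because
# the recursive call inside find_subtrees does not forward blocking_kinds, so
# the ('PP',) block only takes effect at the root level.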
model/model/question_model_base.py ADDED
@@ -0,0 +1,85 @@
+import torch
+import scipy.stats
+import numpy as np
+import operator
+import random
+from model.model.question_generator import QuestionGenerator
+
+
+class QuestionAskingModel():
+    def __init__(self, args):
+        self.device = args.device
+        self.include_what = args.include_what
+        self.max_length = 128
+        self.eps = 1e-25
+        self.multiplier_mode = args.multiplier_mode
+        self.num_images = args.num_images
+
+        # Initialize question generation model
+        class Namespace:
+            def __init__(self, **kwargs):
+                self.__dict__.update(kwargs)
+
+        self.question_generator = QuestionGenerator()
+        self.reset_question_bank()
+
+    def get_questions(self, images, captions, target_idx=0):
+        raise NotImplementedError
+
+    def get_negative_information_gain(self, p_y_x, question, image_set, caption_set, response_model):
+        if question in self.question_bank:
+            p_r_qy = self.question_bank[question]
+        else:
+            is_a_questions = self.question_generator.generate_is_there_question_v2(question)
+            is_a_multiplier = []
+            if self.multiplier_mode != "none":
+                for is_a_q in is_a_questions:
+                    # print(f"IsA Question: {is_a_q}")
+                    p_r_qy = response_model.get_p_r_qy(None, is_a_q, image_set, caption_set, is_a=True)
+                    p_r_qy = p_r_qy.detach().cpu().numpy()
+                    is_a_multiplier.append(p_r_qy)
+                if len(is_a_multiplier)==0: is_a_multiplier.append([0 for _ in range(self.num_images)])
+                is_a_multiplier = torch.tensor(scipy.stats.mstats.gmean(is_a_multiplier, axis=0)).to("cuda")
+                if self.multiplier_mode=="hard":
+                    for i in range(is_a_multiplier.shape[0]):
+                        if is_a_multiplier[i]<0.5: is_a_multiplier[i]=1e-6
+                        else: is_a_multiplier[i]=0.9
+                elif self.multiplier_mode=="soft":
+                    pass
+            elif self.multiplier_mode == "none":
+                is_a_multiplier = torch.tensor([1 for _ in range(self.num_images)]).to("cuda")
+            p_r_qy = response_model.get_p_r_qy(None, question, image_set, caption_set)
+            p_r_qy = torch.stack([is_a_multiplier*p_r_qy[r] for r in range(len(p_r_qy))])
+            for i in range(self.num_images):
+                if is_a_multiplier[i] < 0.5:
+                    p_r_qy[i] = 1-is_a_multiplier
+                else:
+                    p_r_qy[i] *= is_a_multiplier
+            if self.multiplier_mode=="none":
+                p_r_qy = response_model.get_p_r_qy(None, question, image_set, caption_set)
+            if not self.include_what:
+                p_r_qy = [p_r_qy, 1-p_r_qy]
+            self.question_bank[question] = p_r_qy
+
+        p_y_xqr = torch.stack([p_y_x*p_r_qy[r] for r in range(len(p_r_qy))])
+        p_y_xqr = [p_y_xqr[r]/torch.sum(p_y_xqr[r]) if torch.sum(p_y_xqr[r]) != 0 \
+                   else [0]*len(p_y_xqr[r]) for r in range(len(p_y_xqr))]
+        return torch.sum(torch.stack([p_r_qy[r]*p_y_x*torch.log2(1/(p_y_xqr[r]+self.eps)) for r in range(len(p_r_qy))]))
+
+    def get_question_ranking(self, p_y_x, question_set, image_set, caption_set, response_model):
+        H_y_rxq = [0]*len(question_set)
+
+        for i, question in enumerate(question_set):
+            H_y_rxq[i] = self.get_negative_information_gain(p_y_x, question, image_set, caption_set, response_model)
+
+        IG = - torch.stack(H_y_rxq).unsqueeze(1)
+        ranked_questions = sorted(zip(list(IG.data.cpu().numpy()), question_set),
+                                  key = operator.itemgetter(0))[::-1]
+        return ranked_questions
+
+    def select_best_question(self, p_y_x, question_set, image_set, caption_set, response_model):
+        ranked_questions = self.get_question_ranking(p_y_x, question_set, image_set, caption_set, response_model)
+        return ranked_questions[0][1]
+
+    def reset_question_bank(self):
+        self.question_bank = {}
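Note: get_negative_information_gain computes the expected posterior entropy of the guess after asking a question; select_best_question then minimizes it. A minimal numeric sketch of the same formula for a single yes/no question over four images (the probabilities are made up for illustration):

import torch

eps = 1e-25
p_y_x = torch.tensor([0.25, 0.25, 0.25, 0.25])   # prior over 4 images
p_yes = torch.tensor([0.9, 0.8, 0.1, 0.2])       # hypothetical per-image P("yes")
p_r_qy = [p_yes, 1 - p_yes]                      # responses r in {yes, no}
H = torch.tensor(0.0)
for r in range(2):
    p_y_xqr = p_y_x * p_r_qy[r]                  # unnormalized posterior given r
    p_y_xqr = p_y_xqr / torch.sum(p_y_xqr)
    H += torch.sum(p_r_qy[r] * p_y_x * torch.log2(1 / (p_y_xqr + eps)))
# Questions are ranked by IG = -H; the largest IG (lowest H) wins.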
model/model/response_model.py ADDED
@@ -0,0 +1,190 @@
+import torch
+import torch.nn as nn
+from transformers import AutoProcessor, AutoTokenizer, AutoModelForQuestionAnswering, pipeline
+from transformers import ViltProcessor, ViltForQuestionAnswering
+from transformers import BlipProcessor, BlipForQuestionAnswering
+from sentence_transformers import SentenceTransformer
+import openai
+from PIL import Image
+
+def get_response_model(args, response_type):
+    if response_type=="QA":
+        return ResponseModelQA(args.device, args.include_what)
+    elif response_type=="VQA1":
+        return ResponseModelVQA(args.device, args.include_what, args.question_generator, vqa_type="vilt1")
+    elif response_type=="VQA2":
+        return ResponseModelVQA(args.device, args.include_what, args.question_generator, vqa_type="vilt2")
+    elif response_type=="VQA3":
+        return ResponseModelVQA(args.device, args.include_what, args.question_generator, vqa_type="blip")
+    elif response_type=="VQA4":
+        return ResponseModelVQA(args.device, args.include_what, args.question_generator, vqa_type="git")
+    else:
+        raise ValueError(f"{response_type} is not a valid response type.")
+
+
+class ResponseModel(nn.Module):
+    # Class for the other ResponseModels to inherit from
+    def __init__(self, device, include_what):
+        super(ResponseModel, self).__init__()
+        self.device = device
+        self.include_what = include_what
+        self.model = None
+
+    def get_response(self, question, image, caption, target_questions, **kwargs):
+        raise NotImplementedError
+
+    def get_p_r_qy(self, response, question, images, captions, **kwargs):
+        raise NotImplementedError
+
+class ResponseModelQA(ResponseModel):
+    def __init__(self, device, include_what):
+        super(ResponseModelQA, self).__init__(device, include_what)
+        if not self.include_what:
+            tokenizer = AutoTokenizer.from_pretrained("AmazonScience/qanlu")
+            model = AutoModelForQuestionAnswering.from_pretrained("AmazonScience/qanlu")
+            self.model = pipeline('question-answering', model=model, tokenizer=tokenizer, device=0) # remove device=0 for cpu
+        elif self.include_what:
+            tokenizer = AutoTokenizer.from_pretrained("deepset/roberta-base-squad2")
+            model = AutoModelForQuestionAnswering.from_pretrained("deepset/roberta-base-squad2")
+            self.model_wh = pipeline('question-answering', model=model, tokenizer=tokenizer, device=0) # remove device=0 for cpu
+
+    def get_response(self, question, image, caption, target_questions, **kwargs):
+        if self.include_what:
+            answer = self.model({'context':caption, 'question':question})
+            return answer['answer'].split(' ')[-1]
+        else:
+            answer = self.model({'context':f"Yes. No. {caption}", 'question':question})
+            response, score = answer['answer'], answer['score']
+            if score>0.5:
+                response = response.lower().replace('.','')
+                if "yes" in response.split() and "no" not in response.split():
+                    response = 'yes'
+                elif "no" in response.split() and "yes" not in response.split():
+                    response = 'no'
+                else:
+                    response = 'yes' if question in target_questions else 'no'
+            else:
+                response = 'yes' if question in target_questions else 'no'
+            return response
+
+    def get_p_r_qy(self, response, question, images, captions, **kwargs):
+        if self.include_what:
+            raise NotImplementedError
+        else:
+            p_r_qy = torch.zeros(len(captions))
+            qa_input = {'context':[f"Yes. No. {c}" for c in captions], 'question':[question for _ in captions]}
+            answers = self.model(qa_input)
+            for idx, answer in enumerate(answers):
+                curr_ans, score = answer['answer'], answer['score']
+                if curr_ans.strip() in ["Yes.", "No."]:
+                    if response==None:
+                        if curr_ans.strip()=="No.": p_r_qy[idx] = 1-score
+                        if curr_ans.strip()=="Yes.": p_r_qy[idx] = score
+                    elif curr_ans.strip().lower().replace('.','')==response: p_r_qy[idx]=score
+                    else: p_r_qy[idx]=1-score
+                else:
+                    p_r_qy[idx]=0.5
+            return p_r_qy.to(self.device)
+
+class ResponseModelVQA(ResponseModel):
+    def __init__(self, device, include_what, question_generator, vqa_type):
+        super(ResponseModelVQA, self).__init__(device, include_what)
+        self.vqa_type = vqa_type
+        self.question_generator = question_generator
+        self.sentence_transformer = SentenceTransformer('all-MiniLM-L6-v2')
+        if vqa_type=="vilt1":
+            self.processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
+            self.model = ViltForQuestionAnswering.from_pretrained("dandelin/vilt-b32-finetuned-vqa").to(device)
+            self.vocab = list(self.model.config.label2id.keys())
+        elif vqa_type=="vilt2":
+            self.processor = AutoProcessor.from_pretrained("tufa15nik/vilt-finetuned-vqasi")
+            self.model = ViltForQuestionAnswering.from_pretrained("tufa15nik/vilt-finetuned-vqasi").to("cuda")
+            self.vocab = list(self.model.config.label2id.keys())
+        elif vqa_type=="blip":
+            self.processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base")
+            self.model = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base")
+        elif vqa_type=="git":
+            pass
+        else:
+            raise ValueError(f"{vqa_type} is not a valid vqa_type.")
+
+
+    def get_response(self, question, image, caption, target_questions, is_a=False):
+        encoding = self.processor(image, question, return_tensors="pt").to(self.device)
+        if is_a==False:
+            is_a_questions = self.question_generator.generate_is_there_question_v2(question)
+            is_a_responses = []
+            if question in ["What is in the photo?", "What is in the picture?", "What is in the background?"]:
+                is_a_questions = []
+            for q in is_a_questions:
+                is_a_responses.append(self.get_response(q, image, caption, target_questions, is_a=True))
+            no_cnt = sum([i.lower()=="no" for i in is_a_responses])
+            if len(is_a_responses)>0 and no_cnt/len(is_a_responses)>=0.5:
+                if question[:8]=="How many": return "0"
+                else: return "nothing"
+        if self.vqa_type in ["vilt1", "vilt2"]:
+            outputs = self.model(**encoding)
+            logits = torch.nn.functional.softmax(outputs.logits, dim=1)
+            idx = logits.argmax(-1).item()
+            response = self.model.config.id2label[idx]
+            response = response.lower().replace('.','').strip()
+        elif self.vqa_type == "blip":
+            outputs = self.model.generate(**encoding)
+            response = self.processor.decode(outputs[0], skip_special_tokens=True)
+        return response
+
+    def get_p_r_qy(self, response, question, images, captions, is_a=False):
+        p_r_qy = torch.zeros(len(captions))
+        logits_arr = []
+        for i, image in enumerate(images):
+            with torch.no_grad():
+                if len(question) > 150: question="" # ignore question if too long
+                encoding = self.processor(image, question, return_tensors="pt").to(self.device)
+                outputs = self.model(**encoding)
+                logits = torch.nn.functional.softmax(outputs.logits, dim=1)
+                idx = logits.argmax(-1).item()
+                curr_response = self.model.config.id2label[idx]
+                curr_response = curr_response.lower().replace('.','').strip()
+            if self.include_what==False or is_a==True:
+                if response==None:
+                    if curr_response=="yes": p_r_qy[i] = logits[0][3].item()
+                    elif curr_response=="no": p_r_qy[i] = 1-logits[0][9].item()
+                    else: p_r_qy[i] = 0.5
+                elif curr_response==response: p_r_qy[i] = logits[0][idx].item()
+                else: p_r_qy[i] = 1-logits[0][idx].item()
+            else:
+                logits_arr.append(logits)
+        if self.include_what==False or is_a==True:
+            return p_r_qy.to(self.device)
+        else:
+            logits = torch.concat(logits_arr)
+            if response==None:
+                top_answers = logits.argmax(1)
+                p_r_qy = logits[:,top_answers]
+            else:
+                response_idx = self.get_response_idx(response)
+                p_r_qy = logits[:,response_idx]
+
+            # check if this
+            # consider rerunning also without the geometric mean
+            if response=="nothing":
+                is_a_questions = self.question_generator.generate_is_there_question_v2(question)
+                for idx, (caption, image) in enumerate(zip(captions, images)):
+                    current_responses = []
+                    for is_a_q in is_a_questions:
+                        current_responses.append(self.get_response(is_a_q, image, caption, None, is_a=True))
+                    no_cnt = sum([i.lower()=="no" for i in current_responses])
+                    if len(current_responses)>0 and no_cnt/len(current_responses)>=0.5:
+                        p_r_qy[idx] = 1.0
+            return p_r_qy.to(self.device)
+
+    def get_response_idx(self, response):
+        if response in self.model.config.label2id:
+            return self.model.config.label2id[response]
+        else:
+            embs = self.sentence_transformer.encode(self.vocab, convert_to_tensor=True)
+            emb_response = self.sentence_transformer.encode([response], convert_to_tensor=True)
+            dists = torch.nn.CosineSimilarity(-1)(emb_response, embs)
+            best_response_idx = torch.argmax(dists)
+            best_response = self.vocab[best_response_idx]
+            return self.model.config.label2id[best_response]
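Note: get_response_idx falls back to nearest-neighbor matching in embedding space when a player's free-form answer is not in the VQA vocabulary. A minimal sketch of that mechanism; the four-word vocab below is a hypothetical stand-in for the real VQA label set:

import torch
from sentence_transformers import SentenceTransformer

st = SentenceTransformer('all-MiniLM-L6-v2')
vocab = ["yes", "no", "red", "two"]              # stand-in for the VQA label set
embs = st.encode(vocab, convert_to_tensor=True)
query = st.encode(["crimson"], convert_to_tensor=True)
best = vocab[torch.argmax(torch.nn.CosineSimilarity(-1)(query, embs))]
# "crimson" is not a VQA label, so it is mapped to its nearest neighbor
# (expected: "red"); the model then scores that label's logit instead.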
model/run_question_asking_model.py ADDED
@@ -0,0 +1,186 @@
+from model.model.question_asking_model import get_question_model
+from model.model.caption_model import get_caption_model
+from model.model.response_model import get_response_model
+
+import torch
+from torch.utils.data import Dataset, DataLoader
+from PIL import Image
+
+import argparse
+import random
+from tqdm.auto import tqdm
+import numpy as np
+import pandas as pd
+import logging
+from model.utils import logging_handler, image_saver, assert_checks
+
+random.seed(123)
+
+parser = argparse.ArgumentParser()
+parser.add_argument('--device', type=str, default='cuda')
+parser.add_argument('--include_what', action='store_true')
+parser.add_argument('--target_idx', type=int, default=0)
+parser.add_argument('--max_num_questions', type=int, default=25)
+parser.add_argument('--num_images', type=int, default=10)
+parser.add_argument('--beam', type=int, default=1)
+parser.add_argument('--num_samples', type=int, default=100)
+parser.add_argument('--threshold', type=float, default=0.9)
+
+parser.add_argument('--caption_strategy', type=str, default='simple', choices=['simple', 'granular', 'gtruth'])
+parser.add_argument('--sample_strategy', type=str, default='random', choices=['random', 'attribute', 'clip'])
+parser.add_argument('--attribute_n', type=int, default=1) # Number of attributes to split
+parser.add_argument('--response_type_simul', type=str, default='VQA1', choices=['simple', 'QA', 'VQA1', 'VQA2', 'VQA3', 'VQA4'])
+parser.add_argument('--response_type_gtruth', type=str, default='VQA2', choices=['simple', 'QA', 'VQA1', 'VQA2', 'VQA3', 'VQA4'])
+parser.add_argument('--question_strategy', type=str, default='gpt3', choices=['rule', 'gpt3'])
+parser.add_argument('--multiplier_mode', type=str, default='soft', choices=['soft', 'hard', 'none'])
+
+parser.add_argument('--gpt3_save_name', type=str, default='questions_gpt3')
+parser.add_argument('--save_name', type=str, default=None)
+parser.add_argument('--verbose', action='store_true')
+args = parser.parse_args()
+args.question_strategy='gpt3'
+args.include_what=True
+args.response_type_simul='VQA1'
+args.response_type_gtruth='VQA3'
+args.multiplier_mode='soft'
+args.sample_strategy='attribute'
+args.attribute_n=1
+args.caption_strategy='gtruth'
+assert_checks(args)
+if args.save_name is None: logger = logging_handler(args.verbose, args.save_name)
+args.load_response_model = True
+
+print("1. Loading question model ...")
+question_model = get_question_model(args)
+args.question_generator = question_model.question_generator
+print("2. Loading response model simul ...")
+response_model_simul = get_response_model(args, args.response_type_simul)
+response_model_simul.to(args.device)
+print("3. Loading response model gtruth ...")
+response_model_gtruth = get_response_model(args, args.response_type_gtruth)
+response_model_gtruth.to(args.device)
+print("4. Loading caption model ...")
+caption_model = get_caption_model(args, question_model)
+
+
+def return_modules():
+    return question_model, response_model_simul, response_model_gtruth, caption_model
+
+
+args.question_strategy='rule'
+args.include_what=False
+args.response_type_simul='VQA1'
+args.response_type_gtruth='VQA3'
+args.multiplier_mode='none'
+args.sample_strategy='attribute'
+args.attribute_n=1
+args.caption_strategy='gtruth'
+
+print("1. Loading question model ...")
+question_model_yn = get_question_model(args)
+args.question_generator_yn = question_model_yn.question_generator
+print("2. Loading response model simul ...")
+response_model_simul_yn = get_response_model(args, args.response_type_simul)
+response_model_simul_yn.to(args.device)
+print("3. Loading response model gtruth ...")
+response_model_gtruth_yn = get_response_model(args, args.response_type_gtruth)
+response_model_gtruth_yn.to(args.device)
+print("4. Loading caption model ...")
+caption_model_yn = get_caption_model(args, question_model_yn)
+
+
+def return_modules_yn():
+    return question_model_yn, response_model_simul_yn, response_model_gtruth_yn, caption_model_yn
+
+
+# args.question_strategy='gpt3'
+# args.include_what=True
+# args.response_type_simul='VQA1'
+# args.response_type_gtruth='VQA3'
+# args.multiplier_mode='none'
+# args.sample_strategy='attribute'
+# args.attribute_n=1
+# args.caption_strategy='gtruth'
+# assert_checks(args)
+# if args.save_name is None: logger = logging_handler(args.verbose, args.save_name)
+# args.load_response_model = True
+
+# print("1. Loading question model ...")
+# question_model = get_question_model(args)
+# args.question_generator = question_model.question_generator
+# print("2. Loading response model simul ...")
+# response_model_simul = get_response_model(args, args.response_type_simul)
+# response_model_simul.to(args.device)
+# print("3. Loading response model gtruth ...")
+# response_model_gtruth = get_response_model(args, args.response_type_gtruth)
+# response_model_gtruth.to(args.device)
+# print("4. Loading caption model ...")
+# caption_model = get_caption_model(args, question_model)

+# # dataloader = DataLoader(dataset=ReferenceGameData(split='test',
+# #                                                   num_images=args.num_images,
+# #                                                   num_samples=args.num_samples,
+# #                                                   sample_strategy=args.sample_strategy,
+# #                                                   attribute_n=args.attribute_n))
+
+# def return_modules():
+#     return question_model, response_model_simul, response_model_gtruth, caption_model
+# # game_lens, game_preds = [], []
+# for t, batch in enumerate(tqdm(dataloader)):
+#     image_files = [image[0] for image in batch['images'][:args.num_images]]
+#     image_files = [str(i).split('/')[1] for i in image_files]
+#     with open("mscoco_images_attribute_n=1.txt", 'a') as f:
+#         for i in image_files:
+#             f.write(str(i)+"\n")
+#     images = [np.asarray(Image.open(f"./../../../data/ms-coco/images/{i}")) for i in image_files]
+#     images = [np.dstack([i]*3) if len(i.shape)==2 else i for i in images]
+#     p_y_x = (torch.ones(args.num_images)/args.num_images).to(question_model.device)
+
+#     if args.save_name is not None:
+#         logger = logging_handler(args.verbose, args.save_name, t)
+#         image_saver(images, args.save_name, t)
+
+#     captions = caption_model.get_captions(image_files)
+#     questions, target_questions = question_model.get_questions(image_files, captions, args.target_idx)
+
+#     question_model.reset_question_bank()
+#     logger.info(questions)
+#     for idx, c in enumerate(captions): logger.info(f"Image_{idx}: {c}")
+
+#     num_questions_original = len(questions)
+#     for j in range(min(args.max_num_questions, num_questions_original)):
+#         # Select best question
+#         question = question_model.select_best_question(p_y_x, questions, images, captions, response_model_simul)
|
157 |
+
# logger.info(f"Question: {question}")
|
158 |
+
|
159 |
+
# # Ask the question and get the model's response
|
160 |
+
# response = response_model_gtruth.get_response(question, images[args.target_idx], captions[args.target_idx], target_questions, is_a=1-args.include_what)
|
161 |
+
# logger.info(f"Response: {response}")
|
162 |
+
|
163 |
+
# # Update the probabilities
|
164 |
+
# p_r_qy = response_model_simul.get_p_r_qy(response, question, images, captions)
|
165 |
+
# logger.info(f"P(r|q,y):\n{np.around(p_r_qy.cpu().detach().numpy(), 3)}")
|
166 |
+
# p_y_xqr = p_y_x*p_r_qy
|
167 |
+
# p_y_xqr = p_y_xqr/torch.sum(p_y_xqr)if torch.sum(p_y_xqr) != 0 else torch.zeros_like(p_y_xqr)
|
168 |
+
# p_y_x = p_y_xqr
|
169 |
+
# logger.info(f"Updated distribution:\n{np.around(p_y_x.cpu().detach().numpy(), 3)}\n")
|
170 |
+
|
171 |
+
# # Don't repeat the same question again in the future
|
172 |
+
# questions.remove(question)
|
173 |
+
|
174 |
+
# # Terminate if probability exceeds threshold or if out of questions to ask
|
175 |
+
# top_prob = torch.max(p_y_x).item()
|
176 |
+
# if top_prob >= args.threshold or j==min(args.max_num_questions, num_questions_original)-1:
|
177 |
+
# game_preds.append(torch.argmax(p_y_x).item())
|
178 |
+
# game_lens.append(j+1)
|
179 |
+
# logger.info(f"pred:{game_preds[-1]} game_len:{game_lens[-1]}")
|
180 |
+
# break
|
181 |
+
|
182 |
+
# logger = logging_handler(args.verbose, args.save_name, "final_results")
|
183 |
+
# logger.info(f"Game lenths:\n{game_lens}")
|
184 |
+
# logger.info(sum(game_lens)/len(game_lens))
|
185 |
+
# logger.info(f"Predictions:\n{game_preds}")
|
186 |
+
# logger.info(f"Accuracy:\n{sum([i==args.target_idx for i in game_preds])/len(game_preds)}")
|
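Note: the commented-out reference-game loop above reduces to one Bayesian update per question: the belief over candidate images p(y|x) is multiplied by the response likelihood p(r|q,y) and renormalized, and the game ends once the top posterior clears --threshold. A minimal standalone sketch of that step (update_belief is an illustrative helper, not a function in this repo):

import torch

def update_belief(p_y_x, p_r_qy):
    # Posterior p(y|x,q,r) is proportional to p(y|x) * p(r|q,y); if all mass
    # vanishes, return the zero vector, mirroring the guard in the loop above.
    p_y_xqr = p_y_x * p_r_qy
    total = torch.sum(p_y_xqr)
    return p_y_xqr / total if total != 0 else torch.zeros_like(p_y_xqr)

# Example: uniform prior over 10 images, one response that favors image 0
p_y_x = torch.ones(10) / 10
p_r_qy = torch.tensor([0.9] + [0.1] * 9)
p_y_x = update_belief(p_y_x, p_r_qy)
# Stop asking once torch.max(p_y_x).item() >= args.threshold (default 0.9).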
model/utils.py
ADDED
@@ -0,0 +1,54 @@
+import os
+import logging
+import matplotlib.pyplot as plt
+from PIL import Image
+import nltk
+
+def logging_handler(verbose, save_name, idx=0):
+    logger = logging.getLogger(str(idx))
+    logger.setLevel(logging.INFO)
+
+    stream_logger = logging.StreamHandler()
+    stream_logger.setFormatter(logging.Formatter("%(message)s"))
+    logger.addHandler(stream_logger)
+
+    if save_name is not None:
+        savepath = f"results/{save_name}"
+        if not os.path.exists(savepath):
+            os.makedirs(savepath)
+        file_logger = logging.FileHandler(f"{savepath}/{idx}.log")
+        file_logger.setFormatter(logging.Formatter("%(message)s"))
+        logger.addHandler(file_logger)
+
+    return logger
+
+
+def image_saver(images, save_name, idx=0, interactive=True):
+    fig, a = plt.subplots(2, 5)
+    fig.set_size_inches(30, 15)
+    for i in range(10):
+        a[i//5][i%5].imshow(images[i])
+        a[i//5][i%5].axis('off')
+        a[i//5][i%5].set_aspect('equal')
+    plt.tight_layout()
+    plt.subplots_adjust(wspace=0, hspace=0)
+    if not interactive:
+        plt.savefig(f"results/{save_name}/{idx}.png")
+    else:
+        plt.savefig(f"{save_name}.png")
+
+def assert_checks(args):
+    if args.question_strategy == "gpt3":
+        assert args.include_what
+
+def extract_nouns(sents):
+    noun_list = []
+    for idx, s in enumerate(sents):
+        curr = []
+        sent = nltk.pos_tag(s.split())
+        for word in sent:
+            if word[1] not in ["NN", "NNS"]: continue
+            currword = word[0].replace('.', '')
+            curr.append(currword.lower())
+        noun_list.append(curr)
+    return noun_list
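A quick usage sketch for these helpers (assumes nltk's 'averaged_perceptron_tagger' data is already downloaded; the exact nouns returned depend on that tagger):

from model.utils import logging_handler, extract_nouns

# Stream-only logger; passing a save_name would also write results/<save_name>/<idx>.log
logger = logging_handler(False, None)
logger.info("game started")

print(extract_nouns(["A man rides a horse on the beach."]))
# Tagger-dependent, but roughly: [['man', 'horse', 'beach']]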
open_db.py
CHANGED
@@ -4,9 +4,4 @@ import pandas as pd
 db = sqlite3.connect("response.db")
 df = pd.read_sql('SELECT * from responses', db)
 print(df)
-
-# conn = sqlite3.connect(db)
-# c = conn.cursor()
-# c.execute("SELECT name FROM sqlite_master WHERE type='table';")
-# for table in c.fetchall()
-#     yield list(c.execute('SELECT * from ?;', (table[0],)))
+df.to_csv("responses.csv", index=False)
pilot-study.csv
ADDED
@@ -0,0 +1,161 @@
+taskID,mscoco-id,task-type
+0,35,wh- standard
+1,102,wh- hard (n=1)
+2,25,wh- standard
+3,133,yes/no hard (n=1)
+4,80,wh- hard (n=1)
+5,15,wh- standard
+6,92,wh- hard (n=1)
+7,108,wh- hard (n=1)
+8,60,yes/no standard
+9,57,yes/no standard
+10,125,yes/no hard (n=1)
+11,56,yes/no standard
+12,137,yes/no hard (n=1)
+13,40,yes/no standard
+14,134,yes/no hard (n=1)
+15,130,yes/no hard (n=1)
+16,89,wh- hard (n=1)
+17,19,wh- standard
+18,58,yes/no standard
+19,81,wh- hard (n=1)
+20,5,wh- standard
+21,73,yes/no standard
+22,54,yes/no standard
+23,0,wh- standard
+24,14,wh- standard
+25,113,wh- hard (n=1)
+26,34,wh- standard
+27,159,yes/no hard (n=1)
+28,135,yes/no hard (n=1)
+29,2,wh- standard
+30,156,yes/no hard (n=1)
+31,30,wh- standard
+32,104,wh- hard (n=1)
+33,128,yes/no hard (n=1)
+34,18,wh- standard
+35,157,yes/no hard (n=1)
+36,1,wh- standard
+37,42,yes/no standard
+38,131,yes/no hard (n=1)
+39,115,wh- hard (n=1)
+40,120,yes/no hard (n=1)
+41,3,wh- standard
+42,63,yes/no standard
+43,65,yes/no standard
+44,103,wh- hard (n=1)
+45,124,yes/no hard (n=1)
+46,21,wh- standard
+47,72,yes/no standard
+48,62,yes/no standard
+49,47,yes/no standard
+50,78,yes/no standard
+51,109,wh- hard (n=1)
+52,136,yes/no hard (n=1)
+53,158,yes/no hard (n=1)
+54,61,yes/no standard
+55,27,wh- standard
+56,24,wh- standard
+57,123,yes/no hard (n=1)
+58,70,yes/no standard
+59,91,wh- hard (n=1)
+60,55,yes/no standard
+61,87,wh- hard (n=1)
+62,46,yes/no standard
+63,33,wh- standard
+64,16,wh- standard
+65,147,yes/no hard (n=1)
+66,85,wh- hard (n=1)
+67,59,yes/no standard
+68,99,wh- hard (n=1)
+69,117,wh- hard (n=1)
+70,9,wh- standard
+71,122,yes/no hard (n=1)
+72,53,yes/no standard
+73,22,wh- standard
+74,8,wh- standard
+75,29,wh- standard
+76,83,wh- hard (n=1)
+77,37,wh- standard
+78,66,yes/no standard
+79,41,yes/no standard
+80,94,wh- hard (n=1)
+81,98,wh- hard (n=1)
+82,110,wh- hard (n=1)
+83,77,yes/no standard
+84,151,yes/no hard (n=1)
+85,121,yes/no hard (n=1)
+86,6,wh- standard
+87,45,yes/no standard
+88,155,yes/no hard (n=1)
+89,88,wh- hard (n=1)
+90,96,wh- hard (n=1)
+91,75,yes/no standard
+92,112,wh- hard (n=1)
+93,49,yes/no standard
+94,152,yes/no hard (n=1)
+95,38,wh- standard
+96,7,wh- standard
+97,52,yes/no standard
+98,101,wh- hard (n=1)
+99,76,yes/no standard
+100,28,wh- standard
+101,114,wh- hard (n=1)
+102,139,yes/no hard (n=1)
+103,74,yes/no standard
+104,149,yes/no hard (n=1)
+105,84,wh- hard (n=1)
+106,79,yes/no standard
+107,127,yes/no hard (n=1)
+108,126,yes/no hard (n=1)
+109,116,wh- hard (n=1)
+110,71,yes/no standard
+111,67,yes/no standard
+112,10,wh- standard
+113,143,yes/no hard (n=1)
+114,132,yes/no hard (n=1)
+115,90,wh- hard (n=1)
+116,140,yes/no hard (n=1)
+117,144,yes/no hard (n=1)
+118,106,wh- hard (n=1)
+119,32,wh- standard
+120,154,yes/no hard (n=1)
+121,11,wh- standard
+122,17,wh- standard
+123,145,yes/no hard (n=1)
+124,118,wh- hard (n=1)
+125,48,yes/no standard
+126,148,yes/no hard (n=1)
+127,26,wh- standard
+128,51,yes/no standard
+129,13,wh- standard
+130,39,wh- standard
+131,153,yes/no hard (n=1)
+132,12,wh- standard
+133,93,wh- hard (n=1)
+134,107,wh- hard (n=1)
+135,86,wh- hard (n=1)
+136,31,wh- standard
+137,95,wh- hard (n=1)
+138,44,yes/no standard
+139,69,yes/no standard
+140,150,yes/no hard (n=1)
+141,4,wh- standard
+142,142,yes/no hard (n=1)
+143,43,yes/no standard
+144,50,yes/no standard
+145,100,wh- hard (n=1)
+146,129,yes/no hard (n=1)
+147,68,yes/no standard
+148,146,yes/no hard (n=1)
+149,64,yes/no standard
+150,23,wh- standard
+151,82,wh- hard (n=1)
+152,111,wh- hard (n=1)
+153,97,wh- hard (n=1)
+154,119,wh- hard (n=1)
+155,141,yes/no hard (n=1)
+156,20,wh- standard
+157,36,wh- standard
+158,138,yes/no hard (n=1)
+159,105,wh- hard (n=1)
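The 160 pilot tasks above are balanced across the four conditions: 40 each of wh- standard, wh- hard (n=1), yes/no standard, and yes/no hard (n=1). A quick pandas check (a sketch, assuming the CSV sits in the working directory):

import pandas as pd

df = pd.read_csv("pilot-study.csv")
print(df["task-type"].value_counts())  # 40 tasks per task-type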
response_db.py
CHANGED
@@ -75,5 +75,4 @@ class StResponseDb(ResponseDb):
 
 if __name__ == "__main__":
     db = ResponseDb()
-    print(db.get_all())
-
+    print(db.get_all())