Amirmarshal committed on
Commit
9e292f0
β€’
1 Parent(s): cb84380

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +326 -0
app.py ADDED
@@ -0,0 +1,326 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Basic example for doing model-in-the-loop dynamic adversarial data collection
2
+ # using Gradio Blocks.
3
+ import json
4
+ import os
5
+ import threading
6
+ import time
7
+ import uuid
8
+ from concurrent.futures import ThreadPoolExecutor
9
+ from pathlib import Path
10
+ from typing import List
11
+ from urllib.parse import parse_qs
12
+
13
+ import gradio as gr
14
+ from dotenv import load_dotenv
15
+ from huggingface_hub import Repository
16
+ from langchain import ConversationChain
17
+ from langchain.chains.conversation.memory import ConversationBufferMemory
18
+ from langchain.llms import HuggingFaceHub
19
+ from langchain.prompts import load_prompt
20
+
21
+ from utils import force_git_push
22
+
23
+
24
def generate_respone(chatbot: ConversationChain, input: str) -> str:
    """Return the chatbot's reply to a single user message.

    NOTE(review): the name is missing an "s" ("respone" -> "response");
    `generate_responses` below calls it by this misspelled name, so renaming
    requires touching both sites. `input` also shadows the builtin.

    Args:
        chatbot: A `langchain` ConversationChain holding the model and its memory.
        input: The user's message.

    Returns:
        The model-generated reply text.
    """
    return chatbot.predict(input=input)
27
+
28
+
29
def generate_responses(chatbots: List[ConversationChain], inputs: List[str]) -> List[str]:
    """Generate one response per chatbot, in parallel.

    Each chatbot is paired positionally with the input at the same index.

    Args:
        chatbots: The `langchain` conversation chains to query.
        inputs: One user message per chatbot (same length as `chatbots`).

    Returns:
        The generated responses, in the same order as `chatbots`.
    """
    # executor.map preserves input order, so the results already line up
    # with `chatbots` — no manual accumulator loop needed.
    with ThreadPoolExecutor(max_workers=100) as executor:
        return list(executor.map(generate_respone, chatbots, inputs))
36
+
37
+
38
# These variables are for storing the MTurk HITs in a Hugging Face dataset.
# Local development can supply them via a .env file; on a deployed Space they
# come from the environment/secrets instead.
if Path(".env").is_file():
    load_dotenv(".env")
DATASET_REPO_URL = os.getenv("DATASET_REPO_URL")  # dataset repo that stores the collected HITs
FORCE_PUSH = os.getenv("FORCE_PUSH")  # "yes" -> force-push instead of a normal git push
HF_TOKEN = os.getenv("HF_TOKEN")  # auth token for the dataset repo and the Hub inference models
PROMPT_TEMPLATES = Path("prompt_templates")  # directory of serialized langchain prompts

DATA_FILENAME = "data.jsonl"
DATA_FILE = os.path.join("data", DATA_FILENAME)  # HIT records are appended here as JSON lines
# Clone the dataset repo into ./data so DATA_FILE lives inside a git checkout
# that the background pusher can commit to.
repo = Repository(local_dir="data", clone_from=DATASET_REPO_URL, use_auth_token=HF_TOKEN)

TOTAL_CNT = 3  # How many user inputs per HIT

# This function pushes the HIT data written in data.jsonl to our Hugging Face
# dataset every minute. Adjust the frequency to suit your needs.
PUSH_FREQUENCY = 60  # seconds between automatic pushes (see asynchronous_push)
55
+
56
+
57
def asynchronous_push(f_stop):
    """Commit and push any pending dataset changes, then reschedule itself.

    Re-arms a `threading.Timer` every PUSH_FREQUENCY seconds until the
    `f_stop` event is set. When FORCE_PUSH == "yes" the push is forced via
    `force_git_push`; otherwise a regular `git push` is used.
    """
    if not repo.is_repo_clean():
        # There are uncommitted HIT records — commit and push them.
        repo.git_add(auto_lfs_track=True)
        repo.git_commit("Auto commit by space")
        if FORCE_PUSH == "yes":
            force_git_push(repo)
        else:
            repo.git_push()
    else:
        print("Repo currently clean. Ignoring push_to_hub")
    if not f_stop.is_set():
        # Schedule the next push (call again in 60 seconds).
        threading.Timer(PUSH_FREQUENCY, asynchronous_push, [f_stop]).start()
70
+
71
+
72
f_stop = threading.Event()  # set this to stop the periodic background pushes
asynchronous_push(f_stop)

# Now let's run the app!
prompt = load_prompt(PROMPT_TEMPLATES / "openai_chatgpt.json")

# TODO: update this list with better, instruction-trained models
MODEL_IDS = ["google/flan-t5-xl", "bigscience/T0_3B", "EleutherAI/gpt-j-6B"]
chatbots = []

# One ConversationChain per model; each keeps its own conversation memory.
for model_id in MODEL_IDS:
    chatbots.append(
        ConversationChain(
            llm=HuggingFaceHub(
                repo_id=model_id,
                model_kwargs={"temperature": 1},
                huggingfacehub_api_token=HF_TOKEN,
            ),
            prompt=prompt,
            verbose=False,
            memory=ConversationBufferMemory(ai_prefix="Assistant"),
        )
    )


# Look up a chain by its Hub model id (used in _select_response to adopt the
# chosen model's memory).
model_id2model = {chatbot.llm.repo_id: chatbot for chatbot in chatbots}
98
+
99
demo = gr.Blocks()

with demo:
    # Hidden textbox used to pass window.location.search (which carries the
    # MTurk assignmentId) from the browser into the Python handlers.
    dummy = gr.Textbox(visible=False)  # dummy for passing assignmentId

    # We keep track of state as a JSON
    state_dict = {
        "conversation_id": str(uuid.uuid4()),
        "assignmentId": "",
        "cnt": 0,  # number of user inputs submitted so far in this HIT
        "data": [],  # one metadata record per user input (written out at HIT end)
        "past_user_inputs": [],
        "generated_responses": [],  # the responses the worker selected
    }
    # One slot per model so every candidate response can be posted with the HIT.
    for idx in range(len(chatbots)):
        state_dict[f"response_{idx+1}"] = ""
    state = gr.JSON(state_dict, visible=False)

    gr.Markdown("# Talk to the assistant")

    state_display = gr.Markdown(f"Your messages: 0/{TOTAL_CNT}")
120
+
121
    # Generate model prediction
    def _predict(txt, state):
        """Handle one user message: query every model and show the candidates.

        Returns positional updates for: text_input, select_response_button,
        select_response, past_conversation, state, example_submit,
        final_submit, final_submit_preview, state_display, dummy.
        NOTE(review): this returns 10 values while submit_ex_button.click
        lists only 9 outputs — confirm against the Gradio version in use.
        """
        start = time.time()
        # Fan the same message out to every model in parallel.
        responses = generate_responses(chatbots, [txt] * len(chatbots))
        print(f"Time taken to generate {len(chatbots)} responses : {time.time() - start:.2f} seconds")

        # Map each candidate response text back to the model that produced it.
        # NOTE(review): if two models return the exact same string, the later
        # one silently overwrites the earlier entry here.
        response2model_id = {}
        for chatbot, response in zip(chatbots, responses):
            response2model_id[response] = chatbot.llm.repo_id

        state["cnt"] += 1

        new_state_md = f"Inputs remaining in HIT: {state['cnt']}/{TOTAL_CNT}"

        # Record everything about this turn so it can be stored with the HIT.
        metadata = {"cnt": state["cnt"], "text": txt}
        for idx, response in enumerate(responses):
            metadata[f"response_{idx + 1}"] = response

        metadata["response2model_id"] = response2model_id

        state["data"].append(metadata)
        state["past_user_inputs"].append(txt)

        # Rebuild the transcript; the trailing "" pads generated_responses so
        # the just-submitted user message appears before a reply is chosen.
        past_conversation_string = "<br />".join(
            [
                "<br />".join(["Human 😃: " + user_input, "Assistant 🤖: " + model_response])
                for user_input, model_response in zip(state["past_user_inputs"], state["generated_responses"] + [""])
            ]
        )
        return (
            gr.update(visible=False),  # hide the free-text input
            gr.update(visible=True),  # show the "Select Response" button
            gr.update(visible=True, choices=responses, interactive=True, value=responses[0]),
            gr.update(value=past_conversation_string),
            state,
            gr.update(visible=False),  # hide the example-submit column
            gr.update(visible=False),  # hide the final-submit column
            gr.update(visible=False),  # hide the preview-submit column
            new_state_md,
            dummy,
        )
162
+
163
    def _select_response(selected_response, state, dummy):
        """Record the worker's chosen response and advance (or finish) the HIT.

        `dummy` carries window.location.search from the browser (see the
        `_js` hook on select_response_button.click). Returns positional
        updates for: select_response, example_submit, text_input,
        select_response_button, state, past_conversation, example_submit
        (again), final_submit, final_submit_preview, dummy.
        """
        done = state["cnt"] == TOTAL_CNT
        state["generated_responses"].append(selected_response)
        state["data"][-1]["selected_response"] = selected_response
        state["data"][-1]["selected_model"] = state["data"][-1]["response2model_id"][selected_response]
        if state["cnt"] == TOTAL_CNT:
            # Write the HIT data to our local dataset because the worker has
            # submitted everything now.
            with open(DATA_FILE, "a") as jsonlfile:
                json_data_with_assignment_id = [
                    json.dumps(
                        dict(
                            {"assignmentId": state["assignmentId"], "conversation_id": state["conversation_id"]},
                            **datum,
                        )
                    )
                    for datum in state["data"]
                ]
                jsonlfile.write("\n".join(json_data_with_assignment_id) + "\n")
        toggle_example_submit = gr.update(visible=not done)
        past_conversation_string = "<br />".join(
            [
                "<br />".join(["😃: " + user_input, "🤖: " + model_response])
                for user_input, model_response in zip(state["past_user_inputs"], state["generated_responses"])
            ]
        )
        # dummy holds window.location.search ("?assignmentId=..."); strip the
        # leading "?" before parsing the query string.
        query = parse_qs(dummy[1:])
        if "assignmentId" in query and query["assignmentId"][0] != "ASSIGNMENT_ID_NOT_AVAILABLE":
            # It seems that someone is using this app on mturk. We need to
            # store the assignmentId in the state before submit_hit_button
            # is clicked, so that the turker can get credit for their HIT.
            state["assignmentId"] = query["assignmentId"][0]
            toggle_final_submit = gr.update(visible=done)
            toggle_final_submit_preview = gr.update(visible=False)
        else:
            toggle_final_submit_preview = gr.update(visible=done)
            toggle_final_submit = gr.update(visible=False)

        if done:
            # Wipe the memory completely because we will be starting a new hit soon.
            for chatbot in chatbots:
                chatbot.memory = ConversationBufferMemory(ai_prefix="Assistant")
        else:
            # Sync all of the model's memories with the conversation path that
            # was actually taken.
            # NOTE(review): every chatbot is pointed at the *same* memory
            # object (the selected model's) — they share state from here on.
            for chatbot in chatbots:
                chatbot.memory = model_id2model[state["data"][-1]["response2model_id"][selected_response]].memory

        text_input = gr.update(visible=False) if done else gr.update(visible=True)
        return (
            gr.update(visible=False),  # hide the response radio
            gr.update(visible=True),  # example_submit (overridden by the toggle below)
            text_input,
            gr.update(visible=False),  # hide the "Select Response" button
            state,
            gr.update(value=past_conversation_string),
            toggle_example_submit,
            toggle_final_submit,
            toggle_final_submit_preview,
            dummy,
        )
225
+
226
    # Input fields
    past_conversation = gr.Markdown()
    text_input = gr.Textbox(placeholder="Enter a statement", show_label=False)
    # Choices are populated dynamically by _predict; [None, None] is a placeholder.
    select_response = gr.Radio(
        choices=[None, None], visible=False, label="Choose the most helpful and honest response"
    )
    select_response_button = gr.Button("Select Response", visible=False)
    with gr.Column() as example_submit:
        submit_ex_button = gr.Button("Submit")
    with gr.Column(visible=False) as final_submit:
        submit_hit_button = gr.Button("Submit HIT")
    with gr.Column(visible=False) as final_submit_preview:
        submit_hit_button_preview = gr.Button(
            "Submit Work (preview mode; no MTurk HIT credit, but your examples will still be stored)"
        )

    # Button event handlers

    # Browser-side hook: forwards window.location.search (the MTurk query
    # string) into the `dummy` input so _select_response can read the
    # assignmentId.
    get_window_location_search_js = """
    function(select_response, state, dummy) {
        return [select_response, state, window.location.search];
    }
    """

    select_response_button.click(
        _select_response,
        inputs=[select_response, state, dummy],
        outputs=[
            select_response,
            example_submit,
            text_input,
            select_response_button,
            state,
            past_conversation,
            # NOTE(review): example_submit appears twice in this list —
            # confirm the intended update is the later toggle value.
            example_submit,
            final_submit,
            final_submit_preview,
            dummy,
        ],
        _js=get_window_location_search_js,
    )

    submit_ex_button.click(
        _predict,
        inputs=[text_input, state],
        outputs=[
            text_input,
            select_response_button,
            select_response,
            past_conversation,
            state,
            example_submit,
            final_submit,
            final_submit_preview,
            state_display,
        ],
    )

    # Browser-side hook: builds and POSTs a form with the HIT state to the
    # MTurk external-submit endpoint so the worker gets credit.
    # NOTE(review): this targets the workersandbox endpoint — switch to the
    # production mturk.com endpoint for live HITs.
    post_hit_js = """
    function(state) {
        // If there is an assignmentId, then the submitter is on mturk
        // and has accepted the HIT. So, we need to submit their HIT.
        const form = document.createElement('form');
        form.action = 'https://workersandbox.mturk.com/mturk/externalSubmit';
        form.method = 'post';
        for (const key in state) {
            const hiddenField = document.createElement('input');
            hiddenField.type = 'hidden';
            hiddenField.name = key;
            hiddenField.value = state[key];
            form.appendChild(hiddenField);
        };
        document.body.appendChild(form);
        form.submit();
        return state;
    }
    """

    submit_hit_button.click(
        lambda state: state,
        inputs=[state],
        outputs=[state],
        _js=post_hit_js,
    )

    # Browser-side hook: reloads the page so a preview-mode user can start
    # another round.
    refresh_app_js = """
    function(state) {
        // The following line here loads the app again so the user can
        // enter in another preview-mode "HIT".
        window.location.href = window.location.href;
        return state;
    }
    """

    submit_hit_button_preview.click(
        lambda state: state,
        inputs=[state],
        outputs=[state],
        _js=refresh_app_js,
    )

demo.launch()