arjunguha committed
Commit 861c325 · unverified · 1 Parent(s): 41c62ec

Copy from repository

Files changed (5)
  1. .gitattributes +2 -0
  2. app.py +189 -0
  3. puzzles_cleaned.csv +3 -0
  4. requirements.txt +1 -0
  5. results.duckdb +3 -0
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+results.duckdb filter=lfs diff=lfs merge=lfs -text
+puzzles_cleaned.csv filter=lfs diff=lfs merge=lfs -text
app.py ADDED
@@ -0,0 +1,189 @@
"""
This program helps us explore models' responses to the benchmark. It is a web
app that displays the following:

1. A list of benchmark items loaded from puzzles_cleaned.csv. The list shows
   the columns ID, challenge, and answer.
2. When we select a puzzle from the list, we see the transcript, Explanation,
   and Editor's Note in textboxes. (Scrollable since they can be long.)
3. The list in (1) also has a column for each model, with checkboxes indicating
   whether the model's response is correct or not. We load the model responses
   from results.duckdb. That file has a table called completions with
   columns 'prompt_id', 'parent_dir', and 'completion'. The prompt_id can be
   joined with ID from puzzles_cleaned.csv. The parent_dir is the model name.
   The completion is the model response, which we compare with the answer from
   puzzles_cleaned.csv using the function check_answer defined below.
4. Finally, when an item is selected from the list, we get a dropdown that lets
   us select a model to see the completion from that model.

Note that not every model has a response for every puzzle.
"""
import re
import duckdb
import gradio as gr
import textwrap


def split_into_words(text: str) -> list:
    return re.findall(r'\b\w+\b', text.lower())

def all_words_match(completion: str, answer: str) -> bool:
    answer_words = split_into_words(answer)
    completion = completion.lower()

    return all(word in completion for word in answer_words)

def answer_without_thoughts(completion: str) -> str:
    if "<think>" not in completion[:200]:
        return completion

    chunks = completion.split("</think>")
    if len(chunks) <= 1:
        return ""

    return chunks[-1].strip()

def check_answer(completion: str, answer: str) -> bool:
    """
    Check if all words in the answer are in the completion, in the same order.
    """
    completion_words = split_into_words(answer_without_thoughts(completion))
    answer_words = split_into_words(answer)
    indices = []
    for word in answer_words:
        if word in completion_words:
            indices.append(completion_words.index(word))
        else:
            return False
    return indices == sorted(indices) or indices == sorted(indices, reverse=True)
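# Illustrative behavior of check_answer (hypothetical strings, not benchmark data):
#   check_answer("The answer is a grand piano.", "a piano")  -> True  (both words, in order)
#   check_answer("It is a piano.", "piano tuner")            -> False ("tuner" never appears)
# Answers whose words appear in exactly reversed order are also accepted.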


def clip_text(text: str, width: int) -> str:
    return text if len(text) <= width else text[:width] + "..."

def wrap_text(text: str, width: int) -> str:
    return textwrap.fill(text, width=width)

def get_model_response(prompt_id, model_name):
    query = f"""
        SELECT completion FROM results.completions
        WHERE prompt_id = {prompt_id} AND parent_dir = '{model_name}'
    """
    response = conn.sql(query).fetchone()
    return response[0] if response else None

def display_puzzle(puzzle_id):
    query = f"""
        SELECT challenge, answer, transcript, Explanation, "Editor's Notes"
        FROM challenges
        WHERE ID = {puzzle_id}
    """
    puzzle = conn.sql(query).fetchone()
    return puzzle if puzzle else (None, None, None, None, None)

def display_model_response(puzzle_id, model_name):
    response = get_model_response(puzzle_id, model_name)
    if response is None:
        return "No response from this model."
    # Show only the text after a chain-of-thought block, if one is present.
    split_thoughts = response.split("</think>")
    if len(split_thoughts) > 1:
        response = split_thoughts[-1].strip()
    return "From " + model_name + ":\n" + response if response else "No response from this model."


conn = duckdb.connect(":memory:")
conn.execute("ATTACH DATABASE 'results.duckdb' AS results")
conn.execute("CREATE TABLE challenges as SELECT * FROM 'puzzles_cleaned.csv'")
conn.create_function("check_answer", check_answer)
conn.create_function("clip_text", clip_text)
conn.create_function("wrap_text", wrap_text)

# Get all unique model names
model_names = [item[0] for item in conn.sql("SELECT DISTINCT parent_dir FROM results.completions").fetchall()]
# Just for display.
cleaned_model_names = [name.replace("completions-", "") for name in model_names]
print(cleaned_model_names)

def build_table():
    # Construct the query to create two columns for each model: MODEL_answer and MODEL_ok
    query = """
        SELECT c.ID, c.challenge, wrap_text(c.answer, 40) AS answer,
    """

    model_correct_columns = []
    for model in model_names:
        normalized_model_name = model.replace("-", "_")
        model_correct_columns.append(normalized_model_name + "_ok")
        query += f"""
            MAX(CASE WHEN r.parent_dir = '{model}' THEN r.completion ELSE NULL END) AS {normalized_model_name}_answer,
            MAX(CASE WHEN r.parent_dir = '{model}' THEN check_answer(r.completion, c.answer) ELSE NULL END) AS {normalized_model_name}_ok,
        """

    query = query.rstrip(',')  # Remove the trailing comma
    query += """
        clip_text(c.challenge, 40) as challenge_clipped,
        FROM challenges c
        LEFT JOIN results.completions r
        ON c.ID = r.prompt_id
        GROUP BY c.ID, c.challenge, c.answer
    """

    joined_df = conn.sql(query).fetchdf()

    # Transform the model_correct columns to use emojis
    for model in model_names:
        normalized_model_name = model.replace("-", "_")
        joined_df[normalized_model_name + '_ok'] = joined_df[normalized_model_name + '_ok'].apply(
            lambda x: "✅" if x == 1 else ("❌" if x == 0 else "❓")
        )

    return joined_df, model_correct_columns


joined_df, model_correct_columns = build_table()

relabelled_df = joined_df[['ID', 'challenge_clipped', 'answer', *model_correct_columns]].rename(columns={
    'ID': 'Puzzle ID',
    'challenge_clipped': 'Challenge',
    'answer': 'Answer',
    **{model.replace("-", "_") + '_ok': model.replace("completions-", "") for model in model_names}
})

# The DataFrame shows Puzzle ID, Challenge, and Answer first, so the model
# columns start at display index 3.
model_columns = {
    index + 3: name for index, name in enumerate(model_names)
}

valid_model_indices = list(model_columns.keys())
default_model = model_columns[valid_model_indices[0]]

def create_interface():
    with gr.Blocks() as demo:
        # Using "markdown" as the datatype makes Gradio interpret newlines.
        puzzle_list = gr.DataFrame(
            value=relabelled_df,
            datatype=["number", "str", "markdown", *["str"] * len(model_correct_columns)],
            # headers=["ID", "Challenge", "Answer", *cleaned_model_names],
        )
        model_response = gr.Textbox(label="Model Response", interactive=False)
        challenge = gr.Textbox(label="Challenge", interactive=False)
        answer = gr.Textbox(label="Answer", interactive=False)
        explanation = gr.Textbox(label="Explanation", interactive=False)
        editors_note = gr.Textbox(label="Editor's Note", interactive=False)
        transcript = gr.Textbox(label="Transcript", interactive=False)

        def update_puzzle(evt: gr.SelectData):
            # evt.index is the (row, column) of the selected cell; clicking a
            # model's column shows that model's response.
            row = evt.index[0]
            model_index = evt.index[1]
            model_name = model_columns[model_index] if model_index in valid_model_indices else default_model
            return (*display_puzzle(row), display_model_response(row, model_name))

        puzzle_list.select(
            fn=update_puzzle,
            inputs=[],
            outputs=[challenge, answer, transcript, explanation, editors_note, model_response]
        )

    demo.launch()


if __name__ == "__main__":
    create_interface()
puzzles_cleaned.csv ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:0a6bd98c71e31ec98439b56cd22bd23af52763d24b66da7eda42d30c610693ce
size 1134920
requirements.txt ADDED
@@ -0,0 +1 @@
duckdb==1.1.3
results.duckdb ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d83d136691e04a7570e3c3eb1b11fca96078d5041c1dc87f3aed86f5c9effa93
size 29634560