arjunguha commited on
Commit
000c07e
·
unverified ·
1 Parent(s): 861c325
Files changed (4) hide show
  1. app.py +6 -52
  2. metrics.py +85 -0
  3. puzzles_cleaned.csv +2 -2
  4. results.duckdb +2 -2
app.py CHANGED
@@ -18,52 +18,10 @@ app that displays the following:
18
 
19
  Note that not every model has a response for every puzzle.
20
  """
21
- import re
22
- import duckdb
23
  import gradio as gr
24
- import textwrap
25
 
26
 
27
- def split_into_words(text: str) -> list:
28
- return re.findall(r'\b\w+\b', text.lower())
29
-
30
- def all_words_match(completion: str, answer: str) -> bool:
31
- answer_words = split_into_words(answer)
32
- completion = completion.lower()
33
-
34
- return all(word in completion for word in answer_words)
35
-
36
- def answer_without_thoughts(completion: str) -> str:
37
- if "<think>" not in completion[:200]:
38
- return completion
39
-
40
- chunks = completion.split("</think>")
41
- if len(chunks) <= 1:
42
- return ""
43
-
44
- return chunks[-1].strip()
45
-
46
- def check_answer(completion: str, answer: str) -> bool:
47
- """
48
- Check if all words in the answer are in the completion, in the same order.
49
- """
50
- completion_words = split_into_words(answer_without_thoughts(completion))
51
- answer_words = split_into_words(answer)
52
- indices = []
53
- for word in answer_words:
54
- if word in completion_words:
55
- indices.append(completion_words.index(word))
56
- else:
57
- return False
58
- return indices == sorted(indices) or indices == sorted(indices, reverse=True)
59
-
60
-
61
- def clip_text(text: str, width: int) -> str:
62
- return text if len(text) <= width else text[:width] + "..."
63
-
64
- def wrap_text(text: str, width: int) -> str:
65
- return textwrap.fill(text, width=width)
66
-
67
  def get_model_response(prompt_id, model_name):
68
  query = f"""
69
  SELECT completion FROM results.completions
@@ -89,18 +47,14 @@ def display_model_response(puzzle_id, model_name):
89
  return "From " + model_name + ":\n" + response if response else "No response from this model."
90
 
91
 
92
- conn = duckdb.connect(":memory:")
93
- conn.execute("ATTACH DATABASE 'results.duckdb' AS results")
94
- conn.execute("CREATE TABLE challenges as SELECT * FROM 'puzzles_cleaned.csv'")
95
- conn.create_function("check_answer", check_answer)
96
- conn.create_function("clip_text", clip_text)
97
- conn.create_function("wrap_text", wrap_text)
98
 
99
  # Get all unique model names
100
  model_names = [item[0] for item in conn.sql("SELECT DISTINCT parent_dir FROM results.completions").fetchall()]
 
101
  # Just for display.
102
  cleaned_model_names = [name.replace("completions-", "") for name in model_names]
103
- print(cleaned_model_names)
104
 
105
  def build_table():
106
  # Construct the query to create two columns for each model: MODEL_answer and MODEL_ok
@@ -141,11 +95,11 @@ def build_table():
141
  joined_df, model_correct_columns = build_table()
142
 
143
  relabelled_df = joined_df[['ID', 'challenge_clipped', 'answer', *model_correct_columns]].rename(columns={
144
- 'ID': 'Puzzle ID',
145
  'challenge_clipped': 'Challenge',
146
  'answer': 'Answer',
147
  **{model.replace("-", "_") + '_ok': model.replace("completions-", "") for model in model_names}
148
- })
149
 
150
  model_columns = {
151
  index + 3: name for index, name in enumerate(model_names)
 
18
 
19
  Note that not every model has a response for every puzzle.
20
  """
 
 
21
  import gradio as gr
22
+ from metrics import load_results
23
 
24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  def get_model_response(prompt_id, model_name):
26
  query = f"""
27
  SELECT completion FROM results.completions
 
47
  return "From " + model_name + ":\n" + response if response else "No response from this model."
48
 
49
 
50
+ conn = load_results()
 
 
 
 
 
51
 
52
  # Get all unique model names
53
  model_names = [item[0] for item in conn.sql("SELECT DISTINCT parent_dir FROM results.completions").fetchall()]
54
+ model_names.sort()
55
  # Just for display.
56
  cleaned_model_names = [name.replace("completions-", "") for name in model_names]
57
+
58
 
59
  def build_table():
60
  # Construct the query to create two columns for each model: MODEL_answer and MODEL_ok
 
95
  joined_df, model_correct_columns = build_table()
96
 
97
  relabelled_df = joined_df[['ID', 'challenge_clipped', 'answer', *model_correct_columns]].rename(columns={
98
+ 'ID': 'ID',
99
  'challenge_clipped': 'Challenge',
100
  'answer': 'Answer',
101
  **{model.replace("-", "_") + '_ok': model.replace("completions-", "") for model in model_names}
102
+ }).sort_values(by='ID')
103
 
104
  model_columns = {
105
  index + 3: name for index, name in enumerate(model_names)
metrics.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ import duckdb
3
+ import textwrap
4
+
5
+ def _parse_answer(text: str) -> str:
6
+ """
7
+ Converts text to lowercase. Interprets "," and "-->" as separators for
8
+ elements of a set. Within each set, drops all non-alphanumeric characters
9
+ and returns that set.
10
+
11
+ Another way to describe this is that we interpret adjacent words as
12
+ phrases that must be present literally. However, comma and arrow separate
13
+ distinct phrases that may be present in any order. All other characters
14
+ are dropped.
15
+ """
16
+ text = text.lower()
17
+ groups = re.split(r'-->|,', text)
18
+ return [" ".join(re.findall(r'\b\w+\b', group)) for group in groups]
19
+
20
+ def _answer_without_thoughts(completion: str) -> str:
21
+ if "<think>" not in completion[:200]:
22
+ return completion
23
+
24
+ chunks = completion.split("</think>")
25
+ if len(chunks) <= 1:
26
+ return ""
27
+
28
+ return chunks[-1].strip()
29
+
30
+ def _check_answer(completion: str, answer: str) -> bool:
31
+ """
32
+ Check that all the phrases that must appear in the answer appear in the
33
+ completion. We ignore "thoughts", capitalization, and punctuation.
34
+ """
35
+ completion = _answer_without_thoughts(completion).lower()
36
+ answer_phrases = _parse_answer(answer)
37
+ r = all(phrase in completion for phrase in answer_phrases)
38
+ return r
39
+
40
+
41
+ def _clip_text(text: str, width: int) -> str:
42
+ return text if len(text) <= width else text[:width] + "..."
43
+
44
+ def _wrap_text(text: str, width: int) -> str:
45
+ return textwrap.fill(text, width=width)
46
+
47
+ def load_results():
48
+ conn = duckdb.connect(":memory:")
49
+ conn.execute("ATTACH DATABASE 'results.duckdb' AS results")
50
+ conn.execute("CREATE TABLE challenges as SELECT * FROM 'puzzles_cleaned.csv'")
51
+ conn.create_function("check_answer", _check_answer)
52
+ conn.create_function("clip_text", _clip_text)
53
+ conn.create_function("wrap_text", _wrap_text)
54
+ return conn
55
+
56
+ def accuracy_by_model(conn):
57
+ model_accuracies = conn.sql("""
58
+ WITH AnswerCheck AS (
59
+ SELECT
60
+ results.parent_dir AS model,
61
+ COUNT(*) AS total,
62
+ SUM(CAST(check_answer(results.completion, challenges.answer) AS INTEGER)) AS correct
63
+ FROM
64
+ results.completions results
65
+ JOIN
66
+ challenges challenges
67
+ ON
68
+ results.prompt_id = challenges.ID
69
+ GROUP BY
70
+ results.parent_dir
71
+ )
72
+ SELECT
73
+ model,
74
+ total,
75
+ correct,
76
+ ROUND(correct / total, 2) AS accuracy
77
+ FROM
78
+ AnswerCheck
79
+ """)
80
+
81
+ print(model_accuracies)
82
+
83
+ if __name__ == "__main__":
84
+ conn = load_results()
85
+ accuracy_by_model(conn)
puzzles_cleaned.csv CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:0a6bd98c71e31ec98439b56cd22bd23af52763d24b66da7eda42d30c610693ce
3
- size 1134920
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7efd3a2897270124ecc8a299b96d14fb54600f3c0faf27b790d8b0312720f3cd
3
+ size 1132332
results.duckdb CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:d83d136691e04a7570e3c3eb1b11fca96078d5041c1dc87f3aed86f5c9effa93
3
- size 29634560
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fa7c7911a1ecf7fe4223995e3d393dd78cf8d4023409197854bf471fd8ab7c48
3
+ size 32518144