"""Score model completions in results.duckdb against the answers in puzzles_cleaned.csv."""
import argparse
import re
import textwrap
from typing import List

import duckdb

def _parse_answer(text: str) -> List[List[str]]:
    """
    Converts text to lowercase, then interprets ";" as a separator between
    alternatives. Within each alternative, interprets "," and "-->" as
    separators for elements of a set. Within each set, drops all
    non-alphanumeric characters and returns that set.

    Another way to describe this is that we interpret adjacent words as
    phrases that must be present literally. However, comma and arrow separate
    distinct phrases that may appear in any order. All other characters
    are dropped.
    """
    text = text.lower()
    alternatives = re.split(r';', text)
    result = []
    for alternative in alternatives:
        # Split an alternative into phrases on "-->" style arrows or commas.
        groups = re.split(r'–?-?-?>|,', alternative)
        result.append([" ".join(re.findall(r'\b\w+\b', group)) for group in groups])
    return result
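
# Illustrative example (hypothetical answer string, not from the dataset):
#   _parse_answer("apple, banana; cherry --> date")
#   == [["apple", "banana"], ["cherry", "date"]]
# i.e. two alternative answers: the first requires the phrases "apple" and
# "banana" in any order, the second requires "cherry" and "date".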

def _answer_without_thoughts(completion: str) -> str:
    """
    Strips the chain-of-thought from a completion: if the completion opens
    with a <think> block, return only the text after the closing </think>;
    if the block is never closed, return the empty string.
    """
    if "<think>" not in completion[:200]:
        return completion
    chunks = completion.split("</think>")
    if len(chunks) <= 1:
        return ""
    return chunks[-1].strip()
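
# Illustrative example (hypothetical completion text):
#   _answer_without_thoughts("<think>reasoning...</think>Paris")  -> "Paris"
# A completion that does not start with a <think> block is returned unchanged.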

def _check_answer(completion: str, answer: str) -> bool:
    """
    Check that all the phrases that must appear in the answer appear in the
    completion. We ignore "thoughts", capitalization, and punctuation.
    """
    completion = _answer_without_thoughts(completion).lower()
    # Replace punctuation with spaces and collapse runs of (Unicode) whitespace,
    # mirroring the normalization that _parse_answer applies to the answer.
    completion = re.sub(r'[^\w\s]', ' ', completion)
    completion = re.sub(r'\s+', ' ', completion)
    alternative_answers = _parse_answer(answer)
    for answer_phrases in alternative_answers:
        # Require each phrase to appear on word boundaries, not merely as a substring.
        if all(re.search(rf'\b{re.escape(phrase)}\b', completion) for phrase in answer_phrases):
            return True
    return False
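
# Illustrative example (hypothetical strings):
#   _check_answer("The date and the cherry.", "cherry --> date")   -> True
#   _check_answer("The dates and the cherry.", "cherry --> date")  -> False
# The second is False because the word-boundary pattern \bdate\b does not
# match inside "dates".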

def _clip_text(text: str, width: int) -> str:
    return text if len(text) <= width else text[:width] + "..."

def _wrap_text(text: str, width: int) -> str:
    return textwrap.fill(text, width=width)

def load_results():
    """
    Opens an in-memory DuckDB connection, attaches the results database
    read-only, loads the challenges CSV, and registers the Python helpers
    above as SQL functions.
    """
    conn = duckdb.connect(":memory:")
    conn.execute("ATTACH DATABASE 'results.duckdb' AS results (READ_ONLY)")
    conn.execute("CREATE TABLE challenges AS SELECT * FROM 'puzzles_cleaned.csv'")
    conn.create_function("check_answer", _check_answer)
    conn.create_function("clip_text", _clip_text)
    conn.create_function("wrap_text", _wrap_text)
    return conn
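
# Illustrative usage (assumes results.duckdb and puzzles_cleaned.csv exist in
# the working directory):
#   conn = load_results()
#   conn.sql("SELECT check_answer('the answer is 42', '42') AS ok").show()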

def r1_accuracy_by_completion_length(conn, model_name):
    """
    For the responses from the given model (e.g., completions-r1):

    1. We calculate completion length and correctness for each problem.
    2. We sort by length.
    3. We compute the cumulative number of correct responses.
    """
    r1_completions = conn.sql(f"""
        WITH LengthsAndCorrectness AS (
            SELECT
                LENGTH(results.completion) AS length,
                CAST(check_answer(results.completion, challenges.answer) AS INT32) AS correct
            FROM results.completions results JOIN challenges
                ON results.prompt_id = challenges.ID
            WHERE results.parent_dir = '{model_name}'
        ),
        TotalItems AS (
            SELECT COUNT(*) AS total_count
            FROM LengthsAndCorrectness
        ),
        CumulativeCorrect AS (
            SELECT
                length,
                SUM(correct) OVER (ORDER BY length) AS cumulative_correct
            FROM LengthsAndCorrectness
        )
        SELECT
            length,
            cumulative_correct,
            CAST(cumulative_correct AS FLOAT) / total_count AS cumulative_accuracy
        FROM CumulativeCorrect, TotalItems
        ORDER BY length
    """)
    return r1_completions
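
# Illustrative usage (the parent_dir value is an assumption; use whatever
# directory name your completions were stored under):
#   curve = r1_accuracy_by_completion_length(conn, "completions-r1")
#   print(curve.df())  # or curve.show()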

def accuracy_by_model_and_time(conn):
    model_accuracies = conn.sql("""
        WITH ChallengesWithDates AS (
            SELECT
                ID,
                answer,
                EXTRACT(YEAR FROM CAST(date AS DATE)) AS year
            FROM challenges
        ),
        DateAnswerCheck AS (
            SELECT
                results.parent_dir AS model,
                dates.year,
                COUNT(*) AS total,
                SUM(CAST(check_answer(results.completion, dates.answer) AS INTEGER)) AS correct
            FROM results.completions results
            JOIN ChallengesWithDates dates
                ON results.prompt_id = dates.ID
            GROUP BY results.parent_dir, dates.year
        )
        SELECT
            model,
            year,
            total,
            correct,
            ROUND(correct / total, 2) AS accuracy
        FROM DateAnswerCheck
        ORDER BY model, year
    """)
    return model_accuracies

def accuracy_by_model(conn):
    return conn.sql("""
        WITH AnswerCheck AS (
            SELECT
                results.parent_dir AS model,
                COUNT(*) AS total,
                SUM(CAST(check_answer(results.completion, challenges.answer) AS INTEGER)) AS correct
            FROM results.completions results
            JOIN challenges
                ON results.prompt_id = challenges.ID
            GROUP BY results.parent_dir
        )
        SELECT
            model,
            total,
            correct,
            ROUND(correct / total, 2) AS accuracy
        FROM AnswerCheck
    """)

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--by-model-and-time", action="store_true")
    args = parser.parse_args()
    conn = load_results()
    if args.by_model_and_time:
        print(accuracy_by_model_and_time(conn))
    else:
        print(accuracy_by_model(conn))

if __name__ == "__main__":
    main()