|
import re |
|
import duckdb |
|
import textwrap |
|
from typing import List, Tuple |
|
import argparse |
|
|
|
def _parse_answer(text: str) -> List[List[str]]: |
|
""" |
|
Converts text to lowercase. Then interprets ";" as a separator between |
|
alternatives. Within each alternative, interprets "," and "-->" as separators |
|
for elements of a set. Within each set, drops all non-alphanumeric characters |
|
and returns that set. |
|
|
|
|
|
Another way to describe this is that we interpret adjacent words as |
|
phrases that must be present literally. However, comma and arrow separate |
|
distinct phrases that may be present in any order. All other characters |
|
are dropped. |
|
""" |
|
text = text.lower() |
|
alternatives = re.split(r';', text) |
|
result = [ ] |
|
for alternative in alternatives: |
|
groups = re.split(r'–?-?-?>|,', alternative) |
|
result.append([" ".join(re.findall(r'\b\w+\b', group)) for group in groups]) |
|
return result |
|
|
|
def _answer_without_thoughts(completion: str) -> str: |
|
if "<think>" not in completion[:200]: |
|
return completion |
|
|
|
chunks = completion.split("</think>") |
|
if len(chunks) <= 1: |
|
return "" |
|
|
|
return chunks[-1].strip() |
|
|
|
def _check_answer(completion: str, answer: str) -> bool:
    """
    Return True if the completion contains every phrase of at least one
    alternative of the reference answer.

    The "thoughts" section of the completion is removed first
    (``_answer_without_thoughts``), and the answer is parsed with
    ``_parse_answer``. Matching is case-insensitive, punctuation in the
    completion is treated as whitespace, and each phrase must appear as whole
    words.

    Empty phrases are ignored, and an alternative with no non-empty phrases
    never matches. A malformed answer (e.g. trailing ";" or ",") therefore
    cannot accept arbitrary completions.
    """
    text = _answer_without_thoughts(completion).lower()
    # Turn punctuation into spaces and collapse whitespace so that the
    # single-space-joined phrases produced by _parse_answer match literally.
    text = re.sub(r"[^\w\s]", " ", text)
    text = re.sub(r"\s+", " ", text)

    for answer_phrases in _parse_answer(answer):
        phrases = [phrase for phrase in answer_phrases if phrase]
        if phrases and all(
            re.search(rf"\b{re.escape(phrase)}\b", text) for phrase in phrases
        ):
            return True
    return False
|
|
|
|
|
def _clip_text(text: str, width: int) -> str: |
|
return text if len(text) <= width else text[:width] + "..." |
|
|
|
def _wrap_text(text: str, width: int) -> str: |
|
return textwrap.fill(text, width=width) |
|
|
|
def load_results(
    results_path: str = "results.duckdb",
    challenges_path: str = "puzzles_cleaned.csv",
):
    """
    Open an in-memory DuckDB connection prepared for the accuracy queries.

    The returned connection has:
      * the DuckDB database at ``results_path`` attached read-only as ``results``;
      * a ``challenges`` table created from the file at ``challenges_path``;
      * the Python UDFs ``check_answer``, ``clip_text`` and ``wrap_text``.

    Args:
        results_path: Path of the results database (default "results.duckdb").
        challenges_path: Path of the challenges file (default "puzzles_cleaned.csv").

    Returns:
        The configured DuckDB connection. The caller owns it and should close it.

    Raises:
        Whatever DuckDB raises if a file cannot be attached/read; the
        connection is closed before the exception propagates.
    """

    def sql_literal(value: str) -> str:
        # The paths are spliced into the SQL text, so double any single quotes
        # to keep an unusual path from breaking the statement.
        return "'" + value.replace("'", "''") + "'"

    conn = duckdb.connect(":memory:")
    try:
        conn.execute(
            f"ATTACH DATABASE {sql_literal(results_path)} AS results (READ_ONLY)"
        )
        conn.execute(
            f"CREATE TABLE challenges AS SELECT * FROM {sql_literal(challenges_path)}"
        )
        conn.create_function("check_answer", _check_answer)
        conn.create_function("clip_text", _clip_text)
        conn.create_function("wrap_text", _wrap_text)
    except Exception:
        # Don't leak the connection (and the attached database) on a failed setup.
        conn.close()
        raise
    return conn
|
|
|
def r1_accuracy_by_completion_length(conn, model_name):
    """
    Cumulative accuracy of one model's completions, ordered by completion length.

    For every completion whose ``results.parent_dir`` equals ``model_name``:

    1. compute the completion length (characters) and whether it is correct
       (via the ``check_answer`` UDF);
    2. sort by length;
    3. compute the running number of correct completions, and divide it by
       the total number of that model's completions.

    Completions of equal length are peers under ``ORDER BY length`` in the
    window, so they share the same cumulative values.

    Args:
        conn: Connection returned by ``load_results``.
        model_name: Value of ``results.parent_dir`` to select, e.g.
            "completions-r1".

    Returns:
        A DuckDB relation with columns ``length``, ``cumulative_correct`` and
        ``cumulative_accuracy``, ordered by ``length``.
    """
    # model_name is spliced into the query text; escape single quotes so an
    # arbitrary name cannot break (or inject into) the SQL.
    quoted_model = model_name.replace("'", "''")

    return conn.sql(f"""
        WITH LengthsAndCorrectness AS (
            SELECT
                LENGTH(results.completion) AS length,
                CAST(check_answer(results.completion, challenges.answer) AS INTEGER) AS correct
            FROM results.completions results
            JOIN challenges
                ON results.prompt_id = challenges.ID
            WHERE results.parent_dir = '{quoted_model}'
        ),
        TotalItems AS (
            SELECT COUNT(*) AS total_count
            FROM LengthsAndCorrectness
        ),
        CumulativeCorrect AS (
            SELECT
                length,
                SUM(correct) OVER (ORDER BY length) AS cumulative_correct
            FROM LengthsAndCorrectness
        )
        SELECT
            length,
            cumulative_correct,
            CAST(cumulative_correct AS DOUBLE) / total_count AS cumulative_accuracy
        FROM CumulativeCorrect, TotalItems
        ORDER BY length
    """)
|
|
|
|
|
def accuracy_by_model_and_time(conn):
    """
    Accuracy of each model, broken down by the year of the challenge.

    Each challenge's ``date`` is cast to DATE and its year extracted. Then, for
    every (model, year) pair, completions are counted and graded with the
    ``check_answer`` UDF.

    Args:
        conn: Connection returned by ``load_results``.

    Returns:
        A DuckDB relation with columns ``model``, ``year``, ``total``,
        ``correct`` and ``accuracy`` (rounded to 2 decimals), ordered by model
        then year.
    """
    return conn.sql("""
        WITH ChallengesWithDates AS (
            SELECT
                ID,
                answer,
                EXTRACT(YEAR FROM CAST(date AS DATE)) AS year
            FROM
                challenges
        ),
        DateAnswerCheck AS (
            SELECT
                results.parent_dir AS model,
                dates.year,
                COUNT(*) AS total,
                SUM(CAST(check_answer(results.completion, dates.answer) AS INTEGER)) AS correct
            FROM
                results.completions results
            JOIN
                ChallengesWithDates dates
            ON
                results.prompt_id = dates.ID
            GROUP BY
                results.parent_dir,
                dates.year
        )
        SELECT
            model,
            year,
            total,
            correct,
            -- Cast explicitly so the ratio is never truncated by integer
            -- division (DuckDB's '/' behaviour is configurable).
            ROUND(CAST(correct AS DOUBLE) / total, 2) AS accuracy
        FROM
            DateAnswerCheck
        ORDER BY
            model,
            year
    """)
|
|
|
def accuracy_by_model(conn):
    """
    Overall accuracy of each model across all challenges.

    Completions are joined to their challenge, graded with the
    ``check_answer`` UDF and aggregated per ``results.parent_dir``.

    Args:
        conn: Connection returned by ``load_results``.

    Returns:
        A DuckDB relation with columns ``model``, ``total``, ``correct`` and
        ``accuracy`` (rounded to 2 decimals), ordered by model.
    """
    return conn.sql("""
        WITH AnswerCheck AS (
            SELECT
                results.parent_dir AS model,
                COUNT(*) AS total,
                SUM(CAST(check_answer(results.completion, challenges.answer) AS INTEGER)) AS correct
            FROM
                results.completions results
            JOIN
                challenges challenges
            ON
                results.prompt_id = challenges.ID
            GROUP BY
                results.parent_dir
        )
        SELECT
            model,
            total,
            correct,
            -- Cast explicitly so the ratio is never truncated by integer
            -- division (DuckDB's '/' behaviour is configurable).
            ROUND(CAST(correct AS DOUBLE) / total, 2) AS accuracy
        FROM
            AnswerCheck
        ORDER BY
            model
    """)
|
|
|
def main():
    """
    Command-line entry point: print a per-model accuracy report.

    With ``--by-model-and-time`` the accuracy is broken down by model and
    challenge year; otherwise it is reported per model only. The DuckDB
    connection is always closed before returning.
    """
    parser = argparse.ArgumentParser(
        description="Report answer accuracy of model completions."
    )
    parser.add_argument(
        "--by-model-and-time",
        action="store_true",
        help="break accuracy down by model and challenge year",
    )
    args = parser.parse_args()

    conn = load_results()
    try:
        if args.by_model_and_time:
            report = accuracy_by_model_and_time(conn)
        else:
            report = accuracy_by_model(conn)
        # The relation is lazy; printing it runs the query, so this must
        # happen before the connection is closed.
        print(report)
    finally:
        conn.close()
|
|
|
if __name__ == "__main__":
    # Run the CLI only when this file is executed directly, not on import.
    main()
|
|