Spaces:
Running
Running
import re | |
import duckdb | |
import textwrap | |
from typing import List, Tuple | |
import argparse | |
def _parse_answer(text: str) -> List[List[str]]: | |
""" | |
Converts text to lowercase. Then interprets ";" as a separator between | |
alternatives. Within each alternative, interprets "," and "-->" as separators | |
for elements of a set. Within each set, drops all non-alphanumeric characters | |
and returns that set. | |
Another way to describe this is that we interpret adjacent words as | |
phrases that must be present literally. However, comma and arrow separate | |
distinct phrases that may be present in any order. All other characters | |
are dropped. | |
""" | |
text = text.lower() | |
alternatives = re.split(r';', text) | |
result = [ ] | |
for alternative in alternatives: | |
groups = re.split(r'-->|,', alternative) | |
result.append([" ".join(re.findall(r'\b\w+\b', group)) for group in groups]) | |
return result | |
def _answer_without_thoughts(completion: str) -> str: | |
if "<think>" not in completion[:200]: | |
return completion | |
chunks = completion.split("</think>") | |
if len(chunks) <= 1: | |
return "" | |
return chunks[-1].strip() | |
def _check_answer(completion: str, answer: str) -> bool:
    """
    Check that all the phrases that must appear in the answer appear in the
    completion. We ignore "thoughts", capitalization, and punctuation.

    The completion is reduced to its words joined by single spaces, which is
    the same shape _parse_answer gives each answer phrase. Without collapsing
    runs of punctuation/whitespace, a completion such as "D.C. area" would
    turn into "d c  area" (two spaces) and fail to match the phrase
    "d c area" even though it contains the answer verbatim.

    Args:
        completion: The raw model output, possibly with a <think> section.
        answer: The reference answer in the format accepted by _parse_answer.

    Returns:
        True if, for at least one alternative, every non-empty phrase occurs
        in the completion on word boundaries. Alternatives without any
        non-empty phrase are ignored, so an empty answer never matches.
    """
    completion = _answer_without_thoughts(completion).lower()
    completion = " ".join(re.findall(r'\w+', completion))
    for answer_phrases in _parse_answer(answer):
        # An empty phrase would compile to r'\b\b', which matches nearly any
        # text; drop such phrases so they cannot satisfy an alternative.
        required = [phrase for phrase in answer_phrases if phrase]
        if not required:
            continue
        if all(re.search(rf'\b{re.escape(phrase)}\b', completion) for phrase in required):
            return True
    return False
def _clip_text(text: str, width: int) -> str: | |
return text if len(text) <= width else text[:width] + "..." | |
def _wrap_text(text: str, width: int) -> str: | |
return textwrap.fill(text, width=width) | |
def load_results():
    """
    Open an in-memory DuckDB connection prepared for the accuracy queries.

    Attaches 'results.duckdb' read-only under the name "results", loads
    'puzzles_cleaned.csv' into a table named "challenges", and registers the
    Python helpers check_answer, clip_text and wrap_text as SQL functions.

    Returns:
        The configured DuckDB connection.
    """
    connection = duckdb.connect(":memory:")

    setup_statements = (
        "ATTACH DATABASE 'results.duckdb' AS results (READ_ONLY)",
        "CREATE TABLE challenges as SELECT * FROM 'puzzles_cleaned.csv'",
    )
    for statement in setup_statements:
        connection.execute(statement)

    # SQL function name -> Python implementation.
    sql_functions = {
        "check_answer": _check_answer,
        "clip_text": _clip_text,
        "wrap_text": _wrap_text,
    }
    for sql_name, implementation in sql_functions.items():
        connection.create_function(sql_name, implementation)

    return connection
def r1_accuracy_by_completion_length(conn, model_name):
    """
    Cumulative accuracy of one model's responses, ordered by completion length.

    For the responses whose parent_dir equals model_name:
    1. We calculate completion length and correctness for each problem.
    2. We sort by length.
    3. We compute the cumulative number of correct responses and divide it by
       the total number of responses to get the cumulative accuracy.

    Args:
        conn: A DuckDB connection that exposes results.completions, the
            challenges table and the check_answer SQL function.
        model_name: Value of results.completions.parent_dir to select.

    Returns:
        The relation produced by conn.sql with columns length,
        cumulative_correct and cumulative_accuracy.
    """
    # model_name is bound as a prepared-statement parameter instead of being
    # interpolated into the SQL text, so a name containing a quote can neither
    # break the statement nor change its meaning.
    return conn.sql(
        """
        WITH LengthsAndCorrectness AS (
            SELECT
                LENGTH(results.completion) AS length,
                CAST(check_answer(results.completion, challenges.answer) AS INT32) AS correct
            FROM results.completions results JOIN challenges
            ON results.prompt_id = challenges.ID
            WHERE results.parent_dir = ?
        ),
        TotalItems AS (
            SELECT COUNT(*) as total_count
            FROM LengthsAndCorrectness
        ),
        CumulativeCorrect AS (
            SELECT
                length,
                SUM(correct) OVER (ORDER BY length) as cumulative_correct
            FROM LengthsAndCorrectness
        )
        SELECT
            length,
            cumulative_correct,
            CAST(cumulative_correct AS FLOAT) / total_count AS cumulative_accuracy
        FROM CumulativeCorrect, TotalItems
        ORDER BY length
        """,
        params=[model_name],
    )
def accuracy_by_model_and_time(conn):
    """
    Accuracy per model and per challenge year.

    Joins results.completions with the challenges table on the prompt ID,
    groups by model (parent_dir) and the year extracted from the challenge
    date, and reports total, correct and accuracy rounded to two decimals,
    ordered by model and year.

    Args:
        conn: A DuckDB connection that exposes results.completions, the
            challenges table and the check_answer SQL function.

    Returns:
        The relation produced by conn.sql.
    """
    query = """
        WITH ChallengesWithDates AS (
            SELECT
                ID,
                answer,
                EXTRACT(YEAR FROM CAST(date AS DATE)) AS year
            FROM
                challenges
        ),
        DateAnswerCheck AS (
            SELECT
                results.parent_dir AS model,
                dates.year,
                COUNT(*) AS total,
                SUM(CAST(check_answer(results.completion, dates.answer) AS INTEGER)) AS correct
            FROM
                results.completions results
            JOIN
                ChallengesWithDates dates
            ON
                results.prompt_id = dates.ID
            GROUP BY
                results.parent_dir,
                dates.year
        )
        SELECT
            model,
            year,
            total,
            correct,
            ROUND(correct / total, 2) AS accuracy
        FROM
            DateAnswerCheck
        ORDER BY
            model,
            year
    """
    return conn.sql(query)
def accuracy_by_model(conn):
    """
    Overall accuracy per model.

    Joins results.completions with the challenges table on the prompt ID,
    groups by model (parent_dir), and reports total, correct and accuracy
    rounded to two decimals. Rows are ordered by model so that the report is
    reproducible between runs; a grouped result without ORDER BY has no
    guaranteed row order.

    Args:
        conn: A DuckDB connection that exposes results.completions, the
            challenges table and the check_answer SQL function.

    Returns:
        The relation produced by conn.sql.
    """
    return conn.sql("""
        WITH AnswerCheck AS (
            SELECT
                results.parent_dir AS model,
                COUNT(*) AS total,
                SUM(CAST(check_answer(results.completion, challenges.answer) AS INTEGER)) AS correct
            FROM
                results.completions results
            JOIN
                challenges challenges
            ON
                results.prompt_id = challenges.ID
            GROUP BY
                results.parent_dir
        )
        SELECT
            model,
            total,
            correct,
            ROUND(correct / total, 2) AS accuracy
        FROM
            AnswerCheck
        ORDER BY
            model
    """)
def main():
    """
    Command-line entry point: print an accuracy report.

    With --by-model-and-time the report is broken down by model and year;
    otherwise it is broken down by model only.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--by-model-and-time", action="store_true")
    options = cli.parse_args()

    # Pick the report first, then open the database and run it.
    report = accuracy_by_model_and_time if options.by_model_and_time else accuracy_by_model
    print(report(load_results()))


if __name__ == "__main__":
    main()