# Page header captured from a web file viewer (not part of the program):
#   arjunguha's picture | Update | 000c07e unverified | raw | history blame
#   | 2.71 kB
import re
import duckdb
import textwrap
def _parse_answer(text: str) -> list[str]:
    """
    Split a reference answer into the normalized phrases it requires.

    The text is lowercased and then split on "," and "-->"; each piece
    between separators becomes one phrase. Within a piece, runs of word
    characters (letters, digits, underscore) are kept and re-joined with
    single spaces; every other character is dropped.

    Another way to describe this is that adjacent words form a phrase that
    must be present literally, while comma and arrow separate distinct
    phrases that may be present in any order.

    Returns a list with one string per separator-delimited piece, in the
    order the pieces appear. A piece that contains no word characters (for
    example the text after a trailing comma) yields an empty string.
    """
    pieces = re.split(r'-->|,', text.lower())
    return [" ".join(re.findall(r'\b\w+\b', piece)) for piece in pieces]
def _answer_without_thoughts(completion: str) -> str:
    """
    Return the part of a completion that follows its "thinking" section.

    A completion is treated as containing thoughts only when "<think>"
    occurs within its first 200 characters; otherwise it is returned
    unchanged. When thoughts are present, the result is the text after the
    last "</think>", stripped of surrounding whitespace. If the closing tag
    never appears, the result is the empty string.
    """
    has_thoughts = "<think>" in completion[:200]
    if not has_thoughts:
        return completion

    # Only the text after the final closing tag counts as the answer.
    _, closing_tag, tail = completion.rpartition("</think>")
    if not closing_tag:
        # The thinking section was opened but never closed.
        return ""
    return tail.strip()
def _check_answer(completion: str, answer: str) -> bool:
    """
    Check that all the phrases that must appear in the answer appear in the
    completion. We ignore "thoughts", capitalization, and punctuation.

    The completion is reduced to its post-thinking text, lowercased, and
    then normalized to word characters separated by single spaces. Each
    phrase parsed from the answer must occur in that normalized text as a
    substring; phrases may occur in any order.

    Returns True when every phrase is found, False otherwise.
    """
    visible = _answer_without_thoughts(completion).lower()
    # Normalize the completion exactly the way _parse_answer normalizes each
    # phrase. Without this, an answer such as "don't stop" (parsed to
    # "don t stop") could never match a completion containing "don't stop".
    normalized = " ".join(re.findall(r'\b\w+\b', visible))
    return all(phrase in normalized for phrase in _parse_answer(answer))
def _clip_text(text: str, width: int) -> str:
    """
    Truncate text to at most width characters, marking any truncation.

    Text no longer than width is returned unchanged. Longer text is cut to
    its first width characters and "..." is appended, so a clipped result
    is three characters longer than width.
    """
    if len(text) > width:
        return text[:width] + "..."
    return text
def _wrap_text(text: str, width: int) -> str:
    """
    Re-flow text so that no line is longer than width characters.

    Returns a single string whose wrapped lines are joined by newlines.
    """
    wrapped_lines = textwrap.wrap(text, width=width)
    return "\n".join(wrapped_lines)
def load_results():
    """
    Build the DuckDB connection used for scoring.

    Opens an in-memory database, attaches 'results.duckdb' under the name
    "results", creates a "challenges" table from 'puzzles_cleaned.csv', and
    registers the Python helpers check_answer, clip_text and wrap_text as
    SQL functions.

    Returns the configured connection.
    """
    db = duckdb.connect(":memory:")

    setup_statements = (
        "ATTACH DATABASE 'results.duckdb' AS results",
        "CREATE TABLE challenges as SELECT * FROM 'puzzles_cleaned.csv'",
    )
    for statement in setup_statements:
        db.execute(statement)

    # SQL function name -> Python implementation, registered in this order.
    sql_functions = {
        "check_answer": _check_answer,
        "clip_text": _clip_text,
        "wrap_text": _wrap_text,
    }
    for sql_name, implementation in sql_functions.items():
        db.create_function(sql_name, implementation)

    return db
def accuracy_by_model(conn):
    """
    Print per-model accuracy over all stored completions.

    Joins results.completions with challenges on prompt_id = ID, groups the
    rows by parent_dir (reported as "model"), and for each group prints the
    number of completions, the number for which check_answer is true, and
    their ratio rounded to two decimal places.

    conn is a connection on which the "results" database, the "challenges"
    table and the check_answer SQL function are already available. Nothing
    is returned.
    """
    # The SQL text is kept flush-left so the string sent to DuckDB stays
    # exactly as written.
    query = """
WITH AnswerCheck AS (
SELECT
results.parent_dir AS model,
COUNT(*) AS total,
SUM(CAST(check_answer(results.completion, challenges.answer) AS INTEGER)) AS correct
FROM
results.completions results
JOIN
challenges challenges
ON
results.prompt_id = challenges.ID
GROUP BY
results.parent_dir
)
SELECT
model,
total,
correct,
ROUND(correct / total, 2) AS accuracy
FROM
AnswerCheck
"""
    print(conn.sql(query))
if __name__ == "__main__":
    # Script entry point: load the results database and print the
    # per-model accuracy table.
    accuracy_by_model(load_results())