import re
import textwrap

import duckdb


def _parse_answer(text: str) -> list[str]:
    """
    Converts text to lowercase and interprets "," and "-->" as separators
    between elements of a set. Within each element, drops all
    non-alphanumeric characters and returns the resulting list of phrases.

    Another way to describe this: adjacent words are interpreted as phrases
    that must be present literally. However, comma and arrow separate
    distinct phrases that may appear in any order. All other characters
    are dropped.
    """
    text = text.lower()
    groups = re.split(r'-->|,', text)
    return [" ".join(re.findall(r'\b\w+\b', group)) for group in groups]


def _answer_without_thoughts(completion: str) -> str:
    # A reasoning model opens with a <think> block; if none appears near the
    # start, treat the whole completion as the answer.
    if "<think>" not in completion[:200]:
        return completion
    chunks = completion.split("</think>")
    # An unterminated <think> block means the model never produced an answer.
    if len(chunks) <= 1:
        return ""
    return chunks[-1].strip()
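
# Illustrative behavior (my examples, not from the original source):
#   _answer_without_thoughts("<think>2+2=4</think> The answer is 4.") == "The answer is 4."
#   _answer_without_thoughts("<think>never closed") == ""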


def _check_answer(completion: str, answer: str) -> bool:
    """
    Check that all the phrases that must appear in the answer appear in the
    completion. We ignore "thoughts", capitalization, and punctuation.
    """
    completion = _answer_without_thoughts(completion).lower()
    answer_phrases = _parse_answer(answer)
    # Substring containment: each phrase may appear anywhere in the completion.
    return all(phrase in completion for phrase in answer_phrases)
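
# Illustrative usage (hypothetical strings, not from the original dataset):
#   _check_answer("The route is Paris --> Lyon.", "paris, lyon") == True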


def _clip_text(text: str, width: int) -> str:
    return text if len(text) <= width else text[:width] + "..."


def _wrap_text(text: str, width: int) -> str:
    return textwrap.fill(text, width=width)
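
# Illustrative behavior (my examples):
#   _clip_text("abcdefgh", 5) == "abcde..."
#   _wrap_text("one two three", 7) == "one two\nthree"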


def load_results():
    conn = duckdb.connect(":memory:")
    conn.execute("ATTACH DATABASE 'results.duckdb' AS results")
    conn.execute("CREATE TABLE challenges AS SELECT * FROM 'puzzles_cleaned.csv'")
    # Register the Python helpers as SQL functions; DuckDB infers the argument
    # and return types from the functions' type annotations.
    conn.create_function("check_answer", _check_answer)
    conn.create_function("clip_text", _clip_text)
    conn.create_function("wrap_text", _wrap_text)
    return conn
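
# Once registered, the helpers can be called from SQL like built-ins. A
# minimal sketch (my example; assumes only the functions registered above):
#   conn = load_results()
#   conn.sql("SELECT check_answer('the answer is foo', 'foo') AS ok")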


def accuracy_by_model(conn):
    model_accuracies = conn.sql("""
        WITH AnswerCheck AS (
            SELECT
                results.parent_dir AS model,
                COUNT(*) AS total,
                SUM(CAST(check_answer(results.completion, challenges.answer) AS INTEGER)) AS correct
            FROM
                results.completions results
            JOIN
                challenges challenges
            ON
                results.prompt_id = challenges.ID
            GROUP BY
                results.parent_dir
        )
        SELECT
            model,
            total,
            correct,
            ROUND(correct / total, 2) AS accuracy
        FROM
            AnswerCheck
    """)
    print(model_accuracies)
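
# clip_text and wrap_text are registered but unused above; presumably they are
# for ad-hoc inspection of completions, e.g. (a sketch, assuming the
# results.completions schema used in accuracy_by_model):
#   conn.sql("SELECT clip_text(completion, 80) FROM results.completions LIMIT 5")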


if __name__ == "__main__":
    conn = load_results()
    accuracy_by_model(conn)