File size: 6,065 Bytes
000c07e
 
 
8fcea9a
1b3f3b2
000c07e
8fcea9a
000c07e
8fcea9a
 
 
000c07e
 
8fcea9a
000c07e
 
 
 
 
 
8fcea9a
 
 
9a3117e
8fcea9a
 
000c07e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2b8f77d
9a3117e
8fcea9a
 
2b8f77d
 
8fcea9a
 
000c07e
 
 
 
 
 
 
 
 
 
1b3f3b2
000c07e
 
 
 
 
 
2b8f77d
1b3f3b2
 
 
 
 
 
2b8f77d
1b3f3b2
 
 
 
 
 
2b8f77d
 
 
 
 
 
 
 
 
 
 
1b3f3b2
2b8f77d
1b3f3b2
 
2b8f77d
 
 
 
1b3f3b2
 
 
 
 
000c07e
1b3f3b2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
000c07e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1b3f3b2
 
 
 
 
 
 
 
 
000c07e
 
1b3f3b2
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
import re
import duckdb
import textwrap
from typing import List, Tuple
import argparse

def _parse_answer(text: str) -> List[List[str]]:
    """
    Converts text to lowercase. Then interprets ";" as a separator between
    alternatives. Within each alternative, interprets "," and "-->" as separators
    for elements of a set. Within each set, drops all non-alphanumeric characters
    and returns that set.


    Another way to describe this is that we interpret adjacent words as
    phrases that must be present literally. However, comma and arrow separate
    distinct phrases that may be present in any order. All other characters
    are dropped.
    """
    text = text.lower()
    alternatives = re.split(r';', text)
    result = [ ]
    for alternative in alternatives:
        groups = re.split(r'–?-?-?>|,', alternative)
        result.append([" ".join(re.findall(r'\b\w+\b', group)) for group in groups])
    return result

def _answer_without_thoughts(completion: str) -> str:
    if "<think>" not in completion[:200]:
        return completion
    
    chunks = completion.split("</think>")
    if len(chunks) <= 1:
        return ""
    
    return chunks[-1].strip()

def _check_answer(completion: str, answer: str) -> bool:
    """
    Return True when some alternative of ``answer`` is fully present in ``completion``.

    Any "<think>...</think>" section is removed first via
    _answer_without_thoughts. The rest is lowercased, every character that is
    neither a word character nor whitespace becomes a space, and whitespace
    runs collapse to a single space. This mirrors how _parse_answer
    normalizes the answer key. An alternative matches when each of its
    phrases occurs in the normalized text delimited by word boundaries.
    """
    normalized = _answer_without_thoughts(completion).lower()
    # Punctuation becomes a space, lining up with _parse_answer's " ".join of words.
    normalized = re.sub(r'[^\w\s]', ' ', normalized)
    # Collapse consecutive (Unicode) whitespace so multi-word phrases line up.
    normalized = re.sub(r'\s+', ' ', normalized)

    def _phrase_present(phrase: str) -> bool:
        return re.search(rf'\b{re.escape(phrase)}\b', normalized) is not None

    return any(
        all(_phrase_present(phrase) for phrase in phrases)
        for phrases in _parse_answer(answer)
    )


def _clip_text(text: str, width: int) -> str:
    return text if len(text) <= width else text[:width] + "..."

def _wrap_text(text: str, width: int) -> str:
    return textwrap.fill(text, width=width)

def load_results():
    """
    Open an in-memory DuckDB connection prepared for the accuracy queries.

    Attaches 'results.duckdb' read-only under the alias ``results``, and
    copies 'puzzles_cleaned.csv' into a ``challenges`` table. It then
    registers the Python helpers check_answer, clip_text and wrap_text as SQL
    functions of the same names. Both files are given as relative paths.

    Returns:
        The configured DuckDB connection.
    """
    db = duckdb.connect(":memory:")
    for statement in (
        "ATTACH DATABASE 'results.duckdb' AS results (READ_ONLY)",
        "CREATE TABLE challenges as SELECT * FROM 'puzzles_cleaned.csv'",
    ):
        db.execute(statement)
    udfs = {
        "check_answer": _check_answer,
        "clip_text": _clip_text,
        "wrap_text": _wrap_text,
    }
    for sql_name, func in udfs.items():
        db.create_function(sql_name, func)
    return db

def r1_accuracy_by_completion_length(conn, model_name):
    """
    Cumulative accuracy of one model's completions, ordered by completion length.

    For the completions whose ``parent_dir`` equals ``model_name``:
    1. We calculate completion length and correctness for each problem
       (correctness via the ``check_answer`` SQL function from load_results).
    2. We sort by length.
    3. We compute the cumulative number of correct responses and divide it by
       the total number of completions for this model.

    Completions of equal length are window peers under the default frame, so
    they share the same cumulative values.

    Args:
        conn: DuckDB connection returned by load_results().
        model_name: value of ``results.completions.parent_dir`` to select.

    Returns:
        A DuckDB relation with columns length, cumulative_correct and
        cumulative_accuracy, ordered by length.
    """
    # model_name is bound as a query parameter instead of being interpolated,
    # so a name containing a quote cannot break (or inject into) the SQL.
    return conn.sql(
        """
        WITH LengthsAndCorrectness AS (
            SELECT
                LENGTH(results.completion) AS length,
                CAST(check_answer(results.completion, challenges.answer) AS INT32) AS correct
            FROM results.completions results
            JOIN challenges
              ON results.prompt_id = challenges.ID
            WHERE results.parent_dir = ?
        ),
        TotalItems AS (
            SELECT COUNT(*) AS total_count
            FROM LengthsAndCorrectness
        ),
        CumulativeCorrect AS (
            SELECT
                length,
                SUM(correct) OVER (ORDER BY length) AS cumulative_correct
            FROM LengthsAndCorrectness
        )
        SELECT
            length,
            cumulative_correct,
            CAST(cumulative_correct AS FLOAT) / total_count AS cumulative_accuracy
        FROM CumulativeCorrect, TotalItems
        ORDER BY length
        """,
        params=[model_name],
    )


def accuracy_by_model_and_time(conn):
    """
    Per-model, per-year accuracy over all stored completions.

    Each completion is joined to its challenge on prompt_id = ID. The year is
    extracted from the challenge's ``date`` column cast to DATE. Correctness
    comes from the ``check_answer`` SQL function registered by load_results.

    Args:
        conn: DuckDB connection returned by load_results().

    Returns:
        A DuckDB relation with columns model, year, total, correct and
        accuracy (correct / total rounded to 2 places), ordered by model and
        then year.
    """
    query = """
        WITH ChallengesWithDates AS (
            SELECT 
                ID,
                answer,
                EXTRACT(YEAR FROM CAST(date AS DATE)) AS year
            FROM 
                challenges
        ),
        DateAnswerCheck AS (
            SELECT 
                results.parent_dir AS model,
                dates.year,
                COUNT(*) AS total,
                SUM(CAST(check_answer(results.completion, dates.answer) AS INTEGER)) AS correct
            FROM 
                results.completions results
            JOIN 
                ChallengesWithDates dates
            ON 
                results.prompt_id = dates.ID
            GROUP BY 
                results.parent_dir,
                dates.year
        )
        SELECT 
            model,
            year,
            total,
            correct,
            ROUND(correct / total, 2) AS accuracy
        FROM 
            DateAnswerCheck
        ORDER BY
            model,
            year
    """
    return conn.sql(query)

def accuracy_by_model(conn):
    """
    Overall accuracy of each model across all stored completions.

    Each completion is joined to its challenge on prompt_id = ID and graded
    with the ``check_answer`` SQL function registered by load_results.

    Args:
        conn: DuckDB connection returned by load_results().

    Returns:
        A DuckDB relation with columns model, total, correct and accuracy
        (correct / total rounded to 2 places), ordered by model.
    """
    return conn.sql("""
        WITH AnswerCheck AS (
            SELECT
                results.parent_dir AS model,
                COUNT(*) AS total,
                SUM(CAST(check_answer(results.completion, challenges.answer) AS INTEGER)) AS correct
            FROM results.completions results
            JOIN challenges
              ON results.prompt_id = challenges.ID
            GROUP BY results.parent_dir
        )
        SELECT
            model,
            total,
            correct,
            ROUND(correct / total, 2) AS accuracy
        FROM AnswerCheck
        -- Without an explicit ORDER BY the row order is unspecified and may vary between runs.
        ORDER BY model
    """)

def main():
    """
    Command-line entry point: print an accuracy table.

    By default, prints accuracy per model. With ``--by-model-and-time``,
    prints accuracy per model and year instead.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--by-model-and-time", action="store_true")
    options = cli.parse_args()
    conn = load_results()
    report = accuracy_by_model_and_time if options.by_model_and_time else accuracy_by_model
    print(report(conn))

# Run the CLI report only when executed as a script, not when imported.
if __name__ == "__main__":
    main()