#!/usr/bin/env python3
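"""Run tasks from the DABstep benchmark with smolagents code agents.

Each task is answered by an agent configured for either a reasoning-style or a
chat-style model; answers and console logs are appended under
runs/<model>/<split>/<timestamp>/, and agent traces are exported via
OpenTelemetry (see the instrumentation setup below).
"""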

import yaml
from openinference.instrumentation.smolagents import SmolagentsInstrumentor
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import SimpleSpanProcessor

# Export agent traces over OTLP/HTTP. Port 6006 is the Arize Phoenix default,
# so a local Phoenix instance will pick up the spans from this endpoint.
endpoint = "http://0.0.0.0:6006/v1/traces"
trace_provider = TracerProvider()
trace_provider.add_span_processor(SimpleSpanProcessor(OTLPSpanExporter(endpoint)))

# Instrument smolagents before any agents are created so every run is traced.
SmolagentsInstrumentor().instrument(tracer_provider=trace_provider)

import argparse
import logging
import os
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path

import datasets
import pandas as pd
from tqdm import tqdm

from dabstep_benchmark.utils import evaluate
from smolagents.utils import console

from constants import REPO_ID
from prompts import (
    chat_llm_system_prompt,
    chat_llm_task_prompt,
    reasoning_llm_system_prompt,
    reasoning_llm_task_prompt,
)
from utils import (
    TqdmLoggingHandler,
    append_answer,
    append_console_output,
    create_code_agent_with_chat_llm,
    create_code_agent_with_reasoning_llm,
    download_context,
    get_tasks_to_run,
    is_reasoning_llm,
)

# Route log records through a tqdm-aware handler so they do not break progress
# bars; the script logs its own per-task output at WARNING to pass this filter.
logging.basicConfig(level=logging.WARNING, handlers=[TqdmLoggingHandler()])
logger = logging.getLogger(__name__)

def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--concurrency", type=int, default=4)
    parser.add_argument("--model-id", type=str, default="openai/o3-mini")
    parser.add_argument("--experiment", type=str, default=None)
    parser.add_argument("--max-tasks", type=int, default=-1)
    parser.add_argument("--max-steps", type=int, default=10)
    parser.add_argument("--tasks-ids", type=int, nargs="+", default=None)
    parser.add_argument("--api-base", type=str, default=None)
    parser.add_argument("--api-key", type=str, default=None)
    parser.add_argument("--split", type=str, default="default", choices=["default", "dev"])
    parser.add_argument("--timestamp", type=str, default=None)
    return parser.parse_args()
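
# Example invocation (script name is illustrative):
#   python run_benchmark.py --model-id openai/o3-mini --split dev --max-tasks 5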


def run_single_task(
        task: dict,
        model_id: str,
        api_base: str,
        api_key: str,
        ctx_path: str,
        base_filename: Path,
        is_dev_data: bool,
        max_steps: int
):
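    # Reasoning-style models use a prompt without the context path (their agent
    # factory receives ctx_path directly); chat models get ctx_path in-prompt.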
    if is_reasoning_llm(model_id):
        prompt = reasoning_llm_task_prompt.format(
            question=task["question"],
            guidelines=task["guidelines"]
        )
        agent = create_code_agent_with_reasoning_llm(model_id, api_base, api_key, max_steps, ctx_path)
    else:
        prompt = chat_llm_task_prompt.format(
            ctx_path=ctx_path,
            question=task["question"],
            guidelines=task["guidelines"]
        )
        agent = create_code_agent_with_chat_llm(model_id, api_base, api_key, max_steps)

    # Capture the agent's rich console output so it can be saved to logs.txt.
    with console.capture() as capture:
        answer = agent.run(prompt)

    logger.warning(f"Task id: {task['task_id']}\tQuestion: {task['question']} Answer: {answer}\n{'=' * 50}")

    answer_dict = {"task_id": str(task["task_id"]), "agent_answer": str(answer)}
    answers_file = base_filename / "answers.jsonl"
    logs_file = base_filename / "logs.txt"

    # The dev split ships ground-truth answers, so score the answer immediately.
    if is_dev_data:
        scores = evaluate(agent_answers=pd.DataFrame([answer_dict]), tasks_with_gt=pd.DataFrame([task]))
        entry = {**answer_dict, "answer": task["answer"], "score": scores[0]["score"], "level": scores[0]["level"]}
        append_answer(entry, answers_file)
    else:
        append_answer(answer_dict, answers_file)
    append_console_output(capture.get(), logs_file)


def main():
    args = parse_args()
    logger.warning(f"Starting run with arguments: {args}")

    # Fetch the benchmark context files into the current working directory.
    ctx_path = download_context(str(Path().resolve()))

    # Each run writes its outputs under runs/<model>/<split>/<timestamp>/.
    runs_dir = Path().resolve() / "runs"
    runs_dir.mkdir(parents=True, exist_ok=True)
    timestamp = args.timestamp if args.timestamp else time.time()
    base_filename = runs_dir / f"{args.model_id.replace('/', '_').replace('.', '_')}/{args.split}/{int(timestamp)}"

    # Save the run configuration (CLI args plus the resolved system prompt).
    os.makedirs(base_filename, exist_ok=True)
    if is_reasoning_llm(args.model_id):
        args.system_prompt = reasoning_llm_system_prompt
    else:
        args.system_prompt = chat_llm_system_prompt
    with open(base_filename / "config.yaml", "w", encoding="utf-8") as f:
        yaml.dump(vars(args), f, default_flow_style=False)

    # Load dataset with user-chosen split
    data = datasets.load_dataset(REPO_ID, name="tasks", split=args.split, download_mode='force_redownload')

    if args.max_tasks >= 0 and args.tasks_ids is not None:
        raise ValueError(f"Cannot provide {args.max_tasks=} and {args.tasks_ids=} at the same time")
    total = len(data) if args.max_tasks < 0 else min(len(data), args.max_tasks)

    tasks_to_run = get_tasks_to_run(data, total, base_filename, args.tasks_ids)
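    # Fan the tasks out across worker threads; each task builds its own agent.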
    with ThreadPoolExecutor(max_workers=args.concurrency) as exe:
        futures = [
            exe.submit(
                run_single_task,
                task,
                args.model_id,
                args.api_base,
                args.api_key,
                ctx_path,
                base_filename,
                (args.split == "dev"),
                args.max_steps,
            )
            for task in tasks_to_run
        ]
        for f in tqdm(as_completed(futures), total=len(tasks_to_run), desc="Processing tasks"):
            f.result()  # re-raise any exception from the worker thread

    logger.warning("All tasks processed.")


if __name__ == "__main__":
    main()