#!/usr/bin/env python3
"""Run DABstep benchmark tasks with smolagents code agents, exporting traces
over OTLP (e.g. to a locally running Arize Phoenix collector)."""
import yaml
from opentelemetry.sdk.trace import TracerProvider
from openinference.instrumentation.smolagents import SmolagentsInstrumentor
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
from opentelemetry.sdk.trace.export import SimpleSpanProcessor

# Instrument smolagents before any agents are created so every run is traced.
endpoint = "http://0.0.0.0:6006/v1/traces"
trace_provider = TracerProvider()
trace_provider.add_span_processor(SimpleSpanProcessor(OTLPSpanExporter(endpoint)))
SmolagentsInstrumentor().instrument(tracer_provider=trace_provider)

import argparse
import logging
import os
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path

import datasets
import pandas as pd
from dabstep_benchmark.utils import evaluate
from smolagents.utils import console
from tqdm import tqdm

from constants import REPO_ID
from prompts import (
    reasoning_llm_system_prompt,
    reasoning_llm_task_prompt,
    chat_llm_task_prompt,
    chat_llm_system_prompt,
)
from utils import (
    TqdmLoggingHandler,
    is_reasoning_llm,
    create_code_agent_with_chat_llm,
    create_code_agent_with_reasoning_llm,
    get_tasks_to_run,
    append_answer,
    append_console_output,
    download_context,
)

logging.basicConfig(level=logging.WARNING, handlers=[TqdmLoggingHandler()])
logger = logging.getLogger(__name__)


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--concurrency", type=int, default=4)
    parser.add_argument("--model-id", type=str, default="openai/o3-mini")
    parser.add_argument("--experiment", type=str, default=None)
    parser.add_argument("--max-tasks", type=int, default=-1)
    parser.add_argument("--max-steps", type=int, default=10)
    parser.add_argument("--tasks-ids", type=int, nargs="+", default=None)
    parser.add_argument("--api-base", type=str, default=None)
    parser.add_argument("--api-key", type=str, default=None)
    parser.add_argument("--split", type=str, default="default", choices=["default", "dev"])
    parser.add_argument("--timestamp", type=str, default=None)
    return parser.parse_args()


def run_single_task(
    task: dict,
    model_id: str,
    api_base: str,
    api_key: str,
    ctx_path: str,
    base_filename: Path,
    is_dev_data: bool,
    max_steps: int,
):
    # Reasoning models get the context path via their agent factory; chat models
    # receive it directly inside the task prompt.
    if is_reasoning_llm(model_id):
        prompt = reasoning_llm_task_prompt.format(
            question=task["question"],
            guidelines=task["guidelines"],
        )
        agent = create_code_agent_with_reasoning_llm(model_id, api_base, api_key, max_steps, ctx_path)
    else:
        prompt = chat_llm_task_prompt.format(
            ctx_path=ctx_path,
            question=task["question"],
            guidelines=task["guidelines"],
        )
        agent = create_code_agent_with_chat_llm(model_id, api_base, api_key, max_steps)

    # Capture the agent's console output so it can be persisted next to the answer.
    with console.capture() as capture:
        answer = agent.run(prompt)

    logger.warning(f"Task id: {task['task_id']}\tQuestion: {task['question']} Answer: {answer}\n{'=' * 50}")

    answer_dict = {"task_id": str(task["task_id"]), "agent_answer": str(answer)}
    answers_file = base_filename / "answers.jsonl"
    logs_file = base_filename / "logs.txt"

    if is_dev_data:
        # The dev split ships ground-truth answers, so score the task immediately.
        scores = evaluate(agent_answers=pd.DataFrame([answer_dict]), tasks_with_gt=pd.DataFrame([task]))
        entry = {**answer_dict, "answer": task["answer"], "score": scores[0]["score"], "level": scores[0]["level"]}
        append_answer(entry, answers_file)
    else:
        append_answer(answer_dict, answers_file)

    append_console_output(capture.get(), logs_file)
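
# For reference, append_answer() writes one JSON object per line to
# answers.jsonl. An illustrative dev-split entry (all values made up) would be:
#   {"task_id": "7", "agent_answer": "42", "answer": "42", "score": 1, "level": "easy"}
# Non-dev entries carry only "task_id" and "agent_answer", since no ground
# truth is available for scoring.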


def main():
    args = parse_args()
    logger.warning(f"Starting run with arguments: {args}")

    ctx_path = download_context(str(Path().resolve()))

    runs_dir = Path().resolve() / "runs"
    runs_dir.mkdir(parents=True, exist_ok=True)
    timestamp = time.time() if not args.timestamp else args.timestamp
    base_filename = runs_dir / f"{args.model_id.replace('/', '_').replace('.', '_')}/{args.split}/{int(timestamp)}"

    # Save the resolved config (including the system prompt) next to the results.
    os.makedirs(base_filename, exist_ok=True)
    with open(base_filename / "config.yaml", "w", encoding="utf-8") as f:
        if is_reasoning_llm(args.model_id):
            args.system_prompt = reasoning_llm_system_prompt
        else:
            args.system_prompt = chat_llm_system_prompt
        args_dict = vars(args)
        yaml.dump(args_dict, f, default_flow_style=False)

    # Load the dataset with the user-chosen split.
    data = datasets.load_dataset(REPO_ID, name="tasks", split=args.split, download_mode="force_redownload")

    if args.max_tasks >= 0 and args.tasks_ids is not None:
        raise ValueError(f"Cannot provide {args.max_tasks=} and {args.tasks_ids=} at the same time")

    total = len(data) if args.max_tasks < 0 else min(len(data), args.max_tasks)
    tasks_to_run = get_tasks_to_run(data, total, base_filename, args.tasks_ids)

    with ThreadPoolExecutor(max_workers=args.concurrency) as exe:
        futures = [
            exe.submit(
                run_single_task,
                task,
                args.model_id,
                args.api_base,
                args.api_key,
                ctx_path,
                base_filename,
                (args.split == "dev"),
                args.max_steps,
            )
            for task in tasks_to_run
        ]
        for f in tqdm(as_completed(futures), total=len(tasks_to_run), desc="Processing tasks"):
            f.result()

    logger.warning("All tasks processed.")


if __name__ == "__main__":
    main()
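
# Example invocation (filename and flag values are illustrative, assuming this
# script is saved as run.py):
#   python run.py --model-id openai/o3-mini --split dev --max-tasks 5 --concurrency 8
# Traces are exported to the OTLP endpoint configured at the top of this file,
# e.g. a local Phoenix instance listening on http://0.0.0.0:6006.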