#!/usr/bin/env python3
import yaml
from opentelemetry.sdk.trace import TracerProvider
from openinference.instrumentation.smolagents import SmolagentsInstrumentor
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
from opentelemetry.sdk.trace.export import SimpleSpanProcessor
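
# Trace smolagents runs via OpenTelemetry. The OTLP endpoint below assumes a local
# collector is already running (e.g. an Arize Phoenix instance, which listens on
# port 6006 by default and accepts traces at /v1/traces).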
endpoint = "http://0.0.0.0:6006/v1/traces"
trace_provider = TracerProvider()
trace_provider.add_span_processor(SimpleSpanProcessor(OTLPSpanExporter(endpoint)))
SmolagentsInstrumentor().instrument(tracer_provider=trace_provider)

import argparse
import logging
import os
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path

import datasets
import pandas as pd
from dabstep_benchmark.utils import evaluate
from smolagents.utils import console
from tqdm import tqdm

from constants import REPO_ID
from prompts import (
    reasoning_llm_system_prompt,
    reasoning_llm_task_prompt,
    chat_llm_task_prompt,
    chat_llm_system_prompt,
)
from utils import (
    TqdmLoggingHandler,
    is_reasoning_llm,
    create_code_agent_with_chat_llm,
    create_code_agent_with_reasoning_llm,
    get_tasks_to_run,
    append_answer,
    append_console_output,
    download_context,
)
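
# Route log records through the tqdm-aware handler so progress bars are not broken
# by interleaved log output; WARNING level keeps the per-task result lines visible
# while silencing chattier library logs.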
logging.basicConfig(level=logging.WARNING, handlers=[TqdmLoggingHandler()])
logger = logging.getLogger(__name__)


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--concurrency", type=int, default=4)
    parser.add_argument("--model-id", type=str, default="openai/o3-mini")
    parser.add_argument("--experiment", type=str, default=None)
    parser.add_argument("--max-tasks", type=int, default=-1)
    parser.add_argument("--max-steps", type=int, default=10)
    parser.add_argument("--tasks-ids", type=int, nargs="+", default=None)
    parser.add_argument("--api-base", type=str, default=None)
    parser.add_argument("--api-key", type=str, default=None)
    parser.add_argument("--split", type=str, default="default", choices=["default", "dev"])
    parser.add_argument("--timestamp", type=str, default=None)
    return parser.parse_args()
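
# Example invocation (flags match parse_args above; the script filename and model id
# are illustrative, and the API key for the chosen provider is assumed to be set):
#   python run.py --model-id openai/o3-mini --split dev --concurrency 4 --max-steps 10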


def run_single_task(
    task: dict,
    model_id: str,
    api_base: str,
    api_key: str,
    ctx_path: str,
    base_filename: Path,
    is_dev_data: bool,
    max_steps: int,
):
    # Reasoning and chat models use different task prompts and agent factories.
    if is_reasoning_llm(model_id):
        prompt = reasoning_llm_task_prompt.format(
            question=task["question"],
            guidelines=task["guidelines"],
        )
        agent = create_code_agent_with_reasoning_llm(model_id, api_base, api_key, max_steps, ctx_path)
    else:
        prompt = chat_llm_task_prompt.format(
            ctx_path=ctx_path,
            question=task["question"],
            guidelines=task["guidelines"],
        )
        agent = create_code_agent_with_chat_llm(model_id, api_base, api_key, max_steps)

    # Capture the agent's console output so it can be persisted alongside the answer.
    with console.capture() as capture:
        answer = agent.run(prompt)

    logger.warning(f"Task id: {task['task_id']}\tQuestion: {task['question']} Answer: {answer}\n{'=' * 50}")

    answer_dict = {"task_id": str(task["task_id"]), "agent_answer": str(answer)}
    answers_file = base_filename / "answers.jsonl"
    logs_file = base_filename / "logs.txt"

    if is_dev_data:
        # The dev split ships ground-truth answers, so score each task immediately.
        scores = evaluate(agent_answers=pd.DataFrame([answer_dict]), tasks_with_gt=pd.DataFrame([task]))
        entry = {**answer_dict, "answer": task["answer"], "score": scores[0]["score"], "level": scores[0]["level"]}
        append_answer(entry, answers_file)
    else:
        append_answer(answer_dict, answers_file)

    append_console_output(capture.get(), logs_file)


def main():
    args = parse_args()
    logger.warning(f"Starting run with arguments: {args}")

    # Download the benchmark context files the agents rely on.
    ctx_path = download_context(str(Path().resolve()))

    runs_dir = Path().resolve() / "runs"
    runs_dir.mkdir(parents=True, exist_ok=True)
    timestamp = time.time() if not args.timestamp else args.timestamp
    base_filename = runs_dir / f"{args.model_id.replace('/', '_').replace('.', '_')}/{args.split}/{int(timestamp)}"

    # Save the run configuration, including the resolved system prompt, alongside the outputs.
    os.makedirs(base_filename, exist_ok=True)
    with open(base_filename / "config.yaml", "w", encoding="utf-8") as f:
        if is_reasoning_llm(args.model_id):
            args.system_prompt = reasoning_llm_system_prompt
        else:
            args.system_prompt = chat_llm_system_prompt
        args_dict = vars(args)
        yaml.dump(args_dict, f, default_flow_style=False)

    # Load the benchmark tasks for the user-chosen split.
    data = datasets.load_dataset(REPO_ID, name="tasks", split=args.split, download_mode="force_redownload")

    if args.max_tasks >= 0 and args.tasks_ids is not None:
        # These options are mutually exclusive: abort instead of silently picking one.
        raise ValueError(f"Cannot provide {args.max_tasks=} and {args.tasks_ids=} at the same time")

    total = len(data) if args.max_tasks < 0 else min(len(data), args.max_tasks)
    tasks_to_run = get_tasks_to_run(data, total, base_filename, args.tasks_ids)

    # Run tasks concurrently; each worker appends its answer and console logs as it finishes.
    with ThreadPoolExecutor(max_workers=args.concurrency) as exe:
        futures = [
            exe.submit(
                run_single_task,
                task,
                args.model_id,
                args.api_base,
                args.api_key,
                ctx_path,
                base_filename,
                (args.split == "dev"),
                args.max_steps,
            )
            for task in tasks_to_run
        ]
        for f in tqdm(as_completed(futures), total=len(tasks_to_run), desc="Processing tasks"):
            # Propagate any exception raised inside a worker thread.
            f.result()

    logger.warning("All tasks processed.")


if __name__ == "__main__":
    main()