# custom_agent.py
"""Custom transformers Agent that generates text via a remote HF inference endpoint."""

import base64
import io
import os
import time

import requests
import torch
from transformers import Agent

from logger import log_response


class CustomHfAgent(Agent):
    """Agent that sends generation requests to a Hugging Face inference endpoint.

    Extends ``transformers.Agent`` by POSTing prompts to ``url_endpoint`` with an
    ``Authorization`` header built from ``token``.
    """

    def __init__(self, url_endpoint, token, chat_prompt_template=None,
                 run_prompt_template=None, additional_tools=None, input_params=None):
        """Configure the agent.

        Args:
            url_endpoint: Full URL of the text-generation inference endpoint.
            token: Value sent as the ``Authorization`` header
                (NOTE(review): sent verbatim — confirm whether a ``"Bearer "``
                prefix is expected by the endpoint).
            chat_prompt_template: Forwarded to ``transformers.Agent``.
            run_prompt_template: Forwarded to ``transformers.Agent``.
            additional_tools: Forwarded to ``transformers.Agent``.
            input_params: Optional dict of generation settings; currently only
                ``"max_new_tokens"`` is read (default 192).
        """
        super().__init__(
            chat_prompt_template=chat_prompt_template,
            run_prompt_template=run_prompt_template,
            additional_tools=additional_tools,
        )
        self.url_endpoint = url_endpoint
        self.token = token
        # Fix: a None input_params would crash generate_one's .get() call;
        # normalize to an empty dict so defaults apply.
        self.input_params = input_params if input_params is not None else {}

    def generate_one(self, prompt, stop):
        """POST ``prompt`` to the endpoint and return the generated text.

        Retries once per 429 (rate-limit) response after a 1 s pause. Any
        trailing stop sequence from ``stop`` is trimmed off the result.

        Args:
            prompt: The input text to send.
            stop: Sequence of stop strings; passed to the endpoint and used to
                trim the tail of the returned text.

        Returns:
            The generated text with at most one trailing stop sequence removed.

        Raises:
            ValueError: If the endpoint returns any status other than 200/429.
        """
        headers = {"Authorization": self.token}
        max_new_tokens = self.input_params.get("max_new_tokens", 192)
        parameters = {
            "max_new_tokens": max_new_tokens,
            "return_full_text": False,
            "stop": stop,
            "padding": True,
            "truncation": True,
        }
        inputs = {
            "inputs": prompt,
            "parameters": parameters,
        }
        response = requests.post(self.url_endpoint, json=inputs, headers=headers)
        if response.status_code == 429:
            log_response("Getting rate-limited, waiting a tiny bit before trying again.")
            time.sleep(1)
            # BUG FIX: original retried via self._generate_one(prompt) — a method
            # that does not exist — and dropped the `stop` argument, so every
            # rate-limited request raised AttributeError. Retry correctly.
            return self.generate_one(prompt, stop)
        elif response.status_code != 200:
            raise ValueError(f"Errors {inputs} {response.status_code}: {response.json()}")
        log_response(response)
        result = response.json()[0]["generated_text"]
        # Trim a single trailing stop sequence, if present.
        for stop_seq in stop:
            if result.endswith(stop_seq):
                return result[: -len(stop_seq)]
        return result