|
|
|
import os |
|
import base64 |
|
import io |
|
import requests |
|
import time |
|
from transformers import Agent |
|
from logger import log_response |
|
|
|
import time |
|
import torch |
|
|
|
class CustomHfAgent(Agent):
    """Transformers ``Agent`` backed by a Hugging Face Inference Endpoint.

    Sends generation requests to ``url_endpoint`` over HTTP, authenticating
    with ``token``, and retries once per attempt on HTTP 429 rate limits.
    """

    def __init__(self, url_endpoint, token, chat_prompt_template=None, run_prompt_template=None, additional_tools=None, input_params=None):
        """Configure the agent.

        Args:
            url_endpoint: Full URL of the inference endpoint to POST to.
            token: Value sent in the ``Authorization`` header (presumably a
                ``"Bearer <hf_token>"`` string — TODO confirm against caller).
            chat_prompt_template: Forwarded to ``Agent.__init__``.
            run_prompt_template: Forwarded to ``Agent.__init__``.
            additional_tools: Forwarded to ``Agent.__init__``.
            input_params: Optional dict of generation options; only
                ``"max_new_tokens"`` is read (default 192).
        """
        super().__init__(
            chat_prompt_template=chat_prompt_template,
            run_prompt_template=run_prompt_template,
            additional_tools=additional_tools,
        )
        self.url_endpoint = url_endpoint
        self.token = token
        self.input_params = input_params

    def generate_one(self, prompt, stop):
        """Generate a single completion for ``prompt``.

        Args:
            prompt: Text prompt sent as the endpoint's ``inputs`` field.
            stop: Iterable of stop sequences; a trailing stop sequence is
                stripped from the returned text.

        Returns:
            The generated text (``generated_text`` of the first result),
            with any trailing stop sequence removed.

        Raises:
            ValueError: If the endpoint returns a non-200, non-429 status.
        """
        headers = {"Authorization": self.token}
        # BUG FIX: guard against input_params=None (the constructor default);
        # the original called .get() on None and raised AttributeError.
        params = self.input_params or {}
        max_new_tokens = params.get("max_new_tokens", 192)
        parameters = {"max_new_tokens": max_new_tokens, "return_full_text": False, "stop": stop, "padding": True, "truncation": True}
        inputs = {
            "inputs": prompt,
            "parameters": parameters,
        }
        # Explicit timeout so a stalled endpoint cannot hang the agent forever.
        response = requests.post(self.url_endpoint, json=inputs, headers=headers, timeout=120)

        if response.status_code == 429:
            log_response("Getting rate-limited, waiting a tiny bit before trying again.")
            time.sleep(1)
            # BUG FIX: the original retried via self._generate_one(prompt) — a
            # method that does not exist on this class — and also dropped the
            # `stop` argument. Retry through the real method with both args.
            return self.generate_one(prompt, stop)
        elif response.status_code != 200:
            raise ValueError(f"Errors {inputs} {response.status_code}: {response.json()}")
        log_response(response)
        result = response.json()[0]["generated_text"]
        # The endpoint may echo a stop sequence at the end; strip the first
        # one that matches so callers see clean text.
        for stop_seq in stop:
            if result.endswith(stop_seq):
                return result[: -len(stop_seq)]
        return result
|
|