File size: 1,677 Bytes
af9f214
60bb959
 
 
af9f214
 
 
c18ea18
438881a
 
 
 
af9f214
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
# custom_agent.py
import os
import base64
import io
import requests
import time
from transformers import Agent
from logger import log_response

import time  
import torch

class CustomHfAgent(Agent):
    """A ``transformers`` Agent that generates text via a remote HTTP inference endpoint.

    Prompts are POSTed to ``url_endpoint`` with an ``Authorization`` header;
    generation parameters (e.g. ``max_new_tokens``) come from ``input_params``.
    """

    def __init__(self, url_endpoint, token, chat_prompt_template=None, run_prompt_template=None, additional_tools=None, input_params=None):
        """Configure the agent.

        Args:
            url_endpoint: Full URL of the text-generation inference endpoint.
            token: Value sent as the ``Authorization`` header (presumably a
                ``"Bearer ..."`` string — confirm against the endpoint's auth scheme).
            chat_prompt_template: Forwarded to ``Agent.__init__``.
            run_prompt_template: Forwarded to ``Agent.__init__``.
            additional_tools: Forwarded to ``Agent.__init__``.
            input_params: Optional dict of generation options; currently only
                ``"max_new_tokens"`` is read (default 192).
        """
        super().__init__(
            chat_prompt_template=chat_prompt_template,
            run_prompt_template=run_prompt_template,
            additional_tools=additional_tools,
        )
        self.url_endpoint = url_endpoint
        self.token = token
        # BUG FIX: default to an empty dict — generate_one() unconditionally
        # calls self.input_params.get(...), which raised AttributeError when
        # the caller omitted input_params and it was stored as None.
        self.input_params = input_params if input_params is not None else {}

    def generate_one(self, prompt, stop):
        """Send ``prompt`` to the endpoint and return the generated text.

        Args:
            prompt: The text prompt to complete.
            stop: Sequence of stop strings; a single trailing stop sequence is
                stripped from the returned text.

        Returns:
            The generated text (without a trailing stop sequence, if present).

        Raises:
            ValueError: If the endpoint returns a non-200, non-429 status.
        """
        headers = {"Authorization": self.token}
        max_new_tokens = self.input_params.get("max_new_tokens", 192)
        parameters = {
            "max_new_tokens": max_new_tokens,
            "return_full_text": False,
            "stop": stop,
            "padding": True,
            "truncation": True,
        }
        inputs = {
            "inputs": prompt,
            "parameters": parameters,
        }
        response = requests.post(self.url_endpoint, json=inputs, headers=headers)

        if response.status_code == 429:
            log_response("Getting rate-limited, waiting a tiny bit before trying again.")
            time.sleep(1)
            # BUG FIX: the original retried via self._generate_one(prompt) —
            # a method that does not exist (AttributeError) — and also dropped
            # the ``stop`` argument. Retry the real method with both arguments.
            # NOTE(review): this retries indefinitely while rate-limited;
            # consider a bounded retry count.
            return self.generate_one(prompt, stop)
        elif response.status_code != 200:
            raise ValueError(f"Errors {inputs} {response.status_code}: {response.json()}")
        log_response(response)
        result = response.json()[0]["generated_text"]
        # Inference backends may echo a stop sequence at the end of the text;
        # strip the first one found so callers receive clean output.
        for stop_seq in stop:
            if result.endswith(stop_seq):
                return result[: -len(stop_seq)]
        return result