Create custom_agent.py
Browse files- custom_agent.py +38 -0
custom_agent.py
ADDED
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# custom_agent.py
|
2 |
+
import requests
|
3 |
+
import time
|
4 |
+
from transformers import Agent
|
5 |
+
|
6 |
+
class CustomHfAgent(Agent):
|
7 |
+
def __init__(self, url_endpoint, token, chat_prompt_template=None, run_prompt_template=None, additional_tools=None, input_params=None):
|
8 |
+
super().__init__(
|
9 |
+
chat_prompt_template=chat_prompt_template,
|
10 |
+
run_prompt_template=run_prompt_template,
|
11 |
+
additional_tools=additional_tools,
|
12 |
+
)
|
13 |
+
self.url_endpoint = url_endpoint
|
14 |
+
self.token = token
|
15 |
+
self.input_params = input_params
|
16 |
+
|
17 |
+
def generate_one(self, prompt, stop):
|
18 |
+
headers = {"Authorization": self.token}
|
19 |
+
max_new_tokens = self.input_params.get("max_new_tokens", 192)
|
20 |
+
parameters = {"max_new_tokens": max_new_tokens, "return_full_text": False, "stop": stop, "padding": True, "truncation": True}
|
21 |
+
inputs = {
|
22 |
+
"inputs": prompt,
|
23 |
+
"parameters": parameters,
|
24 |
+
}
|
25 |
+
response = requests.post(self.url_endpoint, json=inputs, headers=headers)
|
26 |
+
|
27 |
+
if response.status_code == 429:
|
28 |
+
log_response("Getting rate-limited, waiting a tiny bit before trying again.")
|
29 |
+
time.sleep(1)
|
30 |
+
return self._generate_one(prompt)
|
31 |
+
elif response.status_code != 200:
|
32 |
+
raise ValueError(f"Errors {inputs} {response.status_code}: {response.json()}")
|
33 |
+
log_response(response)
|
34 |
+
result = response.json()[0]["generated_text"]
|
35 |
+
for stop_seq in stop:
|
36 |
+
if result.endswith(stop_seq):
|
37 |
+
return result[: -len(stop_seq)]
|
38 |
+
return result
|