"""
Module: custom_agent

This module provides a custom class, CustomHfAgent, for interacting with the
Hugging Face model API.

Dependencies:
- time: Standard Python time module for time-related operations.
- requests: HTTP library for making requests.
- transformers: Hugging Face's transformers library for NLP tasks.
- utils.logger: Custom logger module for logging responses.

Classes:
- CustomHfAgent: A custom class for interacting with the Hugging Face model API.

Reason for writing this: https://github.com/huggingface/transformers/issues/28217
Based on: https://github.com/huggingface/transformers/blob/main/src/transformers/tools/agents.py

Note: requests are sent with "return_full_text": False so the endpoint returns
only the newly generated text instead of echoing the prompt.
"""

import time

import requests
from transformers import Agent, AutoTokenizer

from utils.logger import log_response

CHAT_MESSAGE_PROMPT = """ |
|
Human: <<task>> |
|
|
|
Assistant: """ |


class CustomHfAgent(Agent):
    """A custom class for interacting with the Hugging Face model API."""

    def __init__(self, url_endpoint, token, chat_prompt_template=None, run_prompt_template=None, additional_tools=None, input_params=None):
        """
        Initialize the CustomHfAgent.

        Args:
        - url_endpoint (str): The URL endpoint for the Hugging Face model API.
        - token (str): The authentication token required to access the API.
        - chat_prompt_template (str): Template for chat prompts.
        - run_prompt_template (str): Template for run prompts.
        - additional_tools (list): Additional tools for the agent.
        - input_params (dict): Additional parameters for generation requests.

        Returns:
        - None
        """
        super().__init__(
            chat_prompt_template=chat_prompt_template,
            run_prompt_template=run_prompt_template,
            additional_tools=additional_tools,
        )
        self.url_endpoint = url_endpoint
        self.token = token
        # Default to an empty dict so generate_one can call .get() even when
        # no input_params were supplied.
        self.input_params = input_params or {}

    def generate_one(self, prompt, stop):
        """
        Generate one response from the Hugging Face model.

        Args:
        - prompt (str): The prompt to generate a response for.
        - stop (list): A list of strings indicating where to stop generating text.

        Returns:
        - str: The generated response.
        """
        headers = {"Authorization": "Bearer " + self.token}
        max_new_tokens = self.input_params.get("max_new_tokens", 192)
        parameters = {
            "max_new_tokens": max_new_tokens,
            "return_full_text": False,
            "stop": stop,
            "padding": True,
            "truncation": True,
        }
        inputs = {
            "inputs": prompt,
            "parameters": parameters,
        }
        print(inputs)
        try:
            response = requests.post(self.url_endpoint, json=inputs, headers=headers, timeout=300)
        except (requests.Timeout, requests.ConnectionError) as exc:
            # Silently swallowing these would leave `response` unbound below,
            # so re-raise them as an explicit error instead.
            raise ValueError(f"Request to {self.url_endpoint} failed: {exc}") from exc
        if response.status_code == 429:
            log_response("Getting rate-limited, waiting a tiny bit before trying again.")
            time.sleep(1)
            return self.generate_one(prompt, stop)
        if response.status_code != 200:
            raise ValueError(f"Errors {inputs} {response.status_code}: {response.json()}")
        log_response(response)
        result = response.json()[0]["generated_text"]
        # Trim the stop sequence from the end of the generated text, if present.
        for stop_seq in stop:
            if result.endswith(stop_seq):
                return result[: -len(stop_seq)]
        return result

    def format_prompt(self, task, chat_mode=False):
        """
        Build the prompt for a task and log its chat-templated form.

        Args:
        - task (str): The task description to embed in the prompt.
        - chat_mode (bool): Whether to format the prompt for chat mode.

        Returns:
        - str: The formatted prompt.
        """
        checkpoint = "bigcode/starcoder"
        tokenizer = AutoTokenizer.from_pretrained(checkpoint, token=self.token)

        description = "\n".join([f"- {name}: {tool.description}" for name, tool in self.toolbox.items()])

        if chat_mode:
            if self.chat_history is None:
                print("no history yet")
                prompt = self.chat_prompt_template.replace("<<all_tools>>", description)
                prompt += CHAT_MESSAGE_PROMPT.replace("<<task>>", task)
            else:
                print("chat history")
                print(self.chat_history)
                prompt = self.chat_history
                prompt += CHAT_MESSAGE_PROMPT.replace("<<task>>", task)
        else:
            print("else block, not chat mode")
            prompt = self.run_prompt_template.replace("<<all_tools>>", description)
            prompt = prompt.replace("<<prompt>>", task)

        # Debug output only: show how the prompt looks after the tokenizer's
        # chat template is applied. The plain prompt is what gets returned.
        messages = [
            {
                "role": "user",
                "content": prompt,
            }
        ]
        print("tokenized " + tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False))

        print("formatted prompt ---- " + prompt)
        return prompt
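

# Minimal usage sketch, not part of the original module. The endpoint URL and
# token below are placeholders; any Hugging Face text-generation inference
# endpoint that accepts the payload built in generate_one() should work.
if __name__ == "__main__":
    agent = CustomHfAgent(
        url_endpoint="https://api-inference.huggingface.co/models/bigcode/starcoder",  # placeholder
        token="hf_...",  # placeholder, supply a real API token
        input_params={"max_new_tokens": 192},
    )
    # run() is inherited from transformers.Agent; it formats the task with
    # format_prompt() and generates code/text via generate_one().
    result = agent.run("Generate an image of rivers and lakes.")
    print(result)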