# controller.py — Hugging Face Space file by Chris4K (commit 3e43065, 2.79 kB).
# (Web-page scrape artifacts "raw / history blame" removed so the module parses.)
import os
from logger import log_response
from custom_agent import CustomHfAgent
from langchain.vectorstores import FAISS
from langchain.memory import ConversationBufferMemory
from langchain.chains import ConversationalRetrievalChain, ConversationChain
from langchain.llms import HuggingFaceHub
# Module-level default image list. NOTE(review): shadowed by the `image`
# parameter of handle_submission below, so it appears unused within this
# module — confirm no other module imports it before removing.
image = []
def handle_submission(user_message, selected_tools, url_endpoint, document, image, context):
    """Run one agent turn: log every input, build a CustomHfAgent against the
    given endpoint, and return the agent's chat response.

    Args:
        user_message: text prompt from the user.
        selected_tools: extra tools handed to the agent.
        url_endpoint: inference endpoint URL for the agent.
        document, image, context: auxiliary inputs forwarded to agent.chat().

    Returns:
        Whatever agent.chat() produces for this turn.
    """
    # Log each input under its label (same messages as before, data-driven).
    inputs = (
        ("User input", user_message),
        ("selected_tools", selected_tools),
        ("url_endpoint", url_endpoint),
        ("document", document),
        ("image", image),
        ("context", context),
    )
    for label, value in inputs:
        log_response("{} \n {}".format(label, value))

    # Token is read from the environment; raises KeyError if HF_token is unset.
    agent = CustomHfAgent(
        url_endpoint=url_endpoint,
        token=os.environ['HF_token'],
        additional_tools=selected_tools,
        input_params={"max_new_tokens": 192},
    )

    response = agent.chat(user_message, document=document, image=image, context=context)
    log_response("Agent Response\n {}".format(response))
    return response
def cut_text_after_keyword(text, keyword):
    """Return *text* truncated just before the first occurrence of *keyword*.

    The keyword and everything after it are dropped and the remainder is
    stripped of surrounding whitespace. If the keyword is absent, *text*
    is returned unchanged.
    """
    keyword_pos = text.find(keyword)
    if keyword_pos == -1:
        return text
    return text[:keyword_pos].strip()
def handle_submission_chat(user_message, response):
    """Run the user's message through the shared conversation chain and return
    the reply, truncated at the first fabricated "Human:" turn.

    Args:
        user_message: text prompt from the user.
        response: optional prior agent response appended to the prompt as
            context; skipped when None.

    Returns:
        str: the chain's prediction, cut off before any "Human:" marker.
    """
    # os.environ['HUGGINGFACEHUB_API_TOKEN'] = os.environ['HF_token']
    agent_chat_bot = ConversationChainSingleton().get_conversation_chain()

    # NOTE(review): raw concatenation assumes `response` is a string — a
    # non-str agent result would raise TypeError here; confirm with callers.
    if response is not None:
        text = agent_chat_bot.predict(input=user_message + response)
    else:
        text = agent_chat_bot.predict(input=user_message)

    # The model sometimes continues the transcript with its own "Human:" turn;
    # drop that hallucinated tail.
    result = cut_text_after_keyword(text, "Human:")
    # Fix: route the result through the module logger instead of the stray
    # debug print() — consistent with every other function in this file.
    log_response("Chat response \n {}".format(result))
    return result
class ConversationChainSingleton:
    """Process-wide holder for a single shared conversation chain.

    The first instantiation builds the chain via get_conversation_chain();
    every later instantiation returns the same object.
    """

    _instance = None  # cached singleton instance (class-level)

    def __new__(cls, *args, **kwargs):
        if cls._instance is None:
            instance = super().__new__(cls)
            # Build the chain once, on first construction.
            instance.conversation_chain = get_conversation_chain()
            cls._instance = instance
        return cls._instance

    def get_conversation_chain(self):
        """Return the shared, lazily created conversation chain."""
        return self.conversation_chain
def get_conversation_chain():
    """
    Build the module's conversation chain: a Mixtral-8x7B-Instruct model via
    HuggingFaceHub, wrapped in a verbose ConversationChain with buffer memory.

    Returns:
        ConversationChain: chain backed by its own ConversationBufferMemory.
    """
    llm = HuggingFaceHub(
        repo_id="mistralai/Mixtral-8x7B-Instruct-v0.1",
        model_kwargs={"max_length": 1048, "temperature": 0.2, "max_new_tokens": 256, "top_p": 0.95, "repetition_penalty": 1.0},
    )
    # llm = ChatOpenAI(temperature=0, model="gpt-3.5-turbo-0613")

    # Fix: a ConversationBufferMemory(memory_key="chat_history",
    # return_messages=True) was previously built here and never used — the
    # chain below always received a fresh default memory. The dead object is
    # removed rather than wired in, because ConversationChain's default prompt
    # expects the "history" memory key and would reject "chat_history".
    conversation_chain = ConversationChain(
        llm=llm, verbose=True, memory=ConversationBufferMemory()
    )
    return conversation_chain