|
import os |
|
from logger import log_response |
|
from custom_agent import CustomHfAgent |
|
|
|
|
|
from langchain.vectorstores import FAISS |
|
from langchain.memory import ConversationBufferMemory |
|
from langchain.chains import ConversationalRetrievalChain, ConversationChain |
|
from langchain.llms import HuggingFaceHub |
|
|
|
|
|
# Module-level ``image`` list. handle_submission declares a parameter with the
# same name that shadows it, and none of the functions here read this global.
image = list()
|
def handle_submission(user_message, selected_tools, url_endpoint, document, image, context):
    """Run one request through a freshly built ``CustomHfAgent`` and log it.

    Every input is written via ``log_response`` before the agent is created;
    the agent's reply is then logged and returned unchanged.

    Args:
        user_message: The user's prompt, passed positionally to ``agent.chat``.
        selected_tools: Given to the agent as ``additional_tools``.
        url_endpoint: Endpoint URL the agent is constructed with.
        document: Forwarded to ``agent.chat`` as ``document=``.
        image: Forwarded to ``agent.chat`` as ``image=`` (shadows the
            module-level ``image`` list).
        context: Forwarded to ``agent.chat`` as ``context=``.

    Returns:
        Whatever ``agent.chat`` returns.

    Raises:
        KeyError: If the ``HF_token`` environment variable is not set.
    """
    # (label, value) pairs, logged in this exact order.
    inputs_to_log = (
        ("User input", user_message),
        ("selected_tools", selected_tools),
        ("url_endpoint", url_endpoint),
        ("document", document),
        ("image", image),
        ("context", context),
    )
    for label, value in inputs_to_log:
        log_response(f"{label} \n {value}")

    # A new agent is built on every call; the token is read from the
    # environment at this point, after the inputs have been logged.
    hf_agent = CustomHfAgent(
        url_endpoint=url_endpoint,
        token=os.environ["HF_token"],
        additional_tools=selected_tools,
        input_params={"max_new_tokens": 192},
    )

    agent_reply = hf_agent.chat(
        user_message, document=document, image=image, context=context
    )
    log_response(f"Agent Response\n {agent_reply}")
    return agent_reply
|
|
|
def cut_text_after_keyword(text, keyword):
    """Return the part of *text* that precedes the first *keyword*.

    When *keyword* occurs in *text*, everything from its first occurrence
    onward is dropped and the remaining prefix is stripped of surrounding
    whitespace. When *keyword* is absent (or empty), *text* is returned
    unchanged, without stripping.

    Args:
        text: The string to trim.
        keyword: The marker at which to cut.

    Returns:
        The stripped prefix before *keyword*, or *text* itself when there is
        nothing to cut at.
    """
    # An empty keyword "matches" at index 0 under str.find, which would
    # discard the entire text; treat it as "no keyword" instead.
    if not keyword:
        return text
    head, found, _ = text.partition(keyword)
    return head.strip() if found else text
|
|
|
|
|
|
|
|
|
def handle_submission_chat(user_message, response):
    """Send a message, optionally with a prior response, to the shared chat chain.

    Args:
        user_message: The user's text.
        response: Output of an earlier step to include in the prompt after
            the user's text, or ``None`` to send the message alone.
            (Presumably the agent reply from ``handle_submission`` — confirm
            against callers.)

    Returns:
        The chain's reply, cut at the first ``"Human:"`` marker (presumably a
        next user turn the model generated on its own) and stripped. The
        result is also printed to stdout.
    """
    chat_chain = ConversationChainSingleton().get_conversation_chain()

    if response is None:
        prompt = user_message
    else:
        # str() tolerates non-string agent outputs (plain `+` raised
        # TypeError), and the newline keeps the two texts from being glued
        # into one word ("...?Answer").
        response_text = str(response)
        prompt = f"{user_message}\n{response_text}" if response_text else user_message

    text = chat_chain.predict(input=prompt)
    result = cut_text_after_keyword(text, "Human:")
    print(result)
    return result
|
|
|
class ConversationChainSingleton:
    """Process-wide holder for a single conversation chain.

    The first instantiation builds the chain with the module-level
    ``get_conversation_chain()`` factory; every later instantiation returns
    that same instance. The instance is cached only after the chain has been
    built successfully, so if the factory raises, the next instantiation
    retries instead of returning a half-initialised object that lacks
    ``conversation_chain``.

    NOTE(review): there is no locking; concurrent first calls could each build
    a chain. Add a lock if this is used from multiple threads.
    """

    _instance = None  # the cached singleton; None until the first successful build

    def __new__(cls, *args, **kwargs):
        if cls._instance is None:
            instance = super().__new__(cls)
            # Inside a method this bare name resolves to the module-level
            # factory function, not to the method of the same name below.
            instance.conversation_chain = get_conversation_chain()
            # Cache only once fully initialised (previously the instance was
            # cached before the build, so a failure poisoned the singleton).
            cls._instance = instance
        return cls._instance

    def get_conversation_chain(self):
        """Return the chain built when the singleton was first created."""
        return self.conversation_chain
|
|
|
|
|
def get_conversation_chain():
    """Build a LangChain ``ConversationChain`` backed by a Hugging Face Hub model.

    The chain uses ``mistralai/Mixtral-8x7B-Instruct-v0.1`` through
    ``HuggingFaceHub`` with the generation settings below, a fresh
    ``ConversationBufferMemory`` and ``verbose=True``.

    Returns:
        A new ``ConversationChain``. Every call creates an independent chain
        and memory object.
    """
    # NOTE(review): no API token is passed here, unlike handle_submission,
    # which reads os.environ["HF_token"]; presumably HuggingFaceHub picks one
    # up from its own environment variable — confirm.
    llm = HuggingFaceHub(
        repo_id="mistralai/Mixtral-8x7B-Instruct-v0.1",
        model_kwargs={
            # NOTE(review): both max_length and max_new_tokens are set, and
            # 1048 may be a typo for 1024 — confirm which limit is intended.
            "max_length": 1048,
            "temperature": 0.2,
            "max_new_tokens": 256,
            "top_p": 0.95,
            "repetition_penalty": 1.0,
        },
    )

    # The memory is created with its default settings. An unused
    # ConversationBufferMemory(memory_key="chat_history", return_messages=True)
    # local was dropped: it was never given to the chain. Per the LangChain
    # docs, the default prompt expects the default "history" key.
    conversation_chain = ConversationChain(
        llm=llm, verbose=True, memory=ConversationBufferMemory()
    )
    return conversation_chain
|
|