Spaces:
Sleeping
Sleeping
import gradio as gr | |
from langchain_chroma import Chroma | |
from langchain.prompts import ChatPromptTemplate | |
from langchain.chains import create_retrieval_chain, create_history_aware_retriever | |
from langchain.chains.combine_documents import create_stuff_documents_chain | |
from langchain_core.prompts import MessagesPlaceholder | |
from langchain_community.chat_message_histories import ChatMessageHistory | |
from langchain_core.runnables.history import RunnableWithMessageHistory | |
import torch | |
import chromadb | |
from typing import List | |
from langchain_core.documents import Document | |
from langchain_core.retrievers import BaseRetriever | |
from langchain_core.callbacks import CallbackManagerForRetrieverRun | |
from langchain_core.vectorstores import VectorStoreRetriever | |
from langchain_openai import ChatOpenAI | |
from mixedbread_ai.client import MixedbreadAI | |
from langchain.callbacks.tracers import ConsoleCallbackHandler | |
from langchain_huggingface import HuggingFaceEmbeddings | |
import os | |
# from chroma_datasets.utils import import_into_chroma | |
from hf_to_chroma_ds import import_into_chroma | |
from datasets import load_dataset | |
from chromadb.utils import embedding_functions | |
from hf_to_chroma_ds import Memoires_DS | |
from dotenv import load_dotenv | |
# Global params
CHROMA_PATH = "chromadb_mem10_mxbai_800_complete"
# NOTE(review): MODEL_EMB is not referenced elsewhere in this file; the
# embedding model actually used is `model_emb` further down — confirm and
# consolidate.
MODEL_EMB = "mxbai-embed-large"
MODEL_RRK = "mixedbread-ai/mxbai-rerank-large-v1"  # Mixedbread reranking model id
LLM_NAME = "gpt-4o-mini"  # chat model passed to ChatOpenAI

# Load a local .env file (if present) into os.environ before the keys below
# are read; without this call, keys stored only in .env would come back as
# None. load_dotenv() does not override variables that are already set, so
# secrets injected by the hosting environment still take precedence.
load_dotenv()

# API credentials, read from the environment (None when a variable is unset).
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
MXBAI_API_KEY = os.environ.get("MXBAI_API_KEY")
HF_TOKEN = os.environ.get("HF_TOKEN")
HF_API_KEY = os.environ.get("HF_API_KEY")
# Remote clients: the Mixedbread API client (used for reranking) and the
# Hugging Face-hosted embedding function handed to import_into_chroma below.
# NOTE(review): `device` is not referenced anywhere in this file; it appears
# to be left over from a local CrossEncoder reranker — confirm and remove.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"
mxbai_client = MixedbreadAI(api_key=MXBAI_API_KEY)

model_emb = "mixedbread-ai/mxbai-embed-large-v1"
huggingface_ef = (
    embedding_functions.huggingface_embedding_function.HuggingFaceEmbeddingFunction(
        model_name=model_emb,
        api_key=HF_API_KEY,
    )
)
# Set up ChromaDB: an in-memory client filled from the Memoires_DS dataset,
# embedded with the Hugging Face embedding function.
# NOTE(review): Memoires_DS presumably wraps the "eliot-hub/memoires_vec_800"
# Hugging Face dataset (named in earlier loading code) — confirm in
# hf_to_chroma_ds.
client = chromadb.Client()
collection = import_into_chroma(
    chroma_client=client,
    dataset=Memoires_DS,
    embedding_function=huggingface_ef,
)

# LangChain vector store over the same client. Queries are embedded with the
# same model id (model_emb) that was used for the stored vectors.
# NOTE(review): the collection name is hard-coded and must match the one
# import_into_chroma creates; otherwise retrieval would hit a different
# (likely empty) collection — confirm.
db = Chroma(
    embedding_function=HuggingFaceEmbeddings(model_name=model_emb),
    collection_name="embeddings_mxbai",
    client=client,
)
# Reranker class
class Reranker(BaseRetriever):
    """Retriever that reranks a base retriever's hits with the Mixedbread API.

    The wrapped ``retriever`` supplies candidate documents; their text is sent
    to the ``MODEL_RRK`` reranking model and only the ``k`` best-scoring
    candidates are returned, in reranked order.
    """

    retriever: VectorStoreRetriever  # source of candidate documents
    k: int  # number of documents kept after reranking

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Return at most ``k`` candidates for ``query``, best first.

        Args:
            query: The (possibly reformulated) user question.
            run_manager: LangChain callback manager for this run (unused).

        Returns:
            The original candidate ``Document`` objects, with page content and
            metadata intact, ordered by reranker score. Returns an empty list
            when the base retriever found nothing.
        """
        docs = self.retriever.invoke(query)
        # Nothing to rerank: skip the remote call, which would otherwise be
        # sent with an empty input list.
        if not docs:
            return []
        results = mxbai_client.reranking(
            model=MODEL_RRK,
            query=query,
            input=[doc.page_content for doc in docs],
            top_k=self.k,
        )
        # Each ranked result's `index` points back into the input list, so
        # return the original Document (keeping its metadata) rather than
        # rebuilding a metadata-less one from the bare text.
        return [docs[res.index] for res in results.data]
# Set up reranker + LLM.
# Stage 1: plain similarity search pulls 25 candidate documents from Chroma.
retriever = db.as_retriever(
    search_type="similarity",
    search_kwargs={"k": 25},
)
# Stage 2: the Mixedbread reranker keeps the 4 best of those candidates.
reranker = Reranker(retriever=retriever, k=4)
# No api_key argument: ChatOpenAI reads OPENAI_API_KEY from the environment.
llm = ChatOpenAI(model=LLM_NAME, verbose=True)
# Set up the contextualize question prompt.
# The (French) system instruction asks the LLM to rewrite the latest user
# question as a standalone question understandable without the chat history,
# and explicitly NOT to answer it.
contextualize_q_system_prompt = (
    "Compte tenu de l'historique des discussions et de la dernière question de l'utilisateur "
    "qui peut faire référence à un contexte dans l'historique du chat, "
    "formuler une question autonome qui peut être comprise "
    "sans l'historique du chat. Ne répondez PAS à la question, "
    "juste la reformuler si nécessaire et sinon la renvoyer telle quelle."
)
contextualize_q_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", contextualize_q_system_prompt),
        MessagesPlaceholder(variable_name="chat_history"),
        ("human", "{input}"),
    ]
)

# Create the history-aware retriever. When "chat_history" is non-empty, the
# LLM first rewrites the question using the prompt above, and the rewrite is
# what reaches the reranking retriever; otherwise "input" is used as-is.
history_aware_retriever = create_history_aware_retriever(
    llm=llm,
    retriever=reranker,
    prompt=contextualize_q_prompt,
)
# Set up the QA prompt: retrieved documents are stuffed into {context}, and
# the model is told (in French) to answer from that context only.
system_prompt = (
    "Réponds à la question en te basant uniquement sur le contexte suivant: \n\n {context}"
)
qa_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system_prompt),
        MessagesPlaceholder(variable_name="chat_history"),
        ("human", "{input}"),
    ]
)

# Create the question-answer chain: the history-aware, reranked retrieval
# feeds the stuff-documents chain, whose reply ends up under "answer".
question_answer_chain = create_stuff_documents_chain(llm=llm, prompt=qa_prompt)
rag_chain = create_retrieval_chain(
    retriever=history_aware_retriever,
    combine_docs_chain=question_answer_chain,
)
# Set up the conversation history: one ChatMessageHistory per session id,
# held in this process's memory. Nothing in this file ever removes entries.
store = {}


def get_session_history(session_id: str) -> ChatMessageHistory:
    """Return the history for ``session_id``, creating an empty one on first use."""
    try:
        return store[session_id]
    except KeyError:
        fresh_history = store[session_id] = ChatMessageHistory()
        return fresh_history
# RAG chain with automatic history handling. On each call, the history of
# the session named in config["configurable"]["session_id"] is fetched via
# get_session_history and injected as "chat_history". The new user message
# is read from "input", and the reply recorded in the history is taken from
# the chain's "answer" output.
conversational_rag_chain = RunnableWithMessageHistory(
    runnable=rag_chain,
    get_session_history=get_session_history,
    input_messages_key="input",
    history_messages_key="chat_history",
    output_messages_key="answer",
)
# Gradio interface
def _gradio_history_to_messages(history):
    """Convert Gradio chat history into LangChain ``(role, text)`` message tuples.

    Both Gradio history formats are accepted: ``[user_text, bot_text]`` pairs
    ("tuples" format) and ``{"role": ..., "content": ...}`` dicts ("messages"
    format). Entries without plain-text content (e.g. ``None`` placeholders
    or file attachments) are skipped. ``None`` or an empty history yields [].
    """
    messages = []
    for turn in history or []:
        if isinstance(turn, dict):
            role = {"user": "human", "assistant": "ai"}.get(turn.get("role"))
            content = turn.get("content")
            if role is not None and isinstance(content, str):
                messages.append((role, content))
            continue
        # "tuples" format: element 0 is the user's text, element 1 the bot's.
        for role, text in zip(("human", "ai"), turn):
            if isinstance(text, str):
                messages.append((role, text))
    return messages


def chatbot(message, history):
    """Answer ``message`` with the RAG chain, using Gradio's history as context.

    Args:
        message: The user's new message.
        history: The conversation so far, as tracked by Gradio for this
            browser session (either format accepted by
            ``_gradio_history_to_messages``).

    Returns:
        The ``"answer"`` string produced by ``rag_chain``.
    """
    # Use the per-browser-session history that Gradio already keeps, rather
    # than a server-side store under one fixed session id. With a fixed id,
    # every visitor would read and write the same conversation, and the
    # "Effacer la conversation" button would not reset the model's context.
    # NOTE(review): conversational_rag_chain / store are no longer used by
    # this handler.
    chat_history = _gradio_history_to_messages(history)
    result = rag_chain.invoke(
        {"input": message, "chat_history": chat_history},
        config={"callbacks": [ConsoleCallbackHandler()]},
    )
    return result["answer"]
# Gradio chat UI with French labels: the retry and undo buttons are hidden
# and the clear button is relabelled.
# NOTE(review): retry_btn / undo_btn / clear_btn are gradio 4.x ChatInterface
# arguments that appear to have been removed in gradio 5 — pin gradio<5 or
# migrate before upgrading.
iface = gr.ChatInterface(
    fn=chatbot,
    title="Assurance Chatbot",
    description="Posez vos questions sur l'assurance",
    theme="soft",
    examples=[
        "Qu'est-ce que l'assurance multirisque habitation ?",
        "Qu'est-ce que la garantie DTA ?",
    ],
    retry_btn=None,
    undo_btn=None,
    clear_btn="Effacer la conversation",
)

if __name__ == "__main__":
    # Pass share=True to get a temporary public URL.
    iface.launch()