"""Gradio RAG chatbot: Chroma vector store + Mixedbread reranking + GPT-4o-mini."""

import os
from typing import List

import chromadb
import gradio as gr
from chroma_datasets.utils import import_into_chroma
from datasets import load_dataset
from langchain.callbacks.tracers import ConsoleCallbackHandler
from langchain.chains import create_retrieval_chain, create_history_aware_retriever
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_community.chat_message_histories import ChatMessageHistory
from langchain_community.vectorstores import Chroma
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.retrievers import BaseRetriever
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_core.vectorstores import VectorStoreRetriever
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_openai import ChatOpenAI
from mixedbread_ai.client import MixedbreadAI

# Global params
CHROMA_PATH = "chromadb_mem10_mxbai_800_complete"
MODEL_EMB = "mixedbread-ai/mxbai-embed-large-v1"
MODEL_RRK = "mixedbread-ai/mxbai-rerank-large-v1"
LLM_NAME = "gpt-4o-mini"
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
MXBAI_API_KEY = os.environ.get("MXBAI_API_KEY")

# Mixedbread AI client (reranking runs through the hosted API, so no local model is loaded)
mxbai_client = MixedbreadAI(api_key=MXBAI_API_KEY)

# Set up ChromaDB: load the pre-embedded corpus and import it into an in-memory client
client = chromadb.Client()
dataset = load_dataset("eliot-hub/memoires_vec_800", split="data")
embeddings = HuggingFaceEmbeddings(model_name=MODEL_EMB)

# Alternative: persist the database on disk instead of in memory
# client = chromadb.PersistentClient(
#     path=os.path.join(os.path.abspath(os.getcwd()), "01_Notebooks", "RAG-ollama", "chatbot_actuariat_APP", CHROMA_PATH)
# )

collection = import_into_chroma(
    chroma_client=client,
    dataset=dataset,
    embedding_function=embeddings,
)

# import_into_chroma returns a raw chromadb collection; wrap it in LangChain's
# Chroma vectorstore so it exposes .as_retriever() below.
db = Chroma(
    client=client,
    collection_name=collection.name,
    embedding_function=embeddings,
)


# Reranker: wraps a vector-store retriever and reorders its candidates with the
# Mixedbread AI reranking API, keeping only the top k documents.
class Reranker(BaseRetriever):
    retriever: VectorStoreRetriever
    k: int

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        docs = self.retriever.invoke(query)
        results = mxbai_client.reranking(
            model=MODEL_RRK,
            query=query,
            input=[doc.page_content for doc in docs],
            return_input=True,
            top_k=self.k,
        )
        return [Document(page_content=res.input) for res in results.data]


# Set up reranker + LLM: retrieve 25 candidates, keep the 4 best after reranking
retriever = db.as_retriever(search_type="similarity", search_kwargs={"k": 25})
reranker = Reranker(retriever=retriever, k=4)
llm = ChatOpenAI(model=LLM_NAME, api_key=OPENAI_API_KEY, verbose=True)

# Contextualization prompt (French): rewrite the latest user question as a
# standalone question given the chat history, without answering it.
contextualize_q_system_prompt = (
    "Compte tenu de l'historique des discussions et de la dernière question de l'utilisateur "
    "qui peut faire référence à un contexte dans l'historique du chat, "
    "formuler une question autonome qui peut être comprise "
    "sans l'historique du chat. Ne répondez PAS à la question, "
    "juste la reformuler si nécessaire et sinon la renvoyer telle quelle."
)
contextualize_q_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", contextualize_q_system_prompt),
        MessagesPlaceholder("chat_history"),
        ("human", "{input}"),
    ]
)

# Create the history-aware retriever
history_aware_retriever = create_history_aware_retriever(
    llm, reranker, contextualize_q_prompt
)

# QA prompt (French): answer using only the retrieved context
system_prompt = (
    "Réponds à la question en te basant uniquement sur le contexte suivant: \n\n {context}"
)
qa_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system_prompt),
        MessagesPlaceholder("chat_history"),
        ("human", "{input}"),
    ]
)

# Create the question-answer chain
question_answer_chain = create_stuff_documents_chain(llm, qa_prompt)
rag_chain = create_retrieval_chain(history_aware_retriever, question_answer_chain)

# Set up the conversation history (one in-memory history per session id)
store = {}


def get_session_history(session_id: str) -> ChatMessageHistory:
    if session_id not in store:
        store[session_id] = ChatMessageHistory()
    return store[session_id]


conversational_rag_chain = RunnableWithMessageHistory(
    rag_chain,
    get_session_history,
    input_messages_key="input",
    history_messages_key="chat_history",
    output_messages_key="answer",
)


# Gradio interface
def chatbot(message, history):
    # Gradio's `history` argument is unused here: RunnableWithMessageHistory
    # keeps its own per-session history, keyed by session_id.
    session_id = "gradio_session"
    response = conversational_rag_chain.invoke(
        {"input": message},
        config={
            "configurable": {"session_id": session_id},
            "callbacks": [ConsoleCallbackHandler()],
        },
    )["answer"]
    return response


iface = gr.ChatInterface(
    chatbot,
    title="Assurance Chatbot",
    description="Posez vos questions sur l'assurance",
    theme="soft",
    examples=[
        "Qu'est-ce que l'assurance multirisque habitation ?",
        "Qu'est-ce que la garantie DTA ?",
    ],
    retry_btn=None,
    undo_btn=None,
    clear_btn="Effacer la conversation",
)

if __name__ == "__main__":
    iface.launch()  # pass share=True to create a public link
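
# Minimal smoke test (a sketch: "smoke_test" is an arbitrary session id and the
# question is one of the UI examples) to exercise the chain without the UI:
#
#   answer = conversational_rag_chain.invoke(
#       {"input": "Qu'est-ce que la garantie DTA ?"},
#       config={"configurable": {"session_id": "smoke_test"}},
#   )["answer"]
#   print(answer)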