from langchain_community.tools import TavilySearchResults
from langchain_core.retrievers import BaseRetriever
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.vectorstores import VectorStoreRetriever
from langgraph.prebuilt import create_react_agent
from langchain_core.documents import Document
from langchain_openai import ChatOpenAI
from langgraph.checkpoint.memory import MemorySaver
from mixedbread_ai.client import MixedbreadAI
from langchain.chains import create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.prompts import ChatPromptTemplate
from dotenv import load_dotenv
import os
from langchain_chroma import Chroma
import chromadb
from typing import List
from datasets import load_dataset
from langchain_huggingface import HuggingFaceEmbeddings
from tqdm import tqdm

load_dotenv()

# Global params
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
MXBAI_API_KEY = os.environ.get("MXBAI_API_KEY")
HF_TOKEN = os.environ.get("HF_TOKEN")
HF_API_KEY = os.environ.get("HF_API_KEY")
MODEL_EMB = "mixedbread-ai/mxbai-embed-large-v1"
MODEL_RRK = "mixedbread-ai/mxbai-rerank-large-v1"
LLM_NAME = "gpt-4o-mini"

# Mixedbread AI client, used for reranking
mxbai_client = MixedbreadAI(api_key=MXBAI_API_KEY)

# Set up ChromaDB: stream the pre-embedded documents from the Hugging Face Hub
# and load them into an in-memory Chroma collection in batches.
memoires_ds = load_dataset("eliot-hub/memoires_vec_800", split="data", token=HF_TOKEN, streaming=True)
batched_ds = memoires_ds.batch(batch_size=41000)
client = chromadb.Client()
collection = client.get_or_create_collection(name="embeddings_mxbai")

for batch in tqdm(batched_ds, desc="Processing dataset batches"):
    collection.add(
        ids=batch["id"],
        metadatas=batch["metadata"],
        documents=batch["document"],
        embeddings=batch["embedding"],
    )
print(f"Collection complete: {collection.count()}")
del memoires_ds, batched_ds


def init_rag_tool():
    """Initialize the RAG tool that lets the LLM query the document collection."""
    db = Chroma(
        client=client,
        collection_name="embeddings_mxbai",
        embedding_function=HuggingFaceEmbeddings(model_name=MODEL_EMB),
    )

    # Retriever that reranks the similarity hits with the Mixedbread reranking API
    class Reranker(BaseRetriever):
        retriever: VectorStoreRetriever
        k: int

        def _get_relevant_documents(
            self, query: str, *, run_manager: CallbackManagerForRetrieverRun
        ) -> List[Document]:
            docs = self.retriever.invoke(query)
            results = mxbai_client.reranking(
                model=MODEL_RRK,
                query=query,
                input=[doc.page_content for doc in docs],
                return_input=True,
                top_k=self.k,
            )
            return [Document(page_content=res.input) for res in results.data]

    # Retrieve 25 candidates by similarity, keep the 4 best after reranking
    retriever = db.as_retriever(search_type="similarity", search_kwargs={"k": 25})
    reranker = Reranker(retriever=retriever, k=4)
    llm = ChatOpenAI(model=LLM_NAME, verbose=True)
    system_prompt = (
        "Answer the question using only the following context:\n\n{context}\n"
        "If you do not know the answer, say that you do not know."
    )
    prompt = ChatPromptTemplate.from_messages(
        [
            ("system", system_prompt),
            ("human", "{input}"),
        ]
    )
    question_answer_chain = create_stuff_documents_chain(llm, prompt)
    rag_chain = create_retrieval_chain(reranker, question_answer_chain)
    rag_tool = rag_chain.as_tool(
        name="RAG_search",
        description="Searches for information in the actuarial theses",
        arg_types={"input": str},
    )
    return rag_tool
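
# A quick way to smoke-test the RAG tool in isolation (a sketch, not part of
# the original script; the question below is a placeholder):
#
#   tool = init_rag_tool()
#   print(tool.invoke({"input": "How is the SCR calculated?"}))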

def init_websearch_tool():
    """Initialize the Tavily web search tool."""
    web_search_tool = TavilySearchResults(
        name="Web_search",
        max_results=5,
        description="Searches for information on the web",
        search_depth="advanced",
        include_answer=True,
        include_raw_content=True,
        include_images=False,
        verbose=False,
    )
    return web_search_tool


def create_agent():
    """Build a ReAct agent with the RAG and web search tools plus conversational memory."""
    rag_tool = init_rag_tool()
    web_search_tool = init_websearch_tool()
    memory = MemorySaver()
    llm_4o = ChatOpenAI(model=LLM_NAME, api_key=OPENAI_API_KEY, verbose=True, temperature=0, streaming=True)
    tools = [rag_tool, web_search_tool]
    system_message = """
    You are an assistant whose role is to answer questions about insurance and actuarial science.
    Use the RAG_search or Web_search tools to answer the user's questions.
    """
    # In the final answer, separate the information coming from the RAG tool and from the Web tool.
    react_agent = create_react_agent(llm_4o, tools, state_modifier=system_message, checkpointer=memory, debug=False)
    return react_agent
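
# Minimal usage sketch (an assumption, not part of the original script): the
# compiled graph returned by create_react_agent takes a message list and,
# because a MemorySaver checkpointer is attached, a thread_id that scopes the
# conversation memory. The question and thread id below are placeholders.
if __name__ == "__main__":
    agent = create_agent()
    config = {"configurable": {"thread_id": "demo-thread"}}
    result = agent.invoke(
        {"messages": [("human", "What is the SCR under Solvency II?")]},
        config=config,
    )
    # The agent returns the full message history; the last message is the answer.
    print(result["messages"][-1].content)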