import os
from typing import List

import chromadb
from datasets import load_dataset
from dotenv import load_dotenv
from langchain.chains import create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.prompts import ChatPromptTemplate
from langchain_chroma import Chroma
from langchain_community.tools import TavilySearchResults
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever
from langchain_core.vectorstores import VectorStoreRetriever
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_openai import ChatOpenAI
from langgraph.checkpoint.memory import MemorySaver
from langgraph.prebuilt import create_react_agent
from mixedbread_ai.client import MixedbreadAI
from tqdm import tqdm

load_dotenv()

# Global params
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
MXBAI_API_KEY = os.environ.get("MXBAI_API_KEY")
HF_TOKEN = os.environ.get("HF_TOKEN")
MODEL_EMB = "mixedbread-ai/mxbai-embed-large-v1"   # embedding model (also used for queries)
MODEL_RRK = "mixedbread-ai/mxbai-rerank-large-v1"  # cross-encoder used for reranking
LLM_NAME = "gpt-4o-mini"

# Mixedbread AI client (reranking API)
mxbai_client = MixedbreadAI(api_key=MXBAI_API_KEY)

# Set up ChromaDB: stream the pre-embedded corpus from the Hub and load it
# into an in-memory collection in batches
memoires_ds = load_dataset("eliot-hub/memoires_vec_800", split="data", token=HF_TOKEN, streaming=True)
batched_ds = memoires_ds.batch(batch_size=41000)
client = chromadb.Client()
collection = client.get_or_create_collection(name="embeddings_mxbai")
for batch in tqdm(batched_ds, desc="Processing dataset batches"):
    collection.add(
        ids=batch["id"],
        metadatas=batch["metadata"],
        documents=batch["document"],
        embeddings=batch["embedding"],
    )
print(f"Collection complete: {collection.count()} documents")
del memoires_ds, batched_ds
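
# Note: chromadb.Client() keeps the collection in memory only, so the ingest
# above reruns on every start. A persistent store is a possible alternative
# (sketch only, with a hypothetical path):
# client = chromadb.PersistentClient(path="./chroma_db")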

def init_rag_tool():
    """Build a retrieval tool so the agent can query the actuarial documents."""
    db = Chroma(
        client=client,
        collection_name="embeddings_mxbai",
        embedding_function=HuggingFaceEmbeddings(model_name=MODEL_EMB),
    )
    # Retriever that reranks the vector-store hits with the Mixedbread
    # cross-encoder and keeps only the top k documents
    class Reranker(BaseRetriever):
        retriever: VectorStoreRetriever
        k: int

        def _get_relevant_documents(
            self, query: str, *, run_manager: CallbackManagerForRetrieverRun
        ) -> List[Document]:
            docs = self.retriever.invoke(query)
            results = mxbai_client.reranking(
                model=MODEL_RRK,
                query=query,
                input=[doc.page_content for doc in docs],
                return_input=True,
                top_k=self.k,
            )
            return [Document(page_content=res.input) for res in results.data]

    # Retrieve 25 candidates by similarity, then rerank down to the best 4
    retriever = db.as_retriever(search_type="similarity", search_kwargs={"k": 25})
    reranker = Reranker(retriever=retriever, k=4)
    llm = ChatOpenAI(model=LLM_NAME, api_key=OPENAI_API_KEY, verbose=True)


    system_prompt = (
        "Réponds à la question en te basant uniquement sur le contexte suivant :\n\n"
        "{context}\n\n"
        "Si tu ne connais pas la réponse, dis que tu ne sais pas."
    )

    prompt = ChatPromptTemplate.from_messages(
        [
            ("system", system_prompt),
            ("human", "{input}"),
        ]
    )

    question_answer_chain = create_stuff_documents_chain(llm, prompt)
    rag_chain = create_retrieval_chain(reranker, question_answer_chain)

    rag_tool = rag_chain.as_tool(
        name="RAG_search",
        description="Recherche d'information dans les mémoires d'actuariat",
        arg_types={"input": str},
    )
    return rag_tool
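
# Quick smoke test for the RAG tool (hypothetical question), kept commented out
# so that importing this module has no extra side effects:
# rag_tool = init_rag_tool()
# print(rag_tool.invoke({"input": "Qu'est-ce qu'une provision technique ?"}))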


def init_websearch_tool():
    """Build a Tavily web-search tool for questions outside the document corpus."""
    web_search_tool = TavilySearchResults(
        name="Web_search",
        max_results=5,
        description="Recherche d'informations sur le web",
        search_depth="advanced",
        include_answer=True,
        include_raw_content=True,
        include_images=False,
        verbose=False,
    )
    return web_search_tool
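
# The web-search tool can likewise be exercised on its own (hypothetical query):
# init_websearch_tool().invoke({"query": "dernières actualités assurance vie"})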


def create_agent():
    """Assemble a ReAct agent combining the RAG and web-search tools, with conversation memory."""
    rag_tool = init_rag_tool()
    web_search_tool = init_websearch_tool()
    memory = MemorySaver()
    llm_4o = ChatOpenAI(model=LLM_NAME, api_key=OPENAI_API_KEY, verbose=True, temperature=0, streaming=True)
    tools = [rag_tool, web_search_tool]
    system_message = """
        Tu es un assistant dont la fonction est de répondre à des questions à propos de l'assurance et de l'actuariat.
        Utilise les outils RAG_search ou Web_search pour répondre aux questions de l'utilisateur.
    """

    react_agent = create_react_agent(llm_4o, tools, state_modifier=system_message, checkpointer=memory, debug=False)
    return react_agent
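

# Minimal usage sketch (the question and thread_id below are illustrative).
# MemorySaver checkpointing requires a thread_id in the invocation config.
if __name__ == "__main__":
    agent = create_agent()
    config = {"configurable": {"thread_id": "demo-thread"}}
    result = agent.invoke(
        {"messages": [("human", "Qu'est-ce que la provision pour sinistres à payer ?")]},
        config=config,
    )
    print(result["messages"][-1].content)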