import os
from typing import List

import chromadb
from datasets import load_dataset
from dotenv import load_dotenv
from langchain.chains import create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.prompts import ChatPromptTemplate
from langchain_chroma import Chroma
from langchain_community.tools import TavilySearchResults
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever
from langchain_core.vectorstores import VectorStoreRetriever
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_openai import ChatOpenAI
from langgraph.checkpoint.memory import MemorySaver
from langgraph.prebuilt import create_react_agent
from mixedbread_ai.client import MixedbreadAI


load_dotenv()

# Global params
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
MXBAI_API_KEY = os.environ.get("MXBAI_API_KEY")
HF_TOKEN = os.environ.get("HF_TOKEN")
MODEL_EMB = "mixedbread-ai/mxbai-embed-large-v1"
MODEL_RRK = "mixedbread-ai/mxbai-rerank-large-v1"
LLM_NAME = "gpt-4o-mini"

# Mixedbread AI client, used for reranking
mxbai_client = MixedbreadAI(api_key=MXBAI_API_KEY)

# Set up ChromaDB: stream the pre-embedded corpus from the Hugging Face Hub
memoires_ds = load_dataset("eliot-hub/memoires_vec_800", split="data", token=HF_TOKEN, streaming=True)
batched_ds = memoires_ds.batch(batch_size=41000)
client = chromadb.Client()
collection = client.get_or_create_collection(name="embeddings_mxbai")
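
# Populate the in-memory collection from the streamed batches. This is a
# minimal sketch: the "id", "text" and "embedding" column names are assumed
# and should be adjusted to the actual schema of the dataset.
for batch in batched_ds:
    collection.add(
        ids=batch["id"],
        documents=batch["text"],
        embeddings=batch["embedding"],
    )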



def init_rag_tool():
    """Initialize the retrieval tool that lets the LLM query the thesis corpus."""
    db = Chroma(
        client=client,
        collection_name="embeddings_mxbai",
        embedding_function=HuggingFaceEmbeddings(model_name=MODEL_EMB),
    )
    # Custom retriever: fetch candidates with the base retriever, then rerank
    # them with the Mixedbread cross-encoder and keep the top k
    class Reranker(BaseRetriever):
        retriever: VectorStoreRetriever
        k: int

        def _get_relevant_documents(
            self, query: str, *, run_manager: CallbackManagerForRetrieverRun
        ) -> List[Document]:
            docs = self.retriever.invoke(query)
            results = mxbai_client.reranking(
                model=MODEL_RRK,
                query=query,
                input=[doc.page_content for doc in docs],
                return_input=True,
                top_k=self.k,
            )
            return [Document(page_content=res.input) for res in results.data]

    # Set up reranker + LLM: over-fetch 25 candidates, rerank down to 4
    retriever = db.as_retriever(search_type="similarity", search_kwargs={"k": 25})
    reranker = Reranker(retriever=retriever, k=4)
    llm = ChatOpenAI(model=LLM_NAME, api_key=OPENAI_API_KEY, verbose=True)


    system_prompt = (
        "Answer the question using only the following context:\n\n{context}\n"
        "If you do not know the answer, say that you do not know."
    )

    prompt = ChatPromptTemplate.from_messages(
        [
            ("system", system_prompt),
            ("human", "{input}"),
        ]
    )

    question_answer_chain = create_stuff_documents_chain(llm, prompt)
    rag_chain = create_retrieval_chain(reranker, question_answer_chain)

    rag_tool = rag_chain.as_tool(
        name="RAG_search",
        description="Search for information in a corpus of actuarial science theses",
        arg_types={"input": str},
    )
    return rag_tool
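
# The tool wraps the full retrieval chain, so it can be exercised on its own
# before wiring it into the agent, e.g. (a sketch, with an illustrative query):
#   init_rag_tool().invoke({"input": "What is the Solvency II SCR?"})
# create_retrieval_chain produces a dict with "input", "context" and "answer" keys.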


def init_websearch_tool():
    """Initialize the Tavily web search tool (expects TAVILY_API_KEY in the environment)."""
    web_search_tool = TavilySearchResults(
        name="Web_search",
        max_results=5,
        description="Search for information on the web",
        search_depth="advanced",
        include_answer=True,
        include_raw_content=True,
        include_images=False,
        verbose=False,
    )
    return web_search_tool


def create_agent():
    """Assemble a ReAct agent that can call the RAG and web search tools."""
    rag_tool = init_rag_tool()
    web_search_tool = init_websearch_tool()
    memory = MemorySaver()
    llm_4o = ChatOpenAI(model=LLM_NAME, api_key=OPENAI_API_KEY, verbose=True, temperature=0, streaming=True)
    tools = [rag_tool, web_search_tool]
    system_message = """
        You are an assistant whose role is to answer questions about insurance and actuarial science.
        Use the RAG_search or Web_search tools to answer the user's questions.
    """  # Optionally: in the final answer, separate the information coming from the RAG tool and from the Web tool.

    react_agent = create_react_agent(llm_4o, tools, state_modifier=system_message, checkpointer=memory, debug=False)
    return react_agent
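

# A minimal usage sketch (the thread id and question are illustrative). Because
# the agent is compiled with a MemorySaver checkpointer, each invocation must
# carry a thread_id so that conversation state is tracked per session.
if __name__ == "__main__":
    agent = create_agent()
    config = {"configurable": {"thread_id": "demo-thread"}}
    result = agent.invoke(
        {"messages": [("human", "What are the main reserving methods in non-life insurance?")]},
        config,
    )
    print(result["messages"][-1].content)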