# chatbot_app/tools.py
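"""Tools for the chatbot app: a RAG tool over actuarial theses stored in
ChromaDB (with Mixedbread reranking), a Tavily web-search tool, and a
LangGraph ReAct agent that combines them."""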
import os
from typing import List

import chromadb
from datasets import load_dataset
from dotenv import load_dotenv
from langchain.chains import create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.prompts import ChatPromptTemplate
from langchain_chroma import Chroma
from langchain_community.tools import TavilySearchResults
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever
from langchain_core.vectorstores import VectorStoreRetriever
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_openai import ChatOpenAI
from langgraph.checkpoint.memory import MemorySaver
from langgraph.prebuilt import create_react_agent
from mixedbread_ai.client import MixedbreadAI
from tqdm import tqdm

load_dotenv()
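
# Expected environment variables (loaded from .env): OPENAI_API_KEY,
# MXBAI_API_KEY, HF_TOKEN, HF_API_KEY, and TAVILY_API_KEY (read implicitly
# by TavilySearchResults).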
# Global params
MODEL_EMB = "mixedbread-ai/mxbai-embed-large-v1"
MODEL_RRK = "mixedbread-ai/mxbai-rerank-large-v1"
LLM_NAME = "gpt-4o-mini"
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
MXBAI_API_KEY = os.environ.get("MXBAI_API_KEY")
HF_TOKEN = os.environ.get("HF_TOKEN")
HF_API_KEY = os.environ.get("HF_API_KEY")
# MixedbreadAI client (used for reranking in the RAG tool below)
mxbai_client = MixedbreadAI(api_key=MXBAI_API_KEY)
# Set up ChromaDB: stream the pre-computed embeddings from the Hub dataset
# into an in-memory collection (rebuilt on every startup)
memoires_ds = load_dataset("eliot-hub/memoires_vec_800", split="data", token=HF_TOKEN, streaming=True)
batched_ds = memoires_ds.batch(batch_size=41000)
client = chromadb.Client()
collection = client.get_or_create_collection(name="embeddings_mxbai")
for batch in tqdm(batched_ds, desc="Processing dataset batches"):
    collection.add(
        ids=batch["id"],
        metadatas=batch["metadata"],
        documents=batch["document"],
        embeddings=batch["embedding"],
    )
print(f"Collection complete: {collection.count()} documents")
del memoires_ds, batched_ds
def init_rag_tool():
    """Init the RAG tool that lets the LLM query the thesis collection"""
    db = Chroma(
        client=client,
        collection_name="embeddings_mxbai",
        embedding_function=HuggingFaceEmbeddings(model_name=MODEL_EMB),
    )

    # Retriever that reranks the vector-store hits with the Mixedbread API
    # and keeps only the k most relevant documents
    class Reranker(BaseRetriever):
        retriever: VectorStoreRetriever
        k: int

        def _get_relevant_documents(
            self, query: str, *, run_manager: CallbackManagerForRetrieverRun
        ) -> List[Document]:
            docs = self.retriever.invoke(query)
            results = mxbai_client.reranking(
                model=MODEL_RRK,
                query=query,
                input=[doc.page_content for doc in docs],
                return_input=True,
                top_k=self.k,
            )
            return [Document(page_content=res.input) for res in results.data]
    # Set up reranker + LLM: retrieve 25 candidates, keep the 4 best after reranking
    retriever = db.as_retriever(search_type="similarity", search_kwargs={"k": 25})
    reranker = Reranker(retriever=retriever, k=4)
    llm = ChatOpenAI(model=LLM_NAME, verbose=True)
    system_prompt = (
        "Réponds à la question en te basant uniquement sur le contexte suivant :\n\n{context}\n\n"
        "Si tu ne connais pas la réponse, dis que tu ne sais pas."
    )
    prompt = ChatPromptTemplate.from_messages(
        [
            ("system", system_prompt),
            ("human", "{input}"),
        ]
    )
    question_answer_chain = create_stuff_documents_chain(llm, prompt)
    rag_chain = create_retrieval_chain(reranker, question_answer_chain)
    rag_tool = rag_chain.as_tool(
        name="RAG_search",
        description="Recherche d'information dans les mémoires d'actuariat",
        arg_types={"input": str},
    )
    return rag_tool
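
# Hypothetical standalone usage of the RAG tool (the app itself only calls it
# through the agent): init_rag_tool().invoke({"input": "..."})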
def init_websearch_tool():
    """Init the Tavily web search tool (expects TAVILY_API_KEY in the environment)"""
    web_search_tool = TavilySearchResults(
        name="Web_search",
        max_results=5,
        description="Recherche d'informations sur le web",
        search_depth="advanced",
        include_answer=True,
        include_raw_content=True,
        include_images=False,
        verbose=False,
    )
    return web_search_tool
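
# Hypothetical standalone usage: init_websearch_tool().invoke({"query": "..."})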
def create_agent():
    """Build the ReAct agent combining both tools, with in-memory chat history"""
    rag_tool = init_rag_tool()
    web_search_tool = init_websearch_tool()
    memory = MemorySaver()
    llm_4o = ChatOpenAI(model=LLM_NAME, api_key=OPENAI_API_KEY, verbose=True, temperature=0, streaming=True)
    tools = [rag_tool, web_search_tool]
    system_message = """
    Tu es un assistant dont la fonction est de répondre à des questions à propos de l'assurance et de l'actuariat.
    Utilise les outils RAG_search ou Web_search pour répondre aux questions de l'utilisateur.
    """
    react_agent = create_react_agent(llm_4o, tools, state_modifier=system_message, checkpointer=memory, debug=False)
    return react_agent
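
# Minimal usage sketch (an assumption, not part of the original app wiring):
# the MemorySaver checkpointer tracks one conversation per thread, so a
# thread_id must be supplied in the config.
if __name__ == "__main__":
    agent = create_agent()
    config = {"configurable": {"thread_id": "demo"}}
    result = agent.invoke(
        {"messages": [("user", "Qu'est-ce qu'une provision technique ?")]},
        config,
    )
    print(result["messages"][-1].content)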