# NOTE(review): removed non-Python page-scrape residue ("Spaces: / Sleeping / Sleeping").
import os
from datetime import datetime
from typing import List

import chromadb
from datasets import load_dataset
from dotenv import load_dotenv
from langchain.chains import create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.prompts import ChatPromptTemplate
from langchain_chroma import Chroma
from langchain_community.tools import TavilySearchResults
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever
from langchain_core.vectorstores import VectorStoreRetriever
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_openai import ChatOpenAI
from langgraph.checkpoint.memory import MemorySaver
from langgraph.prebuilt import create_react_agent
from mixedbread_ai.client import MixedbreadAI
from tqdm import tqdm
load_dotenv()

# --- Global configuration (read once at import time) ---
# Fix: OPENAI_API_KEY was fetched from the environment twice; read it once.
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
MODEL_EMB = "mxbai-embed-large"
MODEL_RRK = "mixedbread-ai/mxbai-rerank-large-v1"
LLM_NAME = "gpt-4o-mini"
MXBAI_API_KEY = os.environ.get("MXBAI_API_KEY")
HF_TOKEN = os.environ.get("HF_TOKEN")
HF_API_KEY = os.environ.get("HF_API_KEY")

# Mixedbread AI client, used by the reranker inside init_rag_tool().
mxbai_client = MixedbreadAI(api_key=MXBAI_API_KEY)
model_emb = "mixedbread-ai/mxbai-embed-large-v1"

# --- Populate an in-memory ChromaDB collection ---
# The dataset already contains precomputed embeddings; stream it in batches
# so the whole dataset is never fully materialized in memory at once.
memoires_ds = load_dataset("eliot-hub/memoires_vec_800", split="data", token=HF_TOKEN, streaming=True)
batched_ds = memoires_ds.batch(batch_size=41000)
client = chromadb.Client()
collection = client.get_or_create_collection(name="embeddings_mxbai")

for batch in tqdm(batched_ds, desc="Processing dataset batches"):
    collection.add(
        ids=batch["id"],
        metadatas=batch["metadata"],
        documents=batch["document"],
        embeddings=batch["embedding"],
    )
print(f"Collection complete: {collection.count()}")

# Release the streaming handles now that the collection is built.
del memoires_ds, batched_ds

llm_4o = ChatOpenAI(model="gpt-4o-mini", api_key=OPENAI_API_KEY, temperature=0)
def init_rag_tool():
    """Build the "RAG_search" tool: Chroma similarity search + Mixedbread rerank + LLM.

    Returns:
        A LangChain tool that answers questions from the documents stored in
        the module-level "embeddings_mxbai" Chroma collection.
    """
    # Reuse the in-memory Chroma client populated at import time.
    db = Chroma(
        client=client,
        collection_name="embeddings_mxbai",  # fix: was an f-string with no placeholder
        embedding_function=HuggingFaceEmbeddings(model_name=model_emb),
    )

    class Reranker(BaseRetriever):
        """Retriever that reranks the base retriever's hits via the Mixedbread API."""

        retriever: VectorStoreRetriever
        k: int  # number of documents kept after reranking

        def _get_relevant_documents(
            self, query: str, *, run_manager: CallbackManagerForRetrieverRun
        ) -> List[Document]:
            docs = self.retriever.invoke(query)
            results = mxbai_client.reranking(
                model=MODEL_RRK,
                query=query,
                input=[doc.page_content for doc in docs],
                return_input=True,
                top_k=self.k,
            )
            # return_input=True makes each result carry its original text back.
            return [Document(page_content=res.input) for res in results.data]

    # Cast a wide net (k=25) with similarity search, then keep the 4 best.
    retriever = db.as_retriever(search_type="similarity", search_kwargs={"k": 25})
    reranker = Reranker(retriever=retriever, k=4)
    llm = ChatOpenAI(model=LLM_NAME, verbose=True)

    system_prompt = (
        "Réponds à la question en te basant uniquement sur le contexte suivant: \n\n {context}"
        "Si tu ne connais pas la réponse, dis que tu ne sais pas."
    )
    prompt = ChatPromptTemplate.from_messages(
        [
            ("system", system_prompt),
            ("human", "{input}"),
        ]
    )
    question_answer_chain = create_stuff_documents_chain(llm, prompt)
    rag_chain = create_retrieval_chain(reranker, question_answer_chain)

    rag_tool = rag_chain.as_tool(
        name="RAG_search",
        description="Recherche d'information dans les mémoires d'actuariat",
        arg_types={"input": str},
    )
    return rag_tool
def init_websearch_tool():
    """Return a Tavily web-search tool, exposed to the agent as "Web_search"."""
    tavily_kwargs = {
        "name": "Web_search",
        "max_results": 5,
        "description": "Recherche d'informations sur le web",
        "search_depth": "advanced",
        "include_answer": True,
        "include_raw_content": True,
        "include_images": False,
        "verbose": False,
    }
    return TavilySearchResults(**tavily_kwargs)
def create_agent():
    """Assemble the ReAct agent wired with the RAG and web-search tools."""
    tools = [init_rag_tool(), init_websearch_tool()]
    # In-memory checkpointing so each conversation thread keeps its history.
    checkpointer = MemorySaver()
    # Local model instance (streaming enabled); renamed so it no longer
    # shadows the module-level llm_4o.
    chat_model = ChatOpenAI(model="gpt-4o-mini", api_key=OPENAI_API_KEY, verbose=True, temperature=0, streaming=True)
    system_message = """
    Tu es un assistant dont la fonction est de répondre à des questions à propos de l'assurance et de l'actuariat.
    Utilise les outils RAG_search ou Web_search pour répondre aux questions de l'utilisateur.
    """ # (disabled prompt addition) In the final answer, separate the RAG-tool and Web-tool information.
    return create_react_agent(chat_model, tools, state_modifier=system_message, checkpointer=checkpointer, debug=False)