# NOTE(review): lines above this module previously contained Hugging Face Space
# page residue (status text, file size, commit hashes, line-number gutter) left
# over from extraction; it was not valid Python and has been removed.
from langchain_community.tools import TavilySearchResults
from langchain_core.retrievers import BaseRetriever
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.vectorstores import VectorStoreRetriever
from langgraph.prebuilt import create_react_agent
from langchain_core.documents import Document
from langchain_openai import ChatOpenAI
from langgraph.checkpoint.memory import MemorySaver
from mixedbread_ai.client import MixedbreadAI
from langchain.chains import create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.prompts import ChatPromptTemplate
from dotenv import load_dotenv
import os
from langchain_chroma import Chroma
import chromadb
from typing import List
from datasets import load_dataset
from langchain_huggingface import HuggingFaceEmbeddings
from tqdm import tqdm
from datetime import datetime
load_dotenv()

# Global params: model identifiers used throughout the module.
MODEL_EMB = "mxbai-embed-large"
MODEL_RRK = "mixedbread-ai/mxbai-rerank-large-v1"
LLM_NAME = "gpt-4o-mini"

# API credentials, loaded from the environment / .env file.
# (A duplicate OPENAI_API_KEY assignment was removed.)
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
MXBAI_API_KEY = os.environ.get("MXBAI_API_KEY")
HF_TOKEN = os.environ.get("HF_TOKEN")
HF_API_KEY = os.environ.get("HF_API_KEY")

# MixedbreadAI client, used by the Reranker inside init_rag_tool().
mxbai_client = MixedbreadAI(api_key=MXBAI_API_KEY)
model_emb = "mixedbread-ai/mxbai-embed-large-v1"

# Set up ChromaDB: stream the pre-computed embedding dataset from the HF Hub
# and ingest it batch by batch into an in-memory Chroma collection.
memoires_ds = load_dataset("eliot-hub/memoires_vec_800", split="data", token=HF_TOKEN, streaming=True)
batched_ds = memoires_ds.batch(batch_size=41000)
client = chromadb.Client()
collection = client.get_or_create_collection(name="embeddings_mxbai")

for batch in tqdm(batched_ds, desc="Processing dataset batches"):
    collection.add(
        ids=batch["id"],
        metadatas=batch["metadata"],
        documents=batch["document"],
        embeddings=batch["embedding"],
    )
print(f"Collection complete: {collection.count()}")

# Release the streamed dataset handles once the collection is built.
del memoires_ds, batched_ds

# NOTE(review): this module-level LLM is shadowed by the local llm_4o created
# in create_agent(); kept for backward compatibility with any external callers.
llm_4o = ChatOpenAI(model="gpt-4o-mini", api_key=OPENAI_API_KEY, temperature=0)
def init_rag_tool():
    """Build the RAG tool: Chroma similarity search + mxbai reranking + LLM answer chain.

    Returns:
        A LangChain tool named "RAG_search" that answers questions using the
        actuarial-thesis vector collection ingested at module import time.
    """
    db = Chroma(
        client=client,
        collection_name="embeddings_mxbai",  # plain string (was a placeholder-less f-string)
        embedding_function=HuggingFaceEmbeddings(model_name=model_emb),
    )

    class Reranker(BaseRetriever):
        """Retriever that reranks the hits of a wrapped vector retriever via the mxbai API."""

        retriever: VectorStoreRetriever
        # Number of documents kept after reranking.
        k: int

        def _get_relevant_documents(
            self, query: str, *, run_manager: CallbackManagerForRetrieverRun
        ) -> List[Document]:
            docs = self.retriever.invoke(query)
            results = mxbai_client.reranking(
                model=MODEL_RRK,
                query=query,
                input=[doc.page_content for doc in docs],
                return_input=True,
                top_k=self.k,
            )
            # return_input=True makes the API echo each document's text back as res.input.
            return [Document(page_content=res.input) for res in results.data]

    # Retrieve a wide candidate set (25), then keep only the 4 best after reranking.
    retriever = db.as_retriever(search_type="similarity", search_kwargs={"k": 25})
    reranker = Reranker(retriever=retriever, k=4)

    # api_key passed explicitly, consistent with the other ChatOpenAI instances in this module.
    llm = ChatOpenAI(model=LLM_NAME, api_key=OPENAI_API_KEY, verbose=True)
    system_prompt = (
        "Réponds à la question en te basant uniquement sur le contexte suivant: \n\n {context}"
        # Newline added: the two fragments previously fused as "...{context}Si tu...".
        "\nSi tu ne connais pas la réponse, dis que tu ne sais pas."
    )
    prompt = ChatPromptTemplate.from_messages(
        [
            ("system", system_prompt),
            ("human", "{input}"),
        ]
    )
    question_answer_chain = create_stuff_documents_chain(llm, prompt)
    rag_chain = create_retrieval_chain(reranker, question_answer_chain)
    rag_tool = rag_chain.as_tool(
        name="RAG_search",
        description="Recherche d'information dans les mémoires d'actuariat",
        arg_types={"input": str},
    )
    return rag_tool
def init_websearch_tool():
    """Return a Tavily web-search tool ("Web_search") configured for detailed results.

    Uses advanced search depth, includes the synthesized answer and raw page
    content, and caps results at 5; images are excluded.
    """
    tavily_config = {
        "name": "Web_search",
        "max_results": 5,
        "description": "Recherche d'informations sur le web",
        "search_depth": "advanced",
        "include_answer": True,
        "include_raw_content": True,
        "include_images": False,
        "verbose": False,
    }
    return TavilySearchResults(**tavily_config)
def create_agent():
    """Assemble and return the ReAct agent wired with the RAG and web-search tools.

    The agent uses a streaming gpt-4o-mini model at temperature 0 and an
    in-memory checkpointer for conversation state.
    """
    tools = [init_rag_tool(), init_websearch_tool()]
    checkpointer = MemorySaver()
    chat_model = ChatOpenAI(
        model="gpt-4o-mini",
        api_key=OPENAI_API_KEY,
        verbose=True,
        temperature=0,
        streaming=True,
    )
    system_message = """
    Tu es un assistant dont la fonction est de répondre à des questions à propos de l'assurance et de l'actuariat.
    Utilise les outils RAG_search ou Web_search pour répondre aux questions de l'utilisateur.
    """
    # (Considered instruction, currently disabled: ask the model to separate
    # RAG-tool and Web-tool information in its final answer.)
    return create_react_agent(
        chat_model,
        tools,
        state_modifier=system_message,
        checkpointer=checkpointer,
        debug=False,
    )