eliot-hub committed on
Commit
42f87c6
·
1 Parent(s): 9b25e9e

first commit

Browse files
Files changed (2) hide show
  1. app.py +156 -0
  2. requirements.txt +194 -0
app.py ADDED
@@ -0,0 +1,156 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from langchain_community.vectorstores import Chroma
3
+ from langchain.prompts import ChatPromptTemplate
4
+ from langchain.chains import create_retrieval_chain, create_history_aware_retriever
5
+ from langchain.chains.combine_documents import create_stuff_documents_chain
6
+ from langchain_core.prompts import MessagesPlaceholder
7
+ from langchain_community.chat_message_histories import ChatMessageHistory
8
+ from langchain_core.runnables.history import RunnableWithMessageHistory
9
+ import torch
10
+ import chromadb
11
+ from typing import List
12
+ from langchain_core.documents import Document
13
+ from langchain_core.retrievers import BaseRetriever
14
+ from langchain_core.callbacks import CallbackManagerForRetrieverRun
15
+ from langchain_core.vectorstores import VectorStoreRetriever
16
+
17
+ from langchain_openai import ChatOpenAI
18
+ from mixedbread_ai.client import MixedbreadAI
19
+
20
+ from langchain.callbacks.tracers import ConsoleCallbackHandler
21
+ from langchain_huggingface import HuggingFaceEmbeddings
22
+ import os
23
+ from chroma_datasets.utils import import_into_chroma
24
+ from datasets import load_dataset
25
+
26
+ # Global params
27
+ CHROMA_PATH = "chromadb_mem10_mxbai_800_complete"
28
+ MODEL_EMB = "mxbai-embed-large"
29
+ MODEL_RRK = "mixedbread-ai/mxbai-rerank-large-v1"
30
+ LLM_NAME = "gpt-4o-mini"
31
+ OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
32
+ MXBAI_API_KEY = os.environ.get("MXBAI_API_KEY")
33
+
34
+ # Load the reranker model
35
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
36
+ mxbai_client = MixedbreadAI(api_key=MXBAI_API_KEY)
37
+ model_emb = "mixedbread-ai/mxbai-embed-large-v1"
38
+
39
+ # Set up ChromaDB
40
+ client = chromadb.Client()
41
+ dataset = load_dataset("eliot-hub/memoires_vec_800", split="data")
42
+ # client = chromadb.PersistentClient(path=os.path.join(os.path.abspath(os.getcwd()), "01_Notebooks", "RAG-ollama", "chatbot_actuariat_APP", CHROMA_PATH))
43
+
44
+
45
+ db = import_into_chroma(
46
+ chroma_client=client,
47
+ dataset=dataset,
48
+ embedding_function=HuggingFaceEmbeddings(model_name=model_emb)
49
+ )
50
+ # db = Chroma(
51
+ # client=client,
52
+ # collection_name=f"embeddings_mxbai",
53
+ # embedding_function= HuggingFaceEmbeddings(model_name=model_emb)
54
+ # )
55
+
56
+
57
+ # Reranker class
58
+ class Reranker(BaseRetriever):
59
+ retriever: VectorStoreRetriever
60
+ # model: CrossEncoder
61
+ k: int
62
+
63
+ def _get_relevant_documents(
64
+ self, query: str, *, run_manager: CallbackManagerForRetrieverRun
65
+ ) -> List[Document]:
66
+ docs = self.retriever.invoke(query)
67
+ results = mxbai_client.reranking(model="mixedbread-ai/mxbai-rerank-large-v1", query=query, input=[doc.page_content for doc in docs], return_input=True, top_k=self.k)
68
+ return [Document(page_content=res.input) for res in results.data]
69
+
70
+ # Set up reranker + LLM
71
+ retriever = db.as_retriever(search_type="similarity", search_kwargs={"k": 25})
72
+ reranker = Reranker(retriever=retriever, k=4) #Reranker(retriever=retriever, model=model, k=4)
73
+ llm = ChatOpenAI(model=LLM_NAME, api_key=OPENAI_API_KEY, verbose=True)
74
+
75
+ # Set up the contextualize question prompt
76
+ contextualize_q_system_prompt = (
77
+ "Compte tenu de l'historique des discussions et de la dernière question de l'utilisateur "
78
+ "qui peut faire référence à un contexte dans l'historique du chat, "
79
+ "formuler une question autonome qui peut être comprise "
80
+ "sans l'historique du chat. Ne répondez PAS à la question, "
81
+ "juste la reformuler si nécessaire et sinon la renvoyer telle quelle."
82
+ )
83
+
84
+ contextualize_q_prompt = ChatPromptTemplate.from_messages(
85
+ [
86
+ ("system", contextualize_q_system_prompt),
87
+ MessagesPlaceholder("chat_history"),
88
+ ("human", "{input}"),
89
+ ]
90
+ )
91
+
92
+ # Create the history-aware retriever
93
+ history_aware_retriever = create_history_aware_retriever(
94
+ llm, reranker, contextualize_q_prompt
95
+ )
96
+
97
+ # Set up the QA prompt
98
+ system_prompt = (
99
+ "Réponds à la question en te basant uniquement sur le contexte suivant: \n\n {context}"
100
+ )
101
+ qa_prompt = ChatPromptTemplate.from_messages(
102
+ [
103
+ ("system", system_prompt),
104
+ MessagesPlaceholder("chat_history"),
105
+ ("human", "{input}"),
106
+ ]
107
+ )
108
+
109
+ # Create the question-answer chain
110
+ question_answer_chain = create_stuff_documents_chain(llm, qa_prompt)
111
+ rag_chain = create_retrieval_chain(history_aware_retriever, question_answer_chain)
112
+
113
+ # Set up the conversation history
114
+ store = {}
115
+
116
+ def get_session_history(session_id: str) -> ChatMessageHistory:
117
+ if session_id not in store:
118
+ store[session_id] = ChatMessageHistory()
119
+ return store[session_id]
120
+
121
+ conversational_rag_chain = RunnableWithMessageHistory(
122
+ rag_chain,
123
+ get_session_history,
124
+ input_messages_key="input",
125
+ history_messages_key="chat_history",
126
+ output_messages_key="answer",
127
+ )
128
+
129
+ # Gradio interface
130
+ def chatbot(message, history):
131
+ session_id = "gradio_session"
132
+ response = conversational_rag_chain.invoke(
133
+ {"input": message},
134
+ config={
135
+ "configurable": {"session_id": session_id},
136
+ "callbacks": [ConsoleCallbackHandler()]
137
+ },
138
+ )["answer"]
139
+ return response
140
+
141
+ iface = gr.ChatInterface(
142
+ chatbot,
143
+ title="Assurance Chatbot",
144
+ description="Posez vos questions sur l'assurance",
145
+ theme="soft",
146
+ examples=[
147
+ "Qu'est-ce que l'assurance multirisque habitation ?",
148
+ "Qu'est-ce que la garantie DTA ?",
149
+ ],
150
+ retry_btn=None,
151
+ undo_btn=None,
152
+ clear_btn="Effacer la conversation",
153
+ )
154
+
155
+ if __name__ == "__main__":
156
+ iface.launch() # share=True
requirements.txt ADDED
@@ -0,0 +1,194 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ aiofiles==23.2.1
2
+ aiohappyeyeballs==2.4.0
3
+ aiohttp==3.10.5
4
+ aiosignal==1.3.1
5
+ altair==5.4.1
6
+ annotated-types==0.7.0
7
+ anyio==4.4.0
8
+ asgiref==3.8.1
9
+ asttokens==2.4.1
10
+ attrs==24.2.0
11
+ backoff==2.2.1
12
+ bcrypt==4.2.0
13
+ blinker==1.8.2
14
+ build==1.2.1
15
+ cachetools==5.5.0
16
+ certifi==2024.8.30
17
+ charset-normalizer==3.3.2
18
+ chroma-datasets==0.1.5
19
+ chroma-hnswlib==0.7.6
20
+ chromadb==0.5.7
21
+ click==8.1.7
22
+ colorama==0.4.6
23
+ coloredlogs==15.0.1
24
+ comm==0.2.2
25
+ contourpy==1.3.0
26
+ cycler==0.12.1
27
+ dataclasses-json==0.6.7
28
+ datasets==3.0.0
29
+ debugpy==1.8.5
30
+ decorator==5.1.1
31
+ Deprecated==1.2.14
32
+ dill==0.3.8
33
+ distro==1.9.0
34
+ executing==2.1.0
35
+ fastapi==0.112.2
36
+ ffmpy==0.4.0
37
+ filelock==3.15.4
38
+ flatbuffers==24.3.25
39
+ fonttools==4.54.0
40
+ frozenlist==1.4.1
41
+ fsspec==2024.6.1
42
+ gitdb==4.0.11
43
+ GitPython==3.1.43
44
+ google-auth==2.34.0
45
+ googleapis-common-protos==1.65.0
46
+ gradio==4.44.0
47
+ gradio_client==1.3.0
48
+ greenlet==3.0.3
49
+ grpcio==1.66.1
50
+ h11==0.14.0
51
+ httpcore==1.0.5
52
+ httptools==0.6.1
53
+ httpx==0.27.2
54
+ httpx-sse==0.4.0
55
+ huggingface-hub==0.24.6
56
+ humanfriendly==10.0
57
+ idna==3.8
58
+ importlib_metadata==8.4.0
59
+ importlib_resources==6.4.4
60
+ ipykernel==6.29.5
61
+ ipython==8.27.0
62
+ jedi==0.19.1
63
+ Jinja2==3.1.4
64
+ jiter==0.5.0
65
+ joblib==1.4.2
66
+ jsonpatch==1.33
67
+ jsonpointer==3.0.0
68
+ jsonschema==4.23.0
69
+ jsonschema-specifications==2023.12.1
70
+ jupyter_client==8.6.2
71
+ jupyter_core==5.7.2
72
+ kiwisolver==1.4.7
73
+ kubernetes==30.1.0
74
+ langchain==0.3.0
75
+ langchain-chroma==0.1.4
76
+ langchain-community==0.3.0
77
+ langchain-core==0.3.5
78
+ langchain-huggingface==0.1.0
79
+ langchain-openai==0.2.0
80
+ langchain-text-splitters==0.3.0
81
+ langsmith==0.1.126
82
+ markdown-it-py==3.0.0
83
+ MarkupSafe==2.1.5
84
+ marshmallow==3.22.0
85
+ matplotlib==3.9.2
86
+ matplotlib-inline==0.1.7
87
+ mdurl==0.1.2
88
+ mixedbread-ai==2.2.6
89
+ mmh3==4.1.0
90
+ monotonic==1.6
91
+ mpmath==1.3.0
92
+ multidict==6.0.5
93
+ multiprocess==0.70.16
94
+ mypy-extensions==1.0.0
95
+ narwhals==1.6.0
96
+ nest-asyncio==1.6.0
97
+ networkx==3.3
98
+ numpy==1.26.4
99
+ oauthlib==3.2.2
100
+ onnxruntime==1.19.0
101
+ openai==1.43.0
102
+ opentelemetry-api==1.27.0
103
+ opentelemetry-exporter-otlp-proto-common==1.27.0
104
+ opentelemetry-exporter-otlp-proto-grpc==1.27.0
105
+ opentelemetry-instrumentation==0.48b0
106
+ opentelemetry-instrumentation-asgi==0.48b0
107
+ opentelemetry-instrumentation-fastapi==0.48b0
108
+ opentelemetry-proto==1.27.0
109
+ opentelemetry-sdk==1.27.0
110
+ opentelemetry-semantic-conventions==0.48b0
111
+ opentelemetry-util-http==0.48b0
112
+ orjson==3.10.7
113
+ overrides==7.7.0
114
+ packaging==24.1
115
+ pandas==2.2.2
116
+ parso==0.8.4
117
+ pillow==10.4.0
118
+ platformdirs==4.3.2
119
+ posthog==3.6.0
120
+ prompt_toolkit==3.0.47
121
+ protobuf==4.25.4
122
+ psutil==6.0.0
123
+ pure_eval==0.2.3
124
+ pyarrow==17.0.0
125
+ pyasn1==0.6.0
126
+ pyasn1_modules==0.4.0
127
+ pydantic==2.8.2
128
+ pydantic-settings==2.5.2
129
+ pydantic_core==2.20.1
130
+ pydeck==0.9.1
131
+ pydub==0.25.1
132
+ Pygments==2.18.0
133
+ pyparsing==3.1.4
134
+ pypdf==4.3.1
135
+ PyPika==0.48.9
136
+ pyproject_hooks==1.1.0
137
+ pyreadline3==3.4.1
138
+ python-dateutil==2.9.0.post0
139
+ python-dotenv==1.0.1
140
+ python-multipart==0.0.10
141
+ pytz==2024.1
142
+ pywin32==306
143
+ PyYAML==6.0.2
144
+ pyzmq==26.2.0
145
+ referencing==0.35.1
146
+ regex==2024.7.24
147
+ requests==2.32.3
148
+ requests-oauthlib==2.0.0
149
+ rich==13.8.0
150
+ rpds-py==0.20.0
151
+ rsa==4.9
152
+ ruff==0.6.7
153
+ safetensors==0.4.4
154
+ scikit-learn==1.5.2
155
+ scipy==1.14.1
156
+ semantic-version==2.10.0
157
+ sentence-transformers==3.1.1
158
+ sentencepiece==0.2.0
159
+ setuptools==72.1.0
160
+ shellingham==1.5.4
161
+ six==1.16.0
162
+ smmap==5.0.1
163
+ sniffio==1.3.1
164
+ SQLAlchemy==2.0.32
165
+ stack-data==0.6.3
166
+ starlette==0.38.4
167
+ sympy==1.13.2
168
+ tenacity==8.5.0
169
+ threadpoolctl==3.5.0
170
+ tiktoken==0.7.0
171
+ tokenizers==0.19.1
172
+ toml==0.10.2
173
+ tomlkit==0.12.0
174
+ torch==2.4.0
175
+ tornado==6.4.1
176
+ tqdm==4.66.5
177
+ traitlets==5.14.3
178
+ transformers==4.44.2
179
+ typer==0.12.5
180
+ typing-inspect==0.9.0
181
+ typing_extensions==4.12.2
182
+ tzdata==2024.1
183
+ urllib3==2.2.2
184
+ uvicorn==0.30.6
185
+ watchdog==4.0.2
186
+ watchfiles==0.24.0
187
+ wcwidth==0.2.13
188
+ websocket-client==1.8.0
189
+ websockets==12.0
190
+ wheel==0.43.0
191
+ wrapt==1.16.0
192
+ xxhash==3.5.0
193
+ yarl==1.9.7
194
+ zipp==3.20.1