Spaces:
Sleeping
Sleeping
import gradio as gr | |
from huggingface_hub import login, InferenceClient | |
import os | |
from langchain_community.vectorstores import FAISS | |
from langchain_community.embeddings import HuggingFaceEmbeddings | |
import umap | |
import pandas as pd | |
HF_TOKEN = os.getenv("HUGGINGFACE_TOKEN") | |
login(token=HF_TOKEN) | |
client = InferenceClient(token=HF_TOKEN) | |
embeddings = HuggingFaceEmbeddings(model_name="OrdalieTech/Solon-embeddings-large-0.1") | |
db_code = FAISS.load_local("faiss_code_education", | |
embeddings, | |
allow_dangerous_deserialization=True) | |
reducer = umap.UMAP() | |
index = db_code.index | |
ntotal = min(index.ntotal, 4998) | |
embeds = index.reconstruct_n(0, ntotal) | |
umap_embeds = reducer.fit_transform(embeds) | |
articles_df = pd.DataFrame({ | |
"x" : umap_embeds[:,0], | |
"y" : umap_embeds[:,1], | |
"type" : [ "Source" ] * len(umap_embeds), | |
}) | |
system_prompt = """Tu es un assistant juridique spécialisé dans le Code de l'éducation français. | |
Ta mission est d'aider les utilisateurs à comprendre la législation en répondant à leurs questions. | |
Voici comment tu dois procéder : | |
1. **Analyse de la question:** Lis attentivement la question de l'utilisateur. | |
2. **Identification des articles pertinents:** Examine les 10 articles de loi fournis et sélectionne ceux qui sont les plus pertinents pour répondre à la question. | |
3. **Formulation de la réponse:** Rédige une réponse claire et concise en français, en utilisant les informations des articles sélectionnés. Cite explicitement les articles que tu utilises (par exemple, "Selon l'article L311-3..."). | |
4. **Structure de la réponse:** Si ta réponse s'appuie sur plusieurs articles, structure-la de manière logique, en séparant les informations provenant de chaque article. | |
5. **Cas ambigus:** | |
* Si la question est trop vague, demande des précisions à l'utilisateur. | |
* Si plusieurs articles pourraient s'appliquer, présente les différentes interprétations possibles.""" | |
def query_rag(query, model, system_prompt): | |
docs = db_code.similarity_search(query, 10) | |
article_dict = {} | |
context_list = [] | |
for doc in docs: | |
article = doc.metadata | |
context_list.append(' > '.join(article['chemin'])+'\n'+article['texte']+'\n---\n') | |
article_dict[article['article']] = article | |
user = 'Question de l\'utilisateur : ' + query + '\nContexte législatif :\n' + '\n'.join(context_list) | |
messages = [ { "role" : "system", "content" : system_prompt } ] | |
messages.append( { "role" : "user", "content" : user } ) | |
if "factice" in model: | |
return user, article_dict | |
chat_completion = client.chat_completion( | |
messages=messages, | |
model=model, | |
max_tokens=1024) | |
return chat_completion.choices[0].message.content, article_dict | |
def create_context_response(response, article_dict): | |
context = '\n' | |
for i, article in enumerate(article_dict): | |
art = article_dict[article] | |
context += '* **' + ' > '.join(art['chemin']) + '** : '+ art['texte'].replace('\n', '\n ')+'\n' | |
return context | |
def chat_interface(query, model, system_prompt): | |
response, article_dict = query_rag(query, model, system_prompt) | |
context = create_context_response(response, article_dict) | |
return response, context | |
def update_plot(query): | |
query_embed = embeddings.embed_documents([query])[0] | |
query_umap_embed = reducer.transform([query_embed]) | |
data = { | |
"x": umap_embeds[:, 0].tolist() + [query_umap_embed[0, 0]], | |
"y": umap_embeds[:, 1].tolist() + [query_umap_embed[0, 1]], | |
"type": ["Source"] * len(umap_embeds) + ["Requête"] | |
} | |
return pd.DataFrame(data) | |
with gr.Blocks(title="Assistant Juridique pour le Code de l'éducation (Beta)") as demo: | |
gr.Markdown( | |
""" | |
## Posez vos questions sur le Code de l'éducation | |
**Créé par Marc de Falco** | |
**Avertissement :** Les informations fournies sont données à titre indicatif et ne constituent pas un avis juridique. Les serveurs étant publics, veuillez ne pas inclure de données sensibles. | |
""" | |
) | |
query_box = gr.Textbox(label="Votre question") | |
model = gr.Dropdown( | |
label="Modèle de langage", | |
choices=[ | |
"meta-llama/Meta-Llama-3-70B-Instruct", | |
"meta-llama/Meta-Llama-3-8B-Instruct", | |
"HuggingFaceH4/zephyr-7b-beta", | |
"NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO", | |
"mistralai/Mixtral-8x22B-v0.1", | |
"factice: question+contexte" | |
], | |
value="meta-llama/Meta-Llama-3-70B-Instruct") | |
submit_button = gr.Button("Envoyer") | |
with gr.Tab(label="Réponse"): | |
response_box = gr.Markdown() | |
with gr.Tab(label="Sources"): | |
sources_box = gr.Markdown() | |
with gr.Tab(label="Visualisation"): | |
scatter_plot = gr.ScatterPlot(articles_df, | |
x = "x", y = "y", | |
color="type", | |
label="Visualisation des embeddings", | |
width=500, | |
height=500) | |
with gr.Tab(label="Paramètres"): | |
system_box = gr.Textbox(label="Invite systeme", value=system_prompt, | |
lines=20) | |
submit_button.click(chat_interface, | |
inputs=[query_box, model, system_box], | |
outputs=[response_box, sources_box]) | |
submit_button.click(update_plot, inputs=[query_box], outputs=[scatter_plot]) | |
demo.launch() | |