Optimisation d'un système RAG pour la recherche sémantique

Community Article Published July 13, 2024

Article en cours de rédaction… ✍️

Cet article fait référence à l'article : https://huggingface.co./blog/not-lain/rag-chatbot-using-llama3

Introduction

Le Retrieval-Augmented Generation (RAG) est une technique puissante qui combine la recherche d'information et la génération de texte pour produire des réponses plus précises et contextuellement pertinentes. Dans cet article, nous allons explorer comment optimiser un système RAG traditionnel pour la recherche sémantique, améliorant ainsi la qualité et la pertinence des résultats.

Pourquoi optimiser pour la recherche sémantique ?

La recherche sémantique va au-delà de la simple correspondance de mots-clés. Elle cherche à comprendre l'intention et le contexte de la requête, permettant de trouver des informations pertinentes même lorsque les termes exacts ne sont pas présents dans le texte. Cette approche est particulièrement bénéfique pour les systèmes RAG, car elle permet de récupérer des informations plus pertinentes pour alimenter le modèle de génération.

Modifications clés pour l'optimisation sémantique

Nous allons examiner les modifications à apporter à un système RAG traditionnel pour l'optimiser pour la recherche sémantique. Nous comparerons la version initiale avec la version optimisée pour chaque composant clé.

Structure des données

Avant de plonger dans les optimisations, voici la structure de nos chunks de données :

{
    "id": "01",
    "title": "…",
    "content": "…",
    "tags": ["…","…"]
}

Ces chunks sont stockés au format .parquet et publiés sur Hugging Face pour faciliter leur utilisation.

Cette comparaison va se dérouler en trois étapes. 1ère étape ajout des embedding au dataset ; 2ème étape vérification de l'indexation et du retrievial ; 3ème étape intégration du systeme RAG dans gradio.

1. Ajout des embedding au dataset.

version de base :

from datasets import load_dataset

dataset = load_dataset("not-lain/wikipedia")
dataset
from sentence_transformers import SentenceTransformer
ST = SentenceTransformer("mixedbread-ai/mxbai-embed-large-v1")
# embed the dataset
def embed(batch):
  # or you can combine multiple columns here, for example the title and the text
  information = batch["text"]
  return {"embeddings" : ST.encode(information)}
dataset = dataset.map(embed,batched=True,batch_size=16)
dataset.push_to_hub("not-lain/wikipedia", revision="embedded")

version optimisée recherche sémantique :

from datasets import load_dataset

dataset = load_dataset("path-dataset")
dataset
from sentence_transformers import SentenceTransformer

# Nous gardons le modèle original pour sa qualité d'embedding
ST = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
def embed(batch):
    """
    Ajoute une colonne 'embeddings' au dataset en tenant compte de la structure spécifique des chunks
    """
    # Combinaison du titre, du contenu et des tags pour une représentation riche
    combined_info = []
    for item in batch:
        # Joindre les tags en une seule chaîne
        tags_string = " ".join(item['tags'])
        # Combiner titre, contenu et tags
        combined = f"{item['title']} {item['content']} {tags_string}"
        combined_info.append(combined)
    
    # Création et normalisation des embeddings
    embeddings = ST.encode(combined_info, normalize_embeddings=True)
    
    return {"embeddings": embeddings}

# Utilisation de map avec batching pour une meilleure efficacité
dataset = dataset.map(embed, batched=True, batch_size=16)
dataset.push_to_hub("path-dataset", revision="embedded")

Les améliorations apportées :

  • Le modèle d'embedding à été changé pour un modèle plus obtimisé pour la recherche sémantique. "sentence-transformers/all-MiniLM-L6-v2"
  • Combinaison du titre et du contenu pour une représentation plus riche.
  • Normalisation des embeddings pour améliorer la cohérence des comparaisons vectorielles.

2. Recherche sémantique à travers le dataset.

version de base :

from datasets import load_dataset

dataset = load_dataset("not-lain/wikipedia",revision = "embedded")
data = dataset["train"]
data = data.add_faiss_index("embeddings") # column name that has the embeddings of the dataset
def search(query: str, k: int = 3 ):
    """a function that embeds a new query and returns the most probable results"""
    embedded_query = ST.encode(query) # embed new query
    scores, retrieved_examples = data.get_nearest_examples( # retrieve results
        "embeddings", embedded_query, # compare our new embedded query with the dataset embeddings
        k=k # get only top k results
    )
    return scores, retrieved_examples
scores , result = search("anarchy", 4 ) # search for word anarchy and get the best 4 matching values from the dataset
# the lower the better
scores
result['title']
print(result["text"][0])

version optimisée :

from datasets import load_dataset

dataset = load_dataset("path-dataset",revision = "embedded")
data = dataset["train"]
# Optimisation : utilisation de la métrique du produit scalaire pour les embeddings normalisés
dataset.add_faiss_index("embeddings", metric_type=faiss.METRIC_INNER_PRODUCT) #Vérifier que la fonction existe et qu'elle fonctionne.
def semantic_search(query: str, k: int = 3):
    # Normalisation de l'embedding de la requête
    embedded_query = ST.encode(query, normalize_embeddings=True)
    
    scores, retrieved_chunks = dataset.get_nearest_examples(
        "embeddings", embedded_query, k=k
    )
    
    results = []
    for score, chunk in zip(scores, retrieved_chunks):
        results.append({
            'score': score,
            'id': chunk['id'],
            'title': chunk['title'],
            'content': chunk['content'],
            'tags': chunk['tags'],
            'similarity': (1 + score) / 2  # Conversion de la similarité cosinus [-1, 1] à [0, 1]
        })
    
    results.sort(key=lambda x: x['similarity'], reverse=True)
    return results
query = "Quelle est l'identité de Lucas?"
results = semantic_search(query, k=3)

for result in results:
    print(f"ID: {result['id']}")
    print(f"Titre: {result['title']}")
    print(f"Similarité: {result['similarity']:.2f}")
    print(f"Tags: {', '.join(result['tags'])}")
    print(f"Contenu: {result['content'][:200]}...")
    print("---")

3. Intégration dans gradio.

version de base :