|
import gradio as gr |
|
import spaces |
|
import torch |
|
import pandas as pd |
|
import plotly.graph_objects as go |
|
from datasets import load_dataset |
|
from sentence_transformers import SentenceTransformer |
|
from sentence_transformers.evaluation import InformationRetrievalEvaluator, SequentialEvaluator |
|
from sentence_transformers.util import cos_sim |
|
|
|
|
|
# Pick the compute device: use CUDA when torch reports it as available,
# otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# One-element probe tensor moved to the chosen device; its `.device` is
# printed so the startup log shows where tensors actually land.
zero = torch.Tensor([0]).to(device)
print(f"Device being used: {zero.device}")
|
|
|
|
|
@spaces.GPU
def evaluate_model(model_id, num_questions):
    """Evaluate a Sentence Transformer model on three Arabic retrieval datasets.

    For each dataset (Financial, MLQA, ARCD) a sample of question/context
    rows is turned into a retrieval task in which every question has exactly
    one relevant document: the context from the same row. The model is then
    scored with cosine similarity at each truncated embedding dimension in
    ``[768, 512, 256, 128, 64]``.

    Args:
        model_id: Model identifier passed to ``SentenceTransformer``.
        num_questions: Number of rows to sample from each dataset. It is
            converted with ``int()`` because UI sliders may deliver floats.

    Returns:
        A 4-tuple ``(result_df, financial_chart, mlqa_chart, arcd_chart)``:
        a DataFrame with one row per (dataset, dimension) holding the
        ``NDCG@10`` and ``MRR@10`` scores, followed by one grouped bar chart
        per dataset in the order the datasets are declared below.

    Raises:
        ValueError: If ``num_questions`` is smaller than 1.
    """
    sample_size = int(num_questions)
    if sample_size < 1:
        raise ValueError("num_questions must be at least 1")

    model = SentenceTransformer(model_id, device=device)
    matryoshka_dimensions = [768, 512, 256, 128, 64]

    # "columns" is (question column, context column) in the raw dataset.
    # "last_rows" switches sampling from the head to the tail of the split.
    datasets_info = [
        {
            "name": "Financial",
            "dataset_id": "Omartificial-Intelligence-Space/Arabic-finanical-rag-embedding-dataset",
            "split": "train",
            "columns": ("question", "context"),
            "sample_size": sample_size,
        },
        {
            "name": "MLQA",
            "dataset_id": "google/xtreme",
            "subset": "MLQA.ar.ar",
            "split": "validation",
            "columns": ("question", "context"),
            "sample_size": sample_size,
        },
        {
            "name": "ARCD",
            "dataset_id": "hsseinmz/arcd",
            "split": "train",
            "columns": ("question", "context"),
            "sample_size": sample_size,
            "last_rows": True,
        },
    ]

    evaluation_results = []
    scores_by_dataset = {}

    for dataset_info in datasets_info:
        # The configuration name is only passed when the entry declares one.
        load_args = [dataset_info["dataset_id"]]
        if "subset" in dataset_info:
            load_args.append(dataset_info["subset"])
        dataset = load_dataset(*load_args, split=dataset_info["split"])

        dataset = _select_eval_rows(
            dataset,
            dataset_info["sample_size"],
            last_rows=bool(dataset_info.get("last_rows")),
        )

        question_column, context_column = dataset_info["columns"]
        dataset = dataset.rename_column(question_column, "anchor")
        dataset = dataset.rename_column(context_column, "positive")

        # Rows need an identifier to link each query to its document.
        if "id" not in dataset.column_names:
            dataset = dataset.add_column("id", list(range(len(dataset))))

        row_ids = dataset["id"]
        corpus = dict(zip(row_ids, dataset["positive"]))
        queries = dict(zip(row_ids, dataset["anchor"]))

        # Each query's only relevant document is the context from its own row.
        # NOTE(review): if several rows carry the same context text, the
        # other copies are still counted as non-relevant -- confirm this is
        # the intended scoring.
        relevant_docs = {q_id: [q_id] for q_id in queries}

        matryoshka_evaluators = [
            InformationRetrievalEvaluator(
                queries=queries,
                corpus=corpus,
                relevant_docs=relevant_docs,
                name=f"dim_{dim}",
                truncate_dim=dim,
                score_functions={"cosine": cos_sim},
            )
            for dim in matryoshka_dimensions
        ]
        results = SequentialEvaluator(matryoshka_evaluators)(model)

        scores_ndcg = []
        scores_mrr = []
        for dim in matryoshka_dimensions:
            # A missing key yields None, which the table and chart show as-is.
            ndcg_score = results.get(f"dim_{dim}_cosine_ndcg@10")
            mrr_score = results.get(f"dim_{dim}_cosine_mrr@10")
            evaluation_results.append({
                "Dataset": dataset_info["name"],
                "Dimension": dim,
                "NDCG@10": ndcg_score,
                "MRR@10": mrr_score,
            })
            scores_ndcg.append(ndcg_score)
            scores_mrr.append(mrr_score)

        scores_by_dataset[dataset_info["name"]] = {
            "NDCG@10": scores_ndcg,
            "MRR@10": scores_mrr,
        }

    result_df = pd.DataFrame(evaluation_results)

    charts = [
        _build_score_chart(dataset_name, matryoshka_dimensions, scores)
        for dataset_name, scores in scores_by_dataset.items()
    ]

    return result_df, charts[0], charts[1], charts[2]


def _select_eval_rows(dataset, sample_size, last_rows=False):
    """Return at most ``sample_size`` rows from the head or tail of ``dataset``."""
    total = len(dataset)
    # Clamp so that asking for more rows than exist never produces an
    # out-of-range (negative) start index when sampling from the tail.
    count = max(0, min(int(sample_size), total))
    start = total - count if last_rows else 0
    return dataset.select(range(start, start + count))


def _build_score_chart(dataset_name, dimensions, scores):
    """Build a grouped bar chart of NDCG@10 and MRR@10 per embedding dimension."""
    metric_colors = (("NDCG@10", "#a05195"), ("MRR@10", "#2f4b7c"))
    x_labels = [str(dim) for dim in dimensions]

    fig = go.Figure()
    for metric, color in metric_colors:
        values = scores[metric]
        fig.add_trace(go.Bar(
            x=x_labels,
            y=values,
            name=metric,
            marker_color=color,
            # Compare against None so a genuine score of 0.0 is still shown.
            text=[f"{value:.3f}" if value is not None else "N/A" for value in values],
            textposition="auto",
        ))

    fig.update_layout(
        title=f"{dataset_name} Evaluation",
        xaxis_title="Embedding Dimension",
        yaxis_title="Score",
        barmode="group",
        template="plotly_white",
    )
    return fig
|
|
|
|
|
|
|
def display_results(model_name, num_questions):
    """Run the evaluation and hand its outputs to the Gradio interface.

    Args:
        model_name: Model identifier forwarded to ``evaluate_model``.
        num_questions: Sample size forwarded to ``evaluate_model``.

    Returns:
        The results table followed by the three charts, in the same order
        ``evaluate_model`` returns them.
    """
    outputs = evaluate_model(model_name, num_questions)
    # Unpack explicitly so a result of the wrong length fails here.
    scores_table, first_chart, second_chart, third_chart = outputs
    return scores_table, first_chart, second_chart, third_chart
|
|
|
|
|
|
|
# Input widgets: the model identifier and the per-dataset sample size.
interface_inputs = [
    gr.Textbox(
        label="Enter a Hugging Face Model ID",
        placeholder="e.g., Omartificial-Intelligence-Space/GATE-AraBert-v1",
    ),
    gr.Slider(label="Number of Questions", minimum=1, maximum=500, step=1, value=500),
]

# Output widgets, in the order `display_results` returns its values.
interface_outputs = [
    gr.Dataframe(label="Evaluation Results"),
    gr.Plot(label="Financial Dataset"),
    gr.Plot(label="MLQA Dataset"),
    gr.Plot(label="ARCD Dataset"),
]

# Markdown description shown under the title.
interface_description = (
    "Evaluate your Embedding model or any Arabic Sentence Transformer model's performance on **context and question retrieval** for Arabic datasets for Enhancing RAG (Retrieval-Augmented Generation).\n"
    "- **ARCD** evaluates short context retrieval performance.\n"
    "- **MLQA Arabic** evaluates long context retrieval performance.\n"
    "- **Arabic Financial Dataset** focuses on financial context retrieval.\n\n"
    "**Evaluation Metrics:**\n"
    "The evaluation uses **NDCG@10** and **MRR@10**, which measure how well the retrieved documents (contexts) match the query relevance.\n"
    "Higher scores indicate better performance. Embedding dimensions are reduced from 768 to 64, evaluating how well the model performs with fewer dimensions."
)

demo = gr.Interface(
    fn=display_results,
    inputs=interface_inputs,
    outputs=interface_outputs,
    title="MERAA : Matryoshka Embedding Retrieval Assessment for Arabic",
    description=interface_description,
    theme="default",
    live=False,
    css="footer {visibility: hidden;}",
)

# Start the web app with debug output enabled.
demo.launch(debug=True)

print("\nCreated by Omar Najar | Omartificial Intelligence Space")
|
|