potion-retrieval-32M Model Card


This Model2Vec model is optimized for retrieval tasks. It is a fine-tune of potion-base-32M, trained with a modified version of the approach described in Tom Aarsen's blog post on training static embedding models. Because it uses static embeddings, text embeddings can be computed orders of magnitude faster than with transformer-based models, on both GPU and CPU. It is designed for applications where computational resources are limited or where real-time performance is critical.
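
To illustrate why static embeddings are cheap to compute, here is a minimal sketch. It assumes a toy whitespace tokenizer and a made-up three-dimensional embedding table (`VOCAB`, `EMBEDDINGS`, and `encode` are hypothetical names, not the model2vec API). Encoding reduces to a table lookup per token followed by a mean over the looked-up rows, with no transformer forward pass.

```python
import numpy as np

# Hypothetical toy vocabulary and embedding table. The real model uses a
# subword tokenizer and a learned table with one row per vocabulary token.
VOCAB = {"fast": 0, "static": 1, "embeddings": 2, "are": 3}
EMBEDDINGS = np.array(
    [
        [1.0, 0.0, 0.0],
        [0.0, 1.0, 0.0],
        [0.0, 0.0, 1.0],
        [1.0, 1.0, 0.0],
    ],
    dtype=np.float32,
)


def encode(text: str) -> np.ndarray:
    """Embed text by looking up each known token and mean-pooling the rows.

    Whitespace tokenization is an assumption made only for this sketch.
    """
    ids = [VOCAB[tok] for tok in text.lower().split() if tok in VOCAB]
    if not ids:
        # No known tokens: fall back to a zero vector.
        return np.zeros(EMBEDDINGS.shape[1], dtype=np.float32)
    # No attention or transformer layers: just a lookup and a mean.
    return EMBEDDINGS[ids].mean(axis=0)


print(encode("Static embeddings are fast"))
```

The cost grows with the number of tokens times the embedding width. That is why this style of model can run quickly on CPU.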

Installation

Install model2vec using pip:

pip install model2vec

Usage

Load this model using the from_pretrained method:

from model2vec import StaticModel
# Load a pretrained Model2Vec model
model = StaticModel.from_pretrained("minishlab/potion-retrieval-32M")
# Compute text embeddings
embeddings = model.encode(["Example sentence"])
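
For retrieval, you would typically encode a query and a set of documents with `model.encode` and rank the documents by cosine similarity. Below is a minimal NumPy sketch of that ranking step. `rank_documents` is a hypothetical helper, not part of model2vec. The optional `dim` argument truncates embeddings to their first `dim` dimensions. The reproduction script further down trains with `MatryoshkaLoss` over dimensions 32 to 512, and that loss is what targets this kind of truncation, but this card does not report how much quality is lost at each size.

```python
import numpy as np


def rank_documents(query_emb, doc_embs, dim=None):
    """Return document indices sorted by cosine similarity to the query,
    together with the sorted scores.

    If `dim` is given, both sides are truncated to their first `dim`
    dimensions before comparison (Matryoshka-style truncation).
    """
    q = np.asarray(query_emb, dtype=np.float32)
    d = np.asarray(doc_embs, dtype=np.float32)
    if dim is not None:
        q, d = q[:dim], d[:, :dim]
    # Normalize so that the dot product equals cosine similarity.
    q = q / (np.linalg.norm(q) + 1e-12)
    d = d / (np.linalg.norm(d, axis=1, keepdims=True) + 1e-12)
    scores = d @ q
    order = np.argsort(-scores)  # highest similarity first
    return order, scores[order]


# Toy vectors standing in for model.encode(...) outputs.
query = np.array([1.0, 0.0, 0.0])
docs = np.array([[0.0, 1.0, 0.0], [0.9, 0.1, 0.0], [0.5, 0.5, 0.0]])
order, scores = rank_documents(query, docs)
print(order)  # [1 2 0]
```

With the real model, `query_emb` would be `model.encode(["your query"])[0]` and `doc_embs` would be `model.encode(documents)`.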

How it works

Model2Vec creates a small, static model that outperforms other static embedding models by a large margin on all MTEB tasks. The base model, potion-base-32M, was pre-trained using Tokenlearn with the following steps:

  • Distillation: first, a model is distilled from a sentence transformer model using Model2Vec.
  • Training data creation: the sentence transformer model is used to create training data by computing mean output embeddings on a large corpus.
  • Training: the distilled model is trained on the training data using Tokenlearn.
  • Post-training re-regularization: after training, the model is re-regularized by weighting the tokens based on their frequency, applying PCA, and finally applying SIF weighting.
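
The PCA and SIF parts of the post-training step can be sketched on a `(vocab_size, dim)` embedding table as shown below. This is a simplified illustration, not Tokenlearn's code. It assumes token probabilities are estimated from raw `token_counts` and uses a SIF coefficient `a = 1e-4`, so each token row is scaled by `a / (a + p(token))`. The helper name `pca_then_sif` and both assumptions are hypothetical.

```python
import numpy as np


def pca_then_sif(embeddings, token_counts, n_components, sif_coefficient=1e-4):
    """Reduce a (vocab_size, dim) embedding table with PCA, then scale each
    row by the SIF weight a / (a + p(token)).

    p(token) is estimated from `token_counts` here. This is an assumption of
    the sketch, not necessarily how Tokenlearn estimates frequencies.
    """
    emb = np.asarray(embeddings, dtype=np.float64)
    # PCA via SVD of the mean-centred matrix; rows of vt are principal axes.
    centred = emb - emb.mean(axis=0, keepdims=True)
    _, _, vt = np.linalg.svd(centred, full_matrices=False)
    reduced = centred @ vt[:n_components].T
    # SIF weighting: very frequent tokens get weights near 0, rare ones near 1.
    counts = np.asarray(token_counts, dtype=np.float64)
    proba = counts / counts.sum()
    weights = sif_coefficient / (sif_coefficient + proba)
    return reduced * weights[:, None]
```

The design intent of SIF weighting is that very common tokens (for example, function words) contribute less to a mean-pooled sentence embedding than rare, informative ones.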


Results

The results for this model are shown in the table below. The full Model2Vec results for all models can be found on the Model2Vec results page.

Task                                                          Score
Average (All)                                                 49.73
Average (MTEB)                                                49.76
Classification                                                59.56
Clustering                                                    30.55
PairClassification                                            76.38
Reranking                                                     50.05
Retrieval                                                     36.35
STS                                                           73.22
Summarization                                                 28.85
PEARL                                                         49.31
WordSim                                                       50.02


Library Authors

Model2Vec was developed by the Minish Lab team consisting of Stephan Tulkens and Thomas van Dongen.

Citation

Please cite the Model2Vec repository if you use this model in your work.

@software{minishlab2024model2vec,
  author = {Stephan Tulkens and Thomas van Dongen},
  title = {Model2Vec: The Fastest State-of-the-Art Static Embeddings in the World},
  year = {2024},
  url = {https://github.com/MinishLab/model2vec}
}

Reproducibility

The following script can be used to reproduce this model. All credit goes to Tom Aarsen for this fine-tuning approach and the code he introduced in his blog post. We made a few modifications to the original code, namely:

  • We start with a pre-trained Model2Vec model (potion-base-32M).
  • We reduce the dataset size by a factor of 10. In our experiments, the model converged without needing the full dataset.
  • We decrease the learning rate and train for 3 epochs instead of 1. A high learning rate wipes out the benefits of starting from a pre-trained model.

import random
import logging
from datasets import load_dataset, Dataset, DatasetDict
from sentence_transformers import (
    SentenceTransformer,
    SentenceTransformerTrainer,
    SentenceTransformerTrainingArguments,
    SentenceTransformerModelCardData,
)
from sentence_transformers.losses import MatryoshkaLoss, MultipleNegativesRankingLoss
from sentence_transformers.training_args import BatchSamplers, MultiDatasetBatchSamplers
from sentence_transformers.evaluation import NanoBEIREvaluator
from sentence_transformers.models.StaticEmbedding import StaticEmbedding
import wandb

logging.basicConfig(
    format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO
)
random.seed(12)


def load_train_eval_datasets(factor: int = 1):
    """
    Loads train and eval datasets from disk if available. Otherwise, downloads 
    them from Hugging Face, preprocesses, and saves them to disk. If `factor` is 
    greater than 1, returns a fraction (1/factor) of each dataset subset.

    :param factor: The factor by which the data is reduced. If factor=1, no reduction is performed.
    :return: (train_dataset: DatasetDict, eval_dataset: DatasetDict)
    """
    try:
        # Try loading from disk
        train_dataset = DatasetDict.load_from_disk("datasets/train_dataset")
        eval_dataset = DatasetDict.load_from_disk("datasets/eval_dataset")
    except FileNotFoundError:
        print("Prebuilt datasets not found on disk. Building from scratch...")

        print("Loading gooaq dataset...")
        gooaq_dataset = load_dataset("sentence-transformers/gooaq", split="train")
        gooaq_dataset_dict = gooaq_dataset.train_test_split(test_size=10_000, seed=12)
        gooaq_train_dataset: Dataset = gooaq_dataset_dict["train"]
        gooaq_eval_dataset: Dataset = gooaq_dataset_dict["test"]
        print("Loaded gooaq dataset.")

        print("Loading msmarco dataset...")
        msmarco_dataset = load_dataset(
            "sentence-transformers/msmarco-co-condenser-margin-mse-sym-mnrl-mean-v1",
            "triplet",
            split="train"
        )
        msmarco_dataset_dict = msmarco_dataset.train_test_split(test_size=10_000, seed=12)
        msmarco_train_dataset: Dataset = msmarco_dataset_dict["train"]
        msmarco_eval_dataset: Dataset = msmarco_dataset_dict["test"]
        print("Loaded msmarco dataset.")

        print("Loading squad dataset...")
        squad_dataset = load_dataset("sentence-transformers/squad", split="train")
        squad_dataset_dict = squad_dataset.train_test_split(test_size=10_000, seed=12)
        squad_train_dataset: Dataset = squad_dataset_dict["train"]
        squad_eval_dataset: Dataset = squad_dataset_dict["test"]
        print("Loaded squad dataset.")

        print("Loading s2orc dataset...")
        s2orc_dataset = load_dataset(
            "sentence-transformers/s2orc", 
            "title-abstract-pair", 
            split="train[:100000]"  # limit to 100k
        )
        s2orc_dataset_dict = s2orc_dataset.train_test_split(test_size=10_000, seed=12)
        s2orc_train_dataset: Dataset = s2orc_dataset_dict["train"]
        s2orc_eval_dataset: Dataset = s2orc_dataset_dict["test"]
        print("Loaded s2orc dataset.")

        print("Loading allnli dataset...")
        allnli_train_dataset = load_dataset(
            "sentence-transformers/all-nli", 
            "triplet", 
            split="train"
        )
        allnli_eval_dataset = load_dataset(
            "sentence-transformers/all-nli", 
            "triplet", 
            split="dev"
        )
        print("Loaded allnli dataset.")

        print("Loading paq dataset...")
        paq_dataset = load_dataset("sentence-transformers/paq", split="train")
        paq_dataset_dict = paq_dataset.train_test_split(test_size=10_000, seed=12)
        paq_train_dataset: Dataset = paq_dataset_dict["train"]
        paq_eval_dataset: Dataset = paq_dataset_dict["test"]
        print("Loaded paq dataset.")

        print("Loading trivia_qa dataset...")
        trivia_qa = load_dataset("sentence-transformers/trivia-qa", split="train")
        trivia_qa_dataset_dict = trivia_qa.train_test_split(test_size=5_000, seed=12)
        trivia_qa_train_dataset: Dataset = trivia_qa_dataset_dict["train"]
        trivia_qa_eval_dataset: Dataset = trivia_qa_dataset_dict["test"]
        print("Loaded trivia_qa dataset.")

        print("Loading msmarco_10m dataset...")
        msmarco_10m_dataset = load_dataset("bclavie/msmarco-10m-triplets", split="train")
        msmarco_10m_dataset_dict = msmarco_10m_dataset.train_test_split(
            test_size=10_000, seed=12
        )
        msmarco_10m_train_dataset: Dataset = msmarco_10m_dataset_dict["train"]
        msmarco_10m_eval_dataset: Dataset = msmarco_10m_dataset_dict["test"]
        print("Loaded msmarco_10m dataset.")

        print("Loading swim_ir dataset...")
        swim_ir_dataset = load_dataset(
            "nthakur/swim-ir-monolingual", 
            "en", 
            split="train"
        ).select_columns(["query", "text"])
        swim_ir_dataset_dict = swim_ir_dataset.train_test_split(
            test_size=10_000, seed=12
        )
        swim_ir_train_dataset: Dataset = swim_ir_dataset_dict["train"]
        swim_ir_eval_dataset: Dataset = swim_ir_dataset_dict["test"]
        print("Loaded swim_ir dataset.")

        # NOTE: 20 negatives
        print("Loading pubmedqa dataset...")
        pubmedqa_dataset = load_dataset(
            "sentence-transformers/pubmedqa", 
            "triplet-20", 
            split="train"
        )
        pubmedqa_dataset_dict = pubmedqa_dataset.train_test_split(test_size=100, seed=12)
        pubmedqa_train_dataset: Dataset = pubmedqa_dataset_dict["train"]
        pubmedqa_eval_dataset: Dataset = pubmedqa_dataset_dict["test"]
        print("Loaded pubmedqa dataset.")

        # NOTE: A lot of overlap with anchor/positives
        print("Loading miracl dataset...")
        miracl_dataset = load_dataset(
            "sentence-transformers/miracl", 
            "en-triplet-all", 
            split="train"
        )
        miracl_dataset_dict = miracl_dataset.train_test_split(test_size=10_000, seed=12)
        miracl_train_dataset: Dataset = miracl_dataset_dict["train"]
        miracl_eval_dataset: Dataset = miracl_dataset_dict["test"]
        print("Loaded miracl dataset.")

        # NOTE: A lot of overlap with anchor/positives
        print("Loading mldr dataset...")
        mldr_dataset = load_dataset(
            "sentence-transformers/mldr", 
            "en-triplet-all", 
            split="train"
        )
        mldr_dataset_dict = mldr_dataset.train_test_split(test_size=10_000, seed=12)
        mldr_train_dataset: Dataset = mldr_dataset_dict["train"]
        mldr_eval_dataset: Dataset = mldr_dataset_dict["test"]
        print("Loaded mldr dataset.")

        # NOTE: A lot of overlap with anchor/positives
        print("Loading mr_tydi dataset...")
        mr_tydi_dataset = load_dataset(
            "sentence-transformers/mr-tydi", 
            "en-triplet-all", 
            split="train"
        )
        mr_tydi_dataset_dict = mr_tydi_dataset.train_test_split(test_size=10_000, seed=12)
        mr_tydi_train_dataset: Dataset = mr_tydi_dataset_dict["train"]
        mr_tydi_eval_dataset: Dataset = mr_tydi_dataset_dict["test"]
        print("Loaded mr_tydi dataset.")

        train_dataset = DatasetDict({
            "gooaq": gooaq_train_dataset,
            "msmarco": msmarco_train_dataset,
            "squad": squad_train_dataset,
            "s2orc": s2orc_train_dataset,
            "allnli": allnli_train_dataset,
            "paq": paq_train_dataset,
            "trivia_qa": trivia_qa_train_dataset,
            "msmarco_10m": msmarco_10m_train_dataset,
            "swim_ir": swim_ir_train_dataset,
            "pubmedqa": pubmedqa_train_dataset,
            "miracl": miracl_train_dataset,
            "mldr": mldr_train_dataset,
            "mr_tydi": mr_tydi_train_dataset,
        })
        eval_dataset = DatasetDict({
            "gooaq": gooaq_eval_dataset,
            "msmarco": msmarco_eval_dataset,
            "squad": squad_eval_dataset,
            "s2orc": s2orc_eval_dataset,
            "allnli": allnli_eval_dataset,
            "paq": paq_eval_dataset,
            "trivia_qa": trivia_qa_eval_dataset,
            "msmarco_10m": msmarco_10m_eval_dataset,
            "swim_ir": swim_ir_eval_dataset,
            "pubmedqa": pubmedqa_eval_dataset,
            "miracl": miracl_eval_dataset,
            "mldr": mldr_eval_dataset,
            "mr_tydi": mr_tydi_eval_dataset,
        })

        # Save to disk for next time
        train_dataset.save_to_disk("datasets/train_dataset")
        eval_dataset.save_to_disk("datasets/eval_dataset")

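        # Exit to avoid the memory overhead of keeping the large datasets loaded;
        # rerun the script to load the saved datasets from disk and train.
        print("Datasets saved to disk. Rerun the script to start training.")
        raise SystemExit(0)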

    # Reduce the dataset if factor > 1
    if factor > 1:
        for subset_name in train_dataset:
            ds = train_dataset[subset_name].shuffle(seed=42)
            new_len = len(ds) // factor
            train_dataset[subset_name] = ds.select(range(new_len))

        for subset_name in eval_dataset:
            ds = eval_dataset[subset_name].shuffle(seed=42)
            new_len = len(ds) // factor
            eval_dataset[subset_name] = ds.select(range(new_len))

    return train_dataset, eval_dataset


def main():
    # Log to Weights & Biases; replace entity/project with your own to reproduce
    wandb.init(entity="minishlab", project="minishlab")
    # 1. Load a model to finetune
    static_embedding = StaticEmbedding.from_model2vec("minishlab/potion-base-32M")

    # 2. Initialize the SentenceTransformer model
    model_name = "potion-retrieval-32M"
    model = SentenceTransformer(
        modules=[static_embedding],
        model_card_data=SentenceTransformerModelCardData(
            language="en",
            license="mit",
            model_name=model_name,
        ),
    )

    # 3. Load training & evaluation datasets
    # NOTE: we reduce the total dataset size by a factor of 10 
    train_dataset, eval_dataset = load_train_eval_datasets(factor=10)
    print(train_dataset)

    # 4. Define a loss function
    loss = MultipleNegativesRankingLoss(model)
    loss = MatryoshkaLoss(model, loss, matryoshka_dims=[32, 64, 128, 256, 512])

    # 5. Specify training arguments
    run_name = model_name
    epochs = 3
    lr = 0.05
    args = SentenceTransformerTrainingArguments(
        output_dir=f"models/{run_name}",
        num_train_epochs=epochs,
        per_device_train_batch_size=2048,
        per_device_eval_batch_size=2048,
        learning_rate=lr,
        warmup_ratio=0.1,
        fp16=False,
        bf16=True,
        batch_sampler=BatchSamplers.NO_DUPLICATES,
        multi_dataset_batch_sampler=MultiDatasetBatchSamplers.PROPORTIONAL,
        eval_strategy="steps",
        eval_steps=250,
        save_strategy="steps",
        save_steps=250,
        save_total_limit=2,
        logging_steps=250,
        logging_first_step=True,
        run_name=run_name,
        report_to=["wandb"],
        load_best_model_at_end=True,
        metric_for_best_model="eval_NanoBEIR_mean_cosine_ndcg@10",
        greater_is_better=True,
    )

    # 6. Create an evaluator & evaluate the base model
    evaluator = NanoBEIREvaluator()
    evaluator(model)

    # 7. Create a trainer & train
    trainer = SentenceTransformerTrainer(
        model=model,
        args=args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        loss=loss,
        evaluator=evaluator,
    )
    trainer.train()

    # 8. Evaluate the trained model and save
    evaluator(model)
    model.save_pretrained(f"models/{run_name}/final")


if __name__ == "__main__":
    main()
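
The script combines `MultipleNegativesRankingLoss` with `MatryoshkaLoss`. To show roughly what that combination computes, here is a NumPy sketch that is not the sentence-transformers implementation. It uses only (anchor, positive) pairs and ignores the hard negatives that triplet datasets contribute. It assumes the library's default cosine similarity with a scale of 20 and weights every Matryoshka dimension equally. `mnrl_loss` and `matryoshka_loss` are hypothetical helper names.

```python
import numpy as np


def mnrl_loss(anchors, positives, scale=20.0):
    """In-batch negatives loss: each anchor should score its own positive
    higher than every other positive in the batch (cross-entropy over
    scaled cosine similarities)."""
    a = np.asarray(anchors, dtype=np.float64)
    p = np.asarray(positives, dtype=np.float64)
    a = a / (np.linalg.norm(a, axis=1, keepdims=True) + 1e-12)
    p = p / (np.linalg.norm(p, axis=1, keepdims=True) + 1e-12)
    logits = scale * (a @ p.T)  # (batch, batch); the diagonal holds true pairs
    logits = logits - logits.max(axis=1, keepdims=True)  # numerical stability
    log_probs = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
    return float(-np.mean(np.diag(log_probs)))


def matryoshka_loss(anchors, positives, dims=(32, 64, 128, 256, 512)):
    """Sum the base loss over embeddings truncated to each prefix size."""
    return sum(mnrl_loss(anchors[:, :d], positives[:, :d]) for d in dims)
```

Summing the loss over several prefix sizes pushes the most useful information into the leading dimensions. As a result, truncated embeddings stay usable for retrieval.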
Model size: 32.3M params (Safetensors, tensor type F32)