# imports
from abc import ABC, abstractmethod
from typing import Optional, Union, Sequence, Dict, Mapping, List, Any
from typing_extensions import TypedDict
from chroma_datasets.types import AddEmbedding, Datapoint
from chroma_datasets.utils import load_huggingface_dataset, to_chroma_schema
from chromadb.utils import embedding_functions
import os
from dotenv import load_dotenv

# Load environment variables (HF_API_KEY) from a local .env file, if present
load_dotenv()
HF_API_KEY = os.environ.get("HF_API_KEY")

ef_instruction_dict = {
    "HuggingFaceEmbeddingFunction": """
        from chromadb.utils import embedding_functions
        hf_ef = embedding_functions.huggingface_embedding_function.HuggingFaceEmbeddingFunction(
            api_key={HF_API_KEY},
            model_name="mixedbread-ai/mxbai-embed-large-v1",
        )
    """
}


class Dataset(ABC):
    """
    Abstract class for a dataset.

    All datasets should inherit from this class.

    Properties:
        hf_dataset_name: the name of the HuggingFace dataset
        hf_data: the raw data from HuggingFace
        embedding_function: the name of the embedding function used to generate the embeddings
        embedding_function_instructions: tells the user how to set up the embedding function
    """
    hf_dataset_name: str
    hf_data: Any
    embedding_function: str
    embedding_function_instructions: str

    @classmethod
    def load_data(cls):
        cls.hf_data = load_huggingface_dataset(
            cls.hf_dataset_name,
            split_name="data"
        )

    @classmethod
    def raw_text(cls) -> str:
        if cls.hf_data is None:
            cls.load_data()
        return "\n".join(cls.hf_data["document"])

    @classmethod
    def chunked(cls) -> List[Datapoint]:
        if cls.hf_data is None:
            cls.load_data()
        return cls.hf_data

    @classmethod
    def to_chroma(cls) -> AddEmbedding:
        return to_chroma_schema(cls.chunked())


class Memoires_DS(Dataset):
    """
    Pre-embedded "memoires" dataset hosted on HuggingFace (eliot-hub/memoires_vec_800),
    embedded with the HuggingFaceEmbeddingFunction.
    """
    hf_data = None
    hf_dataset_name = "eliot-hub/memoires_vec_800"
    embedding_function = "HuggingFaceEmbeddingFunction"
    embedding_function_instructions = ef_instruction_dict[embedding_function]


def import_into_chroma(chroma_client, dataset, collection_name=None, embedding_function=None, batch_size=5000):
    """
    Imports a dataset into Chroma in batches.

    Args:
        chroma_client (ChromaClient): The Chroma client to use.
        dataset (Dataset): The dataset to load.
        collection_name (str): The name of the collection to load the dataset into.
            Defaults to the dataset class name.
        embedding_function (Optional[EmbeddingFunction]): The embedding function to attach to the
            collection. Must match the embedding function the dataset was embedded with.
        batch_size (int): The number of records to add per batch.

    Returns:
        The Chroma collection the dataset was loaded into.
    """
    # if chromadb is not installed, raise an error
    try:
        import chromadb
        from chromadb.utils import embedding_functions
    except ImportError:
        raise ImportError("Please install chromadb to use this function. `pip install chromadb`")

    ef = None

    if dataset.embedding_function is not None:
        if embedding_function is None:
            error_msg = "See documentation"
            if dataset.embedding_function_instructions is not None:
                error_msg = dataset.embedding_function_instructions
            raise ValueError(f"""
                Dataset requires embedding function: {dataset.embedding_function}.
                {error_msg}
            """)
        if embedding_function.__class__.__name__ != dataset.embedding_function:
            raise ValueError(
                f"Please use {dataset.embedding_function} as the embedding function for this dataset. "
                f"You passed {embedding_function.__class__.__name__}"
            )
    if embedding_function is not None:
        ef = embedding_function

    # if collection_name is None, get the name from the dataset type
    if collection_name is None:
        collection_name = dataset.__name__

    if ef is None:
        ef = embedding_functions.DefaultEmbeddingFunction()

    print("########### Init collection ###########")
    collection = chroma_client.create_collection(
        collection_name,
        embedding_function=ef
    )

    # Retrieve the mapped data
    print("########### Init to_chroma ###########")
    mapped_data = dataset.to_chroma()
    del dataset

    # Split the data into batches and add them to the collection
    def chunk_data(data, size):
        """Helper function to split data into batches."""
        for i in range(0, len(data), size):
            yield data[i:i + size]

    print("########### Chunking ###########")
    ids_batches = list(chunk_data(mapped_data["ids"], batch_size))
    metadatas_batches = list(chunk_data(mapped_data["metadatas"], batch_size))
    documents_batches = list(chunk_data(mapped_data["documents"], batch_size))
    embeddings_batches = list(chunk_data(mapped_data["embeddings"], batch_size))

    total_docs = len(mapped_data["ids"])

    print("########### Iterating batches ###########")
    for i, (ids, metadatas, documents, embeddings) in enumerate(
        zip(ids_batches, metadatas_batches, documents_batches, embeddings_batches)
    ):
        collection.add(
            ids=ids,
            metadatas=metadatas,
            documents=documents,
            embeddings=embeddings,
        )
        print(f"Batch {i+1}/{len(ids_batches)}: Loaded {len(ids)} documents.")

    print(f"Successfully loaded {total_docs} documents into the collection named: {collection_name}")

    return collection
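

# Usage sketch (illustrative, not part of the module): shows how the pieces above fit
# together, assuming an in-memory Chroma client and an HF_API_KEY set in the environment.
# The `client` and `hf_ef` names below are assumptions for the example, not defined here.
#
#   import chromadb
#   from chromadb.utils import embedding_functions
#
#   client = chromadb.Client()
#   hf_ef = embedding_functions.huggingface_embedding_function.HuggingFaceEmbeddingFunction(
#       api_key=HF_API_KEY,
#       model_name="mixedbread-ai/mxbai-embed-large-v1",
#   )
#   collection = import_into_chroma(client, Memoires_DS, embedding_function=hf_ef)
#   results = collection.query(query_texts=["example query"], n_results=3)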