# imports
from abc import ABC, abstractmethod
from typing import Optional, Union, Sequence, Dict, Mapping, List, Any
from typing_extensions import TypedDict
from chroma_datasets.types import AddEmbedding, Datapoint
from chroma_datasets.utils import load_huggingface_dataset, to_chroma_schema
from chromadb.utils import embedding_functions
import os
from dotenv import load_dotenv

# Load variables from a local .env file (if present) before reading the API key
load_dotenv()
HF_API_KEY = os.environ.get("HF_API_KEY")
# Setup instructions shown to the user when a dataset requires a specific embedding function
ef_instruction_dict = {
    "HuggingFaceEmbeddingFunction": """
        from chromadb.utils import embedding_functions
        hf_ef = embedding_functions.huggingface_embedding_function.HuggingFaceEmbeddingFunction(
            api_key=HF_API_KEY,
            model_name="mixedbread-ai/mxbai-embed-large-v1",
        )
    """
}
class Dataset(ABC):
    """
    Abstract class for a dataset.

    All datasets should inherit from this class.

    Properties:
        hf_dataset_name: name of the dataset on the Hugging Face hub
        hf_data: the raw data loaded from Hugging Face
        embedding_function: name of the embedding function used to generate the embeddings
        embedding_function_instructions: instructions telling the user how to set up the embedding function
    """
    hf_dataset_name: str
    hf_data: Any
    embedding_function: str
    embedding_function_instructions: str
    @classmethod
    def load_data(cls):
        cls.hf_data = load_huggingface_dataset(
            cls.hf_dataset_name,
            split_name="data"
        )

    @classmethod
    def raw_text(cls) -> str:
        if cls.hf_data is None:
            cls.load_data()
        return "\n".join(cls.hf_data["document"])

    @classmethod
    def chunked(cls) -> List[Datapoint]:
        if cls.hf_data is None:
            cls.load_data()
        return cls.hf_data

    @classmethod
    def to_chroma(cls) -> AddEmbedding:
        return to_chroma_schema(cls.chunked())
class Memoires_DS(Dataset):
    """
    "Memoires" dataset hosted on the Hugging Face hub (eliot-hub/memoires_vec_800),
    pre-embedded with the HuggingFace embedding function.
    """
    hf_data = None
    hf_dataset_name = "eliot-hub/memoires_vec_800"
    embedding_function = "HuggingFaceEmbeddingFunction"
    embedding_function_instructions = ef_instruction_dict[embedding_function]
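
# Note: the Dataset subclasses above are used as classes, not instances (import_into_chroma
# below reads dataset.__name__ and calls dataset.to_chroma() directly on the class).
# A minimal sketch of direct use, assuming the Hugging Face dataset exposes a "data"
# split with "document" entries:
#
#   text = Memoires_DS.raw_text()       # all documents joined with newlines
#   payload = Memoires_DS.to_chroma()   # dict with ids / metadatas / documents / embeddings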
def import_into_chroma(chroma_client, dataset, collection_name=None, embedding_function=None, batch_size=5000):
    """
    Imports a dataset into Chroma in batches.

    Args:
        chroma_client (ChromaClient): The Chroma client to use.
        dataset (Dataset): The dataset class to import.
        collection_name (Optional[str]): The name of the collection to load the dataset into.
            Defaults to the name of the dataset class.
        embedding_function (Optional[EmbeddingFunction]): The embedding function to attach to the
            collection; it must match the embedding function named by the dataset.
        batch_size (int): The number of records to add per batch.
    """
    # if chromadb is not installed, raise an error
    try:
        import chromadb
        from chromadb.utils import embedding_functions
    except ImportError:
        raise ImportError("Please install chromadb to use this function. `pip install chromadb`")

    ef = None

    if dataset.embedding_function is not None:
        if embedding_function is None:
            error_msg = "See documentation"
            if dataset.embedding_function_instructions is not None:
                error_msg = dataset.embedding_function_instructions
            raise ValueError(f"""
            Dataset requires embedding function: {dataset.embedding_function}.
            {error_msg}
            """)

        if embedding_function.__class__.__name__ != dataset.embedding_function:
            raise ValueError(f"Please use {dataset.embedding_function} as the embedding function for this dataset. You passed {embedding_function.__class__.__name__}")

    if embedding_function is not None:
        ef = embedding_function

    # if collection_name is None, get the name from the dataset type
    if collection_name is None:
        collection_name = dataset.__name__

    if ef is None:
        ef = embedding_functions.DefaultEmbeddingFunction()

    print("########### Init collection ###########")
    collection = chroma_client.create_collection(
        collection_name,
        embedding_function=ef
    )

    # Retrieve the mapped data
    print("########### Init to_chroma ###########")
    mapped_data = dataset.to_chroma()
    del dataset

    # Split the data into batches and add them to the collection
    def chunk_data(data, size):
        """Helper function to split data into batches."""
        for i in range(0, len(data), size):
            yield data[i:i + size]

    print("########### Chunking ###########")
    ids_batches = list(chunk_data(mapped_data["ids"], batch_size))
    metadatas_batches = list(chunk_data(mapped_data["metadatas"], batch_size))
    documents_batches = list(chunk_data(mapped_data["documents"], batch_size))
    embeddings_batches = list(chunk_data(mapped_data["embeddings"], batch_size))

    total_docs = len(mapped_data["ids"])

    print("########### Iterating batches ###########")
    for i, (ids, metadatas, documents, embeddings) in enumerate(
        zip(ids_batches, metadatas_batches, documents_batches, embeddings_batches)
    ):
        collection.add(
            ids=ids,
            metadatas=metadatas,
            documents=documents,
            embeddings=embeddings,
        )
        print(f"Batch {i + 1}/{len(ids_batches)}: Loaded {len(ids)} documents.")

    print(f"Successfully loaded {total_docs} documents into the collection named: {collection_name}")
    return collection
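

# Minimal usage sketch, under assumptions not stated in this module: a local persistent
# Chroma store at ./chroma_db and a valid HF_API_KEY in the environment. Adjust the path,
# key, and model to your own setup.
if __name__ == "__main__":
    import chromadb
    from chromadb.utils import embedding_functions

    # Persistent local client (assumed storage location)
    client = chromadb.PersistentClient(path="./chroma_db")

    # Build the HuggingFace embedding function that Memoires_DS declares it needs
    hf_ef = embedding_functions.huggingface_embedding_function.HuggingFaceEmbeddingFunction(
        api_key=HF_API_KEY,
        model_name="mixedbread-ai/mxbai-embed-large-v1",
    )

    # Import the dataset class (not an instance) into a new collection named "Memoires_DS"
    memoires_collection = import_into_chroma(
        chroma_client=client,
        dataset=Memoires_DS,
        embedding_function=hf_ef,
    )
    print(memoires_collection.count())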