# chatbot_app/hf_to_chroma_ds.py
# imports
from abc import ABC, abstractmethod
from typing import Optional, Union, Sequence, Dict, Mapping, List, Any
from typing_extensions import TypedDict
from chroma_datasets.types import AddEmbedding, Datapoint
from chroma_datasets.utils import load_huggingface_dataset, to_chroma_schema
from chromadb.utils import embedding_functions
import os
from dotenv import load_dotenv

# Load environment variables (e.g. HF_API_KEY) from a local .env file
load_dotenv()
HF_API_KEY = os.environ.get("HF_API_KEY")

ef_instruction_dict = {
    "HuggingFaceEmbeddingFunction": """
        from chromadb.utils import embedding_functions
        hf_ef = embedding_functions.huggingface_embedding_function.HuggingFaceEmbeddingFunction(api_key={HF_API_KEY}, model_name="mixedbread-ai/mxbai-embed-large-v1")
    """
}


class Dataset(ABC):
    """
    Abstract base class for a dataset.

    All datasets should inherit from this class.

    Properties:
        hf_dataset_name: name of the Hugging Face dataset to load
        hf_data: the raw data loaded from Hugging Face
        embedding_function: name of the embedding function used to generate the embeddings
        embedding_function_instructions: instructions telling the user how to set up the embedding function
    """
    hf_dataset_name: str
    hf_data: Any
    embedding_function: str
    embedding_function_instructions: str

    @classmethod
    def load_data(cls):
        cls.hf_data = load_huggingface_dataset(
            cls.hf_dataset_name,
            split_name="data"
        )

    @classmethod
    def raw_text(cls) -> str:
        if cls.hf_data is None:
            cls.load_data()
        return "\n".join(cls.hf_data["document"])

    @classmethod
    def chunked(cls) -> List[Datapoint]:
        if cls.hf_data is None:
            cls.load_data()
        return cls.hf_data

    @classmethod
    def to_chroma(cls) -> AddEmbedding:
        return to_chroma_schema(cls.chunked())


class Memoires_DS(Dataset):
    """
    "Memoires" document chunks hosted on Hugging Face as eliot-hub/memoires_vec_800.
    """
    hf_data = None
    hf_dataset_name = "eliot-hub/memoires_vec_800"
    embedding_function = "HuggingFaceEmbeddingFunction"
    embedding_function_instructions = ef_instruction_dict[embedding_function]
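
# Usage sketch (illustrative, not executed at import time): the Dataset API works on
# the class itself, so no instantiation is needed.
#
#     chunks = Memoires_DS.chunked()     # loads the "data" split on first access
#     payload = Memoires_DS.to_chroma()  # dict with "ids", "metadatas", "documents", "embeddings"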


def import_into_chroma(chroma_client, dataset, collection_name=None, embedding_function=None, batch_size=5000):
    """
    Imports a dataset into Chroma in batches.

    Args:
        chroma_client (ChromaClient): The Chroma client to use.
        dataset (Dataset): The dataset class to load.
        collection_name (str): The name of the collection to load the dataset into. Defaults to the dataset class name.
        embedding_function (Optional[Callable[[str], np.ndarray]]): A function that takes a string and returns an embedding.
        batch_size (int): The number of records to add per batch.
    """
    # if chromadb is not installed, raise an error
    try:
        import chromadb
        from chromadb.utils import embedding_functions
    except ImportError:
        raise ImportError("Please install chromadb to use this function. `pip install chromadb`")

    ef = None

    if dataset.embedding_function is not None:
        if embedding_function is None:
            error_msg = "See documentation"
            if dataset.embedding_function_instructions is not None:
                error_msg = dataset.embedding_function_instructions
            raise ValueError(f"""
                Dataset requires embedding function: {dataset.embedding_function}.
                {error_msg}
            """)

        if embedding_function.__class__.__name__ != dataset.embedding_function:
            raise ValueError(f"Please use {dataset.embedding_function} as the embedding function for this dataset. You passed {embedding_function.__class__.__name__}")

    if embedding_function is not None:
        ef = embedding_function

    # if collection_name is None, get the name from the dataset type
    if collection_name is None:
        collection_name = dataset.__name__

    if ef is None:
        ef = embedding_functions.DefaultEmbeddingFunction()

    print("########### Init collection ###########")
    collection = chroma_client.create_collection(
        collection_name,
        embedding_function=ef
    )

    # Retrieve the mapped data
    print("########### Init to_chroma ###########")
    mapped_data = dataset.to_chroma()
    del dataset

    # Split the data into batches and add them to the collection
    def chunk_data(data, size):
        """Helper function to split data into batches."""
        for i in range(0, len(data), size):
            yield data[i:i+size]

    print("########### Chunking ###########")
    ids_batches = list(chunk_data(mapped_data["ids"], batch_size))
    metadatas_batches = list(chunk_data(mapped_data["metadatas"], batch_size))
    documents_batches = list(chunk_data(mapped_data["documents"], batch_size))
    embeddings_batches = list(chunk_data(mapped_data["embeddings"], batch_size))
    total_docs = len(mapped_data["ids"])

    print("########### Iterating batches ###########")
    for i, (ids, metadatas, documents, embeddings) in enumerate(zip(ids_batches, metadatas_batches, documents_batches, embeddings_batches)):
        collection.add(
            ids=ids,
            metadatas=metadatas,
            documents=documents,
            embeddings=embeddings,
        )
        print(f"Batch {i+1}/{len(ids_batches)}: Loaded {len(ids)} documents.")

    print(f"Successfully loaded {total_docs} documents into the collection named: {collection_name}")
    return collection
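

# Minimal usage sketch (assumptions: a local persistent Chroma store at "./chroma_db"
# and HF_API_KEY set in the environment; both are illustrative, not part of this module).
if __name__ == "__main__":
    import chromadb

    # Build the embedding function the dataset declares (see ef_instruction_dict above)
    hf_ef = embedding_functions.huggingface_embedding_function.HuggingFaceEmbeddingFunction(
        api_key=HF_API_KEY,
        model_name="mixedbread-ai/mxbai-embed-large-v1",
    )
    client = chromadb.PersistentClient(path="./chroma_db")  # assumed storage path
    collection = import_into_chroma(
        chroma_client=client,
        dataset=Memoires_DS,
        embedding_function=hf_ef,
    )
    print(f"Collection '{collection.name}' now holds {collection.count()} records.")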