eliot-hub committed
Commit de77992 · 1 Parent(s): 31d0102
Files changed (2)
  1. .gitignore +2 -1
  2. app.py +13 -16
.gitignore CHANGED
@@ -1 +1,2 @@
-.env
+.env
+hf_to_chroma_ds
app.py CHANGED
@@ -1,4 +1,8 @@
+import os
+from dotenv import load_dotenv
+
 import gradio as gr
+
 from langchain_chroma import Chroma
 from langchain.prompts import ChatPromptTemplate
 from langchain.chains import create_retrieval_chain, create_history_aware_retriever
@@ -6,25 +10,18 @@ from langchain.chains.combine_documents import create_stuff_documents_chain
 from langchain_core.prompts import MessagesPlaceholder
 from langchain_community.chat_message_histories import ChatMessageHistory
 from langchain_core.runnables.history import RunnableWithMessageHistory
-import torch
-import chromadb
-from typing import List
 from langchain_core.documents import Document
 from langchain_core.retrievers import BaseRetriever
 from langchain_core.callbacks import CallbackManagerForRetrieverRun
 from langchain_core.vectorstores import VectorStoreRetriever
-
 from langchain_openai import ChatOpenAI
-from mixedbread_ai.client import MixedbreadAI
-
 from langchain.callbacks.tracers import ConsoleCallbackHandler
 from langchain_huggingface import HuggingFaceEmbeddings
-import os
-# from hf_to_chroma_ds import import_into_chroma
+
 from datasets import load_dataset
-from chromadb.utils import embedding_functions
-# from hf_to_chroma_ds import Memoires_DS
-from dotenv import load_dotenv
+import chromadb
+from typing import List
+from mixedbread_ai.client import MixedbreadAI
 from tqdm import tqdm
 
 # Global params
@@ -37,18 +34,18 @@ MXBAI_API_KEY = os.environ.get("MXBAI_API_KEY")
 HF_TOKEN = os.environ.get("HF_TOKEN")
 HF_API_KEY = os.environ.get("HF_API_KEY")
 
-# Load the reranker model
-device = "cuda:0" if torch.cuda.is_available() else "cpu"
+# MixedbreadAI Client
+# device = "cuda:0" if torch.cuda.is_available() else "cpu"
 mxbai_client = MixedbreadAI(api_key=MXBAI_API_KEY)
 model_emb = "mixedbread-ai/mxbai-embed-large-v1"
 
 # Set up ChromaDB
 memoires_ds = load_dataset("eliot-hub/memoires_vec_800", split="data", token=HF_TOKEN, streaming=True)
-batched_ds = memoires_ds.batch(batch_size=40000)
+batched_ds = memoires_ds.batch(batch_size=50000)
 client = chromadb.Client()
-collection = client.get_or_create_collection(name="embeddings_mxbai") #, embedding_function=HuggingFaceEmbeddings(model_name=model_emb))
+collection = client.get_or_create_collection(name="embeddings_mxbai")
 
-for batch in tqdm(batched_ds, desc="Processing dataset batches"): #, total=len(batched_ds)):
+for batch in tqdm(batched_ds, desc="Processing dataset batches"):
     collection.add(
         ids=batch["id"],
         metadatas=batch["metadata"],
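For context, this is roughly how the reworked ingestion path in app.py fits together after this commit: the dataset is streamed from the Hub, iterated in large batches, and bulk-inserted into an in-memory Chroma collection. A minimal, self-contained sketch follows; it assumes the dataset exposes `id` and `metadata` columns as shown in the hunk, while the remaining `collection.add` arguments fall outside the hunk, so the `embedding` column name used here is only a guess.

```python
import os

import chromadb
from datasets import load_dataset
from tqdm import tqdm

HF_TOKEN = os.environ.get("HF_TOKEN")

# Stream the dataset so it is never fully materialised in memory,
# then iterate over it in large batches (each batch is a dict of column -> list).
memoires_ds = load_dataset(
    "eliot-hub/memoires_vec_800", split="data", token=HF_TOKEN, streaming=True
)
batched_ds = memoires_ds.batch(batch_size=50000)

# In-memory Chroma client, as in the commit; no embedding_function is attached
# because the vectors are assumed to be precomputed in the dataset.
client = chromadb.Client()
collection = client.get_or_create_collection(name="embeddings_mxbai")

for batch in tqdm(batched_ds, desc="Processing dataset batches"):
    collection.add(
        ids=batch["id"],
        metadatas=batch["metadata"],
        embeddings=batch["embedding"],  # assumed column name, not visible in the hunk above
    )
```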