clean

- .gitignore +2 -1
- app.py +13 -16
.gitignore CHANGED
@@ -1 +1,2 @@
-.env
+.env
+hf_to_chroma_ds
app.py CHANGED
@@ -1,4 +1,8 @@
+import os
+from dotenv import load_dotenv
+
 import gradio as gr
+
 from langchain_chroma import Chroma
 from langchain.prompts import ChatPromptTemplate
 from langchain.chains import create_retrieval_chain, create_history_aware_retriever
@@ -6,25 +10,18 @@ from langchain.chains.combine_documents import create_stuff_documents_chain
 from langchain_core.prompts import MessagesPlaceholder
 from langchain_community.chat_message_histories import ChatMessageHistory
 from langchain_core.runnables.history import RunnableWithMessageHistory
-import torch
-import chromadb
-from typing import List
 from langchain_core.documents import Document
 from langchain_core.retrievers import BaseRetriever
 from langchain_core.callbacks import CallbackManagerForRetrieverRun
 from langchain_core.vectorstores import VectorStoreRetriever
-
 from langchain_openai import ChatOpenAI
-from mixedbread_ai.client import MixedbreadAI
-
 from langchain.callbacks.tracers import ConsoleCallbackHandler
 from langchain_huggingface import HuggingFaceEmbeddings
-
-# from hf_to_chroma_ds import import_into_chroma
+
 from datasets import load_dataset
-
-
-from 
+import chromadb
+from typing import List
+from mixedbread_ai.client import MixedbreadAI
 from tqdm import tqdm
 
 # Global params
@@ -37,18 +34,18 @@ MXBAI_API_KEY = os.environ.get("MXBAI_API_KEY")
 HF_TOKEN = os.environ.get("HF_TOKEN")
 HF_API_KEY = os.environ.get("HF_API_KEY")
 
-#
-device = "cuda:0" if torch.cuda.is_available() else "cpu"
+# MixedbreadAI Client
+# device = "cuda:0" if torch.cuda.is_available() else "cpu"
 mxbai_client = MixedbreadAI(api_key=MXBAI_API_KEY)
 model_emb = "mixedbread-ai/mxbai-embed-large-v1"
 
 # Set up ChromaDB
 memoires_ds = load_dataset("eliot-hub/memoires_vec_800", split="data", token=HF_TOKEN, streaming=True)
-batched_ds = memoires_ds.batch(batch_size=
+batched_ds = memoires_ds.batch(batch_size=50000)
 client = chromadb.Client()
-collection = client.get_or_create_collection(name="embeddings_mxbai")
+collection = client.get_or_create_collection(name="embeddings_mxbai")
 
-for batch in tqdm(batched_ds, desc="Processing dataset batches"):
+for batch in tqdm(batched_ds, desc="Processing dataset batches"):
     collection.add(
         ids=batch["id"],
         metadatas=batch["metadata"],
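For reference, a minimal, self-contained sketch of the ingestion pattern this commit introduces: stream the pre-embedded dataset from the Hub and load it into an in-memory Chroma collection in large batches. Only the "id" and "metadata" columns appear in the diff; the "embedding" and "document" column names and the final count check are assumptions for illustration, and IterableDataset.batch() needs a recent datasets release.

import os

import chromadb
from datasets import load_dataset
from tqdm import tqdm

HF_TOKEN = os.environ.get("HF_TOKEN")

# Stream the pre-embedded dataset so it never has to fit in memory at once
memoires_ds = load_dataset(
    "eliot-hub/memoires_vec_800", split="data", token=HF_TOKEN, streaming=True
)
# Group the stream into large batches (mirrors batch_size=50000 in the commit)
batched_ds = memoires_ds.batch(batch_size=50000)

# Ephemeral in-memory Chroma client, as in app.py
client = chromadb.Client()
collection = client.get_or_create_collection(name="embeddings_mxbai")

for batch in tqdm(batched_ds, desc="Processing dataset batches"):
    collection.add(
        ids=batch["id"],
        metadatas=batch["metadata"],
        embeddings=batch["embedding"],  # assumed column name
        documents=batch["document"],    # assumed column name
    )

print(f"Collection holds {collection.count()} vectors")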