Spaces:
Sleeping
Sleeping
collection
Browse files
app.py
CHANGED
@@ -20,12 +20,12 @@ from mixedbread_ai.client import MixedbreadAI
|
|
20 |
from langchain.callbacks.tracers import ConsoleCallbackHandler
|
21 |
from langchain_huggingface import HuggingFaceEmbeddings
|
22 |
import os
|
23 |
-
# from
|
24 |
-
from hf_to_chroma_ds import import_into_chroma
|
25 |
from datasets import load_dataset
|
26 |
from chromadb.utils import embedding_functions
|
27 |
-
from hf_to_chroma_ds import Memoires_DS
|
28 |
from dotenv import load_dotenv
|
|
|
29 |
|
30 |
# Global params
|
31 |
CHROMA_PATH = "chromadb_mem10_mxbai_800_complete"
|
@@ -42,29 +42,20 @@ device = "cuda:0" if torch.cuda.is_available() else "cpu"
|
|
42 |
mxbai_client = MixedbreadAI(api_key=MXBAI_API_KEY)
|
43 |
model_emb = "mixedbread-ai/mxbai-embed-large-v1"
|
44 |
|
45 |
-
huggingface_ef = embedding_functions.huggingface_embedding_function.HuggingFaceEmbeddingFunction(
|
46 |
-
api_key=HF_API_KEY,
|
47 |
-
model_name=model_emb
|
48 |
-
)
|
49 |
-
|
50 |
# Set up ChromaDB
|
|
|
|
|
51 |
client = chromadb.Client()
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
# )
|
61 |
-
|
62 |
-
|
63 |
-
collection = import_into_chroma(
|
64 |
-
chroma_client=client,
|
65 |
-
dataset=Memoires_DS,
|
66 |
-
embedding_function=huggingface_ef #Memoires_DS.embedding_function
|
67 |
)
|
|
|
68 |
|
69 |
db = Chroma(
|
70 |
client=client,
|
@@ -83,7 +74,7 @@ class Reranker(BaseRetriever):
|
|
83 |
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
|
84 |
) -> List[Document]:
|
85 |
docs = self.retriever.invoke(query)
|
86 |
-
results = mxbai_client.reranking(model=
|
87 |
return [Document(page_content=res.input) for res in results.data]
|
88 |
|
89 |
# Set up reranker + LLM
|
|
|
20 |
from langchain.callbacks.tracers import ConsoleCallbackHandler
|
21 |
from langchain_huggingface import HuggingFaceEmbeddings
|
22 |
import os
|
23 |
+
# from hf_to_chroma_ds import import_into_chroma
|
|
|
24 |
from datasets import load_dataset
|
25 |
from chromadb.utils import embedding_functions
|
26 |
+
# from hf_to_chroma_ds import Memoires_DS
|
27 |
from dotenv import load_dotenv
|
28 |
+
from tqdm import tqdm
|
29 |
|
30 |
# Global params
|
31 |
CHROMA_PATH = "chromadb_mem10_mxbai_800_complete"
|
|
|
42 |
mxbai_client = MixedbreadAI(api_key=MXBAI_API_KEY)
|
43 |
model_emb = "mixedbread-ai/mxbai-embed-large-v1"
|
44 |
|
|
|
|
|
|
|
|
|
|
|
45 |
# Set up ChromaDB
|
46 |
+
memoires_ds = load_dataset("eliot-hub/memoires_vec_800", split="data", token=HF_TOKEN, streaming=True)
|
47 |
+
batched_ds = memoires_ds.batch(batch_size=40000)
|
48 |
client = chromadb.Client()
|
49 |
+
collection = client.get_or_create_collection(name="embeddings_mxbai") #, embedding_function=HuggingFaceEmbeddings(model_name=model_emb))
|
50 |
+
|
51 |
+
for batch in tqdm(batched_ds, desc="Processing dataset batches"): #, total=len(batched_ds)):
|
52 |
+
collection.add(
|
53 |
+
ids=batch["id"],
|
54 |
+
metadatas=batch["metadata"],
|
55 |
+
documents=batch["document"],
|
56 |
+
embeddings=batch["embedding"],
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
57 |
)
|
58 |
+
print(f"Collection complete: {collection.count()}")
|
59 |
|
60 |
db = Chroma(
|
61 |
client=client,
|
|
|
74 |
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
|
75 |
) -> List[Document]:
|
76 |
docs = self.retriever.invoke(query)
|
77 |
+
results = mxbai_client.reranking(model=MODEL_RRK, query=query, input=[doc.page_content for doc in docs], return_input=True, top_k=self.k)
|
78 |
return [Document(page_content=res.input) for res in results.data]
|
79 |
|
80 |
# Set up reranker + LLM
|