olamidegoriola pvanand commited on
Commit
2373fa4
1 Parent(s): 815ddfe

perform vector search using FASS (#7)

Browse files

- perform vector search using FASS (f5d9bbef4679155c5b861f7a8d65521e4b96159c)


Co-authored-by: Anand <[email protected]>

Files changed (1) hide show
  1. actions/search_content.py +61 -0
actions/search_content.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # search_content.py
2
+
3
+ import faiss
4
+ import pandas as pd
5
+ from sentence_transformers import SentenceTransformer
6
+
7
+ # Define paths for model, Faiss index, and data file
8
+ MODEL_SAVE_PATH = "all-distilroberta-v1-model.pkl"
9
+ FAISS_INDEX_FILE_PATH = "index.faiss"
10
+ DATA_FILE_PATH = "omdena_qna_dataset/omdena_faq_training_data.csv"
11
+
12
+ def load_transformer_model(model_file):
13
+ """Load a sentence transformer model from a file."""
14
+ return SentenceTransformer.load(model_file)
15
+
16
+ def load_faiss_index(filename):
17
+ """Load a Faiss index from a file."""
18
+ return faiss.read_index(filename)
19
+
20
+ def load_data(file_path):
21
+ """Load data from a CSV file and preprocess it."""
22
+ data_frame = pd.read_csv(file_path)
23
+ data_frame["id"] = data_frame.index
24
+
25
+ # Create a 'QNA' column that combines 'Questions' and 'Answers'
26
+ data_frame['QNA'] = data_frame.apply(lambda row: f"Question: {row['Questions']}, Answer: {row['Answers']}", axis=1)
27
+ return data_frame.set_index(["id"], drop=False)
28
+
29
+ def search_content(query, data_frame_indexed, transformer_model, faiss_index, k=5):
30
+ """Search the content using a query and return the top k results."""
31
+ # Encode the query using the model
32
+ query_vector = transformer_model.encode([query])
33
+
34
+ # Normalize the query vector
35
+ faiss.normalize_L2(query_vector)
36
+
37
+ # Search the Faiss index using the query vector
38
+ top_k = faiss_index.search(query_vector, k)
39
+
40
+ # Extract the IDs and similarities of the top k results
41
+ ids = top_k[1][0].tolist()
42
+ similarities = top_k[0][0].tolist()
43
+
44
+ # Get the corresponding results from the data frame
45
+ results = data_frame_indexed.loc[ids]
46
+
47
+ # Add a column for the similarities
48
+ results["similarities"] = similarities
49
+ return results
50
+
51
+ def main_search(query):
52
+ """Main function to execute the search."""
53
+ transformer_model = load_transformer_model(MODEL_SAVE_PATH)
54
+ faiss_index = load_faiss_index(FAISS_INDEX_FILE_PATH)
55
+ data_frame_indexed = load_data(DATA_FILE_PATH)
56
+ results = search_content(query, data_frame_indexed, transformer_model, faiss_index)
57
+ return results['QNA'] # return the results
58
+
59
+ if __name__ == "__main__":
60
+ query = "school courses"
61
+ print(main_search(query)) # print the results if this script is run directly