cheesyFishes commited on
Commit
d072d89
·
1 Parent(s): a78f8a0

make models singleton

Browse files
Files changed (1) hide show
  1. app.py +40 -24
app.py CHANGED
@@ -37,13 +37,33 @@ text_embed_model = HuggingFaceEmbedding(
37
  embed_batch_size=1,
38
  )
39
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
  def load_index(index_path: str) -> MultiModalVectorStoreIndex:
41
- storage_context = StorageContext.from_defaults(persist_dir=index_path)
42
- return load_index_from_storage(
43
- storage_context,
44
- embed_model=text_embed_model,
45
- image_embed_model=image_embed_model,
46
- )
47
 
48
  def create_index(file, llama_parse_key, progress=gr.Progress()):
49
  if not file or not llama_parse_key:
@@ -88,14 +108,17 @@ def create_index(file, llama_parse_key, progress=gr.Progress()):
88
  )
89
 
90
  progress(1.0, desc="Complete!")
91
- return index, "Index created successfully!"
 
92
 
93
  except Exception as e:
94
- return None, f"Error creating index: {str(e)}"
95
 
96
- def run_search(index, query, text_top_k, image_top_k):
 
97
  if not index:
98
  return "Please create or select an index first.", [], []
 
99
  retriever = index.as_retriever(
100
  similarity_top_k=text_top_k,
101
  image_similarity_top_k=image_top_k,
@@ -194,41 +217,34 @@ Processing will take a few minutes when creating a new index, depending on the s
194
  elem_id="text_results"
195
  )
196
 
197
- # State
198
- index_state = gr.State()
199
-
200
- # Load default index on startup
201
- default_index = load_index(example_indexes["IONIQ 2024"])
202
- index_state.value = default_index
203
-
204
  # Event handlers
205
  def load_existing_index(index_name):
206
  if index_name:
207
  try:
208
- index = load_index(example_indexes[index_name])
209
- return index, f"Loaded index: {index_name}"
210
  except Exception as e:
211
- return None, f"Error loading index: {str(e)}"
212
- return None, "No index selected"
213
 
214
  existing_index_dropdown.change(
215
  fn=load_existing_index,
216
  inputs=[existing_index_dropdown],
217
- outputs=[index_state, create_status],
218
  api_name=False
219
  )
220
 
221
  create_btn.click(
222
  fn=create_index,
223
  inputs=[file_upload, llama_parse_key],
224
- outputs=[index_state, create_status],
225
  api_name=False,
226
- show_progress=True # Enable progress bar
227
  )
228
 
229
  search_btn.click(
230
  fn=run_search,
231
- inputs=[index_state, query_input, text_top_k, image_top_k],
232
  outputs=[status_output, text_output, image_output],
233
  api_name=False
234
  )
 
37
  embed_batch_size=1,
38
  )
39
 
40
+ class IndexManager:
41
+ """Avoids deepcopying the index object in gr.State"""
42
+ def __init__(self):
43
+ self.current_index = None
44
+ # Initialize with default index
45
+ self.load_index(example_indexes["IONIQ 2024"])
46
+
47
+ def load_index(self, index_path):
48
+ storage_context = StorageContext.from_defaults(persist_dir=index_path)
49
+ self.current_index = load_index_from_storage(
50
+ storage_context,
51
+ embed_model=text_embed_model,
52
+ image_embed_model=image_embed_model,
53
+ )
54
+ return f"Loaded index: {index_path}"
55
+
56
+ def set_index(self, index):
57
+ self.current_index = index
58
+
59
+ def get_index(self):
60
+ return self.current_index
61
+
62
+ index_manager = IndexManager()
63
+
64
  def load_index(index_path: str) -> MultiModalVectorStoreIndex:
65
+ index_manager.load_index(index_path)
66
+ return index_manager.get_index()
 
 
 
 
67
 
68
  def create_index(file, llama_parse_key, progress=gr.Progress()):
69
  if not file or not llama_parse_key:
 
108
  )
109
 
110
  progress(1.0, desc="Complete!")
111
+ index_manager.set_index(index)
112
+ return "Index created successfully!"
113
 
114
  except Exception as e:
115
+ return f"Error creating index: {str(e)}"
116
 
117
+ def run_search(query, text_top_k, image_top_k):
118
+ index = index_manager.get_index()
119
  if not index:
120
  return "Please create or select an index first.", [], []
121
+
122
  retriever = index.as_retriever(
123
  similarity_top_k=text_top_k,
124
  image_similarity_top_k=image_top_k,
 
217
  elem_id="text_results"
218
  )
219
 
 
 
 
 
 
 
 
220
  # Event handlers
221
  def load_existing_index(index_name):
222
  if index_name:
223
  try:
224
+ index_manager.load_index(example_indexes[index_name])
225
+ return f"Loaded index: {index_name}"
226
  except Exception as e:
227
+ return f"Error loading index: {str(e)}"
228
+ return "No index selected"
229
 
230
  existing_index_dropdown.change(
231
  fn=load_existing_index,
232
  inputs=[existing_index_dropdown],
233
+ outputs=[create_status],
234
  api_name=False
235
  )
236
 
237
  create_btn.click(
238
  fn=create_index,
239
  inputs=[file_upload, llama_parse_key],
240
+ outputs=[create_status],
241
  api_name=False,
242
+ show_progress=True
243
  )
244
 
245
  search_btn.click(
246
  fn=run_search,
247
+ inputs=[query_input, text_top_k, image_top_k],
248
  outputs=[status_output, text_output, image_output],
249
  api_name=False
250
  )