cheesyFishes's picture
improve loading experience
050040c
raw
history blame
9.03 kB
import gradio as gr
import os
import torch
from llama_parse import LlamaParse
from llama_index.core import StorageContext, load_index_from_storage
from llama_index.core.indices import MultiModalVectorStoreIndex
from llama_index.core.schema import Document, ImageDocument
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
# Example indexes offered in the UI: dropdown label -> directory the index is
# persisted in (passed to StorageContext as persist_dir).
example_indexes = {
    # Label now matches the "./iconiq_report_index" directory it points at
    # (it was previously spelled "IONIQ").
    "ICONIQ 2024": "./iconiq_report_index",
    "Uber 10k 2021": "./uber_index",
}
# Label of the index loaded at startup; must be a key of example_indexes.
DEFAULT_INDEX = "ICONIQ 2024"
# Pick the compute device: CUDA first, then Apple MPS, otherwise CPU.
# The MPS check is only evaluated when CUDA is unavailable.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "mps" if torch.backends.mps.is_available() else "cpu"
# Hugging Face access token shared by both models (None when the variable is unset).
_hf_token = os.getenv("HUGGINGFACE_TOKEN")

# Embedding model used for the image side of the index; loaded in float16.
image_embed_model = HuggingFaceEmbedding(
    model_name="llamaindex/vdr-2b-multi-v1",
    model_kwargs={"torch_dtype": torch.float16},
    embed_batch_size=2,
    device=device,
    token=_hf_token,
    trust_remote_code=True,
)

# Embedding model used for the text side of the index.
text_embed_model = HuggingFaceEmbedding(
    model_name="BAAI/bge-small-en",
    embed_batch_size=1,
    device=device,
    token=_hf_token,
    trust_remote_code=True,
)
class IndexManager:
    """Hold the active index outside of ``gr.State`` so it is never deep-copied."""

    def __init__(self):
        # Start empty, then immediately populate from the default example index.
        self.current_index = None
        self.load_index(example_indexes[DEFAULT_INDEX])

    def load_index(self, index_path):
        """Load the index persisted at *index_path* and make it the current one.

        Returns a status message naming the path that was loaded.
        """
        loaded = load_index_from_storage(
            StorageContext.from_defaults(persist_dir=index_path),
            embed_model=text_embed_model,
            image_embed_model=image_embed_model,
        )
        self.current_index = loaded
        return f"Loaded index: {index_path}"

    def set_index(self, index):
        """Replace the current index with *index*."""
        self.current_index = index

    def get_index(self):
        """Return the current index (``None`` if none has been set)."""
        return self.current_index
# Module-level manager shared by all handlers; constructing it loads the
# default example index right away (see IndexManager.__init__).
index_manager = IndexManager()
def load_index(index_path: str) -> MultiModalVectorStoreIndex:
    """Load the index stored at *index_path* through the shared manager and return it."""
    manager = index_manager
    manager.load_index(index_path)
    return manager.get_index()
def create_index(file, llama_parse_key, progress=gr.Progress()):
    """Parse an uploaded document with LlamaParse and build a multi-modal index.

    The parsed markdown of all pages becomes one text document, every image
    returned by LlamaParse becomes an image document, and the resulting index
    is installed as the current index on ``index_manager``.

    Args:
        file: The uploaded file. Its path is read from ``file.name``; a plain
            path string is accepted as well.
        llama_parse_key: LlamaParse API key.
        progress: Gradio progress tracker used to report each stage.

    Returns:
        A single status string (success, validation or error message). Any
        exception raised while parsing or indexing is reported in the string
        rather than propagated, so the UI always gets a status to display.
    """
    if not file or not llama_parse_key:
        # The click handler is wired to exactly one output component, so every
        # return path must yield one string (this used to return a 2-tuple).
        return "Please provide both a file and LlamaParse API key"

    try:
        file_path = getattr(file, "name", file)

        progress(0, desc="Initializing LlamaParse...")
        parser = LlamaParse(
            api_key=llama_parse_key,
            take_screenshot=True,
        )

        # Process document
        progress(0.2, desc="Processing document with LlamaParse...")
        json_results = parser.get_json_result(file_path)
        if not json_results:
            # Report an empty parse explicitly instead of an opaque IndexError.
            return "Error creating index: LlamaParse returned no results for this file"
        md_json_obj = json_results[0]

        progress(0.4, desc="Downloading and processing images...")
        # Images go into a sibling "<file>_images" directory. Joining the
        # dirname with the full path again was redundant for absolute paths
        # and duplicated the directory for relative ones.
        image_dicts = parser.get_images(
            [md_json_obj],
            download_path=f"{file_path}_images",
        )

        # Create text document
        progress(0.6, desc="Creating text documents...")
        text = "\n\n".join(page["md"] for page in md_json_obj["pages"]).strip()
        text_docs = [Document(text=text)]

        # Create image documents
        progress(0.8, desc="Creating image documents...")
        image_docs = [
            ImageDocument(text=image_dict["name"], image_path=image_dict["path"])
            for image_dict in image_dicts
        ]

        # Create index
        progress(0.9, desc="Creating final index...")
        index = MultiModalVectorStoreIndex.from_documents(
            text_docs + image_docs,
            embed_model=text_embed_model,
            image_embed_model=image_embed_model,
        )

        progress(1.0, desc="Complete!")
        index_manager.set_index(index)
        return "Index created successfully!"
    except Exception as e:
        # UI boundary: surface the failure as a status message.
        return f"Error creating index: {e}"
def run_search(query, text_top_k, image_top_k):
index = index_manager.get_index()
if not index:
return "Please create or select an index first.", [], []
retriever = index.as_retriever(
similarity_top_k=text_top_k,
image_similarity_top_k=image_top_k,
)
image_nodes = retriever.text_to_image_retrieve(query)
text_nodes = retriever.text_retrieve(query)
# Extract text and scores from nodes
text_results = [{"text": node.text, "score": f"{node.score:.3f}"} for node in text_nodes]
# Load images and scores
image_results = []
for node in image_nodes:
if hasattr(node.node, 'image_path') and os.path.exists(node.node.image_path):
try:
image_results.append((
node.node.image_path,
f"Similarity: {node.score:.3f}",
))
except Exception as e:
print(f"Error loading image {node.node.image_path}: {e}")
return "Search completed!", text_results, image_results
# Create the Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Multi-Modal Retrieval with LlamaIndex and llamaindex/vdr-2b-multi-v1")
    gr.Markdown("""
This demo shows how to use the new `llamaindex/vdr-2b-multi-v1` model for multi-modal document search.
Using this model, we can index images and perform text-to-image retrieval.
This demo compares to pure text retrieval using the `BAAI/bge-small-en` model. Is this a fair comparison? Not really,
but it's the easiest to run in a limited huggingface space, and shows the strengths of screenshot-based retrieval.
"""
    )

    with gr.Row():
        # Left column: index selection/creation and search controls.
        with gr.Column():
            # Index selection/creation
            with gr.Tab("Use Existing Index"):
                existing_index_dropdown = gr.Dropdown(
                    choices=list(example_indexes.keys()),
                    label="Select Pre-made Index",
                    # Show the index that IndexManager actually loads at
                    # startup, not merely the first dictionary key.
                    value=DEFAULT_INDEX,
                )

            with gr.Tab("Create New Index"):
                gr.Markdown(
                    """
To create a new index, enter your LlamaParse API key and upload a PDF.
You can get a free API key by signing up [here](https://cloud.llamaindex.ai).
Processing will take a few minutes when creating a new index, depending on the size of the document.
"""
                )
                file_upload = gr.File(label="Upload PDF")
                llama_parse_key = gr.Textbox(
                    label="LlamaParse API Key",
                    type="password"
                )
                create_btn = gr.Button("Create Index")

            # Status line written by both the index-loading and the
            # index-creation handlers.
            # NOTE(review): placed outside the tabs so both flows can show
            # their status; confirm against the intended layout.
            create_status = gr.Textbox(label="Status", interactive=False)

            # Search controls
            query_input = gr.Textbox(label="Search Query", value="What is the Executive Summary?")
            with gr.Row():
                text_top_k = gr.Slider(
                    minimum=1,
                    maximum=10,
                    value=2,
                    step=1,
                    label="Text Top-K"
                )
                image_top_k = gr.Slider(
                    minimum=1,
                    maximum=10,
                    value=2,
                    step=1,
                    label="Image Top-K"
                )
            search_btn = gr.Button("Search")

        # Right column: results.
        with gr.Column():
            # Results display
            status_output = gr.Textbox(label="Search Status")
            image_output = gr.Gallery(
                label="Retrieved Images",
                show_label=True,  # Display the gallery's own label
                elem_id="gallery"
            )
            text_output = gr.JSON(
                label="Retrieved Text with Similarity Scores",
                elem_id="text_results"
            )

    # Event handlers
    def load_existing_index(index_name, progress=gr.Progress()):
        """Load the selected example index and report the outcome.

        Returns a ``(status, query_value)`` pair matching the two outputs the
        dropdown handler is wired to. The second element is always ``None``.
        NOTE(review): sending ``None`` to the query box looks like it clears
        the current query; confirm that this is intended.
        """
        if index_name:
            try:
                progress(0, desc="Loading index...")
                result = index_manager.load_index(example_indexes[index_name])
                progress(1.0, desc="Index loaded!")
                return result, None
            except Exception as e:
                # UI boundary: report the failure in the status box.
                return f"Error loading index: {str(e)}", None
        return "No index selected", None

    existing_index_dropdown.change(
        fn=load_existing_index,
        inputs=[existing_index_dropdown],
        outputs=[create_status, query_input],
        api_name=False,
        show_progress=True
    )

    create_btn.click(
        fn=create_index,
        inputs=[file_upload, llama_parse_key],
        outputs=[create_status],
        api_name=False,
        show_progress=True
    )

    # run_search returns (status, text results, image results), in that order.
    search_btn.click(
        fn=run_search,
        inputs=[query_input, text_top_k, image_top_k],
        outputs=[status_output, text_output, image_output],
        api_name=False
    )

    gr.Markdown("""
This demo was built with [LlamaIndex](https://docs.llamaindex.ai) and [LlamaParse](https://cloud.llamaindex.ai). To see more multi-modal demos, check out the [llama parse examples](https://github.com/run-llama/llama_parse/tree/main/examples/multimodal).
"""
    )
# Launch the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    # Running locally
    demo.launch()