import spaces
import gradio as gr
from huggingface_hub import list_models
from typing import List
import torch
from transformers import DonutProcessor, VisionEncoderDecoderModel
from PIL import Image
import json
import re
import logging
from datasets import load_dataset
import os
import numpy as np
from datetime import datetime
# Import helper utilities (image compositing and frame labeling)
import utils
# Logging configuration
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Paths to the static image and GIF
README_IMAGE_PATH = os.path.join("figs", "saliencies-merit-dataset.png")
GIF_PATH = os.path.join("figs", "demo-samples.gif")
# Global variables for Donut model, processor, and dataset
dataset = None
def load_merit_dataset():
global dataset
if dataset is None:
dataset = load_dataset(
"de-Rodrigo/merit", name="en-digital-seq", split="test", num_proc=8
)
return dataset
def get_image_from_dataset(index):
global dataset
if dataset is None:
dataset = load_merit_dataset()
image_data = dataset[int(index)]["image"]
return image_data
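# Usage sketch: get_image_from_dataset(0) returns the PIL image stored in the
# "image" column of the first test sample.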
def get_collection_models(tag: str) -> List[str]:
    """List de-Rodrigo's Hub models whose tags include `tag`."""
    models = list_models(author="de-Rodrigo")
    return [model.modelId for model in models if model.tags and tag in model.tags]
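# Example (hypothetical output): get_collection_models("saliency") might return
# ["de-Rodrigo/donut-merit"] if that model carries a "saliency" tag on the Hub.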
def initialize_donut():
try:
donut_model = VisionEncoderDecoderModel.from_pretrained(
"de-Rodrigo/donut-merit"
)
donut_processor = DonutProcessor.from_pretrained("de-Rodrigo/donut-merit")
donut_model = donut_model.to("cuda")
logger.info("Donut model loaded successfully on GPU")
return donut_model, donut_processor
except Exception as e:
logger.error(f"Error loading Donut model: {str(e)}")
raise
def compute_saliency(outputs, pixels, donut_p, image):
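    # outputs.scores is a tuple with one (batch, vocab_size) logits tensor per
    # generated token; stacking yields (batch, num_generated_tokens, vocab_size).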
token_logits = torch.stack(outputs.scores, dim=1)
token_probs = torch.softmax(token_logits, dim=-1)
token_texts = []
saliency_images = []
    # outputs.sequences includes the decoder prompt token(s), while
    # outputs.scores only covers newly generated tokens; align them.
    offset = outputs.sequences.shape[1] - len(outputs.scores)

    for token_index in range(len(token_probs[0])):
        target_token_prob = token_probs[
            0, token_index, outputs.sequences[0, token_index + offset]
        ]

        # Backpropagate the probability of this single token; the per-pixel
        # gradient magnitude, averaged over channels, is its saliency map.
        if pixels.grad is not None:
            pixels.grad.zero_()
        target_token_prob.backward(retain_graph=True)
        saliency = pixels.grad.detach().abs().squeeze().mean(dim=0)

        token_id = outputs.sequences[0][token_index + offset].item()
        token_text = donut_p.tokenizer.decode([token_id])
logger.info(f"Considered sequence token: {token_text}")
saliency = utils.convert_tensor_to_rgba_image(saliency)
        # Composite the heatmap over the document twice: a first alpha blend,
        # then a second pass at 0.7 opacity to boost contrast.
        saliency = utils.add_transparent_image(np.array(image), saliency)
        saliency = utils.convert_rgb_to_rgba_image(saliency)
        saliency = utils.add_transparent_image(np.array(image), saliency, 0.7)
saliency = utils.label_frame(saliency, token_text)
saliency_images.append(saliency)
token_texts.append(token_text)
return saliency_images, token_texts
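# A minimal, generic sketch of the gradient-saliency recipe used above; the
# helper below is illustrative only (model_fn is a hypothetical callable) and
# is never called by this app.
def _saliency_sketch(model_fn, x):
    """Return an (H, W) saliency map of the scalar score model_fn(x) w.r.t. x.

    Assumes x is a (1, C, H, W) float tensor and model_fn returns a scalar,
    e.g. the softmax probability of one generated token.
    """
    x = x.detach().requires_grad_(True)
    score = model_fn(x)
    score.backward()
    return x.grad.abs().squeeze(0).mean(dim=0)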
@spaces.GPU(duration=300)
def process_image_donut(image):
try:
model, processor = initialize_donut()
if not isinstance(image, Image.Image):
image = Image.fromarray(image)
pixel_values = processor(image, return_tensors="pt").pixel_values.to("cuda")
pixel_values.requires_grad = True
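        # Donut conditions generation on a task prompt; this fine-tune reuses
        # the CORD-v2 parsing start token from the base model.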
task_prompt = "<s_cord-v2>"
decoder_input_ids = processor.tokenizer(
task_prompt, add_special_tokens=False, return_tensors="pt"
)["input_ids"].to("cuda")
        # Call the undecorated generate() via __wrapped__ to bypass its
        # @torch.no_grad() wrapper so gradients can reach pixel_values.
        outputs = model.generate.__wrapped__(
            model,
pixel_values,
decoder_input_ids=decoder_input_ids,
max_length=model.decoder.config.max_position_embeddings,
early_stopping=True,
pad_token_id=processor.tokenizer.pad_token_id,
eos_token_id=processor.tokenizer.eos_token_id,
use_cache=True,
num_beams=1,
bad_words_ids=[[processor.tokenizer.unk_token_id]],
return_dict_in_generate=True,
output_scores=True,
)
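        # return_dict_in_generate + output_scores expose both the generated ids
        # (outputs.sequences) and the per-step logits (outputs.scores).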
saliency_images, token_texts = compute_saliency(outputs, pixel_values, processor, image)
sequence = processor.batch_decode(outputs.sequences)[0]
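        # Strip EOS/PAD tokens and the leading task-prompt tag before parsing.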
sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(
processor.tokenizer.pad_token, ""
)
sequence = re.sub(r"<.*?>", "", sequence, count=1).strip()
result = processor.token2json(sequence)
return saliency_images, json.dumps(result, indent=2)
except Exception as e:
logger.error(f"Error processing image with Donut: {str(e)}")
return None, f"Error: {str(e)}"
@spaces.GPU(duration=300)
def process_image(model_name, image=None, dataset_image_index=None):
    # Prefer an explicitly uploaded image; otherwise fall back to the dataset
    # sample selected with the slider.
    if image is None and dataset_image_index is not None:
        image = get_image_from_dataset(dataset_image_index)

    if model_name == "de-Rodrigo/donut-merit":
        saliency_images, result = process_image_donut(image)
    else:
        # Processing for other models is not implemented yet.
        saliency_images, result = None, f"Processing for model {model_name} not implemented"
return saliency_images, result
def update_image(dataset_image_index):
return get_image_from_dataset(dataset_image_index)
if __name__ == "__main__":
# Load the dataset
load_merit_dataset()
    models = get_collection_models("saliency")
    # Make sure the Donut fine-tune is always selectable (without duplicates).
    if "de-Rodrigo/donut-merit" not in models:
        models.append("de-Rodrigo/donut-merit")
with gr.Blocks() as demo:
gr.Markdown("# Saliency Maps with the MERIT Dataset 🎒📃🏆")
with gr.Row():
with gr.Column(scale=1):
gr.Image(value=README_IMAGE_PATH, height=400)
with gr.Column(scale=1):
gr.Image(
value=GIF_PATH, label="Dataset samples you can process", height=400
)
with gr.Tab("Introduction"):
gr.Markdown(
"""
## Welcome to Saliency Maps with the [MERIT Dataset](https://huggingface.co./datasets/de-Rodrigo/merit) 🎒📃🏆
            This Space demonstrates how different Vision-Language models
            approach document-understanding tasks.
### Key Features:
- Process images from the [MERIT Dataset](https://huggingface.co./datasets/de-Rodrigo/merit) or upload your own image.
            - Use fine-tuned versions of the available models to extract grades from documents.
- Visualize saliency maps to understand where the model is looking (WIP 🛠️).
"""
)
with gr.Tab("Try It Yourself"):
gr.Markdown(
"Select a model and an image from the dataset, or upload your own image."
)
with gr.Row():
with gr.Column():
model_dropdown = gr.Dropdown(choices=models, label="Select Model")
dataset_slider = gr.Slider(
minimum=0,
maximum=len(dataset) - 1,
step=1,
label="Dataset Image Index",
)
upload_image = gr.Image(
type="pil", label="Or Upload Your Own Image"
)
preview_image = gr.Image(label="Selected/Uploaded Image")
process_button = gr.Button("Process Image")
with gr.Row():
output_image = gr.Gallery(label="Processed Saliency Images")
output_text = gr.Textbox(label="Result")
# Update preview image when slider changes
dataset_slider.change(
fn=update_image, inputs=[dataset_slider], outputs=[preview_image]
)
# Update preview image when an image is uploaded
upload_image.change(
fn=lambda x: x, inputs=[upload_image], outputs=[preview_image]
)
# Process image when button is clicked
process_button.click(
fn=process_image,
inputs=[model_dropdown, upload_image, dataset_slider],
outputs=[output_image, output_text],
)
demo.launch()