import spaces
import gradio as gr
from huggingface_hub import list_models
from typing import List
import torch
from transformers import DonutProcessor, VisionEncoderDecoderModel
from PIL import Image
import json
import re
import logging
from datasets import load_dataset
import os
import numpy as np
from datetime import datetime
# Import utils (and save_img, if needed) when not already imported
import utils
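# NOTE: utils.py is not shown in this listing; the sketch below is a
# hypothetical rendering of the interface compute_saliency() assumes, kept as
# comments so the real module is not shadowed. The actual helpers may differ.
#
#   def convert_tensor_to_rgba_image(t):
#       """HxW saliency tensor -> HxWx4 uint8 heatmap."""
#
#   def convert_rgb_to_rgba_image(img):
#       """HxWx3 uint8 array -> HxWx4 uint8 array (opaque alpha)."""
#
#   def add_transparent_image(background, overlay, alpha=1.0):
#       """Alpha-blend an RGBA overlay onto an RGB background; returns RGB."""
#
#   def label_frame(img, text):
#       """Draw the decoded token text onto the frame and return it."""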
# Logging configuration
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Paths to the static image and GIF
README_IMAGE_PATH = os.path.join("figs", "saliencies-merit-dataset.png")
GIF_PATH = os.path.join("figs", "demo-samples.gif")
# Global variable for the lazily loaded dataset
dataset = None
def load_merit_dataset():
    global dataset
    if dataset is None:
        dataset = load_dataset(
            "de-Rodrigo/merit", name="en-digital-seq", split="test", num_proc=8
        )
    return dataset
def get_image_from_dataset(index):
    global dataset
    if dataset is None:
        dataset = load_merit_dataset()
    image_data = dataset[int(index)]["image"]
    return image_data
def get_collection_models(tag: str) -> List[str]:
    """List de-Rodrigo models on the Hugging Face Hub that carry the given tag."""
    models = list_models(author="de-Rodrigo")
    return [model.modelId for model in models if tag in model.tags]
def initialize_donut():
    try:
        donut_model = VisionEncoderDecoderModel.from_pretrained(
            "de-Rodrigo/donut-merit"
        )
        donut_processor = DonutProcessor.from_pretrained("de-Rodrigo/donut-merit")
        donut_model = donut_model.to("cuda")
        logger.info("Donut model loaded successfully on GPU")
        return donut_model, donut_processor
    except Exception as e:
        logger.error(f"Error loading Donut model: {str(e)}")
        raise
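# initialize_donut() reloads the checkpoint on every request. A memoized
# wrapper like the sketch below (hypothetical, not part of the original app)
# would load it once per process; swapping it into process_image_donut is
# optional.
_DONUT_CACHE = None


def get_donut_cached():
    """Load the Donut model/processor once and reuse them on later calls."""
    global _DONUT_CACHE
    if _DONUT_CACHE is None:
        _DONUT_CACHE = initialize_donut()
    return _DONUT_CACHE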
def compute_saliency(outputs, pixels, donut_p, image):
    # Stack per-step logits into (batch, seq_len, vocab) and normalize
    token_logits = torch.stack(outputs.scores, dim=1)
    token_probs = torch.softmax(token_logits, dim=-1)

    token_texts = []
    saliency_images = []
    for token_index in range(len(token_probs[0])):
        # Probability assigned to the token that was actually generated
        target_token_prob = token_probs[
            0, token_index, outputs.sequences[0, token_index]
        ]

        if pixels.grad is not None:
            pixels.grad.zero_()
        target_token_prob.backward(retain_graph=True)

        # Saliency = per-pixel gradient magnitude, averaged over channels
        saliency = pixels.grad.data.abs().squeeze().mean(dim=0)

        token_id = outputs.sequences[0][token_index].item()
        token_text = donut_p.tokenizer.decode([token_id])
        logger.info(f"Considered sequence token: {token_text}")

        # File name reserved for optionally saving the frame to disk
        safe_token_text = re.sub(r'[<>:"/\\|?*]', "_", token_text)
        current_datetime = datetime.now().strftime("%Y%m%d%H%M%S")
        unique_safe_token_text = f"{safe_token_text}_{current_datetime}"
        file_name = f"saliency_{unique_safe_token_text}.png"

        # Overlay the heatmap on the document image twice for stronger contrast
        saliency = utils.convert_tensor_to_rgba_image(saliency)
        saliency = utils.add_transparent_image(np.array(image), saliency)
        saliency = utils.convert_rgb_to_rgba_image(saliency)
        saliency = utils.add_transparent_image(np.array(image), saliency, 0.7)
        saliency = utils.label_frame(saliency, token_text)

        saliency_images.append(saliency)
        token_texts.append(token_text)

    return saliency_images, token_texts
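# Self-contained illustration of the gradient-saliency pattern used above,
# with a trivial stand-in score instead of Donut (a sketch for readers, not
# called by the app):
def _toy_saliency_demo():
    pixels = torch.rand(1, 3, 8, 8, requires_grad=True)
    score = pixels.mean()  # stand-in for one generated token's probability
    score.backward()
    # Absolute gradient averaged over channels -> HxW saliency map
    return pixels.grad.abs().squeeze(0).mean(dim=0)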
def process_image_donut(image):
    try:
        model, processor = initialize_donut()

        if not isinstance(image, Image.Image):
            image = Image.fromarray(image)

        pixel_values = processor(image, return_tensors="pt").pixel_values.to("cuda")
        # Gradients w.r.t. the input pixels are needed for the saliency maps
        pixel_values.requires_grad = True

        task_prompt = "<s_cord-v2>"
        decoder_input_ids = processor.tokenizer(
            task_prompt, add_special_tokens=False, return_tensors="pt"
        )["input_ids"].to("cuda")

        # Call the undecorated generate via __wrapped__ to bypass its
        # torch.no_grad() wrapper, so gradients can flow back to pixel_values
        outputs = model.generate.__wrapped__(
            model,
            pixel_values,
            decoder_input_ids=decoder_input_ids,
            max_length=model.decoder.config.max_position_embeddings,
            early_stopping=True,
            pad_token_id=processor.tokenizer.pad_token_id,
            eos_token_id=processor.tokenizer.eos_token_id,
            use_cache=True,
            num_beams=1,
            bad_words_ids=[[processor.tokenizer.unk_token_id]],
            return_dict_in_generate=True,
            output_scores=True,
        )

        saliency_images, token_texts = compute_saliency(
            outputs, pixel_values, processor, image
        )

        # Strip special tokens and the task prompt, then convert to JSON
        sequence = processor.batch_decode(outputs.sequences)[0]
        sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(
            processor.tokenizer.pad_token, ""
        )
        sequence = re.sub(r"<.*?>", "", sequence, count=1).strip()
        result = processor.token2json(sequence)

        return saliency_images, json.dumps(result, indent=2)
    except Exception as e:
        logger.error(f"Error processing image with Donut: {str(e)}")
        return None, f"Error: {str(e)}"
@spaces.GPU
def process_image(model_name, image=None, dataset_image_index=None):
    # Prefer an uploaded image; the slider always carries a value, so checking
    # the upload first keeps user uploads from being silently overridden
    if image is None and dataset_image_index is not None:
        image = get_image_from_dataset(dataset_image_index)

    if model_name == "de-Rodrigo/donut-merit":
        saliency_images, result = process_image_donut(image)
    else:
        # Processing for other models still needs to be implemented here
        saliency_images, result = (
            None,
            f"Processing for model {model_name} not implemented",
        )
    return saliency_images, result
def update_image(dataset_image_index):
    return get_image_from_dataset(dataset_image_index)
if __name__ == "__main__":
    # Load the dataset
    load_merit_dataset()

    models = get_collection_models("saliency")
    models.append("de-Rodrigo/donut-merit")

    with gr.Blocks() as demo:
        gr.Markdown("# Saliency Maps with the MERIT Dataset 🎒📃🏆")

        with gr.Row():
            with gr.Column(scale=1):
                gr.Image(value=README_IMAGE_PATH, height=400)
            with gr.Column(scale=1):
                gr.Image(
                    value=GIF_PATH, label="Dataset samples you can process", height=400
                )
with gr.Tab("Introduction"): | |
gr.Markdown( | |
""" | |
## Welcome to Saliency Maps with the [MERIT Dataset](https://huggingface.co./datasets/de-Rodrigo/merit) 🎒📃🏆 | |
This space demonstrates the capabilities of different Vision Language models | |
for document understanding tasks. | |
### Key Features: | |
- Process images from the [MERIT Dataset](https://huggingface.co./datasets/de-Rodrigo/merit) or upload your own image. | |
- Use a fine-tuned version of the models availabe to extract grades from documents. | |
- Visualize saliency maps to understand where the model is looking (WIP 🛠️). | |
""" | |
) | |
with gr.Tab("Try It Yourself"): | |
gr.Markdown( | |
"Select a model and an image from the dataset, or upload your own image." | |
) | |
with gr.Row(): | |
with gr.Column(): | |
model_dropdown = gr.Dropdown(choices=models, label="Select Model") | |
dataset_slider = gr.Slider( | |
minimum=0, | |
maximum=len(dataset) - 1, | |
step=1, | |
label="Dataset Image Index", | |
) | |
upload_image = gr.Image( | |
type="pil", label="Or Upload Your Own Image" | |
) | |
preview_image = gr.Image(label="Selected/Uploaded Image") | |
process_button = gr.Button("Process Image") | |
with gr.Row(): | |
output_image = gr.Gallery(label="Processed Saliency Images") | |
output_text = gr.Textbox(label="Result") | |
# Update preview image when slider changes | |
dataset_slider.change( | |
fn=update_image, inputs=[dataset_slider], outputs=[preview_image] | |
) | |
# Update preview image when an image is uploaded | |
upload_image.change( | |
fn=lambda x: x, inputs=[upload_image], outputs=[preview_image] | |
) | |
# Process image when button is clicked | |
process_button.click( | |
fn=process_image, | |
inputs=[model_dropdown, upload_image, dataset_slider], | |
outputs=[output_image, output_text], | |
) | |
demo.launch() | |