Commit · f53adeb
Parent(s): bedfdc1
Visualize Saliency Maps
Files changed:
- __pycache__/utils.cpython-310.pyc +0 -0
- app.py +74 -36
- utils.py +157 -0
__pycache__/utils.cpython-310.pyc
ADDED
Binary file (3.56 kB)
app.py
CHANGED
@@ -10,22 +10,22 @@ import re
 import logging
 from datasets import load_dataset
 import os
+import numpy as np
+from datetime import datetime
+# Import utils and save_img if they are not already imported
+import utils

 # Logging configuration
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)

-
 # Paths to the static image and GIF
 README_IMAGE_PATH = os.path.join("figs", "saliencies-merit-dataset.png")
 GIF_PATH = os.path.join("figs", "demo-samples.gif")

 # Global variables for Donut model, processor, and dataset
-donut_model = None
-donut_processor = None
 dataset = None

-
 def load_merit_dataset():
     global dataset
     if dataset is None:
@@ -34,7 +34,6 @@ def load_merit_dataset():
     )
     return dataset

-
 def get_image_from_dataset(index):
     global dataset
     if dataset is None:
@@ -42,44 +41,84 @@ def get_image_from_dataset(index):
     image_data = dataset[int(index)]["image"]
     return image_data

-
 def get_collection_models(tag: str) -> List[str]:
     """Get a list of models from a specific Hugging Face collection."""
     models = list_models(author="de-Rodrigo")
     return [model.modelId for model in models if tag in model.tags]

+def initialize_donut():
+    try:
+        donut_model = VisionEncoderDecoderModel.from_pretrained(
+            "de-Rodrigo/donut-merit"
+        )
+        donut_processor = DonutProcessor.from_pretrained("de-Rodrigo/donut-merit")
+        donut_model = donut_model.to("cuda")
+        logger.info("Donut model loaded successfully on GPU")
+        return donut_model, donut_processor
+    except Exception as e:
+        logger.error(f"Error loading Donut model: {str(e)}")
+        raise
+
+def compute_saliency(outputs, pixels, donut_p, image):
+    token_logits = torch.stack(outputs.scores, dim=1)
+    token_probs = torch.softmax(token_logits, dim=-1)
+    token_texts = []
+    saliency_images = []

-… (content of 11 removed lines not shown in this view)
-    except Exception as e:
-        logger.error(f"Error loading Donut model: {str(e)}")
-        raise
-    return donut_model, donut_processor
+    for token_index in range(len(token_probs[0])):
+        target_token_prob = token_probs[
+            0, token_index, outputs.sequences[0, token_index]
+        ]
+
+        if pixels.grad is not None:
+            pixels.grad.zero_()
+
+        target_token_prob.backward(retain_graph=True)
+
+        saliency = pixels.grad.data.abs().squeeze().mean(dim=0)

+        token_id = outputs.sequences[0][token_index].item()
+        token_text = donut_p.tokenizer.decode([token_id])
+        logger.info(f"Considered sequence token: {token_text}")

-… (content of 2 removed lines not shown in this view)
+        safe_token_text = re.sub(r'[<>:"/\\|?*]', "_", token_text)
+        current_datetime = datetime.now().strftime("%Y%m%d%H%M%S")
+
+        unique_safe_token_text = f"{safe_token_text}_{current_datetime}"
+        file_name = f"saliency_{unique_safe_token_text}.png"
+
+        saliency = utils.convert_tensor_to_rgba_image(saliency)
+
+        # Merge saliency image twice
+        saliency = utils.add_transparent_image(np.array(image), saliency)
+        saliency = utils.convert_rgb_to_rgba_image(saliency)
+        saliency = utils.add_transparent_image(np.array(image), saliency, 0.7)
+
+        saliency = utils.label_frame(saliency, token_text)
+
+        saliency_images.append(saliency)
+        token_texts.append(token_text)
+
+    return saliency_images, token_texts
+
+@spaces.GPU(duration=300)
+def process_image_donut(image):
     try:
+        model, processor = initialize_donut()
+
         if not isinstance(image, Image.Image):
             image = Image.fromarray(image)

         pixel_values = processor(image, return_tensors="pt").pixel_values.to("cuda")
+        pixel_values.requires_grad = True

         task_prompt = "<s_cord-v2>"
         decoder_input_ids = processor.tokenizer(
             task_prompt, add_special_tokens=False, return_tensors="pt"
         )["input_ids"].to("cuda")

-        outputs = model.generate(
+        outputs = model.generate.__wrapped__(
+            model,
             pixel_values,
             decoder_input_ids=decoder_input_ids,
             max_length=model.decoder.config.max_position_embeddings,
@@ -90,8 +129,11 @@ def process_image_donut(model, processor, image):
             num_beams=1,
             bad_words_ids=[[processor.tokenizer.unk_token_id]],
             return_dict_in_generate=True,
+            output_scores=True,
         )

+        saliency_images, token_texts = compute_saliency(outputs, pixel_values, processor, image)
+
         sequence = processor.batch_decode(outputs.sequences)[0]
         sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(
             processor.tokenizer.pad_token, ""
@@ -99,31 +141,27 @@ def process_image_donut(model, processor, image):
         sequence = re.sub(r"<.*?>", "", sequence, count=1).strip()

         result = processor.token2json(sequence)
-        return json.dumps(result, indent=2)
+        return saliency_images, json.dumps(result, indent=2)
     except Exception as e:
         logger.error(f"Error processing image with Donut: {str(e)}")
-        return f"Error: {str(e)}"
+        return None, f"Error: {str(e)}"

-
-@spaces.GPU
+@spaces.GPU(duration=300)
 def process_image(model_name, image=None, dataset_image_index=None):
     if dataset_image_index is not None:
         image = get_image_from_dataset(dataset_image_index)

     if model_name == "de-Rodrigo/donut-merit":
-        … (content of 1 removed line not shown in this view)
-        result = process_image_donut(model, processor, image)
+        saliency_images, result = process_image_donut(image)
     else:
-        #
-        result = f"Processing for model {model_name} not implemented"
-
-    return image, result
+        # Processing for other models should be implemented here
+        saliency_images, result = None, f"Processing for model {model_name} not implemented"

+    return saliency_images, result

 def update_image(dataset_image_index):
     return get_image_from_dataset(dataset_image_index)

-
 if __name__ == "__main__":
     # Load the dataset
     load_merit_dataset()
@@ -180,7 +218,7 @@ if __name__ == "__main__":
             process_button = gr.Button("Process Image")

         with gr.Row():
-            output_image = gr.
+            output_image = gr.Gallery(label="Processed Saliency Images")
             output_text = gr.Textbox(label="Result")

         # Update preview image when slider changes
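The saliency maps added in this commit are plain input-gradient attribution: pixel_values is given requires_grad = True, generation is invoked as model.generate.__wrapped__(model, ...) (presumably to reach the undecorated generate so that gradients are not disabled during decoding), and compute_saliency then backpropagates, for each generated token, the probability assigned to that token and keeps the absolute pixel gradient averaged over colour channels as the map, zeroing pixels.grad between tokens and using retain_graph=True so the graph can be reused. A minimal, self-contained sketch of the same idea (not part of the commit; ToyDecoder is a made-up stand-in for the Donut model):

import torch
import torch.nn as nn


class ToyDecoder(nn.Module):
    # Hypothetical stand-in: maps a 3x32x32 "image" to logits over a 10-token vocabulary.
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(3 * 32 * 32, 10)

    def forward(self, pixels):
        return self.proj(pixels.flatten(1))  # (batch, vocab)


model = ToyDecoder()
pixels = torch.rand(1, 3, 32, 32, requires_grad=True)  # analogous to pixel_values in app.py

probs = torch.softmax(model(pixels), dim=-1)
token_id = int(probs.argmax(dim=-1))  # the "generated" token

# Backpropagate that token's probability to the input pixels.
probs[0, token_id].backward()

# Same reduction as compute_saliency: |gradient| averaged over the channel dimension.
saliency = pixels.grad.data.abs().squeeze().mean(dim=0)  # (32, 32) saliency map
print(saliency.shape)

In app.py this is repeated once per decoded token; outputs.scores (enabled via output_scores=True) supplies the per-step logits from which compute_saliency takes the per-token probabilities it backpropagates.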
utils.py
ADDED
@@ -0,0 +1,157 @@
+import cv2
+import os
+import glob
+import numpy as np
+from datetime import datetime
+
+
+def add_transparent_image(
+    background, foreground, alpha_factor=1.0, x_offset=None, y_offset=None
+):
+    """
+    Function sourced from StackOverflow contributor Ben.
+
+    This function was found on StackOverflow and is the work of Ben, a contributor
+    to the community. We are thankful to Ben for providing this useful method.
+
+    Original Source:
+    https://stackoverflow.com/questions/40895785/
+    using-opencv-to-overlay-transparent-image-onto-another-image
+    """
+
+    bg_h, bg_w, bg_channels = background.shape
+    fg_h, fg_w, fg_channels = foreground.shape
+
+    assert (
+        bg_channels == 3
+    ), f"background image should have exactly 3 channels (RGB). found:{bg_channels}"
+    assert (
+        fg_channels == 4
+    ), f"foreground image should have exactly 4 channels (RGBA). found:{fg_channels}"
+
+    # center by default
+    if x_offset is None:
+        x_offset = (bg_w - fg_w) // 2
+    if y_offset is None:
+        y_offset = (bg_h - fg_h) // 2
+
+    w = min(fg_w, bg_w, fg_w + x_offset, bg_w - x_offset)
+    h = min(fg_h, bg_h, fg_h + y_offset, bg_h - y_offset)
+
+    if w < 1 or h < 1:
+        return
+
+    # clip foreground and background images to the overlapping regions
+    bg_x = max(0, x_offset)
+    bg_y = max(0, y_offset)
+    fg_x = max(0, x_offset * -1)
+    fg_y = max(0, y_offset * -1)
+    foreground = foreground[fg_y : fg_y + h, fg_x : fg_x + w]
+    background_subsection = background[bg_y : bg_y + h, bg_x : bg_x + w]
+
+    # separate alpha and color channels from the foreground image
+    foreground_colors = foreground[:, :, :3]
+    foreground_colors = cv2.cvtColor(foreground_colors, cv2.COLOR_BGR2RGB)
+    alpha_channel = foreground[:, :, 3] / 255 * alpha_factor  # 0-255 => 0.0-1.0
+
+    # construct an alpha_mask that matches the image shape
+    alpha_mask = np.dstack((alpha_channel, alpha_channel, alpha_channel))
+
+    # combine the background with the overlay image weighted by alpha
+    composite = (
+        background_subsection * (1 - alpha_mask) + foreground_colors * alpha_mask
+    )
+
+    # overwrite the section of the background image that has been updated
+    background[bg_y : bg_y + h, bg_x : bg_x + w] = composite
+
+    return background
+
+
+def convert_tensor_to_rgba_image(tensor):
+
+    saliency_array = tensor.cpu().numpy()
+
+    # Normalize the image to 0-255
+    if saliency_array.dtype != np.uint8:
+        saliency_array = (255 * saliency_array / saliency_array.max()).astype(np.uint8)
+
+    heatmap = cv2.applyColorMap(saliency_array, cv2.COLORMAP_JET)
+
+    # Make pixels transparent where there is no saliency ([128, 0, 0] is "black" in COLORMAP_JET)
+    alpha_channel = np.ones(heatmap.shape[:2], dtype=heatmap.dtype) * 255
+    black_pixels_mask = np.all(heatmap == [128, 0, 0], axis=-1)
+    alpha_channel[black_pixels_mask] = 0
+
+    # Combine the RGB and alpha channels
+    saliency_rgba = cv2.merge((heatmap, alpha_channel))
+
+    return saliency_rgba
+
+
+def convert_rgb_to_rgba_image(image):
+
+    alpha_channel = np.ones(image.shape[:2], dtype=image.dtype) * 255
+    rbga = cv2.merge((cv2.cvtColor(image, cv2.COLOR_RGB2BGR), alpha_channel))
+
+    return rbga
+
+
+def label_frame(image, token):
+
+    # Add the text
+    font = cv2.FONT_HERSHEY_SIMPLEX
+    font_scale = 0.7
+    text_color = (255, 255, 255)
+    text_thickness = 1
+    text_size, _ = cv2.getTextSize(token, font, font_scale, text_thickness)
+    text_position = (10, 10 + text_size[1])
+
+    # Draw a rectangle behind the text
+    rectangle_color = (0, 0, 0)
+    rectangle_thickness = -1
+    rectangle_position = (10, 10)
+    rectangle_size = (text_size[0] + 5, text_size[1] + 5)
+    cv2.rectangle(
+        image,
+        rectangle_position,
+        (
+            rectangle_position[0] + rectangle_size[0],
+            rectangle_position[1] + rectangle_size[1],
+        ),
+        rectangle_color,
+        rectangle_thickness,
+    )
+
+    cv2.putText(
+        image, token, text_position, font, font_scale, text_color, text_thickness
+    )
+
+    return image
+
+
+def saliency_video(path, sequence):
+
+    image_files = sorted(glob.glob(os.path.join(path, "*.png")), key=os.path.getctime)
+    image = cv2.imread(image_files[0])
+    height = image.shape[0]
+    width = image.shape[1]
+
+    # Create a VideoWriter object to save the video
+    video_name = os.path.join(path, "saliency.mp4")
+    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
+
+    video = cv2.VideoWriter(video_name, fourcc, 5, (width, height))
+
+    for image_file, token in zip(image_files, sequence):
+
+        image = cv2.imread(image_file)
+
+        # Write the image to the video
+        video.write(image)
+
+    # Release the VideoWriter object
+    video.release()
+
+    print(f"Video saved as {video_name}")
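Taken together, the utils.py helpers turn a single saliency tensor into one labelled overlay frame for the Gallery: the tensor is colour-mapped to an RGBA heatmap, blended over the document image twice (the second pass with alpha_factor 0.7), and stamped with the decoded token; saliency_video can afterwards pack frames saved as PNGs into saliency.mp4 at 5 fps. A rough usage sketch (not part of the commit; the blank page and the random tensor are made-up stand-ins for a MERIT sample and real pixel gradients):

import numpy as np
import torch

import utils

page = np.full((480, 640, 3), 255, dtype=np.uint8)  # stand-in RGB document image
saliency = torch.rand(480, 640)                      # stand-in per-pixel saliency values

overlay = utils.convert_tensor_to_rgba_image(saliency)        # JET-coloured RGBA heatmap
frame = utils.add_transparent_image(page.copy(), overlay)     # first blend onto the page
frame = utils.convert_rgb_to_rgba_image(frame)                # re-attach an alpha channel
frame = utils.add_transparent_image(page.copy(), frame, 0.7)  # second, softer blend
frame = utils.label_frame(frame, "<s_cord-v2>")               # stamp the token text
print(frame.shape)  # (480, 640, 3)

The double blend is a design choice taken directly from compute_saliency: the second pass with a reduced alpha_factor keeps the underlying document legible beneath the heatmap.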