Commit f53adeb by de-Rodrigo · 1 Parent(s): bedfdc1

Visualize Saliency Maps
Files changed (3):
  1. __pycache__/utils.cpython-310.pyc +0 -0
  2. app.py +74 -36
  3. utils.py +157 -0

__pycache__/utils.cpython-310.pyc ADDED
Binary file (3.56 kB)
app.py CHANGED
@@ -10,22 +10,22 @@ import re
 import logging
 from datasets import load_dataset
 import os
+import numpy as np
+from datetime import datetime
+# Import utils and save_img if they are not already imported
+import utils

 # Logging configuration
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)

-
 # Paths to the static image and GIF
 README_IMAGE_PATH = os.path.join("figs", "saliencies-merit-dataset.png")
 GIF_PATH = os.path.join("figs", "demo-samples.gif")

 # Global variables for Donut model, processor, and dataset
-donut_model = None
-donut_processor = None
 dataset = None

-
 def load_merit_dataset():
     global dataset
     if dataset is None:
@@ -34,7 +34,6 @@ def load_merit_dataset():
     )
     return dataset

-
 def get_image_from_dataset(index):
     global dataset
     if dataset is None:
@@ -42,44 +41,84 @@ def get_image_from_dataset(index):
     image_data = dataset[int(index)]["image"]
     return image_data

-
 def get_collection_models(tag: str) -> List[str]:
     """Get a list of models from a specific Hugging Face collection."""
     models = list_models(author="de-Rodrigo")
     return [model.modelId for model in models if tag in model.tags]

+def initialize_donut():
+    try:
+        donut_model = VisionEncoderDecoderModel.from_pretrained(
+            "de-Rodrigo/donut-merit"
+        )
+        donut_processor = DonutProcessor.from_pretrained("de-Rodrigo/donut-merit")
+        donut_model = donut_model.to("cuda")
+        logger.info("Donut model loaded successfully on GPU")
+        return donut_model, donut_processor
+    except Exception as e:
+        logger.error(f"Error loading Donut model: {str(e)}")
+        raise
+
+def compute_saliency(outputs, pixels, donut_p, image):
+    token_logits = torch.stack(outputs.scores, dim=1)
+    token_probs = torch.softmax(token_logits, dim=-1)
+    token_texts = []
+    saliency_images = []

-@spaces.GPU
-def get_donut():
-    global donut_model, donut_processor
-    if donut_model is None or donut_processor is None:
-        try:
-            donut_model = VisionEncoderDecoderModel.from_pretrained(
-                "de-Rodrigo/donut-merit"
-            )
-            donut_processor = DonutProcessor.from_pretrained("de-Rodrigo/donut-merit")
-            donut_model = donut_model.to("cuda")
-            logger.info("Donut model loaded successfully on GPU")
-        except Exception as e:
-            logger.error(f"Error loading Donut model: {str(e)}")
-            raise
-    return donut_model, donut_processor
+    for token_index in range(len(token_probs[0])):
+        target_token_prob = token_probs[
+            0, token_index, outputs.sequences[0, token_index]
+        ]
+
+        if pixels.grad is not None:
+            pixels.grad.zero_()
+
+        target_token_prob.backward(retain_graph=True)
+
+        saliency = pixels.grad.data.abs().squeeze().mean(dim=0)

+        token_id = outputs.sequences[0][token_index].item()
+        token_text = donut_p.tokenizer.decode([token_id])
+        logger.info(f"Considered sequence token: {token_text}")

-@spaces.GPU
-def process_image_donut(model, processor, image):
+        safe_token_text = re.sub(r'[<>:"/\\|?*]', "_", token_text)
+        current_datetime = datetime.now().strftime("%Y%m%d%H%M%S")
+
+        unique_safe_token_text = f"{safe_token_text}_{current_datetime}"
+        file_name = f"saliency_{unique_safe_token_text}.png"
+
+        saliency = utils.convert_tensor_to_rgba_image(saliency)
+
+        # Merge saliency image twice
+        saliency = utils.add_transparent_image(np.array(image), saliency)
+        saliency = utils.convert_rgb_to_rgba_image(saliency)
+        saliency = utils.add_transparent_image(np.array(image), saliency, 0.7)
+
+        saliency = utils.label_frame(saliency, token_text)
+
+        saliency_images.append(saliency)
+        token_texts.append(token_text)
+
+    return saliency_images, token_texts
+
+@spaces.GPU(duration=300)
+def process_image_donut(image):
     try:
+        model, processor = initialize_donut()
+
         if not isinstance(image, Image.Image):
             image = Image.fromarray(image)

         pixel_values = processor(image, return_tensors="pt").pixel_values.to("cuda")
+        pixel_values.requires_grad = True

         task_prompt = "<s_cord-v2>"
         decoder_input_ids = processor.tokenizer(
             task_prompt, add_special_tokens=False, return_tensors="pt"
         )["input_ids"].to("cuda")

-        outputs = model.generate(
+        outputs = model.generate.__wrapped__(
+            model,
             pixel_values,
             decoder_input_ids=decoder_input_ids,
             max_length=model.decoder.config.max_position_embeddings,
@@ -90,8 +129,11 @@ def process_image_donut(model, processor, image):
             num_beams=1,
             bad_words_ids=[[processor.tokenizer.unk_token_id]],
             return_dict_in_generate=True,
+            output_scores=True,
         )

+        saliency_images, token_texts = compute_saliency(outputs, pixel_values, processor, image)
+
         sequence = processor.batch_decode(outputs.sequences)[0]
         sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(
             processor.tokenizer.pad_token, ""
@@ -99,31 +141,27 @@ def process_image_donut(model, processor, image):
         sequence = re.sub(r"<.*?>", "", sequence, count=1).strip()

         result = processor.token2json(sequence)
-        return json.dumps(result, indent=2)
+        return saliency_images, json.dumps(result, indent=2)
     except Exception as e:
         logger.error(f"Error processing image with Donut: {str(e)}")
-        return f"Error: {str(e)}"
+        return None, f"Error: {str(e)}"

-
-@spaces.GPU
+@spaces.GPU(duration=300)
 def process_image(model_name, image=None, dataset_image_index=None):
     if dataset_image_index is not None:
         image = get_image_from_dataset(dataset_image_index)

     if model_name == "de-Rodrigo/donut-merit":
-        model, processor = get_donut()
-        result = process_image_donut(model, processor, image)
+        saliency_images, result = process_image_donut(image)
     else:
         # Here you should implement processing for other models
-        result = f"Processing for model {model_name} not implemented"
-
-    return image, result
+        saliency_images, result = None, f"Processing for model {model_name} not implemented"

+    return saliency_images, result

-
 def update_image(dataset_image_index):
     return get_image_from_dataset(dataset_image_index)

-
 if __name__ == "__main__":
     # Load the dataset
     load_merit_dataset()
@@ -180,7 +218,7 @@ if __name__ == "__main__":
     process_button = gr.Button("Process Image")

     with gr.Row():
-        output_image = gr.Image(label="Processed Image")
+        output_image = gr.Gallery(label="Processed Saliency Images")
         output_text = gr.Textbox(label="Result")

     # Update preview image when slider changes
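Note on the saliency mechanics above: process_image_donut now keeps gradients on pixel_values and calls model.generate.__wrapped__(model, ...) — generate() normally runs under torch.no_grad(), so going through __wrapped__ calls the undecorated method and keeps the autograd graph alive, while output_scores=True exposes the per-step logits that compute_saliency backpropagates. A minimal, self-contained sketch of the same idea (not the Space's exact code; model, processor, image and decoder_input_ids are assumed to be set up as in process_image_donut, and prompt_len/step are illustrative names):

import torch

pixel_values = processor(image, return_tensors="pt").pixel_values.to("cuda")
pixel_values.requires_grad_(True)

# generate() is wrapped in torch.no_grad(); calling the undecorated function through
# __wrapped__ keeps the graph so gradients can reach pixel_values.
outputs = model.generate.__wrapped__(
    model,
    pixel_values,
    decoder_input_ids=decoder_input_ids,
    num_beams=1,
    output_scores=True,
    return_dict_in_generate=True,
)

# Saliency for one generated token: gradient of its probability w.r.t. the input pixels.
prompt_len = decoder_input_ids.shape[1]
step = 0                                                    # first generated token
probs = torch.softmax(outputs.scores[step], dim=-1)
target = probs[0, outputs.sequences[0, prompt_len + step]]  # probability of the chosen token
target.backward()
saliency = pixel_values.grad.abs().squeeze().mean(dim=0)    # (H, W) map, channel-averaged

The loop in compute_saliency repeats this for every generated token, zeroing pixels.grad between steps and keeping the graph with retain_graph=True.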
utils.py ADDED
@@ -0,0 +1,157 @@
+import cv2
+import os
+import glob
+import numpy as np
+from datetime import datetime
+
+
+def add_transparent_image(
+    background, foreground, alpha_factor=1.0, x_offset=None, y_offset=None
+):
+    """
+    Overlay a transparent (RGBA) foreground onto an RGB background.
+
+    Adapted from a StackOverflow answer by Ben; thanks to Ben for sharing it.
+
+    Original source:
+    https://stackoverflow.com/questions/40895785/
+    using-opencv-to-overlay-transparent-image-onto-another-image
+    """
+
+    bg_h, bg_w, bg_channels = background.shape
+    fg_h, fg_w, fg_channels = foreground.shape
+
+    assert (
+        bg_channels == 3
+    ), f"background image should have exactly 3 channels (RGB). found:{bg_channels}"
+    assert (
+        fg_channels == 4
+    ), f"foreground image should have exactly 4 channels (RGBA). found:{fg_channels}"
+
+    # center by default
+    if x_offset is None:
+        x_offset = (bg_w - fg_w) // 2
+    if y_offset is None:
+        y_offset = (bg_h - fg_h) // 2
+
+    w = min(fg_w, bg_w, fg_w + x_offset, bg_w - x_offset)
+    h = min(fg_h, bg_h, fg_h + y_offset, bg_h - y_offset)
+
+    if w < 1 or h < 1:
+        return
+
+    # clip foreground and background images to the overlapping regions
+    bg_x = max(0, x_offset)
+    bg_y = max(0, y_offset)
+    fg_x = max(0, x_offset * -1)
+    fg_y = max(0, y_offset * -1)
+    foreground = foreground[fg_y : fg_y + h, fg_x : fg_x + w]
+    background_subsection = background[bg_y : bg_y + h, bg_x : bg_x + w]
+
+    # separate alpha and color channels from the foreground image
+    foreground_colors = foreground[:, :, :3]
+    foreground_colors = cv2.cvtColor(foreground_colors, cv2.COLOR_BGR2RGB)
+    alpha_channel = foreground[:, :, 3] / 255 * alpha_factor  # 0-255 => 0.0-1.0
+
+    # construct an alpha_mask that matches the image shape
+    alpha_mask = np.dstack((alpha_channel, alpha_channel, alpha_channel))
+
+    # combine the background with the overlay image weighted by alpha
+    composite = (
+        background_subsection * (1 - alpha_mask) + foreground_colors * alpha_mask
+    )
+
+    # overwrite the section of the background image that has been updated
+    background[bg_y : bg_y + h, bg_x : bg_x + w] = composite
+
+    return background
+
+
+def convert_tensor_to_rgba_image(tensor):
+
+    saliency_array = tensor.cpu().numpy()
+
+    # Normalize to 0-255
+    if saliency_array.dtype != np.uint8:
+        saliency_array = (255 * saliency_array / saliency_array.max()).astype(np.uint8)
+
+    heatmap = cv2.applyColorMap(saliency_array, cv2.COLORMAP_JET)
+
+    # Make pixels with no saliency transparent: [128, 0, 0] (dark blue in BGR)
+    # is the lowest value in COLORMAP_JET
+    alpha_channel = np.ones(heatmap.shape[:2], dtype=heatmap.dtype) * 255
+    black_pixels_mask = np.all(heatmap == [128, 0, 0], axis=-1)
+    alpha_channel[black_pixels_mask] = 0
+
+    # Combine the RGB and alpha channels
+    saliency_rgba = cv2.merge((heatmap, alpha_channel))
+
+    return saliency_rgba
+
+
+def convert_rgb_to_rgba_image(image):
+
+    alpha_channel = np.ones(image.shape[:2], dtype=image.dtype) * 255
+    rgba = cv2.merge((cv2.cvtColor(image, cv2.COLOR_RGB2BGR), alpha_channel))
+
+    return rgba
+
+
+def label_frame(image, token):
+
+    # Add the text
+    font = cv2.FONT_HERSHEY_SIMPLEX
+    font_scale = 0.7
+    text_color = (255, 255, 255)
+    text_thickness = 1
+    text_size, _ = cv2.getTextSize(token, font, font_scale, text_thickness)
+    text_position = (10, 10 + text_size[1])
+
+    # Draw a rectangle behind the text
+    rectangle_color = (0, 0, 0)
+    rectangle_thickness = -1
+    rectangle_position = (10, 10)
+    rectangle_size = (text_size[0] + 5, text_size[1] + 5)
+    cv2.rectangle(
+        image,
+        rectangle_position,
+        (
+            rectangle_position[0] + rectangle_size[0],
+            rectangle_position[1] + rectangle_size[1],
+        ),
+        rectangle_color,
+        rectangle_thickness,
+    )
+
+    cv2.putText(
+        image, token, text_position, font, font_scale, text_color, text_thickness
+    )
+
+    return image
+
+
+def saliency_video(path, sequence):
+
+    image_files = sorted(glob.glob(os.path.join(path, "*.png")), key=os.path.getctime)
+    image = cv2.imread(image_files[0])
+    height = image.shape[0]
+    width = image.shape[1]
+
+    # Create a VideoWriter object to save the video
+    video_name = os.path.join(path, "saliency.mp4")
+    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
+
+    video = cv2.VideoWriter(video_name, fourcc, 5, (width, height))
+
+    for image_file, token in zip(image_files, sequence):
+
+        image = cv2.imread(image_file)
+
+        # Write the image to the video
+        video.write(image)
+
+    # Release the VideoWriter object
+    video.release()
+
+    print(f"Video saved as {video_name}")