mattmdjaga commited on
Commit
390940a
·
1 Parent(s): 7c918f9

Changed checking of cached data to take into account people using the app at the same time

Browse files
Files changed (1) hide show
  1. app.py +11 -8
app.py CHANGED
@@ -8,12 +8,13 @@ import cv2
8
  from typing import List
9
 
10
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
11
 
12
  # Load model and processor
13
  model = SamModel.from_pretrained("facebook/sam-vit-base").to(device)
14
  processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
15
 
16
- embedding = None
17
 
18
  def mask_2_dots(mask: np.ndarray) -> List[List[int]]:
19
  gray = cv2.cvtColor(mask, cv2.COLOR_RGB2GRAY)
@@ -31,14 +32,16 @@ def mask_2_dots(mask: np.ndarray) -> List[List[int]]:
31
 
32
  @torch.no_grad()
33
  def foward_pass(image_input: np.ndarray, points: List[List[int]]) -> np.ndarray:
34
- global embedding
35
  image_input = Image.fromarray(image_input)
36
  inputs = processor(image_input, input_points=points, return_tensors="pt").to(device)
37
- if not isinstance(embedding, torch.Tensor):
38
  embedding = model.get_image_embeddings(inputs["pixel_values"])
 
 
39
  del inputs["pixel_values"]
40
 
41
- outputs = model.forward(image_embeddings=embedding, **inputs)
42
  masks = processor.image_processor.post_process_masks(
43
  outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu()
44
  )
@@ -63,9 +66,9 @@ def main_func(inputs) -> List[Image.Image]:
63
 
64
  return pred_masks
65
 
66
- def reset_embedding():
67
- global embedding
68
- embedding = None
69
 
70
  with gr.Blocks() as demo:
71
  gr.Markdown("# How to use")
@@ -81,6 +84,6 @@ with gr.Blocks() as demo:
81
  image_button = gr.Button("Segment Image")
82
 
83
  image_button.click(main_func, inputs=image_input, outputs=image_output)
84
- image_input.upload(reset_embedding)
85
 
86
  demo.launch()
 
8
  from typing import List
9
 
10
  device = "cuda" if torch.cuda.is_available() else "cpu"
11
+ device = 'cpu'
12
 
13
  # Load model and processor
14
  model = SamModel.from_pretrained("facebook/sam-vit-base").to(device)
15
  processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
16
 
17
+ cache_data = None
18
 
19
  def mask_2_dots(mask: np.ndarray) -> List[List[int]]:
20
  gray = cv2.cvtColor(mask, cv2.COLOR_RGB2GRAY)
 
32
 
33
  @torch.no_grad()
34
  def foward_pass(image_input: np.ndarray, points: List[List[int]]) -> np.ndarray:
35
+ global cache_data
36
  image_input = Image.fromarray(image_input)
37
  inputs = processor(image_input, input_points=points, return_tensors="pt").to(device)
38
+ if not cache_data or not torch.equal(inputs['pixel_values'],cache_data[0]):
39
  embedding = model.get_image_embeddings(inputs["pixel_values"])
40
+ pixels = inputs["pixel_values"]
41
+ cache_data = [pixels, embedding]
42
  del inputs["pixel_values"]
43
 
44
+ outputs = model.forward(image_embeddings=cache_data[1], **inputs)
45
  masks = processor.image_processor.post_process_masks(
46
  outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu()
47
  )
 
66
 
67
  return pred_masks
68
 
69
+ def reset_data():
70
+ global cache_data
71
+ cache_data = None
72
 
73
  with gr.Blocks() as demo:
74
  gr.Markdown("# How to use")
 
84
  image_button = gr.Button("Segment Image")
85
 
86
  image_button.click(main_func, inputs=image_input, outputs=image_output)
87
+ image_input.upload(reset_data)
88
 
89
  demo.launch()