watchtowerss commited on
Commit
23d6e96
1 Parent(s): 98d86dc

memory usage reduce for tracking

Browse files
Files changed (4) hide show
  1. app.py +95 -45
  2. inpainter/base_inpainter.py +20 -4
  3. track_anything.py +21 -6
  4. tracker/.DS_Store +0 -0
app.py CHANGED
@@ -8,15 +8,14 @@ import sys
8
  sys.path.append(sys.path[0]+"/tracker")
9
  sys.path.append(sys.path[0]+"/tracker/model")
10
  from track_anything import TrackingAnything
11
- from track_anything import parse_augment
12
  import requests
13
  import json
14
  import torchvision
15
  import torch
16
- from tools.interact_tools import SamControler
17
- from tracker.base_tracker import BaseTracker
18
  from tools.painter import mask_painter
19
  import psutil
 
20
  try:
21
  from mmcv.cnn import ConvModule
22
  except:
@@ -71,6 +70,7 @@ def get_prompt(click_state, click_input):
71
  return prompt
72
 
73
 
 
74
  # extract frames from upload video
75
  def get_frames_from_video(video_input, video_state):
76
  """
@@ -81,49 +81,72 @@ def get_frames_from_video(video_input, video_state):
81
  [[0:nearest_frame], [nearest_frame:], nearest_frame]
82
  """
83
  video_path = video_input
84
- frames = []
 
 
 
85
 
 
 
86
  operation_log = [("",""),("Upload video already. Try click the image for adding targets to track and inpaint.","Normal")]
87
  try:
88
  cap = cv2.VideoCapture(video_path)
89
  fps = cap.get(cv2.CAP_PROP_FPS)
 
 
 
 
 
 
 
 
90
  while cap.isOpened():
91
  ret, frame = cap.read()
92
  if ret == True:
93
  current_memory_usage = psutil.virtual_memory().percent
94
- frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
95
- if current_memory_usage > 70:
96
- operation_log = [("Memory usage is too high (>70%). Stop the video extraction. Please reduce the video resolution or frame rate or wait for other users to complete the operation.", "Error")]
97
- print("Memory usage is too high (>50%). Please reduce the video resolution or frame rate.")
 
 
 
 
98
  break
99
  else:
100
  break
 
101
  except (OSError, TypeError, ValueError, KeyError, SyntaxError) as e:
 
 
102
  print("read_frame_source:{} error. {}\n".format(video_path, str(e)))
103
- image_size = (frames[0].shape[0],frames[0].shape[1])
 
 
 
 
 
104
  # initialize video_state
105
  video_state = {
 
106
  "video_name": os.path.split(video_path)[-1],
107
  "origin_images": frames,
108
  "painted_images": frames.copy(),
109
- "masks": [np.zeros((frames[0].shape[0],frames[0].shape[1]), np.uint8)]*len(frames),
110
  "logits": [None]*len(frames),
111
  "select_frame_number": 0,
112
  "fps": fps
113
  }
114
  video_info = "Video Name: {}, FPS: {}, Total Frames: {}, Image Size:{}".format(video_state["video_name"], video_state["fps"], len(frames), image_size)
115
  model.samcontroler.sam_controler.reset_image()
116
- model.samcontroler.sam_controler.set_image(video_state["origin_images"][0])
117
- return video_state, video_info, video_state["origin_images"][0], gr.update(visible=True, maximum=len(frames), value=1), gr.update(visible=True, maximum=len(frames), value=len(frames)), \
118
- gr.update(visible=True),\
119
- gr.update(visible=True), gr.update(visible=True), \
120
- gr.update(visible=True), gr.update(visible=True), \
121
- gr.update(visible=True), gr.update(visible=True), \
122
- gr.update(visible=True), gr.update(visible=True), \
123
- gr.update(visible=True, value=operation_log)
124
 
125
  def run_example(example):
126
- return video_input
127
  # get the select frame from gradio slider
128
  def select_template(image_selection_slider, video_state, interactive_state):
129
 
@@ -134,21 +157,22 @@ def select_template(image_selection_slider, video_state, interactive_state):
134
  # once select a new template frame, set the image in sam
135
 
136
  model.samcontroler.sam_controler.reset_image()
137
- model.samcontroler.sam_controler.set_image(video_state["origin_images"][image_selection_slider])
138
 
139
  # update the masks when select a new template frame
140
  # if video_state["masks"][image_selection_slider] is not None:
141
  # video_state["painted_images"][image_selection_slider] = mask_painter(video_state["origin_images"][image_selection_slider], video_state["masks"][image_selection_slider])
142
  operation_log = [("",""), ("Select frame {}. Try click image and add mask for tracking.".format(image_selection_slider),"Normal")]
143
 
144
- return video_state["painted_images"][image_selection_slider], video_state, interactive_state, operation_log
145
 
146
  # set the tracking end frame
147
  def get_end_number(track_pause_number_slider, video_state, interactive_state):
 
148
  interactive_state["track_end_number"] = track_pause_number_slider
149
  operation_log = [("",""),("Set the tracking finish at frame {}".format(track_pause_number_slider),"Normal")]
150
 
151
- return video_state["painted_images"][track_pause_number_slider],interactive_state, operation_log
152
 
153
  def get_resize_ratio(resize_ratio_slider, interactive_state):
154
  interactive_state["resize_ratio"] = resize_ratio_slider
@@ -172,18 +196,18 @@ def sam_refine(video_state, point_prompt, click_state, interactive_state, evt:gr
172
 
173
  # prompt for sam model
174
  model.samcontroler.sam_controler.reset_image()
175
- model.samcontroler.sam_controler.set_image(video_state["origin_images"][video_state["select_frame_number"]])
176
  prompt = get_prompt(click_state=click_state, click_input=coordinate)
177
 
178
  mask, logit, painted_image = model.first_frame_click(
179
- image=video_state["origin_images"][video_state["select_frame_number"]],
180
  points=np.array(prompt["input_point"]),
181
  labels=np.array(prompt["input_label"]),
182
  multimask=prompt["multimask_output"],
183
  )
184
  video_state["masks"][video_state["select_frame_number"]] = mask
185
  video_state["logits"][video_state["select_frame_number"]] = logit
186
- video_state["painted_images"][video_state["select_frame_number"]] = painted_image
187
 
188
  operation_log = [("",""), ("Use SAM for segment. You can try add positive and negative points by clicking. Or press Clear clicks button to refresh the image. Press Add mask button when you are satisfied with the segment","Normal")]
189
  return painted_image, video_state, interactive_state, operation_log
@@ -203,7 +227,7 @@ def add_multi_mask(video_state, interactive_state, mask_dropdown):
203
 
204
  def clear_click(video_state, click_state):
205
  click_state = [[],[]]
206
- template_frame = video_state["origin_images"][video_state["select_frame_number"]]
207
  operation_log = [("",""), ("Clear points history and refresh the image.","Normal")]
208
  return template_frame, click_state, operation_log
209
 
@@ -216,7 +240,7 @@ def remove_multi_mask(interactive_state, mask_dropdown):
216
 
217
  def show_mask(video_state, interactive_state, mask_dropdown):
218
  mask_dropdown.sort()
219
- select_frame = video_state["origin_images"][video_state["select_frame_number"]]
220
 
221
  for i in range(len(mask_dropdown)):
222
  mask_number = int(mask_dropdown[i].split("_")[1]) - 1
@@ -253,18 +277,18 @@ def vos_tracking_video(video_state, interactive_state, mask_dropdown):
253
  template_mask[0][0]=1
254
  operation_log = [("Error! Please add at least one mask to track by clicking the left image.","Error"), ("","")]
255
  # return video_output, video_state, interactive_state, operation_error
256
- masks, logits, painted_images = model.generator(images=following_frames, template_mask=template_mask)
257
  # clear GPU memory
258
  model.xmem.clear_memory()
259
 
260
  if interactive_state["track_end_number"]:
261
  video_state["masks"][video_state["select_frame_number"]:interactive_state["track_end_number"]] = masks
262
  video_state["logits"][video_state["select_frame_number"]:interactive_state["track_end_number"]] = logits
263
- video_state["painted_images"][video_state["select_frame_number"]:interactive_state["track_end_number"]] = painted_images
264
  else:
265
  video_state["masks"][video_state["select_frame_number"]:] = masks
266
  video_state["logits"][video_state["select_frame_number"]:] = logits
267
- video_state["painted_images"][video_state["select_frame_number"]:] = painted_images
268
 
269
  video_output = generate_video_from_frames(video_state["painted_images"], output_path="./result/track/{}".format(video_state["video_name"]), fps=fps) # import video_input to name the output video
270
  interactive_state["inference_times"] += 1
@@ -283,20 +307,16 @@ def vos_tracking_video(video_state, interactive_state, mask_dropdown):
283
  for mask in video_state["masks"]:
284
  np.save(os.path.join('./result/mask/{}'.format(video_state["video_name"].split('.')[0]), '{:05d}.npy'.format(i)), mask)
285
  i+=1
286
- # save_mask(video_state["masks"], video_state["video_name"])
287
  #### shanggao code for mask save
288
  return video_output, video_state, interactive_state, operation_log
289
 
290
- # extracting masks from mask_dropdown
291
- # def extract_sole_mask(video_state, mask_dropdown):
292
- # combined_masks =
293
- # unique_masks = np.unique(combined_masks)
294
- # return 0
295
 
296
  # inpaint
297
  def inpaint_video(video_state, interactive_state, mask_dropdown):
298
  operation_log = [("",""), ("Removed the selected masks.","Normal")]
299
 
 
300
  frames = np.asarray(video_state["origin_images"])
301
  fps = video_state["fps"]
302
  inpaint_masks = np.asarray(video_state["masks"])
@@ -319,13 +339,39 @@ def inpaint_video(video_state, interactive_state, mask_dropdown):
319
  except:
320
  operation_log = [("Error! You are trying to inpaint without masks input. Please track the selected mask first, and then press inpaint. If VRAM exceeded, please use the resize ratio to scaling down the image size.","Error"), ("","")]
321
  inpainted_frames = video_state["origin_images"]
322
- video_output = generate_video_from_frames(inpainted_frames, output_path="./result/inpaint/{}".format(video_state["video_name"]), fps=fps) # import video_input to name the output video
323
-
324
  return video_output, operation_log
325
 
326
 
327
  # generate video after vos inference
328
- def generate_video_from_frames(frames, output_path, fps=30):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
329
  """
330
  Generates a video from a list of frames.
331
 
@@ -375,8 +421,8 @@ folder ="./checkpoints"
375
  SAM_checkpoint = download_checkpoint(sam_checkpoint_url, folder, sam_checkpoint)
376
  xmem_checkpoint = download_checkpoint(xmem_checkpoint_url, folder, xmem_checkpoint)
377
  e2fgvi_checkpoint = download_checkpoint_from_google_drive(e2fgvi_checkpoint_id, folder, e2fgvi_checkpoint)
378
- # args.port = 12214
379
- # args.device = "cuda:2"
380
  # args.mask_save = True
381
 
382
  # initialize sam, xmem, e2fgvi models
@@ -409,6 +455,7 @@ with gr.Blocks() as iface:
409
 
410
  video_state = gr.State(
411
  {
 
412
  "video_name": "",
413
  "origin_images": None,
414
  "painted_images": None,
@@ -458,7 +505,7 @@ with gr.Blocks() as iface:
458
  track_pause_number_slider = gr.Slider(minimum=1, maximum=100, step=1, value=1, label="Track end frames", visible=False)
459
 
460
  with gr.Column():
461
- run_status = gr.HighlightedText(value=[("Text","Error"),("to be","Label 2"),("highlighted","Label 3")], visible=False)
462
  mask_dropdown = gr.Dropdown(multiselect=True, value=[], label="Mask selection", info=".", visible=False)
463
  video_output = gr.Video(autosize=True, visible=False).style(height=360)
464
  with gr.Row():
@@ -471,9 +518,10 @@ with gr.Blocks() as iface:
471
  inputs=[
472
  video_input, video_state
473
  ],
474
- outputs=[video_state, video_info, template_frame,
475
- image_selection_slider, track_pause_number_slider,point_prompt, clear_button_click, Add_mask_button, template_frame,
476
- tracking_video_predict_button, video_output, mask_dropdown, remove_mask_button, inpaint_video_predict_button, run_status]
 
477
  )
478
 
479
  # second step: select images from slider
@@ -532,6 +580,8 @@ with gr.Blocks() as iface:
532
  video_input.clear(
533
  lambda: (
534
  {
 
 
535
  "origin_images": None,
536
  "painted_images": None,
537
  "masks": None,
 
8
  sys.path.append(sys.path[0]+"/tracker")
9
  sys.path.append(sys.path[0]+"/tracker/model")
10
  from track_anything import TrackingAnything
11
+ from track_anything import parse_augment, save_image_to_userfolder, read_image_from_userfolder
12
  import requests
13
  import json
14
  import torchvision
15
  import torch
 
 
16
  from tools.painter import mask_painter
17
  import psutil
18
+ import time
19
  try:
20
  from mmcv.cnn import ConvModule
21
  except:
 
70
  return prompt
71
 
72
 
73
+
74
  # extract frames from upload video
75
  def get_frames_from_video(video_input, video_state):
76
  """
 
81
  [[0:nearest_frame], [nearest_frame:], nearest_frame]
82
  """
83
  video_path = video_input
84
+ frames = [] # save image path
85
+ user_name = time.time()
86
+ video_state["video_name"] = os.path.split(video_path)[-1]
87
+ video_state["user_name"] = user_name
88
 
89
+ os.makedirs(os.path.join("/tmp/{}/originimages/{}".format(video_state["user_name"], video_state["video_name"])), exist_ok=True)
90
+ os.makedirs(os.path.join("/tmp/{}/paintedimages/{}".format(video_state["user_name"], video_state["video_name"])), exist_ok=True)
91
  operation_log = [("",""),("Upload video already. Try click the image for adding targets to track and inpaint.","Normal")]
92
  try:
93
  cap = cv2.VideoCapture(video_path)
94
  fps = cap.get(cv2.CAP_PROP_FPS)
95
+ if not cap.isOpened():
96
+ operation_log = [("No frames extracted, please input video file with '.mp4.' '.mov'.", "Error")]
97
+ print("No frames extracted, please input video file with '.mp4.' '.mov'.")
98
+ return None, None, None, None, \
99
+ None, None, None, None, \
100
+ None, None, None, None, \
101
+ None, None, gr.update(visible=True, value=operation_log)
102
+ image_index = 0
103
  while cap.isOpened():
104
  ret, frame = cap.read()
105
  if ret == True:
106
  current_memory_usage = psutil.virtual_memory().percent
107
+
108
+ # try solve memory usage problem, save image to disk instead of memory
109
+ frames.append(save_image_to_userfolder(video_state, image_index, frame, True))
110
+ image_index +=1
111
+ # frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
112
+ if current_memory_usage > 90:
113
+ operation_log = [("Memory usage is too high (>90%). Stop the video extraction. Please reduce the video resolution or frame rate.", "Error")]
114
+ print("Memory usage is too high (>90%). Please reduce the video resolution or frame rate.")
115
  break
116
  else:
117
  break
118
+
119
  except (OSError, TypeError, ValueError, KeyError, SyntaxError) as e:
120
+ # except:
121
+ operation_log = [("read_frame_source:{} error. {}\n".format(video_path, str(e)), "Error")]
122
  print("read_frame_source:{} error. {}\n".format(video_path, str(e)))
123
+ return None, None, None, None, \
124
+ None, None, None, None, \
125
+ None, None, None, None, \
126
+ None, None, gr.update(visible=True, value=operation_log)
127
+ first_image = read_image_from_userfolder(frames[0])
128
+ image_size = (first_image.shape[0], first_image.shape[1])
129
  # initialize video_state
130
  video_state = {
131
+ "user_name": user_name,
132
  "video_name": os.path.split(video_path)[-1],
133
  "origin_images": frames,
134
  "painted_images": frames.copy(),
135
+ "masks": [np.zeros((image_size[0], image_size[1]), np.uint8)]*len(frames),
136
  "logits": [None]*len(frames),
137
  "select_frame_number": 0,
138
  "fps": fps
139
  }
140
  video_info = "Video Name: {}, FPS: {}, Total Frames: {}, Image Size:{}".format(video_state["video_name"], video_state["fps"], len(frames), image_size)
141
  model.samcontroler.sam_controler.reset_image()
142
+ model.samcontroler.sam_controler.set_image(first_image)
143
+ return video_state, video_info, first_image, gr.update(visible=True, maximum=len(frames), value=1), \
144
+ gr.update(visible=True, maximum=len(frames), value=len(frames)), gr.update(visible=True), gr.update(visible=True), gr.update(visible=True), \
145
+ gr.update(visible=True), gr.update(visible=True), gr.update(visible=True), gr.update(visible=True), \
146
+ gr.update(visible=True), gr.update(visible=True), gr.update(visible=True, value=operation_log),
 
 
 
147
 
148
  def run_example(example):
149
+ return example
150
  # get the select frame from gradio slider
151
  def select_template(image_selection_slider, video_state, interactive_state):
152
 
 
157
  # once select a new template frame, set the image in sam
158
 
159
  model.samcontroler.sam_controler.reset_image()
160
+ model.samcontroler.sam_controler.set_image(read_image_from_userfolder(video_state["origin_images"][image_selection_slider]))
161
 
162
  # update the masks when select a new template frame
163
  # if video_state["masks"][image_selection_slider] is not None:
164
  # video_state["painted_images"][image_selection_slider] = mask_painter(video_state["origin_images"][image_selection_slider], video_state["masks"][image_selection_slider])
165
  operation_log = [("",""), ("Select frame {}. Try click image and add mask for tracking.".format(image_selection_slider),"Normal")]
166
 
167
+ return read_image_from_userfolder(video_state["painted_images"][image_selection_slider]), video_state, interactive_state, operation_log
168
 
169
  # set the tracking end frame
170
  def get_end_number(track_pause_number_slider, video_state, interactive_state):
171
+ track_pause_number_slider -= 1
172
  interactive_state["track_end_number"] = track_pause_number_slider
173
  operation_log = [("",""),("Set the tracking finish at frame {}".format(track_pause_number_slider),"Normal")]
174
 
175
+ return read_image_from_userfolder(video_state["painted_images"][track_pause_number_slider]),interactive_state, operation_log
176
 
177
  def get_resize_ratio(resize_ratio_slider, interactive_state):
178
  interactive_state["resize_ratio"] = resize_ratio_slider
 
196
 
197
  # prompt for sam model
198
  model.samcontroler.sam_controler.reset_image()
199
+ model.samcontroler.sam_controler.set_image(read_image_from_userfolder(video_state["origin_images"][video_state["select_frame_number"]]))
200
  prompt = get_prompt(click_state=click_state, click_input=coordinate)
201
 
202
  mask, logit, painted_image = model.first_frame_click(
203
+ image=read_image_from_userfolder(video_state["origin_images"][video_state["select_frame_number"]]),
204
  points=np.array(prompt["input_point"]),
205
  labels=np.array(prompt["input_label"]),
206
  multimask=prompt["multimask_output"],
207
  )
208
  video_state["masks"][video_state["select_frame_number"]] = mask
209
  video_state["logits"][video_state["select_frame_number"]] = logit
210
+ video_state["painted_images"][video_state["select_frame_number"]] = save_image_to_userfolder(video_state, index=video_state["select_frame_number"], image=cv2.cvtColor(np.asarray(painted_image),cv2.COLOR_BGR2RGB),type=False)
211
 
212
  operation_log = [("",""), ("Use SAM for segment. You can try add positive and negative points by clicking. Or press Clear clicks button to refresh the image. Press Add mask button when you are satisfied with the segment","Normal")]
213
  return painted_image, video_state, interactive_state, operation_log
 
227
 
228
  def clear_click(video_state, click_state):
229
  click_state = [[],[]]
230
+ template_frame = read_image_from_userfolder(video_state["origin_images"][video_state["select_frame_number"]])
231
  operation_log = [("",""), ("Clear points history and refresh the image.","Normal")]
232
  return template_frame, click_state, operation_log
233
 
 
240
 
241
  def show_mask(video_state, interactive_state, mask_dropdown):
242
  mask_dropdown.sort()
243
+ select_frame = read_image_from_userfolder(video_state["origin_images"][video_state["select_frame_number"]])
244
 
245
  for i in range(len(mask_dropdown)):
246
  mask_number = int(mask_dropdown[i].split("_")[1]) - 1
 
277
  template_mask[0][0]=1
278
  operation_log = [("Error! Please add at least one mask to track by clicking the left image.","Error"), ("","")]
279
  # return video_output, video_state, interactive_state, operation_error
280
+ masks, logits, painted_images_path = model.generator(images=following_frames, template_mask=template_mask, video_state=video_state)
281
  # clear GPU memory
282
  model.xmem.clear_memory()
283
 
284
  if interactive_state["track_end_number"]:
285
  video_state["masks"][video_state["select_frame_number"]:interactive_state["track_end_number"]] = masks
286
  video_state["logits"][video_state["select_frame_number"]:interactive_state["track_end_number"]] = logits
287
+ video_state["painted_images"][video_state["select_frame_number"]:interactive_state["track_end_number"]] = painted_images_path
288
  else:
289
  video_state["masks"][video_state["select_frame_number"]:] = masks
290
  video_state["logits"][video_state["select_frame_number"]:] = logits
291
+ video_state["painted_images"][video_state["select_frame_number"]:] = painted_images_path
292
 
293
  video_output = generate_video_from_frames(video_state["painted_images"], output_path="./result/track/{}".format(video_state["video_name"]), fps=fps) # import video_input to name the output video
294
  interactive_state["inference_times"] += 1
 
307
  for mask in video_state["masks"]:
308
  np.save(os.path.join('./result/mask/{}'.format(video_state["video_name"].split('.')[0]), '{:05d}.npy'.format(i)), mask)
309
  i+=1
 
310
  #### shanggao code for mask save
311
  return video_output, video_state, interactive_state, operation_log
312
 
313
+
 
 
 
 
314
 
315
  # inpaint
316
  def inpaint_video(video_state, interactive_state, mask_dropdown):
317
  operation_log = [("",""), ("Removed the selected masks.","Normal")]
318
 
319
+ # solve memory
320
  frames = np.asarray(video_state["origin_images"])
321
  fps = video_state["fps"]
322
  inpaint_masks = np.asarray(video_state["masks"])
 
339
  except:
340
  operation_log = [("Error! You are trying to inpaint without masks input. Please track the selected mask first, and then press inpaint. If VRAM exceeded, please use the resize ratio to scaling down the image size.","Error"), ("","")]
341
  inpainted_frames = video_state["origin_images"]
342
+ video_output = generate_video_from_frames(inpainted_frames, output_path="./result/inpaint/{}".format(video_state["video_name"]), fps=fps) # import video_input to name the output video
343
+ video_output = generate_video_from_paintedframes(inpainted_frames, output_path="./result/inpaint/{}".format(video_state["video_name"]), fps=fps)
344
  return video_output, operation_log
345
 
346
 
347
  # generate video after vos inference
348
+ def generate_video_from_frames(frames_path, output_path, fps=30):
349
+ """
350
+ Generates a video from a list of frames.
351
+
352
+ Args:
353
+ frames (list of numpy arrays): The frames to include in the video.
354
+ output_path (str): The path to save the generated video.
355
+ fps (int, optional): The frame rate of the output video. Defaults to 30.
356
+ """
357
+ # height, width, layers = frames[0].shape
358
+ # fourcc = cv2.VideoWriter_fourcc(*"mp4v")
359
+ # video = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
360
+ # print(output_path)
361
+ # for frame in frames:
362
+ # video.write(frame)
363
+
364
+ # video.release()
365
+ frames = []
366
+ for file in frames_path:
367
+ frames.append(read_image_from_userfolder(file))
368
+ frames = torch.from_numpy(np.asarray(frames))
369
+ if not os.path.exists(os.path.dirname(output_path)):
370
+ os.makedirs(os.path.dirname(output_path))
371
+ torchvision.io.write_video(output_path, frames, fps=fps, video_codec="libx264")
372
+ return output_path
373
+
374
+ def generate_video_from_paintedframes(frames, output_path, fps=30):
375
  """
376
  Generates a video from a list of frames.
377
 
 
421
  SAM_checkpoint = download_checkpoint(sam_checkpoint_url, folder, sam_checkpoint)
422
  xmem_checkpoint = download_checkpoint(xmem_checkpoint_url, folder, xmem_checkpoint)
423
  e2fgvi_checkpoint = download_checkpoint_from_google_drive(e2fgvi_checkpoint_id, folder, e2fgvi_checkpoint)
424
+ # args.port = 12213
425
+ # args.device = "cuda:1"
426
  # args.mask_save = True
427
 
428
  # initialize sam, xmem, e2fgvi models
 
455
 
456
  video_state = gr.State(
457
  {
458
+ "user_name": "",
459
  "video_name": "",
460
  "origin_images": None,
461
  "painted_images": None,
 
505
  track_pause_number_slider = gr.Slider(minimum=1, maximum=100, step=1, value=1, label="Track end frames", visible=False)
506
 
507
  with gr.Column():
508
+ run_status = gr.HighlightedText(value=[("Text","Error"),("to be","Label 2"),("highlighted","Label 3")], visible=True)
509
  mask_dropdown = gr.Dropdown(multiselect=True, value=[], label="Mask selection", info=".", visible=False)
510
  video_output = gr.Video(autosize=True, visible=False).style(height=360)
511
  with gr.Row():
 
518
  inputs=[
519
  video_input, video_state
520
  ],
521
+ outputs=[video_state, video_info, template_frame, image_selection_slider,
522
+ track_pause_number_slider,point_prompt, clear_button_click, Add_mask_button,
523
+ template_frame, tracking_video_predict_button, video_output, mask_dropdown,
524
+ remove_mask_button, inpaint_video_predict_button, run_status]
525
  )
526
 
527
  # second step: select images from slider
 
580
  video_input.clear(
581
  lambda: (
582
  {
583
+ "user_name": "",
584
+ "video_name": "",
585
  "origin_images": None,
586
  "painted_images": None,
587
  "masks": None,
inpainter/base_inpainter.py CHANGED
@@ -1,17 +1,28 @@
1
  import os
2
  import glob
3
  from PIL import Image
4
-
5
  import torch
6
  import yaml
7
  import cv2
8
  import importlib
9
  import numpy as np
10
  from tqdm import tqdm
11
-
12
  from inpainter.util.tensor_util import resize_frames, resize_masks
13
 
14
-
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  class BaseInpainter:
16
  def __init__(self, E2FGVI_checkpoint, device) -> None:
17
  """
@@ -46,7 +57,7 @@ class BaseInpainter:
46
  ref_index.append(i)
47
  return ref_index
48
 
49
- def inpaint(self, frames, masks, dilate_radius=15, ratio=1):
50
  """
51
  frames: numpy array, T, H, W, 3
52
  masks: numpy array, T, H, W
@@ -56,6 +67,11 @@ class BaseInpainter:
56
  Output:
57
  inpainted_frames: numpy array, T, H, W, 3
58
  """
 
 
 
 
 
59
  assert frames.shape[:3] == masks.shape, 'different size between frames and masks'
60
  assert ratio > 0 and ratio <= 1, 'ratio must in (0, 1]'
61
  masks = masks.copy()
 
1
  import os
2
  import glob
3
  from PIL import Image
 
4
  import torch
5
  import yaml
6
  import cv2
7
  import importlib
8
  import numpy as np
9
  from tqdm import tqdm
 
10
  from inpainter.util.tensor_util import resize_frames, resize_masks
11
 
12
+ def read_image_from_userfolder(image_path):
13
+ # if type:
14
+ image = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)
15
+ # else:
16
+ # image = cv2.cvtColor(cv2.imread("/tmp/{}/paintedimages/{}/{:08d}.png".format(username, video_state["video_name"], index+ ".png")), cv2.COLOR_BGR2RGB)
17
+ return image
18
+
19
+ def save_image_to_userfolder(video_state, index, image, type:bool):
20
+ if type:
21
+ image_path = "/tmp/{}/originimages/{}/{:08d}.png".format(video_state["user_name"], video_state["video_name"], index)
22
+ else:
23
+ image_path = "/tmp/{}/paintedimages/{}/{:08d}.png".format(video_state["user_name"], video_state["video_name"], index)
24
+ cv2.imwrite(image_path, image)
25
+ return image_path
26
  class BaseInpainter:
27
  def __init__(self, E2FGVI_checkpoint, device) -> None:
28
  """
 
57
  ref_index.append(i)
58
  return ref_index
59
 
60
+ def inpaint(self, frames_path, masks, dilate_radius=15, ratio=1):
61
  """
62
  frames: numpy array, T, H, W, 3
63
  masks: numpy array, T, H, W
 
67
  Output:
68
  inpainted_frames: numpy array, T, H, W, 3
69
  """
70
+ frames = []
71
+ for file in frames_path:
72
+ frames.append(read_image_from_userfolder(file))
73
+ frames = np.asarray(frames)
74
+
75
  assert frames.shape[:3] == masks.shape, 'different size between frames and masks'
76
  assert ratio > 0 and ratio <= 1, 'ratio must in (0, 1]'
77
  masks = masks.copy()
track_anything.py CHANGED
@@ -6,9 +6,22 @@ from tracker.base_tracker import BaseTracker
6
  from inpainter.base_inpainter import BaseInpainter
7
  import numpy as np
8
  import argparse
 
9
 
 
 
 
 
 
 
10
 
11
-
 
 
 
 
 
 
12
  class TrackingAnything():
13
  def __init__(self, sam_checkpoint, xmem_checkpoint, e2fgvi_checkpoint, args):
14
  self.args = args
@@ -39,23 +52,25 @@ class TrackingAnything():
39
  # mask, logit, painted_image = self.samcontroler.interact_loop(image, same_image_flag, points, labels, logits, multimask)
40
  # return mask, logit, painted_image
41
 
42
- def generator(self, images: list, template_mask:np.ndarray):
43
 
44
  masks = []
45
  logits = []
46
  painted_images = []
47
  for i in tqdm(range(len(images)), desc="Tracking image"):
48
  if i ==0:
49
- mask, logit, painted_image = self.xmem.track(images[i], template_mask)
50
  masks.append(mask)
51
  logits.append(logit)
52
- painted_images.append(painted_image)
 
53
 
54
  else:
55
- mask, logit, painted_image = self.xmem.track(images[i])
56
  masks.append(mask)
57
  logits.append(logit)
58
- painted_images.append(painted_image)
 
59
  return masks, logits, painted_images
60
 
61
 
 
6
  from inpainter.base_inpainter import BaseInpainter
7
  import numpy as np
8
  import argparse
9
+ import cv2
10
 
11
+ def read_image_from_userfolder(image_path):
12
+ # if type:
13
+ image = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)
14
+ # else:
15
+ # image = cv2.cvtColor(cv2.imread("/tmp/{}/paintedimages/{}/{:08d}.png".format(username, video_state["video_name"], index+ ".png")), cv2.COLOR_BGR2RGB)
16
+ return image
17
 
18
+ def save_image_to_userfolder(video_state, index, image, type:bool):
19
+ if type:
20
+ image_path = "/tmp/{}/originimages/{}/{:08d}.png".format(video_state["user_name"], video_state["video_name"], index)
21
+ else:
22
+ image_path = "/tmp/{}/paintedimages/{}/{:08d}.png".format(video_state["user_name"], video_state["video_name"], index)
23
+ cv2.imwrite(image_path, image)
24
+ return image_path
25
  class TrackingAnything():
26
  def __init__(self, sam_checkpoint, xmem_checkpoint, e2fgvi_checkpoint, args):
27
  self.args = args
 
52
  # mask, logit, painted_image = self.samcontroler.interact_loop(image, same_image_flag, points, labels, logits, multimask)
53
  # return mask, logit, painted_image
54
 
55
+ def generator(self, images: list, template_mask:np.ndarray, video_state:dict):
56
 
57
  masks = []
58
  logits = []
59
  painted_images = []
60
  for i in tqdm(range(len(images)), desc="Tracking image"):
61
  if i ==0:
62
+ mask, logit, painted_image = self.xmem.track(read_image_from_userfolder(images[i]), template_mask)
63
  masks.append(mask)
64
  logits.append(logit)
65
+ # painted_images.append(painted_image)
66
+ painted_images.append(save_image_to_userfolder(video_state, index=i, image=cv2.cvtColor(np.asarray(painted_image),cv2.COLOR_BGR2RGB), type=False))
67
 
68
  else:
69
+ mask, logit, painted_image = self.xmem.track(read_image_from_userfolder(images[i]))
70
  masks.append(mask)
71
  logits.append(logit)
72
+ # painted_images.append(painted_image)
73
+ painted_images.append(save_image_to_userfolder(video_state, index=i, image=cv2.cvtColor(np.asarray(painted_image),cv2.COLOR_BGR2RGB), type=False))
74
  return masks, logits, painted_images
75
 
76
 
tracker/.DS_Store CHANGED
Binary files a/tracker/.DS_Store and b/tracker/.DS_Store differ