oguzakif committed on
Commit
fe7ce2c
·
1 Parent(s): 5a30a78

gradio state added for frame lists

Browse files
Files changed (1) hide show
  1. app.py +16 -19
app.py CHANGED
@@ -27,15 +27,9 @@ sys.path.append(os.path.abspath(join(project_name, 'FGT_codes', 'FGT', 'checkpoi
27
  sys.path.append(os.path.abspath(join(project_name, 'FGT_codes', 'LAFC',
28
  'flowCheckPoint', 'raft-things.pth')))
29
 
30
- # sys.path.append(join(project_name, 'SiamMask',
31
- # 'experiments', 'siammask_sharp'))
32
- # sys.path.append(join(project_name, 'SiamMask', 'models'))
33
- # sys.path.append(join(project_name, 'SiamMask'))
34
-
35
  exp_path = join(project_name, 'SiamMask/experiments/siammask_sharp')
36
  pretrained_path1 = join(exp_path, 'SiamMask_DAVIS.pth')
37
 
38
-
39
  print(sys.path)
40
 
41
  torch.set_grad_enabled(False)
@@ -58,8 +52,7 @@ object_y = 0
58
  object_width = 0
59
  object_height = 0
60
  in_fps = 24
61
- original_frame_list = []
62
- mask_list = []
63
 
64
  parser = argparse.ArgumentParser()
65
  # parser.add_argument('--opt', default='configs/object_removal.yaml',
@@ -151,7 +144,7 @@ def getBoundaries(mask):
151
  return x1, y1, (x2-x1), (y2-y1)
152
 
153
 
154
- def track_and_mask(vid, original_frame, masked_frame):
155
  x, y, w, h = getBoundaries(masked_frame)
156
  f = 0
157
 
@@ -183,7 +176,7 @@ def track_and_mask(vid, original_frame, masked_frame):
183
  # track
184
  state = siamese_track(
185
  state, frame, mask_enable=True, refine_enable=True, device=device)
186
- original_frame_list.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
187
  location = state['ploygon'].flatten()
188
  mask = state['mask'] > state['p'].seg_thr
189
  frame[:, :, 2] = (mask > 0) * 255 + \
@@ -206,17 +199,19 @@ def track_and_mask(vid, original_frame, masked_frame):
206
 
207
  print('Original Frame Count: ',len(original_frame_list))
208
  print('Mask Frame Count: ',len(mask_list))
209
- return dt_string+"_output.avi"
 
 
210
 
211
 
212
- def inpaint_video():
213
  args.out_fps = in_fps
214
  video_inpainting(args, original_frame_list, mask_list)
215
-
216
- original_frame_list.clear()
217
- mask_list.clear()
218
-
219
- return dt_string+"_result.mp4"
220
 
221
 
222
  def get_first_frame(video):
@@ -253,6 +248,8 @@ def getStartEndPoints(mask):
253
 
254
 
255
  with gr.Blocks() as demo:
 
 
256
  with gr.Row():
257
  with gr.Column(scale=2):
258
  with gr.Row():
@@ -277,8 +274,8 @@ with gr.Blocks() as demo:
277
  approve_mask.click(lambda x: [x['image'], x['mask']], first_frame, [
278
  original_image, masked_image])
279
  track_mask.click(fn=track_and_mask, inputs=[
280
- in_video, original_image, masked_image], outputs=[out_video])
281
- inpaint.click(fn=inpaint_video, outputs=[out_video_inpaint])
282
 
283
 
284
  demo.launch(debug=True)
 
27
  sys.path.append(os.path.abspath(join(project_name, 'FGT_codes', 'LAFC',
28
  'flowCheckPoint', 'raft-things.pth')))
29
 
 
 
 
 
 
30
  exp_path = join(project_name, 'SiamMask/experiments/siammask_sharp')
31
  pretrained_path1 = join(exp_path, 'SiamMask_DAVIS.pth')
32
 
 
33
  print(sys.path)
34
 
35
  torch.set_grad_enabled(False)
 
52
  object_width = 0
53
  object_height = 0
54
  in_fps = 24
55
+
 
56
 
57
  parser = argparse.ArgumentParser()
58
  # parser.add_argument('--opt', default='configs/object_removal.yaml',
 
144
  return x1, y1, (x2-x1), (y2-y1)
145
 
146
 
147
+ def track_and_mask(vid, masked_frame, original_list, mask_list):
148
  x, y, w, h = getBoundaries(masked_frame)
149
  f = 0
150
 
 
176
  # track
177
  state = siamese_track(
178
  state, frame, mask_enable=True, refine_enable=True, device=device)
179
+ original_list.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
180
  location = state['ploygon'].flatten()
181
  mask = state['mask'] > state['p'].seg_thr
182
  frame[:, :, 2] = (mask > 0) * 255 + \
 
199
 
200
  print('Original Frame Count: ',len(original_frame_list))
201
  print('Mask Frame Count: ',len(mask_list))
202
+ return {out_video_inpaint:dt_string+"_output.avi",
203
+ original_frame_list: original_list,
204
+ mask_list: mask_list}
205
 
206
 
207
+ def inpaint_video(original_frame_list, mask_list):
208
  args.out_fps = in_fps
209
  video_inpainting(args, original_frame_list, mask_list)
210
+ original_frame_list = []
211
+ mask_list = []
212
+ return {out_video_inpaint:dt_string+"_result.mp4",
213
+ original_frame_list: original_frame_list,
214
+ mask_list: mask_list}
215
 
216
 
217
  def get_first_frame(video):
 
248
 
249
 
250
  with gr.Blocks() as demo:
251
+ original_frame_list = gr.State([])
252
+ mask_list = gr.State([])
253
  with gr.Row():
254
  with gr.Column(scale=2):
255
  with gr.Row():
 
274
  approve_mask.click(lambda x: [x['image'], x['mask']], first_frame, [
275
  original_image, masked_image])
276
  track_mask.click(fn=track_and_mask, inputs=[
277
+ in_video, masked_image, original_frame_list, mask_list], outputs=[out_video, original_frame_list, mask_list])
278
+ inpaint.click(fn=inpaint_video, outputs=[out_video_inpaint, original_frame_list, mask_list])
279
 
280
 
281
  demo.launch(debug=True)