sdsdsdadasd3 commited on
Commit
acd6fd7
·
1 Parent(s): 85331ff

[Release] v1.0.1

Browse files

- improve the performance
- improve efficiency

Files changed (3) hide show
  1. depthcrafter/utils.py +15 -67
  2. requirements.txt +3 -1
  3. run.py +99 -84
depthcrafter/utils.py CHANGED
@@ -1,79 +1,27 @@
 
 
1
  import numpy as np
2
- import cv2
3
  import matplotlib.cm as cm
 
4
  import torch
5
 
6
- dataset_res_dict = {
7
- "sintel":[448, 1024],
8
- "scannet":[640, 832],
9
- "kitti":[384, 1280],
10
- "bonn":[512, 640],
11
- "nyu":[448, 640],
12
- }
13
-
14
- def read_video_frames(video_path, process_length, target_fps, max_res, dataset):
15
- # a simple function to read video frames
16
- cap = cv2.VideoCapture(video_path)
17
- original_fps = cap.get(cv2.CAP_PROP_FPS)
18
- original_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
19
- original_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
20
- # round the height and width to the nearest multiple of 64
21
-
22
- if dataset=="open":
23
- height = round(original_height / 64) * 64
24
- width = round(original_width / 64) * 64
25
- else:
26
- height = dataset_res_dict[dataset][0]
27
- width = dataset_res_dict[dataset][1]
28
-
29
- # resize the video if the height or width is larger than max_res
30
- if max(height, width) > max_res:
31
- scale = max_res / max(original_height, original_width)
32
- height = round(original_height * scale / 64) * 64
33
- width = round(original_width * scale / 64) * 64
34
-
35
- if target_fps < 0:
36
- target_fps = original_fps
37
-
38
- stride = max(round(original_fps / target_fps), 1)
39
-
40
- frames = []
41
- frame_count = 0
42
- while cap.isOpened():
43
- ret, frame = cap.read()
44
- if not ret or (process_length > 0 and frame_count >= process_length):
45
- break
46
- if frame_count % stride == 0:
47
- frame = cv2.resize(frame, (width, height))
48
- frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) # Convert BGR to RGB
49
- frames.append(frame.astype("float32") / 255.0)
50
- frame_count += 1
51
- cap.release()
52
-
53
- frames = np.array(frames)
54
- return frames, target_fps
55
-
56
 
57
  def save_video(
58
- video_frames,
59
- output_video_path,
60
- fps: int = 15,
 
61
  ) -> str:
62
- # a simple function to save video frames
63
- height, width = video_frames[0].shape[:2]
64
- is_color = video_frames[0].ndim == 3
65
- fourcc = cv2.VideoWriter_fourcc(*"mp4v")
66
- video_writer = cv2.VideoWriter(
67
- output_video_path, fourcc, fps, (width, height), isColor=is_color
68
- )
69
 
70
- for frame in video_frames:
71
- frame = (frame * 255).astype(np.uint8)
72
- if is_color:
73
- frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
74
- video_writer.write(frame)
75
 
76
- video_writer.release()
 
 
77
  return output_video_path
78
 
79
 
 
1
+ from typing import Union, List
2
+ import tempfile
3
  import numpy as np
4
+ import PIL.Image
5
  import matplotlib.cm as cm
6
+ import mediapy
7
  import torch
8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
  def save_video(
11
+ video_frames: Union[List[np.ndarray], List[PIL.Image.Image]],
12
+ output_video_path: str = None,
13
+ fps: int = 10,
14
+ crf: int = 18,
15
  ) -> str:
16
+ if output_video_path is None:
17
+ output_video_path = tempfile.NamedTemporaryFile(suffix=".mp4").name
 
 
 
 
 
18
 
19
+ if isinstance(video_frames[0], np.ndarray):
20
+ video_frames = [(frame * 255).astype(np.uint8) for frame in video_frames]
 
 
 
21
 
22
+ elif isinstance(video_frames[0], PIL.Image.Image):
23
+ video_frames = [np.array(frame) for frame in video_frames]
24
+ mediapy.write_video(output_video_path, video_frames, fps=fps, crf=crf)
25
  return output_video_path
26
 
27
 
requirements.txt CHANGED
@@ -2,7 +2,9 @@ torch==2.0.1
2
  diffusers==0.29.1
3
  numpy==1.26.4
4
  matplotlib==3.8.4
5
- opencv-python==4.8.1.78
6
  transformers==4.41.2
7
  accelerate==0.30.1
8
  xformers==0.0.20
 
 
 
 
2
  diffusers==0.29.1
3
  numpy==1.26.4
4
  matplotlib==3.8.4
 
5
  transformers==4.41.2
6
  accelerate==0.30.1
7
  xformers==0.0.20
8
+ mediapy==1.2.0
9
+ fire==0.6.0
10
+ decord==0.6.0
run.py CHANGED
@@ -2,12 +2,22 @@ import gc
2
  import os
3
  import numpy as np
4
  import torch
5
- import argparse
 
6
  from diffusers.training_utils import set_seed
 
7
 
8
  from depthcrafter.depth_crafter_ppl import DepthCrafterPipeline
9
  from depthcrafter.unet import DiffusersUNetSpatioTemporalConditionModelDepthCrafter
10
- from depthcrafter.utils import vis_sequence_depth, save_video, read_video_frames
 
 
 
 
 
 
 
 
11
 
12
 
13
  class DepthCrafterDemo:
@@ -49,6 +59,45 @@ class DepthCrafterDemo:
49
  print("Xformers is not enabled")
50
  self.pipe.enable_attention_slicing()
51
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
  def infer(
53
  self,
54
  video: str,
@@ -67,11 +116,13 @@ class DepthCrafterDemo:
67
  ):
68
  set_seed(seed)
69
 
70
- frames, target_fps = read_video_frames(
71
- video, process_length, target_fps, max_res, dataset,
 
 
 
 
72
  )
73
- print(f"==> video name: {video}, frames shape: {frames.shape}")
74
-
75
  # inference the depth map using the DepthCrafter pipeline
76
  with torch.inference_mode():
77
  res = self.pipe(
@@ -128,91 +179,55 @@ class DepthCrafterDemo:
128
  return res_path[:2]
129
 
130
 
131
- if __name__ == "__main__":
132
- # running configs
133
- # the most important arguments for memory saving are `cpu_offload`, `enable_xformers`, `max_res`, and `window_size`
134
- # the most important arguments for trade-off between quality and speed are
135
- # `num_inference_steps`, `guidance_scale`, and `max_res`
136
- parser = argparse.ArgumentParser(description="DepthCrafter")
137
- parser.add_argument(
138
- "--video-path", type=str, required=True, help="Path to the input video file(s)"
139
- )
140
- parser.add_argument(
141
- "--save-folder",
142
- type=str,
143
- default="./demo_output",
144
- help="Folder to save the output",
145
- )
146
- parser.add_argument(
147
- "--unet-path",
148
- type=str,
149
- default="tencent/DepthCrafter",
150
- help="Path to the UNet model",
151
- )
152
- parser.add_argument(
153
- "--pre-train-path",
154
- type=str,
155
- default="stabilityai/stable-video-diffusion-img2vid-xt",
156
- help="Path to the pre-trained model",
157
- )
158
- parser.add_argument(
159
- "--process-length", type=int, default=195, help="Number of frames to process"
160
- )
161
- parser.add_argument(
162
- "--cpu-offload",
163
- type=str,
164
- default="model",
165
- choices=["model", "sequential", None],
166
- help="CPU offload option",
167
- )
168
- parser.add_argument(
169
- "--target-fps", type=int, default=15, help="Target FPS for the output video"
170
- ) # -1 for original fps
171
- parser.add_argument("--seed", type=int, default=42, help="Random seed")
172
- parser.add_argument(
173
- "--num-inference-steps", type=int, default=25, help="Number of inference steps"
174
- )
175
- parser.add_argument(
176
- "--guidance-scale", type=float, default=1.2, help="Guidance scale"
177
- )
178
- parser.add_argument("--window-size", type=int, default=110, help="Window size")
179
- parser.add_argument("--overlap", type=int, default=25, help="Overlap size")
180
- parser.add_argument("--max-res", type=int, default=1024, help="Maximum resolution")
181
- parser.add_argument(
182
- "--dataset",
183
- type=str,
184
- default="open",
185
- choices=["open", "sintel", "scannet", "kitti", "bonn", 'nyu'],
186
- help="Assigned resolution for specific dataset evaluation"
187
- )
188
- parser.add_argument("--save_npz", type=bool, default=True, help="Save npz file")
189
- parser.add_argument("--track_time", type=bool, default=False, help="Track time")
190
-
191
- args = parser.parse_args()
192
-
193
  depthcrafter_demo = DepthCrafterDemo(
194
- unet_path=args.unet_path,
195
- pre_train_path=args.pre_train_path,
196
- cpu_offload=args.cpu_offload,
197
  )
198
  # process the videos, the video paths are separated by comma
199
- video_paths = args.video_path.split(",")
200
  for video in video_paths:
201
  depthcrafter_demo.infer(
202
  video,
203
- args.num_inference_steps,
204
- args.guidance_scale,
205
- save_folder=args.save_folder,
206
- window_size=args.window_size,
207
- process_length=args.process_length,
208
- overlap=args.overlap,
209
- max_res=args.max_res,
210
- dataset=args.dataset,
211
- target_fps=args.target_fps,
212
- seed=args.seed,
213
- track_time=args.track_time,
214
- save_npz=args.save_npz,
215
  )
216
  # clear the cache for the next video
217
  gc.collect()
218
  torch.cuda.empty_cache()
 
 
 
 
 
 
 
 
 
2
  import os
3
  import numpy as np
4
  import torch
5
+
6
+ from decord import VideoReader, cpu
7
  from diffusers.training_utils import set_seed
8
+ from fire import Fire
9
 
10
  from depthcrafter.depth_crafter_ppl import DepthCrafterPipeline
11
  from depthcrafter.unet import DiffusersUNetSpatioTemporalConditionModelDepthCrafter
12
+ from depthcrafter.utils import vis_sequence_depth, save_video
13
+
14
+ dataset_res_dict = {
15
+ "sintel": [448, 1024],
16
+ "scannet": [640, 832],
17
+ "KITTI": [384, 1280],
18
+ "bonn": [512, 640],
19
+ "NYUv2": [448, 640],
20
+ }
21
 
22
 
23
  class DepthCrafterDemo:
 
59
  print("Xformers is not enabled")
60
  self.pipe.enable_attention_slicing()
61
 
62
+ @staticmethod
63
+ def read_video_frames(
64
+ video_path, process_length, target_fps, max_res, dataset="open"
65
+ ):
66
+ if dataset == "open":
67
+ print("==> processing video: ", video_path)
68
+ vid = VideoReader(video_path, ctx=cpu(0))
69
+ print(
70
+ "==> original video shape: ", (len(vid), *vid.get_batch([0]).shape[1:])
71
+ )
72
+ original_height, original_width = vid.get_batch([0]).shape[1:3]
73
+ height = round(original_height / 64) * 64
74
+ width = round(original_width / 64) * 64
75
+ if max(height, width) > max_res:
76
+ scale = max_res / max(original_height, original_width)
77
+ height = round(original_height * scale / 64) * 64
78
+ width = round(original_width * scale / 64) * 64
79
+ else:
80
+ height = dataset_res_dict[dataset][0]
81
+ width = dataset_res_dict[dataset][1]
82
+
83
+ vid = VideoReader(video_path, ctx=cpu(0), width=width, height=height)
84
+
85
+ fps = vid.get_avg_fps() if target_fps == -1 else target_fps
86
+ stride = round(vid.get_avg_fps() / fps)
87
+ stride = max(stride, 1)
88
+ frames_idx = list(range(0, len(vid), stride))
89
+ print(
90
+ f"==> downsampled shape: {len(frames_idx), *vid.get_batch([0]).shape[1:]}, with stride: {stride}"
91
+ )
92
+ if process_length != -1 and process_length < len(frames_idx):
93
+ frames_idx = frames_idx[:process_length]
94
+ print(
95
+ f"==> final processing shape: {len(frames_idx), *vid.get_batch([0]).shape[1:]}"
96
+ )
97
+ frames = vid.get_batch(frames_idx).asnumpy().astype("float32") / 255.0
98
+
99
+ return frames, fps
100
+
101
  def infer(
102
  self,
103
  video: str,
 
116
  ):
117
  set_seed(seed)
118
 
119
+ frames, target_fps = self.read_video_frames(
120
+ video,
121
+ process_length,
122
+ target_fps,
123
+ max_res,
124
+ dataset,
125
  )
 
 
126
  # inference the depth map using the DepthCrafter pipeline
127
  with torch.inference_mode():
128
  res = self.pipe(
 
179
  return res_path[:2]
180
 
181
 
182
+ def main(
183
+ video_path: str,
184
+ save_folder: str = "./demo_output",
185
+ unet_path: str = "tencent/DepthCrafter",
186
+ pre_train_path: str = "stabilityai/stable-video-diffusion-img2vid-xt",
187
+ process_length: int = -1,
188
+ cpu_offload: str = "model",
189
+ target_fps: int = -1,
190
+ seed: int = 42,
191
+ num_inference_steps: int = 5,
192
+ guidance_scale: float = 1.0,
193
+ window_size: int = 110,
194
+ overlap: int = 25,
195
+ max_res: int = 1024,
196
+ dataset: str = "open",
197
+ save_npz: bool = True,
198
+ track_time: bool = False,
199
+ ):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
200
  depthcrafter_demo = DepthCrafterDemo(
201
+ unet_path=unet_path,
202
+ pre_train_path=pre_train_path,
203
+ cpu_offload=cpu_offload,
204
  )
205
  # process the videos, the video paths are separated by comma
206
+ video_paths = video_path.split(",")
207
  for video in video_paths:
208
  depthcrafter_demo.infer(
209
  video,
210
+ num_inference_steps,
211
+ guidance_scale,
212
+ save_folder=save_folder,
213
+ window_size=window_size,
214
+ process_length=process_length,
215
+ overlap=overlap,
216
+ max_res=max_res,
217
+ dataset=dataset,
218
+ target_fps=target_fps,
219
+ seed=seed,
220
+ track_time=track_time,
221
+ save_npz=save_npz,
222
  )
223
  # clear the cache for the next video
224
  gc.collect()
225
  torch.cuda.empty_cache()
226
+
227
+
228
+ if __name__ == "__main__":
229
+ # running configs
230
+ # the most important arguments for memory saving are `cpu_offload`, `enable_xformers`, `max_res`, and `window_size`
231
+ # the most important arguments for trade-off between quality and speed are
232
+ # `num_inference_steps`, `guidance_scale`, and `max_res`
233
+ Fire(main)