# DepthCrafter / depthcrafter / depth_crafter_ppl.py
# Page-header residue from extraction: "wbhu-tc's picture", "update", 7c1a14b
from typing import Callable, Dict, List, Optional, Union
import numpy as np
import torch
from diffusers.pipelines.stable_video_diffusion.pipeline_stable_video_diffusion import (
_resize_with_antialiasing,
StableVideoDiffusionPipelineOutput,
StableVideoDiffusionPipeline,
retrieve_timesteps,
)
from diffusers.utils import logging
from diffusers.utils.torch_utils import randn_tensor
# Module-level logger obtained from the `diffusers.utils.logging` helper.
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class DepthCrafterPipeline(StableVideoDiffusionPipeline):
    """
    Pipeline that extends `StableVideoDiffusionPipeline` to process a whole video.

    The input video is encoded frame-wise with the image encoder and the VAE,
    latents are denoised in (optionally overlapping) sliding windows along the
    time axis, the windows are stitched together, and the result is decoded
    back to frames.
    """

    @torch.inference_mode()
    def encode_video(
        self,
        video: torch.Tensor,
        chunk_size: int = 14,
    ) -> torch.Tensor:
        """
        Compute per-frame image embeddings with `self.image_encoder`.

        :param video: [b, c, h, w] in range [-1, 1], the b may contain multiple videos or frames
        :param chunk_size: the chunk size to encode video
        :return: image_embeddings in shape of [b, 1024]
        """
        # The frames are resized to 224x224 in float32 before feature extraction.
        video_224 = _resize_with_antialiasing(video.float(), (224, 224))
        video_224 = (video_224 + 1.0) / 2.0  # [-1, 1] -> [0, 1]

        embeddings = []
        for i in range(0, video_224.shape[0], chunk_size):
            # Only normalisation is left to the feature extractor; resizing and
            # rescaling were already done above.
            tmp = self.feature_extractor(
                images=video_224[i : i + chunk_size],
                do_normalize=True,
                do_center_crop=False,
                do_resize=False,
                do_rescale=False,
                return_tensors="pt",
            ).pixel_values.to(video.device, dtype=video.dtype)
            embeddings.append(self.image_encoder(tmp).image_embeds)  # [b, 1024]

        embeddings = torch.cat(embeddings, dim=0)  # [t, 1024]
        return embeddings

    @torch.inference_mode()
    def encode_vae_video(
        self,
        video: torch.Tensor,
        chunk_size: int = 14,
    ) -> torch.Tensor:
        """
        Encode frames to VAE latents, `chunk_size` frames at a time.

        :param video: [b, c, h, w] in range [-1, 1], the b may contain multiple videos or frames
        :param chunk_size: the chunk size to encode video
        :return: vae latents in shape of [b, c, h, w]
        """
        video_latents = []
        for i in range(0, video.shape[0], chunk_size):
            # `.mode()` of the latent distribution is used rather than a sample.
            video_latents.append(
                self.vae.encode(video[i : i + chunk_size]).latent_dist.mode()
            )
        video_latents = torch.cat(video_latents, dim=0)
        return video_latents

    @staticmethod
    def check_inputs(video, height, width):
        """
        Validate the inputs of the pipeline call.

        :param video: input video, must be a `torch.Tensor` or a `np.ndarray`
        :param height: target height, must be divisible by 8
        :param width: target width, must be divisible by 8
        :return: None
        :raises ValueError: if `video` has an unsupported type or if `height` /
            `width` is not divisible by 8
        """
        if not isinstance(video, (torch.Tensor, np.ndarray)):
            raise ValueError(
                f"Expected `video` to be a `torch.Tensor` or `np.ndarray`, but got a {type(video)}"
            )

        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(
                f"`height` and `width` have to be divisible by 8 but are {height} and {width}."
            )

    @torch.no_grad()
    def __call__(
        self,
        video: Union[np.ndarray, torch.Tensor],
        height: int = 576,
        width: int = 1024,
        num_inference_steps: int = 25,
        guidance_scale: float = 1.0,
        window_size: Optional[int] = 110,
        noise_aug_strength: float = 0.02,
        decode_chunk_size: Optional[int] = None,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
        # NOTE: list default kept for interface compatibility; it is only read.
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        return_dict: bool = True,
        overlap: int = 25,
        track_time: bool = False,
    ):
        """
        Run the pipeline on a video using sliding-window denoising.

        :param video: in shape [t, h, w, c] if np.ndarray or [t, c, h, w] if torch.Tensor, in range [0, 1]
        :param height: latent-grid height in pixels, must be divisible by 8
        :param width: latent-grid width in pixels, must be divisible by 8
        :param num_inference_steps: number of denoising steps per window
        :param guidance_scale: classifier-free guidance weight; `1` corresponds
            to doing no classifier free guidance
        :param window_size: sliding window processing size (in frames); if
            `None` or not smaller than the number of frames, the whole video is
            processed as a single window
        :param noise_aug_strength: scale of the Gaussian noise added to the
            video before VAE encoding
        :param decode_chunk_size: chunk size used for encoding and decoding,
            defaults to 8
        :param generator: random generator(s) for the noise augmentation and
            the initial latents
        :param latents: optional pre-generated latents forwarded to
            `prepare_latents`
        :param output_type: `"latent"` returns the stitched latents, any other
            value is forwarded to `video_processor.postprocess_video`
        :param callback_on_step_end: optional callable invoked after every
            denoising step as `callback(self, i, t, callback_kwargs)`; it must
            return a dict, whose optional `"latents"` entry replaces the latents
        :param callback_on_step_end_tensor_inputs: names of local variables of
            this method passed to the callback in `callback_kwargs`
        :param return_dict: if False, return the frames directly instead of a
            `StableVideoDiffusionPipelineOutput`
        :param overlap: number of frames shared by two consecutive windows; the
            shared frames are blended with linearly increasing weights. Ignored
            (set to 0) when the video fits into a single window
        :param track_time: if True, measure encode / denoise / decode time with
            CUDA events and print it
        :return: `StableVideoDiffusionPipelineOutput` if `return_dict` is True,
            otherwise the frames (or latents for `output_type="latent"`)
        :raises ValueError: if the inputs are invalid, the video is empty, or
            `overlap` is not in `[0, window_size)` for a multi-window video
        """
        # 0. Default height and width to unet
        height = height or self.unet.config.sample_size * self.vae_scale_factor
        width = width or self.unet.config.sample_size * self.vae_scale_factor

        # 1. Check inputs. Raise error if not correct.
        # This happens before `video.shape` is touched so that an unsupported
        # type yields the intended ValueError instead of an AttributeError.
        self.check_inputs(video, height, width)

        num_frames = video.shape[0]
        if num_frames == 0:
            raise ValueError("`video` must contain at least one frame.")
        decode_chunk_size = decode_chunk_size if decode_chunk_size is not None else 8

        if window_size is None or num_frames <= window_size:
            # The whole video fits into one window: no stitching required.
            window_size = num_frames
            overlap = 0
        elif not 0 <= overlap < window_size:
            # A non-positive stride would make the window loop below never end.
            raise ValueError(
                "`overlap` must satisfy 0 <= overlap < window_size when the video is "
                f"longer than one window, but got overlap={overlap} and window_size={window_size}."
            )
        stride = window_size - overlap

        # 2. Define call parameters
        batch_size = 1
        device = self._execution_device
        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        self._guidance_scale = guidance_scale

        # 3. Encode input video
        if isinstance(video, np.ndarray):
            video = torch.from_numpy(video.transpose(0, 3, 1, 2))  # [t, h, w, c] -> [t, c, h, w]
        video = video.to(device=device, dtype=self.dtype)
        video = video * 2.0 - 1.0  # [0,1] -> [-1,1], in [t, c, h, w]

        if track_time:
            start_event = torch.cuda.Event(enable_timing=True)
            encode_event = torch.cuda.Event(enable_timing=True)
            denoise_event = torch.cuda.Event(enable_timing=True)
            decode_event = torch.cuda.Event(enable_timing=True)
            start_event.record()

        video_embeddings = self.encode_video(
            video, chunk_size=decode_chunk_size
        ).unsqueeze(
            0
        )  # [1, t, 1024]
        torch.cuda.empty_cache()

        # 4. Encode input image using VAE
        noise = randn_tensor(
            video.shape, generator=generator, device=device, dtype=video.dtype
        )
        video = video + noise_aug_strength * noise  # in [t, c, h, w]

        # The VAE is temporarily run in float32 when it is in fp16 and its
        # config requests upcasting.
        needs_upcasting = (
            self.vae.dtype == torch.float16 and self.vae.config.force_upcast
        )
        if needs_upcasting:
            self.vae.to(dtype=torch.float32)

        video_latents = self.encode_vae_video(
            video.to(self.vae.dtype),
            chunk_size=decode_chunk_size,
        ).unsqueeze(
            0
        )  # [1, t, c, h, w]

        if track_time:
            encode_event.record()
            torch.cuda.synchronize()
            elapsed_time_ms = start_event.elapsed_time(encode_event)
            print(f"Elapsed time for encoding video: {elapsed_time_ms} ms")

        torch.cuda.empty_cache()

        # cast back to fp16 if needed
        if needs_upcasting:
            self.vae.to(dtype=torch.float16)

        # 5. Get Added Time IDs
        # NOTE(review): 7 and 127 are presumably the fixed fps / motion-bucket
        # conditioning values - confirm against the parent `_get_add_time_ids`.
        added_time_ids = self._get_add_time_ids(
            7,
            127,
            noise_aug_strength,
            video_embeddings.dtype,
            batch_size,
            1,
            False,
        )  # [1 or 2, 3]
        added_time_ids = added_time_ids.to(device)

        # 6. Prepare timesteps
        timesteps, num_inference_steps = retrieve_timesteps(
            self.scheduler, num_inference_steps, device, None, None
        )
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        self._num_timesteps = len(timesteps)

        # 7. Prepare latent variables
        num_channels_latents = self.unet.config.in_channels
        latents_init = self.prepare_latents(
            batch_size,
            window_size,
            num_channels_latents,
            height,
            width,
            video_embeddings.dtype,
            device,
            generator,
            latents,
        )  # [1, t, c, h, w]

        latents_all = None
        idx_start = 0
        if overlap > 0:
            # Linear ramp used to cross-fade the overlapping frames of two windows.
            weights = torch.linspace(0, 1, overlap, device=device)
            weights = weights.view(1, overlap, 1, 1, 1)
        else:
            weights = None

        torch.cuda.empty_cache()

        # inference strategy for long videos
        # two main strategies: 1. noise init from previous frame, 2. segments stitching
        while idx_start < num_frames - overlap:
            idx_end = min(idx_start + window_size, num_frames)
            self.scheduler.set_timesteps(num_inference_steps, device=device)

            # 8. Select the noise and conditioning of the current window
            latents = latents_init[:, : idx_end - idx_start].clone()
            if overlap > 0:
                # Rotate the initial noise so that the first `overlap` frames of
                # the next window reuse the noise of the last `overlap` frames
                # of a full window. (With overlap == 0 the rotation would be the
                # identity, and `-0:` would wrongly select the whole tensor.)
                latents_init = torch.cat(
                    [latents_init[:, -overlap:], latents_init[:, :stride]], dim=1
                )

            video_latents_current = video_latents[:, idx_start:idx_end]
            video_embeddings_current = video_embeddings[:, idx_start:idx_end]

            if latents_all is not None and overlap > 0:
                # Start the overlapping frames from the previous window's result
                # plus the initial noise rescaled to the first sigma.
                latents[:, :overlap] = (
                    latents_all[:, -overlap:]
                    + latents[:, :overlap]
                    / self.scheduler.init_noise_sigma
                    * self.scheduler.sigmas[0]
                )

            # 9. Denoising loop
            with self.progress_bar(total=num_inference_steps) as progress_bar:
                for i, t in enumerate(timesteps):
                    latent_model_input = latents  # [1, t, c, h, w]
                    latent_model_input = self.scheduler.scale_model_input(
                        latent_model_input, t
                    )  # [1, t, c, h, w]
                    latent_model_input = torch.cat(
                        [latent_model_input, video_latents_current], dim=2
                    )
                    noise_pred = self.unet(
                        latent_model_input,
                        t,
                        encoder_hidden_states=video_embeddings_current,
                        added_time_ids=added_time_ids,
                        return_dict=False,
                    )[0]

                    # perform guidance
                    if self.do_classifier_free_guidance:
                        # Unconditional branch: zeroed video latents and embeddings.
                        latent_model_input = latents
                        latent_model_input = self.scheduler.scale_model_input(
                            latent_model_input, t
                        )
                        latent_model_input = torch.cat(
                            [latent_model_input, torch.zeros_like(latent_model_input)],
                            dim=2,
                        )
                        noise_pred_uncond = self.unet(
                            latent_model_input,
                            t,
                            encoder_hidden_states=torch.zeros_like(
                                video_embeddings_current
                            ),
                            added_time_ids=added_time_ids,
                            return_dict=False,
                        )[0]
                        noise_pred = noise_pred_uncond + self.guidance_scale * (
                            noise_pred - noise_pred_uncond
                        )

                    latents = self.scheduler.step(noise_pred, t, latents).prev_sample

                    if callback_on_step_end is not None:
                        # Tensors are looked up by local-variable name, so this
                        # must stay a plain loop in this function's scope.
                        callback_kwargs = {}
                        for k in callback_on_step_end_tensor_inputs:
                            callback_kwargs[k] = locals()[k]
                        callback_outputs = callback_on_step_end(
                            self, i, t, callback_kwargs
                        )
                        latents = callback_outputs.pop("latents", latents)

                    if i == len(timesteps) - 1 or (
                        (i + 1) > num_warmup_steps
                        and (i + 1) % self.scheduler.order == 0
                    ):
                        progress_bar.update()

            # 10. Stitch the current window onto the accumulated result
            if latents_all is None:
                latents_all = latents.clone()
            else:
                if overlap > 0:
                    # Cross-fade: the new window's weight ramps from 0 to 1.
                    latents_all[:, -overlap:] = latents[
                        :, :overlap
                    ] * weights + latents_all[:, -overlap:] * (1 - weights)
                latents_all = torch.cat([latents_all, latents[:, overlap:]], dim=1)

            idx_start += stride

        if track_time:
            denoise_event.record()
            torch.cuda.synchronize()
            elapsed_time_ms = encode_event.elapsed_time(denoise_event)
            print(f"Elapsed time for denoising video: {elapsed_time_ms} ms")

        if output_type != "latent":
            # cast back to fp16 if needed
            if needs_upcasting:
                self.vae.to(dtype=torch.float16)
            frames = self.decode_latents(latents_all, num_frames, decode_chunk_size)

            if track_time:
                decode_event.record()
                torch.cuda.synchronize()
                elapsed_time_ms = denoise_event.elapsed_time(decode_event)
                print(f"Elapsed time for decoding video: {elapsed_time_ms} ms")

            frames = self.video_processor.postprocess_video(
                video=frames, output_type=output_type
            )
        else:
            frames = latents_all

        self.maybe_free_model_hooks()

        if not return_dict:
            return frames

        return StableVideoDiffusionPipelineOutput(frames=frames)