Spaces:
Running
Running
Update mimicmotion/pipelines/pipeline_mimicmotion.py
Browse files
mimicmotion/pipelines/pipeline_mimicmotion.py
CHANGED
@@ -222,40 +222,33 @@ class MimicMotionPipeline(DiffusionPipeline):
|
|
222 |
decode_chunk_size: int = 8):
|
223 |
# [batch, frames, channels, height, width] -> [batch*frames, channels, height, width]
|
224 |
latents = latents.flatten(0, 1)
|
|
|
225 |
latents = 1 / self.vae.config.scaling_factor * latents
|
226 |
-
|
227 |
forward_vae_fn = self.vae._orig_mod.forward if is_compiled_module(self.vae) else self.vae.forward
|
228 |
accepts_num_frames = "num_frames" in set(inspect.signature(forward_vae_fn).parameters.keys())
|
229 |
-
|
230 |
-
#
|
231 |
-
def process_chunk(start, end, frames_list):
|
232 |
-
decode_kwargs = {}
|
233 |
-
if accepts_num_frames:
|
234 |
-
decode_kwargs["num_frames"] = end - start
|
235 |
-
frame = self.vae.decode(latents[start:end], **decode_kwargs).sample
|
236 |
-
frames_list.append(frame.cpu())
|
237 |
-
|
238 |
-
threads = []
|
239 |
frames = []
|
240 |
-
|
241 |
-
# Dividindo o trabalho em chunks e criando threads para processá-los
|
242 |
for i in range(0, latents.shape[0], decode_chunk_size):
|
243 |
-
|
244 |
-
|
245 |
-
|
246 |
-
|
247 |
-
|
248 |
-
|
249 |
-
|
250 |
-
|
251 |
-
# [batch*frames, channels, height, width] -> [batch, channels, frames, height, width]
|
252 |
frames = torch.cat(frames, dim=0)
|
|
|
|
|
253 |
frames = frames.reshape(-1, num_frames, *frames.shape[1:]).permute(0, 2, 1, 3, 4)
|
254 |
-
|
255 |
-
#
|
256 |
frames = frames.float()
|
257 |
return frames
|
258 |
|
|
|
259 |
def check_inputs(self, image, height, width):
|
260 |
if (
|
261 |
not isinstance(image, torch.Tensor)
|
@@ -563,17 +556,21 @@ class MimicMotionPipeline(DiffusionPipeline):
|
|
563 |
# expand the latents if we are doing classifier free guidance
|
564 |
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
|
565 |
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
566 |
-
|
567 |
# Concatenate image_latents over channels dimension
|
568 |
latent_model_input = torch.cat([latent_model_input, image_latents], dim=2)
|
569 |
-
|
570 |
# predict the noise residual
|
571 |
noise_pred = torch.zeros_like(image_latents)
|
572 |
noise_pred_cnt = image_latents.new_zeros((num_frames,))
|
573 |
weight = (torch.arange(tile_size, device=device) + 0.5) * 2. / tile_size
|
574 |
weight = torch.minimum(weight, 2 - weight)
|
575 |
-
|
576 |
-
|
|
|
|
|
|
|
|
|
577 |
# classification-free inference
|
578 |
pose_latents = self.pose_net(image_pose[idx].to(device))
|
579 |
_noise_pred = self.unet(
|
@@ -585,8 +582,8 @@ class MimicMotionPipeline(DiffusionPipeline):
|
|
585 |
image_only_indicator=image_only_indicator,
|
586 |
return_dict=False,
|
587 |
)[0]
|
588 |
-
|
589 |
-
|
590 |
# normal inference
|
591 |
_noise_pred = self.unet(
|
592 |
latent_model_input[1:, idx],
|
@@ -597,26 +594,34 @@ class MimicMotionPipeline(DiffusionPipeline):
|
|
597 |
image_only_indicator=image_only_indicator,
|
598 |
return_dict=False,
|
599 |
)[0]
|
600 |
-
|
601 |
-
|
602 |
-
|
603 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
604 |
noise_pred.div_(noise_pred_cnt[:, None, None, None])
|
605 |
-
|
606 |
# perform guidance
|
607 |
if self.do_classifier_free_guidance:
|
608 |
noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
|
609 |
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)
|
610 |
-
|
611 |
# compute the previous noisy sample x_t -> x_t-1
|
612 |
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
|
613 |
-
|
614 |
if callback_on_step_end is not None:
|
615 |
callback_kwargs = {}
|
616 |
for k in callback_on_step_end_tensor_inputs:
|
617 |
callback_kwargs[k] = locals()[k]
|
618 |
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
619 |
-
|
620 |
latents = callback_outputs.pop("latents", latents)
|
621 |
|
622 |
self.pose_net.cpu()
|
|
|
222 |
decode_chunk_size: int = 8):
|
223 |
# [batch, frames, channels, height, width] -> [batch*frames, channels, height, width]
|
224 |
latents = latents.flatten(0, 1)
|
225 |
+
|
226 |
latents = 1 / self.vae.config.scaling_factor * latents
|
227 |
+
|
228 |
forward_vae_fn = self.vae._orig_mod.forward if is_compiled_module(self.vae) else self.vae.forward
|
229 |
accepts_num_frames = "num_frames" in set(inspect.signature(forward_vae_fn).parameters.keys())
|
230 |
+
|
231 |
+
# decode decode_chunk_size frames at a time to avoid OOM
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
232 |
frames = []
|
|
|
|
|
233 |
for i in range(0, latents.shape[0], decode_chunk_size):
|
234 |
+
num_frames_in = latents[i: i + decode_chunk_size].shape[0]
|
235 |
+
decode_kwargs = {}
|
236 |
+
if accepts_num_frames:
|
237 |
+
# we only pass num_frames_in if it's expected
|
238 |
+
decode_kwargs["num_frames"] = num_frames_in
|
239 |
+
|
240 |
+
frame = self.vae.decode(latents[i: i + decode_chunk_size], **decode_kwargs).sample
|
241 |
+
frames.append(frame.cpu())
|
|
|
242 |
frames = torch.cat(frames, dim=0)
|
243 |
+
|
244 |
+
# [batch*frames, channels, height, width] -> [batch, channels, frames, height, width]
|
245 |
frames = frames.reshape(-1, num_frames, *frames.shape[1:]).permute(0, 2, 1, 3, 4)
|
246 |
+
|
247 |
+
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
|
248 |
frames = frames.float()
|
249 |
return frames
|
250 |
|
251 |
+
|
252 |
def check_inputs(self, image, height, width):
|
253 |
if (
|
254 |
not isinstance(image, torch.Tensor)
|
|
|
556 |
# expand the latents if we are doing classifier free guidance
|
557 |
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
|
558 |
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
559 |
+
|
560 |
# Concatenate image_latents over channels dimension
|
561 |
latent_model_input = torch.cat([latent_model_input, image_latents], dim=2)
|
562 |
+
|
563 |
# predict the noise residual
|
564 |
noise_pred = torch.zeros_like(image_latents)
|
565 |
noise_pred_cnt = image_latents.new_zeros((num_frames,))
|
566 |
weight = (torch.arange(tile_size, device=device) + 0.5) * 2. / tile_size
|
567 |
weight = torch.minimum(weight, 2 - weight)
|
568 |
+
|
569 |
+
# Paralelização do loop sobre `indices` usando ThreadPoolExecutor
|
570 |
+
def process_index(idx):
|
571 |
+
nonlocal noise_pred, noise_pred_cnt
|
572 |
+
result = torch.zeros_like(image_latents[:1, idx]) # Placeholder for thread-safe accumulation
|
573 |
+
|
574 |
# classification-free inference
|
575 |
pose_latents = self.pose_net(image_pose[idx].to(device))
|
576 |
_noise_pred = self.unet(
|
|
|
582 |
image_only_indicator=image_only_indicator,
|
583 |
return_dict=False,
|
584 |
)[0]
|
585 |
+
result[:1] += _noise_pred * weight[:, None, None, None]
|
586 |
+
|
587 |
# normal inference
|
588 |
_noise_pred = self.unet(
|
589 |
latent_model_input[1:, idx],
|
|
|
594 |
image_only_indicator=image_only_indicator,
|
595 |
return_dict=False,
|
596 |
)[0]
|
597 |
+
result[1:] += _noise_pred * weight[:, None, None, None]
|
598 |
+
|
599 |
+
return result, idx
|
600 |
+
|
601 |
+
with concurrent.futures.ThreadPoolExecutor() as executor:
|
602 |
+
futures = [executor.submit(process_index, idx) for idx in indices]
|
603 |
+
for future in concurrent.futures.as_completed(futures):
|
604 |
+
_noise_pred, idx = future.result()
|
605 |
+
noise_pred[:, idx] += _noise_pred
|
606 |
+
noise_pred_cnt[idx] += weight
|
607 |
+
progress_bar.update()
|
608 |
+
|
609 |
noise_pred.div_(noise_pred_cnt[:, None, None, None])
|
610 |
+
|
611 |
# perform guidance
|
612 |
if self.do_classifier_free_guidance:
|
613 |
noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
|
614 |
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)
|
615 |
+
|
616 |
# compute the previous noisy sample x_t -> x_t-1
|
617 |
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
|
618 |
+
|
619 |
if callback_on_step_end is not None:
|
620 |
callback_kwargs = {}
|
621 |
for k in callback_on_step_end_tensor_inputs:
|
622 |
callback_kwargs[k] = locals()[k]
|
623 |
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
624 |
+
|
625 |
latents = callback_outputs.pop("latents", latents)
|
626 |
|
627 |
self.pose_net.cpu()
|