guardiancc committed
Commit f3672f8 · verified · 1 Parent(s): 46da058

Update mimicmotion/pipelines/pipeline_mimicmotion.py

mimicmotion/pipelines/pipeline_mimicmotion.py CHANGED
@@ -222,40 +222,33 @@ class MimicMotionPipeline(DiffusionPipeline):
             decode_chunk_size: int = 8):
         # [batch, frames, channels, height, width] -> [batch*frames, channels, height, width]
         latents = latents.flatten(0, 1)
+
         latents = 1 / self.vae.config.scaling_factor * latents
-
+
         forward_vae_fn = self.vae._orig_mod.forward if is_compiled_module(self.vae) else self.vae.forward
         accepts_num_frames = "num_frames" in set(inspect.signature(forward_vae_fn).parameters.keys())
-
-        # Helper function to process a chunk of frames
-        def process_chunk(start, end, frames_list):
-            decode_kwargs = {}
-            if accepts_num_frames:
-                decode_kwargs["num_frames"] = end - start
-            frame = self.vae.decode(latents[start:end], **decode_kwargs).sample
-            frames_list.append(frame.cpu())
-
-        threads = []
+
+        # decode decode_chunk_size frames at a time to avoid OOM
         frames = []
-
-        # Split the work into chunks and create threads to process them
         for i in range(0, latents.shape[0], decode_chunk_size):
-            t = threading.Thread(target=process_chunk, args=(i, i + decode_chunk_size, frames))
-            threads.append(t)
-            t.start()
-
-        # Wait for all threads to finish
-        for t in threads:
-            t.join()
-
-        # [batch*frames, channels, height, width] -> [batch, channels, frames, height, width]
+            num_frames_in = latents[i: i + decode_chunk_size].shape[0]
+            decode_kwargs = {}
+            if accepts_num_frames:
+                # we only pass num_frames_in if it's expected
+                decode_kwargs["num_frames"] = num_frames_in
+
+            frame = self.vae.decode(latents[i: i + decode_chunk_size], **decode_kwargs).sample
+            frames.append(frame.cpu())
         frames = torch.cat(frames, dim=0)
+
+        # [batch*frames, channels, height, width] -> [batch, channels, frames, height, width]
         frames = frames.reshape(-1, num_frames, *frames.shape[1:]).permute(0, 2, 1, 3, 4)
-
-        # Cast to float32 for compatibility with bfloat16
+
+        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
         frames = frames.float()
         return frames
 
+
     def check_inputs(self, image, height, width):
         if (
             not isinstance(image, torch.Tensor)
@@ -563,17 +556,21 @@ class MimicMotionPipeline(DiffusionPipeline):
                 # expand the latents if we are doing classifier free guidance
                 latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
                 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
-
+
                 # Concatenate image_latents over channels dimension
                 latent_model_input = torch.cat([latent_model_input, image_latents], dim=2)
-
+
                 # predict the noise residual
                 noise_pred = torch.zeros_like(image_latents)
                 noise_pred_cnt = image_latents.new_zeros((num_frames,))
                 weight = (torch.arange(tile_size, device=device) + 0.5) * 2. / tile_size
                 weight = torch.minimum(weight, 2 - weight)
-                for idx in indices:
-
+
+                # Parallelize the loop over `indices` using a ThreadPoolExecutor
+                def process_index(idx):
+                    nonlocal noise_pred, noise_pred_cnt
+                    result = torch.zeros_like(image_latents[:1, idx])  # Placeholder for thread-safe accumulation
+
                     # classification-free inference
                     pose_latents = self.pose_net(image_pose[idx].to(device))
                     _noise_pred = self.unet(
@@ -585,8 +582,8 @@ class MimicMotionPipeline(DiffusionPipeline):
                         image_only_indicator=image_only_indicator,
                         return_dict=False,
                     )[0]
-                    noise_pred[:1, idx] += _noise_pred * weight[:, None, None, None]
-
+                    result[:1] += _noise_pred * weight[:, None, None, None]
+
                     # normal inference
                     _noise_pred = self.unet(
                         latent_model_input[1:, idx],
@@ -597,26 +594,34 @@ class MimicMotionPipeline(DiffusionPipeline):
                        image_only_indicator=image_only_indicator,
                         return_dict=False,
                     )[0]
-                    noise_pred[1:, idx] += _noise_pred * weight[:, None, None, None]
-
-                    noise_pred_cnt[idx] += weight
-                    progress_bar.update()
+                    result[1:] += _noise_pred * weight[:, None, None, None]
+
+                    return result, idx
+
+                with concurrent.futures.ThreadPoolExecutor() as executor:
+                    futures = [executor.submit(process_index, idx) for idx in indices]
+                    for future in concurrent.futures.as_completed(futures):
+                        _noise_pred, idx = future.result()
+                        noise_pred[:, idx] += _noise_pred
+                        noise_pred_cnt[idx] += weight
+                        progress_bar.update()
+
                 noise_pred.div_(noise_pred_cnt[:, None, None, None])
-
+
                 # perform guidance
                 if self.do_classifier_free_guidance:
                     noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
                     noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)
-
+
                 # compute the previous noisy sample x_t -> x_t-1
                 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
-
+
                 if callback_on_step_end is not None:
                     callback_kwargs = {}
                     for k in callback_on_step_end_tensor_inputs:
                         callback_kwargs[k] = locals()[k]
                     callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
-
+
                     latents = callback_outputs.pop("latents", latents)
 
         self.pose_net.cpu()
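
For reference, the rewritten denoising loop fans each temporal tile out to a worker thread and folds the weighted predictions back into the shared `noise_pred` / `noise_pred_cnt` buffers on the main thread. Below is a minimal, self-contained sketch of that pattern, using toy tensor shapes and a hypothetical `denoise_tile` stand-in for the per-tile UNet call; it is not the pipeline code itself.

import concurrent.futures

import torch

# Toy setup: 16 latent frames, tiles of 8 with stride 4, mirroring how the
# pipeline slides overlapping windows over the frame axis.
num_frames, tile_size, tile_stride = 16, 8, 4
indices = [list(range(s, s + tile_size))
           for s in range(0, num_frames - tile_size + 1, tile_stride)]

latents = torch.randn(2, num_frames, 4, 8, 8)      # stand-in for latent_model_input
noise_pred = torch.zeros_like(latents)
noise_pred_cnt = latents.new_zeros((num_frames,))

# Triangular blending weights, as in the pipeline.
weight = (torch.arange(tile_size) + 0.5) * 2. / tile_size
weight = torch.minimum(weight, 2 - weight)


def denoise_tile(idx):
    # Hypothetical stand-in for the per-tile UNet call: return a weighted
    # prediction for this tile together with its frame indices.
    pred = torch.randn_like(latents[:, idx])        # placeholder model output
    return pred * weight[:, None, None, None], idx


# Fan the tiles out to worker threads; fold the results back in on the main
# thread so the shared accumulators are only mutated from one place.
with concurrent.futures.ThreadPoolExecutor() as executor:
    futures = [executor.submit(denoise_tile, idx) for idx in indices]
    for future in concurrent.futures.as_completed(futures):
        pred, idx = future.result()
        noise_pred[:, idx] += pred
        noise_pred_cnt[idx] += weight

# Normalize the overlapping contributions.
noise_pred.div_(noise_pred_cnt[:, None, None, None])
print(noise_pred.shape)                             # torch.Size([2, 16, 4, 8, 8])

Because the accumulators are only updated inside the `as_completed` loop on the main thread, the `+=` updates need no lock in this sketch.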