svjack commited on
Commit
a042cad
·
verified ·
1 Parent(s): 27e6007

Upload folder using huggingface_hub

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +22 -0
  2. .ipynb_checkpoints/README-checkpoint.md +12 -0
  3. README.md +10 -10
  4. __assets__/feature_visualization.png +3 -0
  5. __assets__/pipeline.png +3 -0
  6. __assets__/teaser.gif +3 -0
  7. __assets__/teaser.mp4 +3 -0
  8. condition_images/rgb/dog_on_grass.png +3 -0
  9. condition_images/scribble/lion_forest.png +0 -0
  10. configs/i2v_rgb.jsonl +1 -0
  11. configs/i2v_rgb.yaml +20 -0
  12. configs/i2v_sketch.jsonl +1 -0
  13. configs/i2v_sketch.yaml +20 -0
  14. configs/model_config/inference-v1.yaml +25 -0
  15. configs/model_config/inference-v2.yaml +24 -0
  16. configs/model_config/inference-v3.yaml +22 -0
  17. configs/model_config/model_config copy.yaml +22 -0
  18. configs/model_config/model_config.yaml +21 -0
  19. configs/model_config/model_config_public.yaml +25 -0
  20. configs/sparsectrl/image_condition.yaml +17 -0
  21. configs/sparsectrl/latent_condition.yaml +17 -0
  22. configs/t2v_camera.jsonl +12 -0
  23. configs/t2v_camera.yaml +19 -0
  24. configs/t2v_object.jsonl +6 -0
  25. configs/t2v_object.yaml +19 -0
  26. environment.yaml +25 -0
  27. generated_videos/camera_zoom_out_Dog,_lying_on_the_grass76739_76739.mp4 +3 -0
  28. generated_videos/inference_config.json +21 -0
  29. generated_videos/sample_white_tiger_Lion,_walks_in_the_forest76739_76739.mp4 +3 -0
  30. i2v_video_app.py +284 -0
  31. i2v_video_sample.py +157 -0
  32. models/Motion_Module/Put motion module checkpoints here.txt +0 -0
  33. motionclone/models/__pycache__/attention.cpython-310.pyc +0 -0
  34. motionclone/models/__pycache__/attention.cpython-38.pyc +0 -0
  35. motionclone/models/__pycache__/motion_module.cpython-310.pyc +0 -0
  36. motionclone/models/__pycache__/motion_module.cpython-38.pyc +0 -0
  37. motionclone/models/__pycache__/resnet.cpython-310.pyc +0 -0
  38. motionclone/models/__pycache__/resnet.cpython-38.pyc +0 -0
  39. motionclone/models/__pycache__/sparse_controlnet.cpython-38.pyc +0 -0
  40. motionclone/models/__pycache__/unet.cpython-310.pyc +0 -0
  41. motionclone/models/__pycache__/unet.cpython-38.pyc +0 -0
  42. motionclone/models/__pycache__/unet_blocks.cpython-310.pyc +0 -0
  43. motionclone/models/__pycache__/unet_blocks.cpython-38.pyc +0 -0
  44. motionclone/models/attention.py +611 -0
  45. motionclone/models/motion_module.py +347 -0
  46. motionclone/models/resnet.py +218 -0
  47. motionclone/models/scheduler.py +155 -0
  48. motionclone/models/sparse_controlnet.py +593 -0
  49. motionclone/models/unet.py +515 -0
  50. motionclone/models/unet_blocks.py +760 -0
.gitattributes CHANGED
@@ -33,3 +33,25 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ __assets__/feature_visualization.png filter=lfs diff=lfs merge=lfs -text
37
+ __assets__/pipeline.png filter=lfs diff=lfs merge=lfs -text
38
+ __assets__/teaser.gif filter=lfs diff=lfs merge=lfs -text
39
+ __assets__/teaser.mp4 filter=lfs diff=lfs merge=lfs -text
40
+ condition_images/rgb/dog_on_grass.png filter=lfs diff=lfs merge=lfs -text
41
+ generated_videos/camera_zoom_out_Dog,_lying_on_the_grass76739_76739.mp4 filter=lfs diff=lfs merge=lfs -text
42
+ generated_videos/sample_white_tiger_Lion,_walks_in_the_forest76739_76739.mp4 filter=lfs diff=lfs merge=lfs -text
43
+ reference_videos/camera_1.mp4 filter=lfs diff=lfs merge=lfs -text
44
+ reference_videos/camera_pan_down.mp4 filter=lfs diff=lfs merge=lfs -text
45
+ reference_videos/camera_pan_up.mp4 filter=lfs diff=lfs merge=lfs -text
46
+ reference_videos/camera_translation_1.mp4 filter=lfs diff=lfs merge=lfs -text
47
+ reference_videos/camera_translation_2.mp4 filter=lfs diff=lfs merge=lfs -text
48
+ reference_videos/camera_zoom_in.mp4 filter=lfs diff=lfs merge=lfs -text
49
+ reference_videos/camera_zoom_out.mp4 filter=lfs diff=lfs merge=lfs -text
50
+ reference_videos/sample_astronaut.mp4 filter=lfs diff=lfs merge=lfs -text
51
+ reference_videos/sample_blackswan.mp4 filter=lfs diff=lfs merge=lfs -text
52
+ reference_videos/sample_cat.mp4 filter=lfs diff=lfs merge=lfs -text
53
+ reference_videos/sample_cow.mp4 filter=lfs diff=lfs merge=lfs -text
54
+ reference_videos/sample_fox.mp4 filter=lfs diff=lfs merge=lfs -text
55
+ reference_videos/sample_leaves.mp4 filter=lfs diff=lfs merge=lfs -text
56
+ reference_videos/sample_white_tiger.mp4 filter=lfs diff=lfs merge=lfs -text
57
+ reference_videos/sample_wolf.mp4 filter=lfs diff=lfs merge=lfs -text
.ipynb_checkpoints/README-checkpoint.md ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: MotionClone-Image-to-Video
3
+ emoji: 📷
4
+ colorFrom: red
5
+ colorTo: pink
6
+ sdk: gradio
7
+ sdk_version: 4.44.1
8
+ app_file: i2v_video_app.py
9
+ pinned: true
10
+ license: bsd-3-clause
11
+ short_description: Motion cloning for controllable video generation
12
+ ---
README.md CHANGED
@@ -1,12 +1,12 @@
1
  ---
2
- title: MotionClone Image To Video
3
- emoji: 🦀
4
- colorFrom: indigo
5
- colorTo: purple
6
  sdk: gradio
7
- sdk_version: 5.17.0
8
- app_file: app.py
9
- pinned: false
10
- ---
11
-
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
+ title: MotionClone-Image-to-Video
3
+ emoji: 📷
4
+ colorFrom: red
5
+ colorTo: pink
6
  sdk: gradio
7
+ sdk_version: 4.44.1
8
+ app_file: i2v_video_app.py
9
+ pinned: true
10
+ license: bsd-3-clause
11
+ short_description: Motion cloning for controllable video generation
12
+ ---
__assets__/feature_visualization.png ADDED

Git LFS Details

  • SHA256: 4c0891fbfe56b1650d6c65dac700d02faee46cff0cc56515c8a23a8be0c9a46b
  • Pointer size: 131 Bytes
  • Size of remote file: 944 kB
__assets__/pipeline.png ADDED

Git LFS Details

  • SHA256: bc9926f5f4a746475cb1963a4e908671db82d0cc630c8a5e9cd43f78885fd82d
  • Pointer size: 132 Bytes
  • Size of remote file: 1.01 MB
__assets__/teaser.gif ADDED

Git LFS Details

  • SHA256: 2ee4ff21495ae52ff2c9f4ff9ad5406c3f4445633a437664f9cc20277460ea6f
  • Pointer size: 133 Bytes
  • Size of remote file: 14.6 MB
__assets__/teaser.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:201747f42691e708b9efe48ea054961fd82cf54b83ac43e0d97a43f81779c00b
3
+ size 4957080
condition_images/rgb/dog_on_grass.png ADDED

Git LFS Details

  • SHA256: 1b3ead35573919274f59d763c5085608ca78a993bf508448ca22af31ebcab113
  • Pointer size: 132 Bytes
  • Size of remote file: 1.51 MB
condition_images/scribble/lion_forest.png ADDED
configs/i2v_rgb.jsonl ADDED
@@ -0,0 +1 @@
 
 
1
+ {"video_path":"reference_videos/camera_zoom_out.mp4", "condition_image_paths":["condition_images/rgb/dog_on_grass.png"], "new_prompt": "Dog, lying on the grass"}
configs/i2v_rgb.yaml ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ motion_module: "models/Motion_Module/v3_sd15_mm.ckpt"
2
+ dreambooth_path: "models/DreamBooth_LoRA/realisticVisionV60B1_v51VAE.safetensors"
3
+ model_config: "configs/model_config/model_config.yaml"
4
+ controlnet_path: "models/SparseCtrl/v3_sd15_sparsectrl_rgb.ckpt"
5
+ controlnet_config: "configs/sparsectrl/latent_condition.yaml"
6
+ adapter_lora_path: "models/Motion_Module/v3_sd15_adapter.ckpt"
7
+
8
+ cfg_scale: 7.5 # in default realistic classifer-free guidance
9
+ negative_prompt: "ugly, deformed, noisy, blurry, distorted, out of focus, bad anatomy, extra limbs, poorly drawn face, poorly drawn hands, missing fingers"
10
+
11
+ inference_steps: 100 # the total denosing step for inference
12
+ guidance_scale: 0.3 # which scale of time step to end guidance
13
+ guidance_steps: 40 # the step for guidance in inference, no more than 1000*guidance_scale, the remaining steps (inference_steps-guidance_steps) is performed without gudiance
14
+ warm_up_steps: 10
15
+ cool_up_steps: 10
16
+
17
+ motion_guidance_weight: 2000
18
+ motion_guidance_blocks: ['up_blocks.1']
19
+
20
+ add_noise_step: 400
configs/i2v_sketch.jsonl ADDED
@@ -0,0 +1 @@
 
 
1
+ {"video_path":"reference_videos/sample_white_tiger.mp4", "condition_image_paths":["condition_images/scribble/lion_forest.png"], "new_prompt": "Lion, walks in the forest"}
configs/i2v_sketch.yaml ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ motion_module: "models/Motion_Module/v3_sd15_mm.ckpt"
2
+ dreambooth_path: "models/DreamBooth_LoRA/realisticVisionV60B1_v51VAE.safetensors"
3
+ model_config: "configs/model_config/model_config.yaml"
4
+ controlnet_config: "configs/sparsectrl/image_condition.yaml"
5
+ controlnet_path: "models/SparseCtrl/v3_sd15_sparsectrl_scribble.ckpt"
6
+ adapter_lora_path: "models/Motion_Module/v3_sd15_adapter.ckpt"
7
+
8
+ cfg_scale: 7.5 # in default realistic classifer-free guidance
9
+ negative_prompt: "ugly, deformed, noisy, blurry, distorted, out of focus, bad anatomy, extra limbs, poorly drawn face, poorly drawn hands, missing fingers"
10
+
11
+ inference_steps: 200 # the total denosing step for inference
12
+ guidance_scale: 0.4 # which scale of time step to end guidance
13
+ guidance_steps: 120 # the step for guidance in inference, no more than 1000*guidance_scale, the remaining steps (inference_steps-guidance_steps) is performed without gudiance
14
+ warm_up_steps: 10
15
+ cool_up_steps: 10
16
+
17
+ motion_guidance_weight: 2000
18
+ motion_guidance_blocks: ['up_blocks.1']
19
+
20
+ add_noise_step: 400
configs/model_config/inference-v1.yaml ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ unet_additional_kwargs:
2
+ use_inflated_groupnorm: true # from config v3
3
+
4
+
5
+ use_motion_module: true
6
+ motion_module_resolutions: [1,2,4,8]
7
+ motion_module_mid_block: false
8
+ motion_module_decoder_only: false
9
+ motion_module_type: "Vanilla"
10
+
11
+ motion_module_kwargs:
12
+ num_attention_heads: 8
13
+ num_transformer_block: 1
14
+ attention_block_types: [ "Temporal_Self", "Temporal_Self" ]
15
+ temporal_position_encoding: true
16
+ temporal_position_encoding_max_len: 32
17
+ temporal_attention_dim_div: 1
18
+ zero_initialize: true # from config v3
19
+
20
+ noise_scheduler_kwargs:
21
+ beta_start: 0.00085
22
+ beta_end: 0.012
23
+ beta_schedule: "linear"
24
+ steps_offset: 1
25
+ clip_sample: False
configs/model_config/inference-v2.yaml ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ unet_additional_kwargs:
2
+ use_inflated_groupnorm: true
3
+ unet_use_cross_frame_attention: false
4
+ unet_use_temporal_attention: false
5
+ use_motion_module: true
6
+ motion_module_resolutions: [1,2,4,8]
7
+ motion_module_mid_block: true
8
+ motion_module_decoder_only: false
9
+ motion_module_type: "Vanilla"
10
+
11
+ motion_module_kwargs:
12
+ num_attention_heads: 8
13
+ num_transformer_block: 1
14
+ attention_block_types: [ "Temporal_Self", "Temporal_Self" ]
15
+ temporal_position_encoding: true
16
+ temporal_position_encoding_max_len: 32
17
+ temporal_attention_dim_div: 1
18
+
19
+ noise_scheduler_kwargs:
20
+ beta_start: 0.00085
21
+ beta_end: 0.012
22
+ beta_schedule: "linear"
23
+ steps_offset: 1
24
+ clip_sample: False
configs/model_config/inference-v3.yaml ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ unet_additional_kwargs:
2
+ use_inflated_groupnorm: true
3
+ use_motion_module: true
4
+ motion_module_resolutions: [1,2,4,8]
5
+ motion_module_mid_block: false
6
+ motion_module_type: Vanilla
7
+
8
+ motion_module_kwargs:
9
+ num_attention_heads: 8
10
+ num_transformer_block: 1
11
+ attention_block_types: [ "Temporal_Self", "Temporal_Self" ]
12
+ temporal_position_encoding: true
13
+ temporal_position_encoding_max_len: 32
14
+ temporal_attention_dim_div: 1
15
+ zero_initialize: true
16
+
17
+ noise_scheduler_kwargs:
18
+ beta_start: 0.00085
19
+ beta_end: 0.012
20
+ beta_schedule: "linear"
21
+ steps_offset: 1
22
+ clip_sample: False
configs/model_config/model_config copy.yaml ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ unet_additional_kwargs:
2
+ use_inflated_groupnorm: true # from config v3
3
+ use_motion_module: true
4
+ motion_module_resolutions: [1,2,4,8]
5
+ motion_module_mid_block: false
6
+ motion_module_type: "Vanilla"
7
+
8
+ motion_module_kwargs:
9
+ num_attention_heads: 8
10
+ num_transformer_block: 1
11
+ attention_block_types: [ "Temporal_Self", "Temporal_Self" ]
12
+ temporal_position_encoding: true
13
+ temporal_position_encoding_max_len: 32
14
+ temporal_attention_dim_div: 1
15
+ zero_initialize: true # from config v3
16
+
17
+ noise_scheduler_kwargs:
18
+ beta_start: 0.00085
19
+ beta_end: 0.012
20
+ beta_schedule: "linear"
21
+ steps_offset: 1
22
+ clip_sample: False
configs/model_config/model_config.yaml ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ unet_additional_kwargs:
2
+ use_inflated_groupnorm: true
3
+ use_motion_module: true
4
+ motion_module_resolutions: [1,2,4,8]
5
+ motion_module_mid_block: false
6
+ motion_module_type: "Vanilla"
7
+
8
+ motion_module_kwargs:
9
+ num_attention_heads: 8
10
+ num_transformer_block: 1
11
+ attention_block_types: [ "Temporal_Self", "Temporal_Self" ]
12
+ temporal_position_encoding: true
13
+ temporal_attention_dim_div: 1
14
+ zero_initialize: true
15
+
16
+ noise_scheduler_kwargs:
17
+ beta_start: 0.00085
18
+ beta_end: 0.012
19
+ beta_schedule: "linear"
20
+ steps_offset: 1
21
+ clip_sample: false
configs/model_config/model_config_public.yaml ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ unet_additional_kwargs:
2
+ use_inflated_groupnorm: true # from config v3
3
+ unet_use_cross_frame_attention: false
4
+ unet_use_temporal_attention: false
5
+ use_motion_module: true
6
+ motion_module_resolutions: [1,2,4,8]
7
+ motion_module_mid_block: false
8
+ motion_module_decoder_only: false
9
+ motion_module_type: "Vanilla"
10
+
11
+ motion_module_kwargs:
12
+ num_attention_heads: 8
13
+ num_transformer_block: 1
14
+ attention_block_types: [ "Temporal_Self", "Temporal_Self" ]
15
+ temporal_position_encoding: true
16
+ temporal_position_encoding_max_len: 32
17
+ temporal_attention_dim_div: 1
18
+ zero_initialize: true # from config v3
19
+
20
+ noise_scheduler_kwargs:
21
+ beta_start: 0.00085
22
+ beta_end: 0.012
23
+ beta_schedule: "linear"
24
+ steps_offset: 1
25
+ clip_sample: False
configs/sparsectrl/image_condition.yaml ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ controlnet_additional_kwargs:
2
+ set_noisy_sample_input_to_zero: true
3
+ use_simplified_condition_embedding: false
4
+ conditioning_channels: 3
5
+
6
+ use_motion_module: true
7
+ motion_module_resolutions: [1,2,4,8]
8
+ motion_module_mid_block: false
9
+ motion_module_type: "Vanilla"
10
+
11
+ motion_module_kwargs:
12
+ num_attention_heads: 8
13
+ num_transformer_block: 1
14
+ attention_block_types: [ "Temporal_Self" ]
15
+ temporal_position_encoding: true
16
+ temporal_position_encoding_max_len: 32
17
+ temporal_attention_dim_div: 1
configs/sparsectrl/latent_condition.yaml ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ controlnet_additional_kwargs:
2
+ set_noisy_sample_input_to_zero: true
3
+ use_simplified_condition_embedding: true
4
+ conditioning_channels: 4
5
+
6
+ use_motion_module: true
7
+ motion_module_resolutions: [1,2,4,8]
8
+ motion_module_mid_block: false
9
+ motion_module_type: "Vanilla"
10
+
11
+ motion_module_kwargs:
12
+ num_attention_heads: 8
13
+ num_transformer_block: 1
14
+ attention_block_types: [ "Temporal_Self" ]
15
+ temporal_position_encoding: true
16
+ temporal_position_encoding_max_len: 32
17
+ temporal_attention_dim_div: 1
configs/t2v_camera.jsonl ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {"video_path":"reference_videos/camera_zoom_in.mp4", "new_prompt": "Relics on the seabed", "seed": 42}
2
+ {"video_path":"reference_videos/camera_zoom_in.mp4", "new_prompt": "A road in the mountain", "seed": 42}
3
+ {"video_path":"reference_videos/camera_zoom_in.mp4", "new_prompt": "Caves, a path for exploration", "seed": 2026}
4
+ {"video_path":"reference_videos/camera_zoom_in.mp4", "new_prompt": "Railway for train"}
5
+ {"video_path":"reference_videos/camera_zoom_out.mp4", "new_prompt": "Tree, in the mountain", "seed": 2026}
6
+ {"video_path":"reference_videos/camera_zoom_out.mp4", "new_prompt": "Red car on the track", "seed": 2026}
7
+ {"video_path":"reference_videos/camera_zoom_out.mp4", "new_prompt": "Man, standing in his garden.", "seed": 2026}
8
+ {"video_path":"reference_videos/camera_1.mp4", "new_prompt": "A island, on the ocean, sunny day"}
9
+ {"video_path":"reference_videos/camera_1.mp4", "new_prompt": "A tower, with fireworks"}
10
+ {"video_path":"reference_videos/camera_pan_up.mp4", "new_prompt": "Beautiful house, around with flowers", "seed": 42}
11
+ {"video_path":"reference_videos/camera_translation_2.mp4", "new_prompt": "Forest, in winter", "seed": 2028}
12
+ {"video_path":"reference_videos/camera_pan_down.mp4", "new_prompt": "Eagle, standing in the tree", "seed": 2026}
configs/t2v_camera.yaml ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ motion_module: "models/Motion_Module/v3_sd15_mm.ckpt"
3
+ dreambooth_path: "models/DreamBooth_LoRA/realisticVisionV60B1_v51VAE.safetensors"
4
+ model_config: "configs/model_config/model_config.yaml"
5
+
6
+ cfg_scale: 7.5 # in default realistic classifer-free guidance
7
+ negative_prompt: "bad anatomy, extra limbs, ugly, deformed, noisy, blurry, distorted, out of focus, poorly drawn face, poorly drawn hands, missing fingers"
8
+ postive_prompt: " 8k, high detailed, best quality, film grain, Fujifilm XT3"
9
+
10
+ inference_steps: 100 # the total denosing step for inference
11
+ guidance_scale: 0.3 # which scale of time step to end guidance 0.2/40
12
+ guidance_steps: 50 # the step for guidance in inference, no more than 1000*guidance_scale, the remaining steps (inference_steps-guidance_steps) is performed without gudiance
13
+ warm_up_steps: 10
14
+ cool_up_steps: 10
15
+
16
+ motion_guidance_weight: 2000
17
+ motion_guidance_blocks: ['up_blocks.1']
18
+
19
+ add_noise_step: 400
configs/t2v_object.jsonl ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ {"video_path":"reference_videos/sample_astronaut.mp4", "new_prompt": "Robot, walks in the street.","seed":59}
2
+ {"video_path":"reference_videos/sample_cat.mp4", "new_prompt": "Tiger, raises its head.", "seed": 2025}
3
+ {"video_path":"reference_videos/sample_leaves.mp4", "new_prompt": "Petals falling in the wind.","seed":3407}
4
+ {"video_path":"reference_videos/sample_fox.mp4", "new_prompt": "Cat, turns its head in the living room."}
5
+ {"video_path":"reference_videos/sample_blackswan.mp4", "new_prompt": "Duck, swims in the river.","seed":3407}
6
+ {"video_path":"reference_videos/sample_cow.mp4", "new_prompt": "Pig, drinks water on beach.","seed":3407}
configs/t2v_object.yaml ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ motion_module: "models/Motion_Module/v3_sd15_mm.ckpt"
3
+ dreambooth_path: "models/DreamBooth_LoRA/realisticVisionV60B1_v51VAE.safetensors"
4
+ model_config: "configs/model_config/model_config.yaml"
5
+
6
+ cfg_scale: 7.5 # in default realistic classifer-free guidance
7
+ negative_prompt: "bad anatomy, extra limbs, ugly, deformed, noisy, blurry, distorted, out of focus, poorly drawn face, poorly drawn hands, missing fingers"
8
+ postive_prompt: "8k, high detailed, best quality, film grain, Fujifilm XT3"
9
+
10
+ inference_steps: 300 # the total denosing step for inference
11
+ guidance_scale: 0.4 # which scale of time step to end guidance
12
+ guidance_steps: 180 # the step for guidance in inference, no more than 1000*guidance_scale, the remaining steps (inference_steps-guidance_steps) is performed without gudiance
13
+ warm_up_steps: 10
14
+ cool_up_steps: 10
15
+
16
+ motion_guidance_weight: 2000
17
+ motion_guidance_blocks: ['up_blocks.1',]
18
+
19
+ add_noise_step: 400
environment.yaml ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: motionclone
2
+ channels:
3
+ - pytorch
4
+ - nvidia
5
+ dependencies:
6
+ - python=3.11.3
7
+ - pytorch=2.0.1
8
+ - torchvision=0.15.2
9
+ - pytorch-cuda=11.8
10
+ - pip
11
+ - pip:
12
+ - accelerate
13
+ - diffusers==0.16.0
14
+ - transformers==4.28.1
15
+ - xformers==0.0.20
16
+ - imageio[ffmpeg]
17
+ - decord==0.6.0
18
+ - gdown
19
+ - einops
20
+ - omegaconf
21
+ - safetensors
22
+ - gradio
23
+ - wandb
24
+ - triton
25
+ - opencv-python
generated_videos/camera_zoom_out_Dog,_lying_on_the_grass76739_76739.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:63ecf6f1250b83d71b50352a020c97eb60223ee33813219b2bd8d7588f1ecfec
3
+ size 285735
generated_videos/inference_config.json ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ motion_module: models/Motion_Module/v3_sd15_mm.ckpt
2
+ dreambooth_path: models/DreamBooth_LoRA/realisticVisionV60B1_v51VAE.safetensors
3
+ model_config: configs/model_config/model_config.yaml
4
+ controlnet_config: configs/sparsectrl/image_condition.yaml
5
+ controlnet_path: models/SparseCtrl/v3_sd15_sparsectrl_scribble.ckpt
6
+ adapter_lora_path: models/Motion_Module/v3_sd15_adapter.ckpt
7
+ cfg_scale: 7.5
8
+ negative_prompt: ugly, deformed, noisy, blurry, distorted, out of focus, bad anatomy,
9
+ extra limbs, poorly drawn face, poorly drawn hands, missing fingers
10
+ inference_steps: 200
11
+ guidance_scale: 0.4
12
+ guidance_steps: 120
13
+ warm_up_steps: 10
14
+ cool_up_steps: 10
15
+ motion_guidance_weight: 2000
16
+ motion_guidance_blocks:
17
+ - up_blocks.1
18
+ add_noise_step: 400
19
+ width: 512
20
+ height: 512
21
+ video_length: 16
generated_videos/sample_white_tiger_Lion,_walks_in_the_forest76739_76739.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3ae68b549f1c6541417009d1cdd35d01286876bada07fb53a3354ad9225856cf
3
+ size 538343
i2v_video_app.py ADDED
@@ -0,0 +1,284 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from omegaconf import OmegaConf
3
+ import torch
4
+ from diffusers import AutoencoderKL, DDIMScheduler
5
+ from transformers import CLIPTextModel, CLIPTokenizer
6
+ from motionclone.models.unet import UNet3DConditionModel
7
+ from motionclone.models.sparse_controlnet import SparseControlNetModel
8
+ from motionclone.pipelines.pipeline_animation import AnimationPipeline
9
+ from motionclone.utils.util import load_weights, auto_download
10
+ from diffusers.utils.import_utils import is_xformers_available
11
+ from motionclone.utils.motionclone_functions import *
12
+ import json
13
+ from motionclone.utils.xformer_attention import *
14
+ import os
15
+ import numpy as np
16
+ import imageio
17
+ import shutil
18
+ import subprocess
19
+ from types import SimpleNamespace
20
+
21
+ # 模型下载逻辑
22
+ def download_weights():
23
+ try:
24
+ # 创建模型目录
25
+ os.makedirs("models", exist_ok=True)
26
+ os.makedirs("models/DreamBooth_LoRA", exist_ok=True)
27
+ os.makedirs("models/Motion_Module", exist_ok=True)
28
+ os.makedirs("models/SparseCtrl", exist_ok=True)
29
+
30
+ # 下载 Stable Diffusion 模型
31
+ if not os.path.exists("models/StableDiffusion"):
32
+ subprocess.run(["git", "clone", "https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5", "models/StableDiffusion"])
33
+
34
+ # 下载 DreamBooth LoRA 模型
35
+ if not os.path.exists("models/DreamBooth_LoRA/realisticVisionV60B1_v51VAE.safetensors"):
36
+ subprocess.run(["wget", "https://huggingface.co/svjack/Realistic-Vision-V6.0-B1/resolve/main/realisticVisionV60B1_v51VAE.safetensors", "-O", "models/DreamBooth_LoRA/realisticVisionV60B1_v51VAE.safetensors"])
37
+
38
+ # 下载 Motion Module 模型
39
+ if not os.path.exists("models/Motion_Module/v3_sd15_mm.ckpt"):
40
+ subprocess.run(["wget", "https://huggingface.co/guoyww/animatediff/resolve/main/v3_sd15_mm.ckpt", "-O", "models/Motion_Module/v3_sd15_mm.ckpt"])
41
+ if not os.path.exists("models/Motion_Module/v3_sd15_adapter.ckpt"):
42
+ subprocess.run(["wget", "https://huggingface.co/guoyww/animatediff/resolve/main/v3_sd15_adapter.ckpt", "-O", "models/Motion_Module/v3_sd15_adapter.ckpt"])
43
+
44
+ # 下载 SparseCtrl 模型
45
+ if not os.path.exists("models/SparseCtrl/v3_sd15_sparsectrl_rgb.ckpt"):
46
+ subprocess.run(["wget", "https://huggingface.co/guoyww/animatediff/resolve/main/v3_sd15_sparsectrl_rgb.ckpt", "-O", "models/SparseCtrl/v3_sd15_sparsectrl_rgb.ckpt"])
47
+ if not os.path.exists("models/SparseCtrl/v3_sd15_sparsectrl_scribble.ckpt"):
48
+ subprocess.run(["wget", "https://huggingface.co/guoyww/animatediff/resolve/main/v3_sd15_sparsectrl_scribble.ckpt", "-O", "models/SparseCtrl/v3_sd15_sparsectrl_scribble.ckpt"])
49
+
50
+ print("Weights downloaded successfully.")
51
+ except Exception as e:
52
+ print(f"Error downloading weights: {e}")
53
+
54
+ # 下载权重
55
+ download_weights()
56
+
57
+ # 模型初始化逻辑
58
+ def initialize_models(pretrained_model_path, config):
59
+ # 设置设备
60
+ adopted_dtype = torch.float16
61
+ device = "cuda"
62
+ set_all_seed(42)
63
+
64
+ # 加载模型组件
65
+ tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_path, subfolder="tokenizer")
66
+ text_encoder = CLIPTextModel.from_pretrained(pretrained_model_path, subfolder="text_encoder").to(device).to(dtype=adopted_dtype)
67
+ vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae").to(device).to(dtype=adopted_dtype)
68
+
69
+ # 更新配置
70
+ config["width"] = config.get("W", 512)
71
+ config["height"] = config.get("H", 512)
72
+ config["video_length"] = config.get("L", 16)
73
+
74
+ # 加载模型配置
75
+ model_config = OmegaConf.load(config.get("model_config", ""))
76
+ unet = UNet3DConditionModel.from_pretrained_2d(pretrained_model_path, subfolder="unet", unet_additional_kwargs=OmegaConf.to_container(model_config.unet_additional_kwargs)).to(device).to(dtype=adopted_dtype)
77
+
78
+ # 加载 controlnet 模型
79
+ controlnet = None
80
+ if config.get("controlnet_path", "") != "":
81
+ assert config.get("controlnet_config", "") != ""
82
+
83
+ unet.config.num_attention_heads = 8
84
+ unet.config.projection_class_embeddings_input_dim = None
85
+
86
+ controlnet_config = OmegaConf.load(config["controlnet_config"])
87
+ controlnet = SparseControlNetModel.from_unet(unet, controlnet_additional_kwargs=controlnet_config.get("controlnet_additional_kwargs", {})).to(device).to(dtype=adopted_dtype)
88
+
89
+ auto_download(config["controlnet_path"], is_dreambooth_lora=False)
90
+ print(f"loading controlnet checkpoint from ", config["controlnet_path"])
91
+ controlnet_state_dict = torch.load(config["controlnet_path"], map_location="cpu")
92
+ controlnet_state_dict = controlnet_state_dict["controlnet"] if "controlnet" in controlnet_state_dict else controlnet_state_dict
93
+ controlnet_state_dict = {name: param for name, param in controlnet_state_dict.items() if "pos_encoder.pe" not in name}
94
+ controlnet_state_dict.pop("animatediff_config", "")
95
+ controlnet.load_state_dict(controlnet_state_dict)
96
+ del controlnet_state_dict
97
+
98
+ # 启用 xformers
99
+ if is_xformers_available():
100
+ unet.enable_xformers_memory_efficient_attention()
101
+
102
+ # 创建 pipeline
103
+ pipeline = AnimationPipeline(
104
+ vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet,
105
+ controlnet=controlnet,
106
+ scheduler=DDIMScheduler(**model_config.noise_scheduler_kwargs),
107
+ ).to(device)
108
+
109
+ # 加载权重
110
+ pipeline = load_weights(
111
+ pipeline,
112
+ motion_module_path=config.get("motion_module", ""),
113
+ adapter_lora_path=config.get("adapter_lora_path", ""),
114
+ adapter_lora_scale=config.get("adapter_lora_scale", 1.0),
115
+ dreambooth_model_path=config.get("dreambooth_path", ""),
116
+ ).to(device)
117
+ pipeline.text_encoder.to(dtype=adopted_dtype)
118
+
119
+ # 加载自定义函数
120
+ pipeline.scheduler.customized_step = schedule_customized_step.__get__(pipeline.scheduler)
121
+ pipeline.scheduler.customized_set_timesteps = schedule_set_timesteps.__get__(pipeline.scheduler)
122
+ pipeline.unet.forward = unet_customized_forward.__get__(pipeline.unet)
123
+ pipeline.sample_video = sample_video.__get__(pipeline)
124
+ pipeline.single_step_video = single_step_video.__get__(pipeline)
125
+ pipeline.get_temp_attn_prob = get_temp_attn_prob.__get__(pipeline)
126
+ pipeline.add_noise = add_noise.__get__(pipeline)
127
+ pipeline.compute_temp_loss = compute_temp_loss.__get__(pipeline)
128
+ pipeline.obtain_motion_representation = obtain_motion_representation.__get__(pipeline)
129
+
130
+ # 冻结 UNet 和 ControlNet 参数
131
+ for param in pipeline.unet.parameters():
132
+ param.requires_grad = False
133
+ if pipeline.controlnet is not None:
134
+ for param in pipeline.controlnet.parameters():
135
+ param.requires_grad = False
136
+
137
+ pipeline.input_config, pipeline.unet.input_config = SimpleNamespace(**config), SimpleNamespace(**config)
138
+ pipeline.unet = prep_unet_attention(pipeline.unet, config.get("motion_guidance_blocks", []))
139
+ pipeline.unet = prep_unet_conv(pipeline.unet)
140
+
141
+ return pipeline
142
+
143
+ # 硬编码的配置值
144
+ config = {
145
+ "motion_module": "models/Motion_Module/v3_sd15_mm.ckpt",
146
+ "dreambooth_path": "models/DreamBooth_LoRA/realisticVisionV60B1_v51VAE.safetensors",
147
+ "model_config": "configs/model_config/model_config.yaml",
148
+ "controlnet_path": "models/SparseCtrl/v3_sd15_sparsectrl_rgb.ckpt",
149
+ "controlnet_config": "configs/sparsectrl/latent_condition.yaml",
150
+ "adapter_lora_path": "models/Motion_Module/v3_sd15_adapter.ckpt",
151
+ "W": 512,
152
+ "H": 512,
153
+ "L": 16,
154
+ "motion_guidance_blocks": ['up_blocks.1'],
155
+ }
156
+
157
+ # 初始化模型
158
+ pretrained_model_path = "models/StableDiffusion"
159
+ pipeline = initialize_models(pretrained_model_path, config)
160
+
161
+ # 视频生成函数
162
+ def generate_video(uploaded_video, condition_images, new_prompt, seed, motion_representation_save_dir, generated_videos_save_dir, visible_gpu, without_xformers, cfg_scale, negative_prompt, positive_prompt, inference_steps, guidance_scale, guidance_steps, warm_up_steps, cool_up_steps, motion_guidance_weight, motion_guidance_blocks, add_noise_step):
163
+ # 更新配置
164
+ config.update({
165
+ "cfg_scale": cfg_scale,
166
+ "negative_prompt": negative_prompt,
167
+ "positive_prompt": positive_prompt,
168
+ "inference_steps": inference_steps,
169
+ "guidance_scale": guidance_scale,
170
+ "guidance_steps": guidance_steps,
171
+ "warm_up_steps": warm_up_steps,
172
+ "cool_up_steps": cool_up_steps,
173
+ "motion_guidance_weight": motion_guidance_weight,
174
+ #"motion_guidance_blocks": motion_guidance_blocks,
175
+ "add_noise_step": add_noise_step
176
+ })
177
+
178
+ # 设置环境变量
179
+ os.environ["CUDA_VISIBLE_DEVICES"] = visible_gpu or str(os.getenv('CUDA_VISIBLE_DEVICES', 0))
180
+
181
+ device = pipeline.device
182
+
183
+ # 创建保存目录
184
+ if not os.path.exists(generated_videos_save_dir):
185
+ os.makedirs(generated_videos_save_dir)
186
+ if not os.path.exists(motion_representation_save_dir):
187
+ os.makedirs(motion_representation_save_dir)
188
+
189
+ # 处理上传的视频
190
+ if uploaded_video is not None:
191
+ pipeline.scheduler.customized_set_timesteps(config["inference_steps"], config["guidance_steps"], config["guidance_scale"], device=device, timestep_spacing_type="uneven")
192
+
193
+ # 将上传的视频保存到指定路径
194
+ video_path = os.path.join(generated_videos_save_dir, os.path.basename(uploaded_video))
195
+ shutil.copy2(uploaded_video, video_path)
196
+
197
+ # 更新配置
198
+ config["video_path"] = video_path
199
+ config["condition_image_path_list"] = condition_images
200
+ config["image_index"] = [0] * len(condition_images)
201
+ config["new_prompt"] = new_prompt + config.get("positive_prompt", "")
202
+ config["controlnet_scale"] = 1.0
203
+
204
+ pipeline.input_config, pipeline.unet.input_config = SimpleNamespace(**config), SimpleNamespace(**config)
205
+
206
+ # 提取运动表示
207
+ seed_motion = seed if seed is not None else 76739
208
+ generator = torch.Generator(device=pipeline.device)
209
+ generator.manual_seed(seed_motion)
210
+ motion_representation_path = os.path.join(motion_representation_save_dir, os.path.splitext(os.path.basename(config["video_path"]))[0] + '.pt')
211
+ pipeline.obtain_motion_representation(generator=generator, motion_representation_path=motion_representation_path, use_controlnet=True)
212
+
213
+ # 生成视频
214
+ seed = seed_motion
215
+ generator = torch.Generator(device=pipeline.device)
216
+ generator.manual_seed(seed)
217
+ pipeline.input_config.seed = seed
218
+ videos = pipeline.sample_video(generator=generator, add_controlnet=True)
219
+
220
+ videos = rearrange(videos, "b c f h w -> b f h w c")
221
+ save_path = os.path.join(generated_videos_save_dir, os.path.splitext(os.path.basename(config["video_path"]))[0] + "_" + config["new_prompt"].strip().replace(' ', '_') + str(seed_motion) + "_" + str(seed) + '.mp4')
222
+ videos_uint8 = (videos[0] * 255).astype(np.uint8)
223
+ imageio.mimwrite(save_path, videos_uint8, fps=8)
224
+ print(save_path, "is done")
225
+
226
+ return save_path
227
+ else:
228
+ return "No video uploaded."
229
+
230
+ # 使用 Gradio 构建界面
231
+ with gr.Blocks() as demo:
232
+ gr.Markdown("# MotionClone Video Generation")
233
+ with gr.Row():
234
+ with gr.Column():
235
+ uploaded_video = gr.Video(label="Upload Video")
236
+ condition_images = gr.Files(label="Condition Images")
237
+ new_prompt = gr.Textbox(label="New Prompt", value="A beautiful scene")
238
+ seed = gr.Number(label="Seed", value=76739)
239
+ generate_button = gr.Button("Generate Video")
240
+ with gr.Column():
241
+ output_video = gr.Video(label="Generated Video")
242
+
243
+ with gr.Accordion("Advanced Settings", open=False):
244
+ motion_representation_save_dir = gr.Textbox(label="Motion Representation Save Dir", value="motion_representation/")
245
+ generated_videos_save_dir = gr.Textbox(label="Generated Videos Save Dir", value="generated_videos/")
246
+ visible_gpu = gr.Textbox(label="Visible GPU", value="0")
247
+ without_xformers = gr.Checkbox(label="Without Xformers", value=False)
248
+ cfg_scale = gr.Number(label="CFG Scale", value=7.5)
249
+ negative_prompt = gr.Textbox(label="Negative Prompt", value="ugly, deformed, noisy, blurry, distorted, out of focus, bad anatomy, extra limbs, poorly drawn face, poorly drawn hands, missing fingers")
250
+ positive_prompt = gr.Textbox(label="Positive Prompt", value="8k, high detailed, best quality, film grain, Fujifilm XT3")
251
+ inference_steps = gr.Number(label="Inference Steps", value=100)
252
+ guidance_scale = gr.Number(label="Guidance Scale", value=0.3)
253
+ guidance_steps = gr.Number(label="Guidance Steps", value=40)
254
+ warm_up_steps = gr.Number(label="Warm Up Steps", value=10)
255
+ cool_up_steps = gr.Number(label="Cool Up Steps", value=10)
256
+ motion_guidance_weight = gr.Number(label="Motion Guidance Weight", value=2000)
257
+ motion_guidance_blocks = gr.Textbox(label="Motion Guidance Blocks", value="['up_blocks.1']")
258
+ add_noise_step = gr.Number(label="Add Noise Step", value=400)
259
+
260
+ # 绑定生成函数
261
+ generate_button.click(
262
+ generate_video,
263
+ inputs=[
264
+ uploaded_video, condition_images, new_prompt, seed, motion_representation_save_dir, generated_videos_save_dir, visible_gpu, without_xformers, cfg_scale, negative_prompt, positive_prompt, inference_steps, guidance_scale, guidance_steps, warm_up_steps, cool_up_steps, motion_guidance_weight, motion_guidance_blocks, add_noise_step
265
+ ],
266
+ outputs=output_video
267
+ )
268
+
269
+ # 添加示例
270
+ examples = [
271
+ {"video_path": "reference_videos/camera_zoom_out.mp4", "condition_image_paths": ["condition_images/rgb/dog_on_grass.png"], "new_prompt": "Dog, lying on the grass", "seed": 42}
272
+ ]
273
+ examples = list(map(lambda d: [d["video_path"], d["condition_image_paths"], d["new_prompt"], d["seed"]], examples))
274
+
275
+ gr.Examples(
276
+ examples=examples,
277
+ inputs=[uploaded_video, condition_images, new_prompt, seed],
278
+ outputs=output_video,
279
+ fn=generate_video,
280
+ cache_examples=False
281
+ )
282
+
283
+ # 启动应用
284
+ demo.launch(share=True)
i2v_video_sample.py ADDED
@@ -0,0 +1,157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ from omegaconf import OmegaConf
3
+ import torch
4
+ from diffusers import AutoencoderKL, DDIMScheduler
5
+ from transformers import CLIPTextModel, CLIPTokenizer
6
+ from motionclone.models.unet import UNet3DConditionModel
7
+ from motionclone.models.sparse_controlnet import SparseControlNetModel
8
+ from motionclone.pipelines.pipeline_animation import AnimationPipeline
9
+ from motionclone.utils.util import load_weights, auto_download
10
+ from diffusers.utils.import_utils import is_xformers_available
11
+ from motionclone.utils.motionclone_functions import *
12
+ import json
13
+ from motionclone.utils.xformer_attention import *
14
+
15
+
16
+ def main(args):
17
+
18
+ os.environ["CUDA_VISIBLE_DEVICES"] = args.visible_gpu or str(os.getenv('CUDA_VISIBLE_DEVICES', 0))
19
+
20
+ config = OmegaConf.load(args.inference_config)
21
+ adopted_dtype = torch.float16
22
+ device = "cuda"
23
+ set_all_seed(42)
24
+
25
+ tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_path, subfolder="tokenizer")
26
+ text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_path, subfolder="text_encoder").to(device).to(dtype=adopted_dtype)
27
+ vae = AutoencoderKL.from_pretrained(args.pretrained_model_path, subfolder="vae").to(device).to(dtype=adopted_dtype)
28
+
29
+ config.width = config.get("W", args.W)
30
+ config.height = config.get("H", args.H)
31
+ config.video_length = config.get("L", args.L)
32
+
33
+ if not os.path.exists(args.generated_videos_save_dir):
34
+ os.makedirs(args.generated_videos_save_dir)
35
+ OmegaConf.save(config, os.path.join(args.generated_videos_save_dir,"inference_config.json"))
36
+
37
+ model_config = OmegaConf.load(config.get("model_config", ""))
38
+ unet = UNet3DConditionModel.from_pretrained_2d(args.pretrained_model_path, subfolder="unet", unet_additional_kwargs=OmegaConf.to_container(model_config.unet_additional_kwargs),).to(device).to(dtype=adopted_dtype)
39
+
40
+ # load controlnet model
41
+ controlnet = None
42
+ if config.get("controlnet_path", "") != "":
43
+ # assert model_config.get("controlnet_images", "") != ""
44
+ assert config.get("controlnet_config", "") != ""
45
+
46
+ unet.config.num_attention_heads = 8
47
+ unet.config.projection_class_embeddings_input_dim = None
48
+
49
+ controlnet_config = OmegaConf.load(config.controlnet_config)
50
+ controlnet = SparseControlNetModel.from_unet(unet, controlnet_additional_kwargs=controlnet_config.get("controlnet_additional_kwargs", {})).to(device).to(dtype=adopted_dtype)
51
+
52
+ auto_download(config.controlnet_path, is_dreambooth_lora=False)
53
+ print(f"loading controlnet checkpoint from {config.controlnet_path} ...")
54
+ controlnet_state_dict = torch.load(config.controlnet_path, map_location="cpu")
55
+ controlnet_state_dict = controlnet_state_dict["controlnet"] if "controlnet" in controlnet_state_dict else controlnet_state_dict
56
+ controlnet_state_dict = {name: param for name, param in controlnet_state_dict.items() if "pos_encoder.pe" not in name}
57
+ controlnet_state_dict.pop("animatediff_config", "")
58
+ controlnet.load_state_dict(controlnet_state_dict)
59
+ del controlnet_state_dict
60
+
61
+ # set xformers
62
+ if is_xformers_available() and (not args.without_xformers):
63
+ unet.enable_xformers_memory_efficient_attention()
64
+
65
+ pipeline = AnimationPipeline(
66
+ vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet,
67
+ controlnet=controlnet,
68
+ scheduler=DDIMScheduler(**OmegaConf.to_container(model_config.noise_scheduler_kwargs)),
69
+ ).to(device)
70
+
71
+ pipeline = load_weights(
72
+ pipeline,
73
+ # motion module
74
+ motion_module_path = config.get("motion_module", ""),
75
+ # domain adapter
76
+ adapter_lora_path = config.get("adapter_lora_path", ""),
77
+ adapter_lora_scale = config.get("adapter_lora_scale", 1.0),
78
+ # image layer
79
+ dreambooth_model_path = config.get("dreambooth_path", ""),
80
+ ).to(device)
81
+ pipeline.text_encoder.to(dtype=adopted_dtype)
82
+
83
+ # customized functions in motionclone_functions
84
+ pipeline.scheduler.customized_step = schedule_customized_step.__get__(pipeline.scheduler)
85
+ pipeline.scheduler.customized_set_timesteps = schedule_set_timesteps.__get__(pipeline.scheduler)
86
+ pipeline.unet.forward = unet_customized_forward.__get__(pipeline.unet)
87
+ pipeline.sample_video = sample_video.__get__(pipeline)
88
+ pipeline.single_step_video = single_step_video.__get__(pipeline)
89
+ pipeline.get_temp_attn_prob = get_temp_attn_prob.__get__(pipeline)
90
+ pipeline.add_noise = add_noise.__get__(pipeline)
91
+ pipeline.compute_temp_loss = compute_temp_loss.__get__(pipeline)
92
+ pipeline.obtain_motion_representation = obtain_motion_representation.__get__(pipeline)
93
+
94
+ for param in pipeline.unet.parameters():
95
+ param.requires_grad = False
96
+ for param in pipeline.controlnet.parameters():
97
+ param.requires_grad = False
98
+
99
+ pipeline.input_config, pipeline.unet.input_config = config, config
100
+ pipeline.unet = prep_unet_attention(pipeline.unet,pipeline.input_config.motion_guidance_blocks)
101
+ pipeline.unet = prep_unet_conv(pipeline.unet)
102
+ pipeline.scheduler.customized_set_timesteps(config.inference_steps, config.guidance_steps,config.guidance_scale,device=device,timestep_spacing_type = "uneven")
103
+
104
+ with open(args.examples, 'r') as files:
105
+ for line in files:
106
+ # prepare infor of each case
107
+ example_infor = json.loads(line)
108
+ config.video_path = example_infor["video_path"]
109
+ config.condition_image_path_list = example_infor["condition_image_paths"]
110
+ config.image_index = example_infor.get("image_index",[0])
111
+ assert len(config.image_index) == len(config.condition_image_path_list)
112
+ config.new_prompt = example_infor["new_prompt"] + config.get("positive_prompt", "")
113
+ config.controlnet_scale = example_infor.get("controlnet_scale", 1.0)
114
+ pipeline.input_config, pipeline.unet.input_config = config, config # update config
115
+
116
+ # perform motion representation extraction
117
+ seed_motion = seed_motion = example_infor.get("seed", args.default_seed)
118
+ generator = torch.Generator(device=pipeline.device)
119
+ generator.manual_seed(seed_motion)
120
+ if not os.path.exists(args.motion_representation_save_dir):
121
+ os.makedirs(args.motion_representation_save_dir)
122
+ motion_representation_path = os.path.join(args.motion_representation_save_dir, os.path.splitext(os.path.basename(config.video_path))[0] + '.pt')
123
+ pipeline.obtain_motion_representation(generator= generator, motion_representation_path = motion_representation_path, use_controlnet=True,)
124
+
125
+ # perform video generation
126
+ seed = seed_motion # can assign other seed here
127
+ generator = torch.Generator(device=pipeline.device)
128
+ generator.manual_seed(seed)
129
+ pipeline.input_config.seed = seed
130
+ videos = pipeline.sample_video(generator = generator, add_controlnet=True,)
131
+
132
+ videos = rearrange(videos, "b c f h w -> b f h w c")
133
+ save_path = os.path.join(args.generated_videos_save_dir, os.path.splitext(os.path.basename(config.video_path))[0]
134
+ + "_" + config.new_prompt.strip().replace(' ', '_') + str(seed_motion) + "_" +str(seed)+'.mp4')
135
+ videos_uint8 = (videos[0] * 255).astype(np.uint8)
136
+ imageio.mimwrite(save_path, videos_uint8, fps=8)
137
+ print(save_path,"is done")
138
+
139
+ if __name__ == "__main__":
140
+ parser = argparse.ArgumentParser()
141
+ parser.add_argument("--pretrained-model-path", type=str, default="models/StableDiffusion",)
142
+
143
+ parser.add_argument("--inference_config", type=str, default="configs/i2v_sketch.yaml")
144
+ parser.add_argument("--examples", type=str, default="configs/i2v_sketch.jsonl")
145
+ parser.add_argument("--motion-representation-save-dir", type=str, default="motion_representation/")
146
+ parser.add_argument("--generated-videos-save-dir", type=str, default="generated_videos/")
147
+
148
+ parser.add_argument("--visible_gpu", type=str, default=None)
149
+ parser.add_argument("--default-seed", type=int, default=76739)
150
+ parser.add_argument("--L", type=int, default=16)
151
+ parser.add_argument("--W", type=int, default=512)
152
+ parser.add_argument("--H", type=int, default=512)
153
+
154
+ parser.add_argument("--without-xformers", action="store_true")
155
+
156
+ args = parser.parse_args()
157
+ main(args)
models/Motion_Module/Put motion module checkpoints here.txt ADDED
File without changes
motionclone/models/__pycache__/attention.cpython-310.pyc ADDED
Binary file (13.7 kB). View file
 
motionclone/models/__pycache__/attention.cpython-38.pyc ADDED
Binary file (13.6 kB). View file
 
motionclone/models/__pycache__/motion_module.cpython-310.pyc ADDED
Binary file (8.71 kB). View file
 
motionclone/models/__pycache__/motion_module.cpython-38.pyc ADDED
Binary file (8.67 kB). View file
 
motionclone/models/__pycache__/resnet.cpython-310.pyc ADDED
Binary file (5.31 kB). View file
 
motionclone/models/__pycache__/resnet.cpython-38.pyc ADDED
Binary file (5.41 kB). View file
 
motionclone/models/__pycache__/sparse_controlnet.cpython-38.pyc ADDED
Binary file (14 kB). View file
 
motionclone/models/__pycache__/unet.cpython-310.pyc ADDED
Binary file (12.7 kB). View file
 
motionclone/models/__pycache__/unet.cpython-38.pyc ADDED
Binary file (12.4 kB). View file
 
motionclone/models/__pycache__/unet_blocks.cpython-310.pyc ADDED
Binary file (12.8 kB). View file
 
motionclone/models/__pycache__/unet_blocks.cpython-38.pyc ADDED
Binary file (12.1 kB). View file
 
motionclone/models/attention.py ADDED
@@ -0,0 +1,611 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention.py
2
+
3
+ from dataclasses import dataclass
4
+ from typing import Optional
5
+
6
+ import torch
7
+ import torch.nn.functional as F
8
+ from torch import nn
9
+
10
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
11
+ from diffusers.models.modeling_utils import ModelMixin
12
+ from diffusers.utils import BaseOutput
13
+ from diffusers.utils.import_utils import is_xformers_available
14
+ from diffusers.models.attention import FeedForward, AdaLayerNorm
15
+
16
+ from einops import rearrange, repeat
17
+ import pdb
18
+
19
+ @dataclass
20
+ class Transformer3DModelOutput(BaseOutput):
21
+ sample: torch.FloatTensor
22
+
23
+
24
+ if is_xformers_available():
25
+ import xformers
26
+ import xformers.ops
27
+ else:
28
+ xformers = None
29
+
30
+
31
+ class Transformer3DModel(ModelMixin, ConfigMixin):
32
+ @register_to_config
33
+ def __init__(
34
+ self,
35
+ num_attention_heads: int = 16,
36
+ attention_head_dim: int = 88,
37
+ in_channels: Optional[int] = None,
38
+ num_layers: int = 1,
39
+ dropout: float = 0.0,
40
+ norm_num_groups: int = 32,
41
+ cross_attention_dim: Optional[int] = None,
42
+ attention_bias: bool = False,
43
+ activation_fn: str = "geglu",
44
+ num_embeds_ada_norm: Optional[int] = None,
45
+ use_linear_projection: bool = False,
46
+ only_cross_attention: bool = False,
47
+ upcast_attention: bool = False,
48
+
49
+ unet_use_cross_frame_attention=None,
50
+ unet_use_temporal_attention=None,
51
+ ):
52
+ super().__init__()
53
+ self.use_linear_projection = use_linear_projection
54
+ self.num_attention_heads = num_attention_heads
55
+ self.attention_head_dim = attention_head_dim
56
+ inner_dim = num_attention_heads * attention_head_dim
57
+
58
+ # Define input layers
59
+ self.in_channels = in_channels
60
+
61
+ self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
62
+ if use_linear_projection:
63
+ self.proj_in = nn.Linear(in_channels, inner_dim)
64
+ else:
65
+ self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
66
+
67
+ # Define transformers blocks
68
+ self.transformer_blocks = nn.ModuleList(
69
+ [
70
+ BasicTransformerBlock(
71
+ inner_dim,
72
+ num_attention_heads,
73
+ attention_head_dim,
74
+ dropout=dropout,
75
+ cross_attention_dim=cross_attention_dim,
76
+ activation_fn=activation_fn,
77
+ num_embeds_ada_norm=num_embeds_ada_norm,
78
+ attention_bias=attention_bias,
79
+ only_cross_attention=only_cross_attention,
80
+ upcast_attention=upcast_attention,
81
+
82
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
83
+ unet_use_temporal_attention=unet_use_temporal_attention,
84
+ )
85
+ for d in range(num_layers)
86
+ ]
87
+ )
88
+
89
+ # 4. Define output layers
90
+ if use_linear_projection:
91
+ self.proj_out = nn.Linear(in_channels, inner_dim)
92
+ else:
93
+ self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
94
+
95
+ def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True):
96
+ # Input
97
+ assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
98
+ video_length = hidden_states.shape[2]
99
+ hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
100
+ encoder_hidden_states = repeat(encoder_hidden_states, 'b n c -> (b f) n c', f=video_length)
101
+
102
+ batch, channel, height, weight = hidden_states.shape
103
+ residual = hidden_states
104
+
105
+ hidden_states = self.norm(hidden_states)
106
+ if not self.use_linear_projection:
107
+ hidden_states = self.proj_in(hidden_states)
108
+ inner_dim = hidden_states.shape[1]
109
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
110
+ else:
111
+ inner_dim = hidden_states.shape[1]
112
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
113
+ hidden_states = self.proj_in(hidden_states)
114
+
115
+ # Blocks
116
+ for block in self.transformer_blocks:
117
+ hidden_states = block(
118
+ hidden_states,
119
+ encoder_hidden_states=encoder_hidden_states,
120
+ timestep=timestep,
121
+ video_length=video_length
122
+ )
123
+
124
+ # Output
125
+ if not self.use_linear_projection:
126
+ hidden_states = (
127
+ hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
128
+ )
129
+ hidden_states = self.proj_out(hidden_states)
130
+ else:
131
+ hidden_states = self.proj_out(hidden_states)
132
+ hidden_states = (
133
+ hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
134
+ )
135
+
136
+ output = hidden_states + residual
137
+
138
+ output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length)
139
+ if not return_dict:
140
+ return (output,)
141
+
142
+ return Transformer3DModelOutput(sample=output)
143
+
144
+
145
+ class BasicTransformerBlock(nn.Module):
146
+ def __init__(
147
+ self,
148
+ dim: int,
149
+ num_attention_heads: int,
150
+ attention_head_dim: int,
151
+ dropout=0.0,
152
+ cross_attention_dim: Optional[int] = None,
153
+ activation_fn: str = "geglu",
154
+ num_embeds_ada_norm: Optional[int] = None,
155
+ attention_bias: bool = False,
156
+ only_cross_attention: bool = False,
157
+ upcast_attention: bool = False,
158
+
159
+ unet_use_cross_frame_attention = None,
160
+ unet_use_temporal_attention = None,
161
+ ):
162
+ super().__init__()
163
+ self.only_cross_attention = only_cross_attention
164
+ self.use_ada_layer_norm = num_embeds_ada_norm is not None
165
+ self.unet_use_cross_frame_attention = unet_use_cross_frame_attention
166
+ self.unet_use_temporal_attention = unet_use_temporal_attention
167
+
168
+ # SC-Attn
169
+ assert unet_use_cross_frame_attention is not None
170
+ if unet_use_cross_frame_attention:
171
+ self.attn1 = SparseCausalAttention2D(
172
+ query_dim=dim,
173
+ heads=num_attention_heads,
174
+ dim_head=attention_head_dim,
175
+ dropout=dropout,
176
+ bias=attention_bias,
177
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
178
+ upcast_attention=upcast_attention,
179
+ )
180
+ else:
181
+ self.attn1 = CrossAttention(
182
+ query_dim=dim,
183
+ heads=num_attention_heads,
184
+ dim_head=attention_head_dim,
185
+ dropout=dropout,
186
+ bias=attention_bias,
187
+ upcast_attention=upcast_attention,
188
+ )
189
+ self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
190
+
191
+ # Cross-Attn
192
+ if cross_attention_dim is not None:
193
+ self.attn2 = CrossAttention(
194
+ query_dim=dim,
195
+ cross_attention_dim=cross_attention_dim,
196
+ heads=num_attention_heads,
197
+ dim_head=attention_head_dim,
198
+ dropout=dropout,
199
+ bias=attention_bias,
200
+ upcast_attention=upcast_attention,
201
+ )
202
+ else:
203
+ self.attn2 = None
204
+
205
+ if cross_attention_dim is not None:
206
+ self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
207
+ else:
208
+ self.norm2 = None
209
+
210
+ # Feed-forward
211
+ self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
212
+ self.norm3 = nn.LayerNorm(dim)
213
+
214
+ # Temp-Attn
215
+ assert unet_use_temporal_attention is not None
216
+ if unet_use_temporal_attention:
217
+ self.attn_temp = CrossAttention(
218
+ query_dim=dim,
219
+ heads=num_attention_heads,
220
+ dim_head=attention_head_dim,
221
+ dropout=dropout,
222
+ bias=attention_bias,
223
+ upcast_attention=upcast_attention,
224
+ )
225
+ nn.init.zeros_(self.attn_temp.to_out[0].weight.data)
226
+ self.norm_temp = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
227
+
228
+ def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool, op=None):
229
+ if not is_xformers_available():
230
+ print("Here is how to install it")
231
+ raise ModuleNotFoundError(
232
+ "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
233
+ " xformers",
234
+ name="xformers",
235
+ )
236
+ elif not torch.cuda.is_available():
237
+ raise ValueError(
238
+ "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only"
239
+ " available for GPU "
240
+ )
241
+ else:
242
+ try:
243
+ # Make sure we can run the memory efficient attention
244
+ _ = xformers.ops.memory_efficient_attention(
245
+ torch.randn((1, 2, 40), device="cuda"),
246
+ torch.randn((1, 2, 40), device="cuda"),
247
+ torch.randn((1, 2, 40), device="cuda"),
248
+ )
249
+ except Exception as e:
250
+ raise e
251
+ self.attn1._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
252
+ if self.attn2 is not None:
253
+ self.attn2._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
254
+ # self.attn_temp._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
255
+
256
+ def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, attention_mask=None, video_length=None):
257
+ # SparseCausal-Attention
258
+ norm_hidden_states = (
259
+ self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states)
260
+ )
261
+
262
+ # if self.only_cross_attention:
263
+ # hidden_states = (
264
+ # self.attn1(norm_hidden_states, encoder_hidden_states, attention_mask=attention_mask) + hidden_states
265
+ # )
266
+ # else:
267
+ # hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask, video_length=video_length) + hidden_states
268
+
269
+ # pdb.set_trace()
270
+ if self.unet_use_cross_frame_attention:
271
+ hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask, video_length=video_length) + hidden_states
272
+ else:
273
+ hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask) + hidden_states
274
+
275
+ if self.attn2 is not None:
276
+ # Cross-Attention
277
+ norm_hidden_states = (
278
+ self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
279
+ )
280
+ hidden_states = (
281
+ self.attn2(
282
+ norm_hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
283
+ )
284
+ + hidden_states
285
+ )
286
+
287
+ # Feed-forward
288
+ hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
289
+
290
+ # Temporal-Attention
291
+ if self.unet_use_temporal_attention:
292
+ d = hidden_states.shape[1]
293
+ hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length)
294
+ norm_hidden_states = (
295
+ self.norm_temp(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_temp(hidden_states)
296
+ )
297
+ hidden_states = self.attn_temp(norm_hidden_states) + hidden_states
298
+ hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d)
299
+
300
+ return hidden_states
301
+
302
+ class CrossAttention(nn.Module):
303
+ r"""
304
+ A cross attention layer.
305
+
306
+ Parameters:
307
+ query_dim (`int`): The number of channels in the query.
308
+ cross_attention_dim (`int`, *optional*):
309
+ The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
310
+ heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention.
311
+ dim_head (`int`, *optional*, defaults to 64): The number of channels in each head.
312
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
313
+ bias (`bool`, *optional*, defaults to False):
314
+ Set to `True` for the query, key, and value linear layers to contain a bias parameter.
315
+ """
316
+
317
+ def __init__(
318
+ self,
319
+ query_dim: int,
320
+ cross_attention_dim: Optional[int] = None,
321
+ heads: int = 8,
322
+ dim_head: int = 64,
323
+ dropout: float = 0.0,
324
+ bias=False,
325
+ upcast_attention: bool = False,
326
+ upcast_softmax: bool = False,
327
+ added_kv_proj_dim: Optional[int] = None,
328
+ norm_num_groups: Optional[int] = None,
329
+ ):
330
+ super().__init__()
331
+ inner_dim = dim_head * heads
332
+ cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
333
+ self.upcast_attention = upcast_attention
334
+ self.upcast_softmax = upcast_softmax
335
+
336
+ self.scale = dim_head**-0.5
337
+
338
+ self.heads = heads
339
+ # for slice_size > 0 the attention score computation
340
+ # is split across the batch axis to save memory
341
+ # You can set slice_size with `set_attention_slice`
342
+ self.sliceable_head_dim = heads
343
+ self._slice_size = None
344
+ self._use_memory_efficient_attention_xformers = False
345
+ self.added_kv_proj_dim = added_kv_proj_dim
346
+
347
+ #### add processer
348
+ self.processor = None
349
+
350
+ if norm_num_groups is not None:
351
+ self.group_norm = nn.GroupNorm(num_channels=inner_dim, num_groups=norm_num_groups, eps=1e-5, affine=True)
352
+ else:
353
+ self.group_norm = None
354
+
355
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)
356
+ self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
357
+ self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
358
+
359
+ if self.added_kv_proj_dim is not None:
360
+ self.add_k_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)
361
+ self.add_v_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)
362
+
363
+ self.to_out = nn.ModuleList([])
364
+ self.to_out.append(nn.Linear(inner_dim, query_dim))
365
+ self.to_out.append(nn.Dropout(dropout))
366
+
367
+ def reshape_heads_to_batch_dim(self, tensor):
368
+ batch_size, seq_len, dim = tensor.shape
369
+ head_size = self.heads
370
+ tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
371
+ tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
372
+ return tensor
373
+
374
+ def reshape_batch_dim_to_heads(self, tensor):
375
+ batch_size, seq_len, dim = tensor.shape
376
+ head_size = self.heads
377
+ tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
378
+ tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
379
+ return tensor
380
+
381
+ def set_attention_slice(self, slice_size):
382
+ if slice_size is not None and slice_size > self.sliceable_head_dim:
383
+ raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.")
384
+
385
+ self._slice_size = slice_size
386
+
387
+ def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
388
+ batch_size, sequence_length, _ = hidden_states.shape
389
+
390
+ encoder_hidden_states = encoder_hidden_states
391
+
392
+ if self.group_norm is not None:
393
+ hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
394
+
395
+ query = self.to_q(hidden_states)
396
+ dim = query.shape[-1]
397
+ # query = self.reshape_heads_to_batch_dim(query) # move backwards
398
+
399
+ if self.added_kv_proj_dim is not None:
400
+ key = self.to_k(hidden_states)
401
+ value = self.to_v(hidden_states)
402
+ encoder_hidden_states_key_proj = self.add_k_proj(encoder_hidden_states)
403
+ encoder_hidden_states_value_proj = self.add_v_proj(encoder_hidden_states)
404
+
405
+ ######record###### record before reshape heads to batch dim
406
+ if self.processor is not None:
407
+ self.processor.record_qkv(self, hidden_states, query, key, value, attention_mask)
408
+ ##################
409
+
410
+ key = self.reshape_heads_to_batch_dim(key)
411
+ value = self.reshape_heads_to_batch_dim(value)
412
+ encoder_hidden_states_key_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_key_proj)
413
+ encoder_hidden_states_value_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_value_proj)
414
+
415
+ key = torch.concat([encoder_hidden_states_key_proj, key], dim=1)
416
+ value = torch.concat([encoder_hidden_states_value_proj, value], dim=1)
417
+ else:
418
+ encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
419
+ key = self.to_k(encoder_hidden_states)
420
+ value = self.to_v(encoder_hidden_states)
421
+
422
+ ######record######
423
+ if self.processor is not None:
424
+ self.processor.record_qkv(self, hidden_states, query, key, value, attention_mask)
425
+ ##################
426
+
427
+ key = self.reshape_heads_to_batch_dim(key)
428
+ value = self.reshape_heads_to_batch_dim(value)
429
+
430
+ query = self.reshape_heads_to_batch_dim(query) # reshape query
431
+
432
+ if attention_mask is not None:
433
+ if attention_mask.shape[-1] != query.shape[1]:
434
+ target_length = query.shape[1]
435
+ attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
436
+ attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
437
+
438
+ ######record######
439
+ if self.processor is not None:
440
+ self.processor.record_attn_mask(self, hidden_states, query, key, value, attention_mask)
441
+ ##################
442
+
443
+ # attention, what we cannot get enough of
444
+ if self._use_memory_efficient_attention_xformers:
445
+ hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
446
+ # Some versions of xformers return output in fp32, cast it back to the dtype of the input
447
+ hidden_states = hidden_states.to(query.dtype)
448
+ else:
449
+ if self._slice_size is None or query.shape[0] // self._slice_size == 1:
450
+ hidden_states = self._attention(query, key, value, attention_mask)
451
+ else:
452
+ hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask)
453
+
454
+ # linear proj
455
+ hidden_states = self.to_out[0](hidden_states)
456
+
457
+ # dropout
458
+ hidden_states = self.to_out[1](hidden_states)
459
+ return hidden_states
460
+
461
+ def _attention(self, query, key, value, attention_mask=None):
462
+ if self.upcast_attention:
463
+ query = query.float()
464
+ key = key.float()
465
+
466
+ attention_scores = torch.baddbmm(
467
+ torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
468
+ query,
469
+ key.transpose(-1, -2),
470
+ beta=0,
471
+ alpha=self.scale,
472
+ )
473
+
474
+ if attention_mask is not None:
475
+ attention_scores = attention_scores + attention_mask
476
+
477
+ if self.upcast_softmax:
478
+ attention_scores = attention_scores.float()
479
+
480
+ attention_probs = attention_scores.softmax(dim=-1)
481
+
482
+ # cast back to the original dtype
483
+ attention_probs = attention_probs.to(value.dtype)
484
+
485
+ # compute attention output
486
+ hidden_states = torch.bmm(attention_probs, value)
487
+
488
+ # reshape hidden_states
489
+ hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
490
+ return hidden_states
491
+
492
+ def _sliced_attention(self, query, key, value, sequence_length, dim, attention_mask):
493
+ batch_size_attention = query.shape[0]
494
+ hidden_states = torch.zeros(
495
+ (batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype
496
+ )
497
+ slice_size = self._slice_size if self._slice_size is not None else hidden_states.shape[0]
498
+ for i in range(hidden_states.shape[0] // slice_size):
499
+ start_idx = i * slice_size
500
+ end_idx = (i + 1) * slice_size
501
+
502
+ query_slice = query[start_idx:end_idx]
503
+ key_slice = key[start_idx:end_idx]
504
+
505
+ if self.upcast_attention:
506
+ query_slice = query_slice.float()
507
+ key_slice = key_slice.float()
508
+
509
+ attn_slice = torch.baddbmm(
510
+ torch.empty(slice_size, query.shape[1], key.shape[1], dtype=query_slice.dtype, device=query.device),
511
+ query_slice,
512
+ key_slice.transpose(-1, -2),
513
+ beta=0,
514
+ alpha=self.scale,
515
+ )
516
+
517
+ if attention_mask is not None:
518
+ attn_slice = attn_slice + attention_mask[start_idx:end_idx]
519
+
520
+ if self.upcast_softmax:
521
+ attn_slice = attn_slice.float()
522
+
523
+ attn_slice = attn_slice.softmax(dim=-1)
524
+
525
+ # cast back to the original dtype
526
+ attn_slice = attn_slice.to(value.dtype)
527
+ attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
528
+
529
+ hidden_states[start_idx:end_idx] = attn_slice
530
+
531
+ # reshape hidden_states
532
+ hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
533
+ return hidden_states
534
+
535
+ def _memory_efficient_attention_xformers(self, query, key, value, attention_mask):
536
+ # TODO attention_mask
537
+ query = query.contiguous()
538
+ key = key.contiguous()
539
+ value = value.contiguous()
540
+ hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
541
+ hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
542
+ return hidden_states
543
+
544
+ def set_processor(self, processor: "AttnProcessor") -> None:
545
+ r"""
546
+ Set the attention processor to use.
547
+
548
+ Args:
549
+ processor (`AttnProcessor`):
550
+ The attention processor to use.
551
+ """
552
+ # if current processor is in `self._modules` and if passed `processor` is not, we need to
553
+ # pop `processor` from `self._modules`
554
+ if (
555
+ hasattr(self, "processor")
556
+ and isinstance(self.processor, torch.nn.Module)
557
+ and not isinstance(processor, torch.nn.Module)
558
+ ):
559
+ logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}")
560
+ self._modules.pop("processor")
561
+
562
+ self.processor = processor
563
+
564
+ def get_attention_scores(
565
+ self, query: torch.Tensor, key: torch.Tensor, attention_mask: torch.Tensor = None
566
+ ) -> torch.Tensor:
567
+ r"""
568
+ Compute the attention scores.
569
+
570
+ Args:
571
+ query (`torch.Tensor`): The query tensor.
572
+ key (`torch.Tensor`): The key tensor.
573
+ attention_mask (`torch.Tensor`, *optional*): The attention mask to use. If `None`, no mask is applied.
574
+
575
+ Returns:
576
+ `torch.Tensor`: The attention probabilities/scores.
577
+ """
578
+ dtype = query.dtype
579
+ if self.upcast_attention:
580
+ query = query.float()
581
+ key = key.float()
582
+
583
+ if attention_mask is None:
584
+ baddbmm_input = torch.empty(
585
+ query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device
586
+ )
587
+ beta = 0
588
+ else:
589
+ baddbmm_input = attention_mask
590
+ beta = 1
591
+
592
+
593
+
594
+ attention_scores = torch.baddbmm(
595
+ baddbmm_input,
596
+ query,
597
+ key.transpose(-1, -2),
598
+ beta=beta,
599
+ alpha=self.scale,
600
+ )
601
+ del baddbmm_input
602
+
603
+ if self.upcast_softmax:
604
+ attention_scores = attention_scores.float()
605
+
606
+ attention_probs = attention_scores.softmax(dim=-1)
607
+ del attention_scores
608
+
609
+ attention_probs = attention_probs.to(dtype)
610
+
611
+ return attention_probs
motionclone/models/motion_module.py ADDED
@@ -0,0 +1,347 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ from typing import List, Optional, Tuple, Union
3
+
4
+ import torch
5
+ import numpy as np
6
+ import torch.nn.functional as F
7
+ from torch import nn
8
+ import torchvision
9
+
10
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
11
+ from diffusers.models.modeling_utils import ModelMixin
12
+ from diffusers.utils import BaseOutput
13
+ from diffusers.utils.import_utils import is_xformers_available
14
+ from diffusers.models.attention import FeedForward
15
+ from .attention import CrossAttention
16
+
17
+ from einops import rearrange, repeat
18
+ import math
19
+
20
+
21
+ def zero_module(module):
22
+ # Zero out the parameters of a module and return it.
23
+ for p in module.parameters():
24
+ p.detach().zero_()
25
+ return module
26
+
27
+
28
+ @dataclass
29
+ class TemporalTransformer3DModelOutput(BaseOutput):
30
+ sample: torch.FloatTensor
31
+
32
+
33
+ if is_xformers_available():
34
+ import xformers
35
+ import xformers.ops
36
+ else:
37
+ xformers = None
38
+
39
+
40
+ def get_motion_module( # 只能返回VanillaTemporalModule类
41
+ in_channels,
42
+ motion_module_type: str,
43
+ motion_module_kwargs: dict
44
+ ):
45
+ if motion_module_type == "Vanilla":
46
+ return VanillaTemporalModule(in_channels=in_channels, **motion_module_kwargs,)
47
+ else:
48
+ raise ValueError
49
+
50
+
51
+ class VanillaTemporalModule(nn.Module):
52
+ def __init__(
53
+ self,
54
+ in_channels,
55
+ num_attention_heads = 8,
56
+ num_transformer_block = 2,
57
+ attention_block_types =( "Temporal_Self", "Temporal_Self" ),
58
+ cross_frame_attention_mode = None,
59
+ temporal_position_encoding = False,
60
+ temporal_position_encoding_max_len = 32,
61
+ temporal_attention_dim_div = 1,
62
+ zero_initialize = True,
63
+ ):
64
+ super().__init__()
65
+
66
+ self.temporal_transformer = TemporalTransformer3DModel(
67
+ in_channels=in_channels,
68
+ num_attention_heads=num_attention_heads,
69
+ attention_head_dim=in_channels // num_attention_heads // temporal_attention_dim_div,
70
+ num_layers=num_transformer_block,
71
+ attention_block_types=attention_block_types,
72
+ cross_frame_attention_mode=cross_frame_attention_mode,
73
+ temporal_position_encoding=temporal_position_encoding,
74
+ temporal_position_encoding_max_len=temporal_position_encoding_max_len,
75
+ )
76
+
77
+ if zero_initialize:
78
+ self.temporal_transformer.proj_out = zero_module(self.temporal_transformer.proj_out)
79
+
80
+ def forward(self, input_tensor, temb, encoder_hidden_states, attention_mask=None, anchor_frame_idx=None):
81
+ hidden_states = input_tensor
82
+ hidden_states = self.temporal_transformer(hidden_states, encoder_hidden_states, attention_mask)
83
+
84
+ output = hidden_states
85
+ return output
86
+
87
+
88
+ class TemporalTransformer3DModel(nn.Module):
89
+ def __init__(
90
+ self,
91
+ in_channels,
92
+ num_attention_heads,
93
+ attention_head_dim,
94
+
95
+ num_layers,
96
+ attention_block_types = ( "Temporal_Self", "Temporal_Self", ), # 两个TempAttn
97
+ dropout = 0.0,
98
+ norm_num_groups = 32,
99
+ cross_attention_dim = 768,
100
+ activation_fn = "geglu",
101
+ attention_bias = False,
102
+ upcast_attention = False,
103
+
104
+ cross_frame_attention_mode = None,
105
+ temporal_position_encoding = False,
106
+ temporal_position_encoding_max_len = 24,
107
+ ):
108
+ super().__init__()
109
+
110
+ inner_dim = num_attention_heads * attention_head_dim
111
+
112
+ self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
113
+ self.proj_in = nn.Linear(in_channels, inner_dim)
114
+
115
+ self.transformer_blocks = nn.ModuleList(
116
+ [
117
+ TemporalTransformerBlock(
118
+ dim=inner_dim,
119
+ num_attention_heads=num_attention_heads,
120
+ attention_head_dim=attention_head_dim,
121
+ attention_block_types=attention_block_types,
122
+ dropout=dropout,
123
+ norm_num_groups=norm_num_groups,
124
+ cross_attention_dim=cross_attention_dim,
125
+ activation_fn=activation_fn,
126
+ attention_bias=attention_bias,
127
+ upcast_attention=upcast_attention,
128
+ cross_frame_attention_mode=cross_frame_attention_mode,
129
+ temporal_position_encoding=temporal_position_encoding,
130
+ temporal_position_encoding_max_len=temporal_position_encoding_max_len,
131
+ )
132
+ for d in range(num_layers)
133
+ ]
134
+ )
135
+ self.proj_out = nn.Linear(inner_dim, in_channels)
136
+
137
+ def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
138
+ assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
139
+ video_length = hidden_states.shape[2]
140
+ hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
141
+
142
+ batch, channel, height, weight = hidden_states.shape
143
+ residual = hidden_states
144
+
145
+ hidden_states = self.norm(hidden_states)
146
+ inner_dim = hidden_states.shape[1]
147
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
148
+ hidden_states = self.proj_in(hidden_states)
149
+
150
+ # Transformer Blocks
151
+ for block in self.transformer_blocks:
152
+ hidden_states = block(hidden_states, encoder_hidden_states=encoder_hidden_states, video_length=video_length)
153
+
154
+ # output
155
+ hidden_states = self.proj_out(hidden_states)
156
+ hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
157
+
158
+ output = hidden_states + residual
159
+ output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length)
160
+
161
+ return output
162
+
163
+
164
+ class TemporalTransformerBlock(nn.Module):
165
+ def __init__(
166
+ self,
167
+ dim,
168
+ num_attention_heads,
169
+ attention_head_dim,
170
+ attention_block_types = ( "Temporal_Self", "Temporal_Self", ),
171
+ dropout = 0.0,
172
+ norm_num_groups = 32,
173
+ cross_attention_dim = 768,
174
+ activation_fn = "geglu",
175
+ attention_bias = False,
176
+ upcast_attention = False,
177
+ cross_frame_attention_mode = None,
178
+ temporal_position_encoding = False,
179
+ temporal_position_encoding_max_len = 24,
180
+ ):
181
+ super().__init__()
182
+
183
+ attention_blocks = []
184
+ norms = []
185
+
186
+ for block_name in attention_block_types:
187
+ attention_blocks.append(
188
+ VersatileAttention(
189
+ attention_mode=block_name.split("_")[0],
190
+ cross_attention_dim=cross_attention_dim if block_name.endswith("_Cross") else None,
191
+
192
+ query_dim=dim,
193
+ heads=num_attention_heads,
194
+ dim_head=attention_head_dim,
195
+ dropout=dropout,
196
+ bias=attention_bias,
197
+ upcast_attention=upcast_attention,
198
+
199
+ cross_frame_attention_mode=cross_frame_attention_mode,
200
+ temporal_position_encoding=temporal_position_encoding,
201
+ temporal_position_encoding_max_len=temporal_position_encoding_max_len,
202
+ )
203
+ )
204
+ norms.append(nn.LayerNorm(dim))
205
+
206
+ self.attention_blocks = nn.ModuleList(attention_blocks)
207
+ self.norms = nn.ModuleList(norms)
208
+
209
+ self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
210
+ self.ff_norm = nn.LayerNorm(dim)
211
+
212
+
213
+ def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None):
214
+ for attention_block, norm in zip(self.attention_blocks, self.norms):
215
+ norm_hidden_states = norm(hidden_states)
216
+ hidden_states = attention_block(
217
+ norm_hidden_states,
218
+ encoder_hidden_states=encoder_hidden_states if attention_block.is_cross_attention else None,
219
+ video_length=video_length,
220
+ ) + hidden_states
221
+
222
+ hidden_states = self.ff(self.ff_norm(hidden_states)) + hidden_states
223
+
224
+ output = hidden_states
225
+ return output
226
+
227
+
228
+ class PositionalEncoding(nn.Module):
229
+ def __init__(
230
+ self,
231
+ d_model,
232
+ dropout = 0.,
233
+ max_len = 24
234
+ ):
235
+ super().__init__()
236
+ self.dropout = nn.Dropout(p=dropout)
237
+ position = torch.arange(max_len).unsqueeze(1)
238
+ div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
239
+ pe = torch.zeros(1, max_len, d_model)
240
+ pe[0, :, 0::2] = torch.sin(position * div_term)
241
+ pe[0, :, 1::2] = torch.cos(position * div_term)
242
+ # self.register_buffer('pe', pe)
243
+ self.register_buffer('pe', pe, persistent=False)
244
+
245
+ def forward(self, x):
246
+ x = x + self.pe[:, :x.size(1)]
247
+ return self.dropout(x)
248
+
249
+
250
+ class VersatileAttention(CrossAttention): # 继承CrossAttention类,不需要在额外写set_processor功能
251
+ def __init__(
252
+ self,
253
+ attention_mode = None,
254
+ cross_frame_attention_mode = None,
255
+ temporal_position_encoding = False,
256
+ temporal_position_encoding_max_len = 24,
257
+ *args, **kwargs
258
+ ):
259
+ super().__init__(*args, **kwargs)
260
+ assert attention_mode == "Temporal"
261
+
262
+ self.attention_mode = attention_mode
263
+ self.is_cross_attention = kwargs["cross_attention_dim"] is not None
264
+
265
+ self.pos_encoder = PositionalEncoding(
266
+ kwargs["query_dim"],
267
+ dropout=0.,
268
+ max_len=temporal_position_encoding_max_len
269
+ ) if (temporal_position_encoding and attention_mode == "Temporal") else None
270
+
271
+ def extra_repr(self):
272
+ return f"(Module Info) Attention_Mode: {self.attention_mode}, Is_Cross_Attention: {self.is_cross_attention}"
273
+
274
+ def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None):
275
+ batch_size, sequence_length, _ = hidden_states.shape
276
+
277
+ if self.attention_mode == "Temporal":
278
+ d = hidden_states.shape[1]
279
+ hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length)
280
+
281
+ if self.pos_encoder is not None:
282
+ hidden_states = self.pos_encoder(hidden_states)
283
+
284
+ encoder_hidden_states = repeat(encoder_hidden_states, "b n c -> (b d) n c", d=d) if encoder_hidden_states is not None else encoder_hidden_states
285
+ else:
286
+ raise NotImplementedError
287
+
288
+ encoder_hidden_states = encoder_hidden_states
289
+
290
+ if self.group_norm is not None:
291
+ hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
292
+
293
+ query = self.to_q(hidden_states)
294
+ dim = query.shape[-1]
295
+ # query = self.reshape_heads_to_batch_dim(query) # move backwards
296
+
297
+ if self.added_kv_proj_dim is not None:
298
+ raise NotImplementedError
299
+
300
+ encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
301
+ key = self.to_k(encoder_hidden_states)
302
+ value = self.to_v(encoder_hidden_states)
303
+
304
+ ######record###### record before reshape heads to batch dim
305
+ if self.processor is not None:
306
+ self.processor.record_qkv(self, hidden_states, query, key, value, attention_mask)
307
+ ##################
308
+
309
+ key = self.reshape_heads_to_batch_dim(key)
310
+ value = self.reshape_heads_to_batch_dim(value)
311
+
312
+ query = self.reshape_heads_to_batch_dim(query) # reshape query here
313
+
314
+ if attention_mask is not None:
315
+ if attention_mask.shape[-1] != query.shape[1]:
316
+ target_length = query.shape[1]
317
+ attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
318
+ attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
319
+
320
+ ######record######
321
+ # if self.processor is not None:
322
+ # self.processor.record_attn_mask(self, hidden_states, query, key, value, attention_mask)
323
+ ##################
324
+
325
+ # attention, what we cannot get enough of
326
+ if self._use_memory_efficient_attention_xformers:
327
+ hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
328
+ # Some versions of xformers return output in fp32, cast it back to the dtype of the input
329
+ hidden_states = hidden_states.to(query.dtype)
330
+ else:
331
+ if self._slice_size is None or query.shape[0] // self._slice_size == 1:
332
+ hidden_states = self._attention(query, key, value, attention_mask)
333
+ else:
334
+ hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask)
335
+
336
+ # linear proj
337
+ hidden_states = self.to_out[0](hidden_states)
338
+
339
+ # dropout
340
+ hidden_states = self.to_out[1](hidden_states)
341
+
342
+ if self.attention_mode == "Temporal":
343
+ hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d)
344
+
345
+ return hidden_states
346
+
347
+
motionclone/models/resnet.py ADDED
@@ -0,0 +1,218 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/resnet.py
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+ from einops import rearrange
8
+
9
+
10
+ class InflatedConv3d(nn.Conv2d):
11
+ def forward(self, x):
12
+ video_length = x.shape[2]
13
+
14
+ x = rearrange(x, "b c f h w -> (b f) c h w")
15
+ x = super().forward(x)
16
+ x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length)
17
+
18
+ return x
19
+
20
+
21
+ class InflatedGroupNorm(nn.GroupNorm):
22
+ def forward(self, x):
23
+ video_length = x.shape[2]
24
+
25
+ x = rearrange(x, "b c f h w -> (b f) c h w")
26
+ x = super().forward(x)
27
+ x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length)
28
+
29
+ return x
30
+
31
+
32
+ class Upsample3D(nn.Module):
33
+ def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
34
+ super().__init__()
35
+ self.channels = channels
36
+ self.out_channels = out_channels or channels
37
+ self.use_conv = use_conv
38
+ self.use_conv_transpose = use_conv_transpose
39
+ self.name = name
40
+
41
+ conv = None
42
+ if use_conv_transpose:
43
+ raise NotImplementedError
44
+ elif use_conv:
45
+ self.conv = InflatedConv3d(self.channels, self.out_channels, 3, padding=1)
46
+
47
+ def forward(self, hidden_states, output_size=None):
48
+ assert hidden_states.shape[1] == self.channels
49
+
50
+ if self.use_conv_transpose:
51
+ raise NotImplementedError
52
+
53
+ # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16
54
+ dtype = hidden_states.dtype
55
+ if dtype == torch.bfloat16:
56
+ hidden_states = hidden_states.to(torch.float32)
57
+
58
+ # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
59
+ if hidden_states.shape[0] >= 64:
60
+ hidden_states = hidden_states.contiguous()
61
+
62
+ # if `output_size` is passed we force the interpolation output
63
+ # size and do not make use of `scale_factor=2`
64
+ if output_size is None:
65
+ hidden_states = F.interpolate(hidden_states, scale_factor=[1.0, 2.0, 2.0], mode="nearest")
66
+ else:
67
+ hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")
68
+
69
+ # If the input is bfloat16, we cast back to bfloat16
70
+ if dtype == torch.bfloat16:
71
+ hidden_states = hidden_states.to(dtype)
72
+
73
+ # if self.use_conv:
74
+ # if self.name == "conv":
75
+ # hidden_states = self.conv(hidden_states)
76
+ # else:
77
+ # hidden_states = self.Conv2d_0(hidden_states)
78
+ hidden_states = self.conv(hidden_states)
79
+
80
+ return hidden_states
81
+
82
+
83
+ class Downsample3D(nn.Module):
84
+ def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
85
+ super().__init__()
86
+ self.channels = channels
87
+ self.out_channels = out_channels or channels
88
+ self.use_conv = use_conv
89
+ self.padding = padding
90
+ stride = 2
91
+ self.name = name
92
+
93
+ if use_conv:
94
+ self.conv = InflatedConv3d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
95
+ else:
96
+ raise NotImplementedError
97
+
98
+ def forward(self, hidden_states):
99
+ assert hidden_states.shape[1] == self.channels
100
+ if self.use_conv and self.padding == 0:
101
+ raise NotImplementedError
102
+
103
+ assert hidden_states.shape[1] == self.channels
104
+ hidden_states = self.conv(hidden_states)
105
+
106
+ return hidden_states
107
+
108
+
109
+ class ResnetBlock3D(nn.Module):
110
+ def __init__(
111
+ self,
112
+ *,
113
+ in_channels,
114
+ out_channels=None,
115
+ conv_shortcut=False,
116
+ dropout=0.0,
117
+ temb_channels=512,
118
+ groups=32,
119
+ groups_out=None,
120
+ pre_norm=True,
121
+ eps=1e-6,
122
+ non_linearity="swish",
123
+ time_embedding_norm="default",
124
+ output_scale_factor=1.0,
125
+ use_in_shortcut=None,
126
+ use_inflated_groupnorm=False,
127
+ ):
128
+ super().__init__()
129
+ self.pre_norm = pre_norm
130
+ self.pre_norm = True
131
+ self.in_channels = in_channels
132
+ out_channels = in_channels if out_channels is None else out_channels
133
+ self.out_channels = out_channels
134
+ self.use_conv_shortcut = conv_shortcut
135
+ self.time_embedding_norm = time_embedding_norm
136
+ self.output_scale_factor = output_scale_factor
137
+ self.upsample = self.downsample = None
138
+
139
+ if groups_out is None:
140
+ groups_out = groups
141
+
142
+ assert use_inflated_groupnorm != None
143
+ if use_inflated_groupnorm:
144
+ self.norm1 = InflatedGroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
145
+ else:
146
+ self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
147
+
148
+ self.conv1 = InflatedConv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
149
+
150
+ if temb_channels is not None:
151
+ if self.time_embedding_norm == "default":
152
+ time_emb_proj_out_channels = out_channels
153
+ elif self.time_embedding_norm == "scale_shift":
154
+ time_emb_proj_out_channels = out_channels * 2
155
+ else:
156
+ raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ")
157
+
158
+ self.time_emb_proj = torch.nn.Linear(temb_channels, time_emb_proj_out_channels)
159
+ else:
160
+ self.time_emb_proj = None
161
+
162
+ if use_inflated_groupnorm:
163
+ self.norm2 = InflatedGroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
164
+ else:
165
+ self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
166
+
167
+ self.dropout = torch.nn.Dropout(dropout)
168
+ self.conv2 = InflatedConv3d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
169
+
170
+ if non_linearity == "swish":
171
+ self.nonlinearity = lambda x: F.silu(x)
172
+ elif non_linearity == "mish":
173
+ self.nonlinearity = Mish()
174
+ elif non_linearity == "silu":
175
+ self.nonlinearity = nn.SiLU()
176
+
177
+ self.use_in_shortcut = self.in_channels != self.out_channels if use_in_shortcut is None else use_in_shortcut
178
+
179
+ self.conv_shortcut = None
180
+ if self.use_in_shortcut:
181
+ self.conv_shortcut = InflatedConv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
182
+
183
+ def forward(self, input_tensor, temb):
184
+ hidden_states = input_tensor
185
+
186
+ hidden_states = self.norm1(hidden_states)
187
+ hidden_states = self.nonlinearity(hidden_states)
188
+
189
+ hidden_states = self.conv1(hidden_states)
190
+
191
+ if temb is not None:
192
+ temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None, None]
193
+
194
+ if temb is not None and self.time_embedding_norm == "default":
195
+ hidden_states = hidden_states + temb
196
+
197
+ hidden_states = self.norm2(hidden_states)
198
+
199
+ if temb is not None and self.time_embedding_norm == "scale_shift":
200
+ scale, shift = torch.chunk(temb, 2, dim=1)
201
+ hidden_states = hidden_states * (1 + scale) + shift
202
+
203
+ hidden_states = self.nonlinearity(hidden_states)
204
+
205
+ hidden_states = self.dropout(hidden_states)
206
+ hidden_states = self.conv2(hidden_states)
207
+
208
+ if self.conv_shortcut is not None:
209
+ input_tensor = self.conv_shortcut(input_tensor)
210
+
211
+ output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
212
+
213
+ return output_tensor
214
+
215
+
216
+ class Mish(torch.nn.Module):
217
+ def forward(self, hidden_states):
218
+ return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states))
motionclone/models/scheduler.py ADDED
@@ -0,0 +1,155 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional, Tuple, Union
2
+
3
+ import torch
4
+ from diffusers import DDIMScheduler
5
+ from diffusers.schedulers.scheduling_ddim import DDIMSchedulerOutput
6
+ from diffusers.utils.torch_utils import randn_tensor
7
+
8
+
9
+ class CustomDDIMScheduler(DDIMScheduler):
10
+ @torch.no_grad()
11
+ def step(
12
+ self,
13
+ model_output: torch.FloatTensor,
14
+ timestep: int,
15
+ sample: torch.FloatTensor,
16
+ eta: float = 0.0,
17
+ use_clipped_model_output: bool = False,
18
+ generator=None,
19
+ variance_noise: Optional[torch.FloatTensor] = None,
20
+ return_dict: bool = True,
21
+
22
+ # Guidance parameters
23
+ score=None,
24
+ guidance_scale=0.0,
25
+ indices=None, # [0]
26
+
27
+ ) -> Union[DDIMSchedulerOutput, Tuple]:
28
+ """
29
+ Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
30
+ process from the learned model outputs (most often the predicted noise).
31
+
32
+ Args:
33
+ model_output (`torch.FloatTensor`): direct output from learned diffusion model.
34
+ timestep (`int`): current discrete timestep in the diffusion chain.
35
+ sample (`torch.FloatTensor`):
36
+ current instance of sample being created by diffusion process.
37
+ eta (`float`): weight of noise for added noise in diffusion step.
38
+ use_clipped_model_output (`bool`): if `True`, compute "corrected" `model_output` from the clipped
39
+ predicted original sample. Necessary because predicted original sample is clipped to [-1, 1] when
40
+ `self.config.clip_sample` is `True`. If no clipping has happened, "corrected" `model_output` would
41
+ coincide with the one provided as input and `use_clipped_model_output` will have not effect.
42
+ generator: random number generator.
43
+ variance_noise (`torch.FloatTensor`): instead of generating noise for the variance using `generator`, we
44
+ can directly provide the noise for the variance itself. This is useful for methods such as
45
+ CycleDiffusion. (https://arxiv.org/abs/2210.05559)
46
+ return_dict (`bool`): option for returning tuple rather than DDIMSchedulerOutput class
47
+
48
+ Returns:
49
+ [`~schedulers.scheduling_utils.DDIMSchedulerOutput`] or `tuple`:
50
+ [`~schedulers.scheduling_utils.DDIMSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When
51
+ returning a tuple, the first element is the sample tensor.
52
+
53
+ """
54
+ if self.num_inference_steps is None:
55
+ raise ValueError(
56
+ "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
57
+ )
58
+
59
+ # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
60
+ # Ideally, read DDIM paper in-detail understanding
61
+
62
+ # Notation (<variable name> -> <name in paper>
63
+ # - pred_noise_t -> e_theta(x_t, t)
64
+ # - pred_original_sample -> f_theta(x_t, t) or x_0
65
+ # - std_dev_t -> sigma_t
66
+ # - eta -> η
67
+ # - pred_sample_direction -> "direction pointing to x_t"
68
+ # - pred_prev_sample -> "x_t-1"
69
+
70
+
71
+ # Support IF models
72
+ if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]:
73
+ model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1)
74
+ else:
75
+ predicted_variance = None
76
+
77
+ # 1. get previous step value (=t-1)
78
+ prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps
79
+
80
+ # 2. compute alphas, betas
81
+ alpha_prod_t = self.alphas_cumprod[timestep]
82
+ alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
83
+
84
+ beta_prod_t = 1 - alpha_prod_t
85
+
86
+ # 3. compute predicted original sample from predicted noise also called
87
+ # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
88
+ if self.config.prediction_type == "epsilon":
89
+ pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
90
+ pred_epsilon = model_output
91
+ elif self.config.prediction_type == "sample":
92
+ pred_original_sample = model_output
93
+ pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
94
+ elif self.config.prediction_type == "v_prediction":
95
+ pred_original_sample = (alpha_prod_t ** 0.5) * sample - (beta_prod_t ** 0.5) * model_output
96
+ pred_epsilon = (alpha_prod_t ** 0.5) * model_output + (beta_prod_t ** 0.5) * sample
97
+ else:
98
+ raise ValueError(
99
+ f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
100
+ " `v_prediction`"
101
+ )
102
+
103
+ # 4. Clip or threshold "predicted x_0"
104
+ if self.config.thresholding:
105
+ pred_original_sample = self._threshold_sample(pred_original_sample)
106
+ elif self.config.clip_sample:
107
+ pred_original_sample = pred_original_sample.clamp(
108
+ -self.config.clip_sample_range, self.config.clip_sample_range
109
+ )
110
+
111
+ # 5. compute variance: "sigma_t(η)" -> see formula (16)
112
+ # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
113
+ variance = self._get_variance(timestep, prev_timestep)
114
+ std_dev_t = eta * variance ** (0.5)
115
+
116
+ if use_clipped_model_output:
117
+ # the pred_epsilon is always re-derived from the clipped x_0 in Glide
118
+ pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5) # [2, 4, 64, 64]
119
+
120
+ # 6. apply guidance following the formula (14) from https://arxiv.org/pdf/2105.05233.pdf
121
+ if score is not None and guidance_scale > 0.0: # indices指定了应用guidance的位置,此处indices = [0]
122
+ if indices is not None:
123
+ # import pdb; pdb.set_trace()
124
+ assert pred_epsilon[indices].shape == score.shape, "pred_epsilon[indices].shape != score.shape"
125
+ pred_epsilon[indices] = pred_epsilon[indices] - guidance_scale * (1 - alpha_prod_t) ** (0.5) * score # 只修改了其中第一个[1, 4, 64, 64]的部分
126
+ else:
127
+ assert pred_epsilon.shape == score.shape
128
+ pred_epsilon = pred_epsilon - guidance_scale * (1 - alpha_prod_t) ** (0.5) * score
129
+ #
130
+
131
+ # 7. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
132
+ pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t ** 2) ** (0.5) * pred_epsilon # [2, 4, 64, 64]
133
+
134
+ # 8. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
135
+ prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction # [2, 4, 64, 64]
136
+
137
+ if eta > 0:
138
+ if variance_noise is not None and generator is not None:
139
+ raise ValueError(
140
+ "Cannot pass both generator and variance_noise. Please make sure that either `generator` or"
141
+ " `variance_noise` stays `None`."
142
+ )
143
+
144
+ if variance_noise is None:
145
+ variance_noise = randn_tensor(
146
+ model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype
147
+ )
148
+ variance = std_dev_t * variance_noise # 最后还要再加一些随机噪声
149
+
150
+ prev_sample = prev_sample + variance # [2, 4, 64, 64]
151
+ self.pred_epsilon = pred_epsilon
152
+ if not return_dict:
153
+ return (prev_sample,)
154
+
155
+ return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
motionclone/models/sparse_controlnet.py ADDED
@@ -0,0 +1,593 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # Changes were made to this source code by Yuwei Guo.
16
+ from dataclasses import dataclass
17
+ from typing import Any, Dict, List, Optional, Tuple, Union
18
+
19
+ import torch
20
+ from torch import nn
21
+ from torch.nn import functional as F
22
+
23
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
24
+ from diffusers.utils import BaseOutput, logging
25
+ from diffusers.models.embeddings import TimestepEmbedding, Timesteps
26
+ from diffusers.models.modeling_utils import ModelMixin
27
+
28
+
29
+ from .unet_blocks import (
30
+ CrossAttnDownBlock3D,
31
+ DownBlock3D,
32
+ UNetMidBlock3DCrossAttn,
33
+ get_down_block,
34
+ )
35
+ from einops import repeat, rearrange
36
+ from .resnet import InflatedConv3d
37
+
38
+ from diffusers.models.unet_2d_condition import UNet2DConditionModel
39
+
40
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
41
+
42
+
43
+ @dataclass
44
+ class SparseControlNetOutput(BaseOutput):
45
+ down_block_res_samples: Tuple[torch.Tensor]
46
+ mid_block_res_sample: torch.Tensor
47
+
48
+
49
+ class SparseControlNetConditioningEmbedding(nn.Module):
50
+ def __init__(
51
+ self,
52
+ conditioning_embedding_channels: int,
53
+ conditioning_channels: int = 3,
54
+ block_out_channels: Tuple[int] = (16, 32, 96, 256),
55
+ ):
56
+ super().__init__()
57
+
58
+ self.conv_in = InflatedConv3d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1)
59
+
60
+ self.blocks = nn.ModuleList([])
61
+
62
+ for i in range(len(block_out_channels) - 1):
63
+ channel_in = block_out_channels[i]
64
+ channel_out = block_out_channels[i + 1]
65
+ self.blocks.append(InflatedConv3d(channel_in, channel_in, kernel_size=3, padding=1))
66
+ self.blocks.append(InflatedConv3d(channel_in, channel_out, kernel_size=3, padding=1, stride=2))
67
+
68
+ self.conv_out = zero_module(
69
+ InflatedConv3d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1)
70
+ )
71
+
72
+ def forward(self, conditioning):
73
+ embedding = self.conv_in(conditioning)
74
+ embedding = F.silu(embedding)
75
+
76
+ for block in self.blocks:
77
+ embedding = block(embedding)
78
+ embedding = F.silu(embedding)
79
+
80
+ embedding = self.conv_out(embedding)
81
+
82
+ return embedding
83
+
84
+
85
+ class SparseControlNetModel(ModelMixin, ConfigMixin):
86
+ _supports_gradient_checkpointing = True
87
+
88
+ @register_to_config
89
+ def __init__(
90
+ self,
91
+ in_channels: int = 4,
92
+ conditioning_channels: int = 3,
93
+ flip_sin_to_cos: bool = True,
94
+ freq_shift: int = 0,
95
+ down_block_types: Tuple[str] = (
96
+ "CrossAttnDownBlock2D",
97
+ "CrossAttnDownBlock2D",
98
+ "CrossAttnDownBlock2D",
99
+ "DownBlock2D",
100
+ ),
101
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
102
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
103
+ layers_per_block: int = 2,
104
+ downsample_padding: int = 1,
105
+ mid_block_scale_factor: float = 1,
106
+ act_fn: str = "silu",
107
+ norm_num_groups: Optional[int] = 32,
108
+ norm_eps: float = 1e-5,
109
+ cross_attention_dim: int = 1280,
110
+ attention_head_dim: Union[int, Tuple[int]] = 8,
111
+ num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
112
+ use_linear_projection: bool = False,
113
+ class_embed_type: Optional[str] = None,
114
+ num_class_embeds: Optional[int] = None,
115
+ upcast_attention: bool = False,
116
+ resnet_time_scale_shift: str = "default",
117
+ projection_class_embeddings_input_dim: Optional[int] = None,
118
+ controlnet_conditioning_channel_order: str = "rgb",
119
+ conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256),
120
+ global_pool_conditions: bool = False,
121
+
122
+ use_motion_module = True,
123
+ motion_module_resolutions = ( 1,2,4,8 ),
124
+ motion_module_mid_block = False,
125
+ motion_module_type = "Vanilla",
126
+ motion_module_kwargs = {
127
+ "num_attention_heads": 8,
128
+ "num_transformer_block": 1,
129
+ "attention_block_types": ["Temporal_Self"],
130
+ "temporal_position_encoding": True,
131
+ "temporal_position_encoding_max_len": 32,
132
+ "temporal_attention_dim_div": 1,
133
+ "causal_temporal_attention": False,
134
+ },
135
+
136
+ concate_conditioning_mask: bool = True,
137
+ use_simplified_condition_embedding: bool = False,
138
+
139
+ set_noisy_sample_input_to_zero: bool = False,
140
+ ):
141
+ super().__init__()
142
+
143
+ # If `num_attention_heads` is not defined (which is the case for most models)
144
+ # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
145
+ # The reason for this behavior is to correct for incorrectly named variables that were introduced
146
+ # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
147
+ # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
148
+ # which is why we correct for the naming here.
149
+ num_attention_heads = num_attention_heads or attention_head_dim
150
+
151
+ # Check inputs
152
+ if len(block_out_channels) != len(down_block_types):
153
+ raise ValueError(
154
+ f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
155
+ )
156
+
157
+ if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
158
+ raise ValueError(
159
+ f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
160
+ )
161
+
162
+ if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
163
+ raise ValueError(
164
+ f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
165
+ )
166
+
167
+ # input
168
+ self.set_noisy_sample_input_to_zero = set_noisy_sample_input_to_zero
169
+
170
+ conv_in_kernel = 3
171
+ conv_in_padding = (conv_in_kernel - 1) // 2
172
+ self.conv_in = InflatedConv3d(
173
+ in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
174
+ )
175
+
176
+ if concate_conditioning_mask:
177
+ conditioning_channels = conditioning_channels + 1
178
+ self.concate_conditioning_mask = concate_conditioning_mask
179
+
180
+ # control net conditioning embedding
181
+ if use_simplified_condition_embedding:
182
+ self.controlnet_cond_embedding = zero_module(
183
+ InflatedConv3d(conditioning_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding)
184
+ ).to(torch.float16)
185
+ else:
186
+ self.controlnet_cond_embedding = SparseControlNetConditioningEmbedding(
187
+ conditioning_embedding_channels=block_out_channels[0],
188
+ block_out_channels=conditioning_embedding_out_channels,
189
+ conditioning_channels=conditioning_channels,
190
+ ).to(torch.float16)
191
+ self.use_simplified_condition_embedding = use_simplified_condition_embedding
192
+
193
+ # time
194
+ time_embed_dim = block_out_channels[0] * 4
195
+
196
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
197
+ timestep_input_dim = block_out_channels[0]
198
+
199
+ self.time_embedding = TimestepEmbedding(
200
+ timestep_input_dim,
201
+ time_embed_dim,
202
+ act_fn=act_fn,
203
+ )
204
+
205
+ # class embedding
206
+ if class_embed_type is None and num_class_embeds is not None:
207
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
208
+ elif class_embed_type == "timestep":
209
+ self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
210
+ elif class_embed_type == "identity":
211
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
212
+ elif class_embed_type == "projection":
213
+ if projection_class_embeddings_input_dim is None:
214
+ raise ValueError(
215
+ "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
216
+ )
217
+ # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
218
+ # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
219
+ # 2. it projects from an arbitrary input dimension.
220
+ #
221
+ # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
222
+ # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
223
+ # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
224
+ self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
225
+ else:
226
+ self.class_embedding = None
227
+
228
+
229
+ self.down_blocks = nn.ModuleList([])
230
+ self.controlnet_down_blocks = nn.ModuleList([])
231
+
232
+ if isinstance(only_cross_attention, bool):
233
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
234
+
235
+ if isinstance(attention_head_dim, int):
236
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
237
+
238
+ if isinstance(num_attention_heads, int):
239
+ num_attention_heads = (num_attention_heads,) * len(down_block_types)
240
+
241
+ # down
242
+ output_channel = block_out_channels[0]
243
+
244
+ controlnet_block = InflatedConv3d(output_channel, output_channel, kernel_size=1)
245
+ controlnet_block = zero_module(controlnet_block)
246
+ self.controlnet_down_blocks.append(controlnet_block)
247
+
248
+ for i, down_block_type in enumerate(down_block_types):
249
+ res = 2 ** i
250
+ input_channel = output_channel
251
+ output_channel = block_out_channels[i]
252
+ is_final_block = i == len(block_out_channels) - 1
253
+
254
+ down_block = get_down_block(
255
+ down_block_type,
256
+ num_layers=layers_per_block,
257
+ in_channels=input_channel,
258
+ out_channels=output_channel,
259
+ temb_channels=time_embed_dim,
260
+ add_downsample=not is_final_block,
261
+ resnet_eps=norm_eps,
262
+ resnet_act_fn=act_fn,
263
+ resnet_groups=norm_num_groups,
264
+ cross_attention_dim=cross_attention_dim,
265
+ attn_num_head_channels=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
266
+ downsample_padding=downsample_padding,
267
+ use_linear_projection=use_linear_projection,
268
+ only_cross_attention=only_cross_attention[i],
269
+ upcast_attention=upcast_attention,
270
+ resnet_time_scale_shift=resnet_time_scale_shift,
271
+
272
+ use_inflated_groupnorm=True,
273
+
274
+ use_motion_module=use_motion_module and (res in motion_module_resolutions),
275
+ motion_module_type=motion_module_type,
276
+ motion_module_kwargs=motion_module_kwargs,
277
+ )
278
+ self.down_blocks.append(down_block)
279
+
280
+ for _ in range(layers_per_block):
281
+ controlnet_block = InflatedConv3d(output_channel, output_channel, kernel_size=1)
282
+ controlnet_block = zero_module(controlnet_block)
283
+ self.controlnet_down_blocks.append(controlnet_block)
284
+
285
+ if not is_final_block:
286
+ controlnet_block = InflatedConv3d(output_channel, output_channel, kernel_size=1)
287
+ controlnet_block = zero_module(controlnet_block)
288
+ self.controlnet_down_blocks.append(controlnet_block)
289
+
290
+ # mid
291
+ mid_block_channel = block_out_channels[-1]
292
+
293
+ controlnet_block = InflatedConv3d(mid_block_channel, mid_block_channel, kernel_size=1)
294
+ controlnet_block = zero_module(controlnet_block)
295
+ self.controlnet_mid_block = controlnet_block
296
+
297
+ self.mid_block = UNetMidBlock3DCrossAttn(
298
+ in_channels=mid_block_channel,
299
+ temb_channels=time_embed_dim,
300
+ resnet_eps=norm_eps,
301
+ resnet_act_fn=act_fn,
302
+ output_scale_factor=mid_block_scale_factor,
303
+ resnet_time_scale_shift=resnet_time_scale_shift,
304
+ cross_attention_dim=cross_attention_dim,
305
+ attn_num_head_channels=num_attention_heads[-1],
306
+ resnet_groups=norm_num_groups,
307
+ use_linear_projection=use_linear_projection,
308
+ upcast_attention=upcast_attention,
309
+
310
+ use_inflated_groupnorm=True,
311
+ use_motion_module=use_motion_module and motion_module_mid_block,
312
+ motion_module_type=motion_module_type,
313
+ motion_module_kwargs=motion_module_kwargs,
314
+ )
315
+
316
+ @classmethod
317
+ def from_unet(
318
+ cls,
319
+ unet: UNet2DConditionModel,
320
+ controlnet_conditioning_channel_order: str = "rgb",
321
+ conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256),
322
+ load_weights_from_unet: bool = True,
323
+
324
+ controlnet_additional_kwargs: dict = {},
325
+ ):
326
+ controlnet = cls(
327
+ in_channels=unet.config.in_channels,
328
+ flip_sin_to_cos=unet.config.flip_sin_to_cos,
329
+ freq_shift=unet.config.freq_shift,
330
+ down_block_types=unet.config.down_block_types,
331
+ only_cross_attention=unet.config.only_cross_attention,
332
+ block_out_channels=unet.config.block_out_channels,
333
+ layers_per_block=unet.config.layers_per_block,
334
+ downsample_padding=unet.config.downsample_padding,
335
+ mid_block_scale_factor=unet.config.mid_block_scale_factor,
336
+ act_fn=unet.config.act_fn,
337
+ norm_num_groups=unet.config.norm_num_groups,
338
+ norm_eps=unet.config.norm_eps,
339
+ cross_attention_dim=unet.config.cross_attention_dim,
340
+ attention_head_dim=unet.config.attention_head_dim,
341
+ num_attention_heads=unet.config.num_attention_heads,
342
+ use_linear_projection=unet.config.use_linear_projection,
343
+ class_embed_type=unet.config.class_embed_type,
344
+ num_class_embeds=unet.config.num_class_embeds,
345
+ upcast_attention=unet.config.upcast_attention,
346
+ resnet_time_scale_shift=unet.config.resnet_time_scale_shift,
347
+ projection_class_embeddings_input_dim=unet.config.projection_class_embeddings_input_dim,
348
+ controlnet_conditioning_channel_order=controlnet_conditioning_channel_order,
349
+ conditioning_embedding_out_channels=conditioning_embedding_out_channels,
350
+
351
+ **controlnet_additional_kwargs,
352
+ )
353
+
354
+ if load_weights_from_unet:
355
+ m, u = controlnet.conv_in.load_state_dict(cls.image_layer_filter(unet.conv_in.state_dict()), strict=False)
356
+ assert len(u) == 0
357
+ m, u = controlnet.time_proj.load_state_dict(cls.image_layer_filter(unet.time_proj.state_dict()), strict=False)
358
+ assert len(u) == 0
359
+ m, u = controlnet.time_embedding.load_state_dict(cls.image_layer_filter(unet.time_embedding.state_dict()), strict=False)
360
+ assert len(u) == 0
361
+
362
+ if controlnet.class_embedding:
363
+ m, u = controlnet.class_embedding.load_state_dict(cls.image_layer_filter(unet.class_embedding.state_dict()), strict=False)
364
+ assert len(u) == 0
365
+ m, u = controlnet.down_blocks.load_state_dict(cls.image_layer_filter(unet.down_blocks.state_dict()), strict=False)
366
+ assert len(u) == 0
367
+ m, u = controlnet.mid_block.load_state_dict(cls.image_layer_filter(unet.mid_block.state_dict()), strict=False)
368
+ assert len(u) == 0
369
+
370
+ return controlnet
371
+
372
+ @staticmethod
373
+ def image_layer_filter(state_dict):
374
+ new_state_dict = {}
375
+ for name, param in state_dict.items():
376
+ if "motion_modules." in name or "lora" in name: continue
377
+ new_state_dict[name] = param
378
+ return new_state_dict
379
+
380
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice
381
+ def set_attention_slice(self, slice_size):
382
+ r"""
383
+ Enable sliced attention computation.
384
+
385
+ When this option is enabled, the attention module splits the input tensor in slices to compute attention in
386
+ several steps. This is useful for saving some memory in exchange for a small decrease in speed.
387
+
388
+ Args:
389
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
390
+ When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
391
+ `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
392
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
393
+ must be a multiple of `slice_size`.
394
+ """
395
+ sliceable_head_dims = []
396
+
397
+ def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
398
+ if hasattr(module, "set_attention_slice"):
399
+ sliceable_head_dims.append(module.sliceable_head_dim)
400
+
401
+ for child in module.children():
402
+ fn_recursive_retrieve_sliceable_dims(child)
403
+
404
+ # retrieve number of attention layers
405
+ for module in self.children():
406
+ fn_recursive_retrieve_sliceable_dims(module)
407
+
408
+ num_sliceable_layers = len(sliceable_head_dims)
409
+
410
+ if slice_size == "auto":
411
+ # half the attention head size is usually a good trade-off between
412
+ # speed and memory
413
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
414
+ elif slice_size == "max":
415
+ # make smallest slice possible
416
+ slice_size = num_sliceable_layers * [1]
417
+
418
+ slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
419
+
420
+ if len(slice_size) != len(sliceable_head_dims):
421
+ raise ValueError(
422
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
423
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
424
+ )
425
+
426
+ for i in range(len(slice_size)):
427
+ size = slice_size[i]
428
+ dim = sliceable_head_dims[i]
429
+ if size is not None and size > dim:
430
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
431
+
432
+ # Recursively walk through all the children.
433
+ # Any children which exposes the set_attention_slice method
434
+ # gets the message
435
+ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
436
+ if hasattr(module, "set_attention_slice"):
437
+ module.set_attention_slice(slice_size.pop())
438
+
439
+ for child in module.children():
440
+ fn_recursive_set_attention_slice(child, slice_size)
441
+
442
+ reversed_slice_size = list(reversed(slice_size))
443
+ for module in self.children():
444
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
445
+
446
+ def _set_gradient_checkpointing(self, module, value=False):
447
+ if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D)):
448
+ module.gradient_checkpointing = value
449
+
450
+ def forward(
451
+ self,
452
+ sample: torch.FloatTensor,
453
+ timestep: Union[torch.Tensor, float, int],
454
+ encoder_hidden_states: torch.Tensor,
455
+
456
+ controlnet_cond: torch.FloatTensor,
457
+ conditioning_mask: Optional[torch.FloatTensor] = None,
458
+
459
+ conditioning_scale: float = 1.0,
460
+ class_labels: Optional[torch.Tensor] = None,
461
+ attention_mask: Optional[torch.Tensor] = None,
462
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
463
+ guess_mode: bool = False,
464
+ return_dict: bool = True,
465
+ ) -> Union[SparseControlNetOutput, Tuple]:
466
+
467
+ # set input noise to zero
468
+ # if self.set_noisy_sample_input_to_zero:
469
+ # sample = torch.zeros_like(sample).to(sample.device)
470
+
471
+ # prepare attention_mask
472
+ if attention_mask is not None:
473
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
474
+ attention_mask = attention_mask.unsqueeze(1)
475
+
476
+ # 1. time
477
+ timesteps = timestep
478
+ if not torch.is_tensor(timesteps):
479
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
480
+ # This would be a good case for the `match` statement (Python 3.10+)
481
+ is_mps = sample.device.type == "mps"
482
+ if isinstance(timestep, float):
483
+ dtype = torch.float32 if is_mps else torch.float64
484
+ else:
485
+ dtype = torch.int32 if is_mps else torch.int64
486
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
487
+ elif len(timesteps.shape) == 0:
488
+ timesteps = timesteps[None].to(sample.device)
489
+
490
+ timesteps = timesteps.repeat(sample.shape[0] // timesteps.shape[0])
491
+ encoder_hidden_states = encoder_hidden_states.repeat(sample.shape[0] // encoder_hidden_states.shape[0], 1, 1)
492
+
493
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
494
+ timesteps = timesteps.expand(sample.shape[0])
495
+
496
+ t_emb = self.time_proj(timesteps)
497
+
498
+ # timesteps does not contain any weights and will always return f32 tensors
499
+ # but time_embedding might actually be running in fp16. so we need to cast here.
500
+ # there might be better ways to encapsulate this.
501
+ t_emb = t_emb.to(dtype=self.dtype)
502
+ emb = self.time_embedding(t_emb)
503
+
504
+ if self.class_embedding is not None:
505
+ if class_labels is None:
506
+ raise ValueError("class_labels should be provided when num_class_embeds > 0")
507
+
508
+ if self.config.class_embed_type == "timestep":
509
+ class_labels = self.time_proj(class_labels)
510
+
511
+ class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
512
+ emb = emb + class_emb
513
+
514
+ # 2. pre-process
515
+ # equal to set input noise to zero
516
+ if self.set_noisy_sample_input_to_zero:
517
+ shape = sample.shape
518
+ sample = self.conv_in.bias.reshape(1,-1,1,1,1).expand(shape[0],-1,shape[2],shape[3],shape[4])
519
+ else:
520
+ sample = self.conv_in(sample)
521
+
522
+ if self.concate_conditioning_mask:
523
+ controlnet_cond = torch.cat([controlnet_cond, conditioning_mask], dim=1).to(torch.float16)
524
+ # import pdb; pdb.set_trace()
525
+ controlnet_cond = self.controlnet_cond_embedding(controlnet_cond)
526
+
527
+ sample = sample + controlnet_cond
528
+
529
+ # 3. down
530
+ down_block_res_samples = (sample,)
531
+ for downsample_block in self.down_blocks:
532
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
533
+ sample, res_samples = downsample_block(
534
+ hidden_states=sample,
535
+ temb=emb,
536
+ encoder_hidden_states=encoder_hidden_states,
537
+ attention_mask=attention_mask,
538
+ # cross_attention_kwargs=cross_attention_kwargs,
539
+ )
540
+ else: sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
541
+
542
+ down_block_res_samples += res_samples
543
+
544
+ # 4. mid
545
+ if self.mid_block is not None:
546
+ sample = self.mid_block(
547
+ sample,
548
+ emb,
549
+ encoder_hidden_states=encoder_hidden_states,
550
+ attention_mask=attention_mask,
551
+ # cross_attention_kwargs=cross_attention_kwargs,
552
+ )
553
+
554
+ # 5. controlnet blocks
555
+ controlnet_down_block_res_samples = ()
556
+
557
+ for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks):
558
+ down_block_res_sample = controlnet_block(down_block_res_sample)
559
+ controlnet_down_block_res_samples = controlnet_down_block_res_samples + (down_block_res_sample,)
560
+
561
+ down_block_res_samples = controlnet_down_block_res_samples
562
+
563
+ mid_block_res_sample = self.controlnet_mid_block(sample)
564
+
565
+ # 6. scaling
566
+ if guess_mode and not self.config.global_pool_conditions:
567
+ scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device) # 0.1 to 1.0
568
+
569
+ scales = scales * conditioning_scale
570
+ down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)]
571
+ mid_block_res_sample = mid_block_res_sample * scales[-1] # last one
572
+ else:
573
+ down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
574
+ mid_block_res_sample = mid_block_res_sample * conditioning_scale
575
+
576
+ if self.config.global_pool_conditions:
577
+ down_block_res_samples = [
578
+ torch.mean(sample, dim=(2, 3), keepdim=True) for sample in down_block_res_samples
579
+ ]
580
+ mid_block_res_sample = torch.mean(mid_block_res_sample, dim=(2, 3), keepdim=True)
581
+
582
+ if not return_dict:
583
+ return (down_block_res_samples, mid_block_res_sample)
584
+
585
+ return SparseControlNetOutput(
586
+ down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample
587
+ )
588
+
589
+
590
+ def zero_module(module):
591
+ for p in module.parameters():
592
+ nn.init.zeros_(p)
593
+ return module
motionclone/models/unet.py ADDED
@@ -0,0 +1,515 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_condition.py
2
+
3
+ from dataclasses import dataclass
4
+ from typing import List, Optional, Tuple, Union
5
+
6
+ import os
7
+ import json
8
+ import pdb
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.utils.checkpoint
13
+
14
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
15
+ from diffusers.models.modeling_utils import ModelMixin
16
+ from diffusers.utils import BaseOutput, logging
17
+ from diffusers.models.embeddings import TimestepEmbedding, Timesteps
18
+ from .unet_blocks import (
19
+ CrossAttnDownBlock3D,
20
+ CrossAttnUpBlock3D,
21
+ DownBlock3D,
22
+ UNetMidBlock3DCrossAttn,
23
+ UpBlock3D,
24
+ get_down_block,
25
+ get_up_block,
26
+ )
27
+ from .resnet import InflatedConv3d, InflatedGroupNorm
28
+
29
+
30
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
31
+
32
+
33
+ @dataclass
34
+ class UNet3DConditionOutput(BaseOutput):
35
+ sample: torch.FloatTensor
36
+
37
+
38
+ class UNet3DConditionModel(ModelMixin, ConfigMixin):
39
+ _supports_gradient_checkpointing = True
40
+
41
+ @register_to_config
42
+ def __init__(
43
+ self,
44
+ sample_size: Optional[int] = None,
45
+ in_channels: int = 4,
46
+ out_channels: int = 4,
47
+ center_input_sample: bool = False,
48
+ flip_sin_to_cos: bool = True,
49
+ freq_shift: int = 0,
50
+ down_block_types: Tuple[str] = (
51
+ "CrossAttnDownBlock3D",
52
+ "CrossAttnDownBlock3D",
53
+ "CrossAttnDownBlock3D",
54
+ "DownBlock3D",
55
+ ),
56
+ mid_block_type: str = "UNetMidBlock3DCrossAttn",
57
+ up_block_types: Tuple[str] = ( # 第一个不带有CrossAttn,后面三个带有CrossAttn
58
+ "UpBlock3D",
59
+ "CrossAttnUpBlock3D",
60
+ "CrossAttnUpBlock3D",
61
+ "CrossAttnUpBlock3D"
62
+ ),
63
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
64
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
65
+ layers_per_block: int = 2,
66
+ downsample_padding: int = 1,
67
+ mid_block_scale_factor: float = 1,
68
+ act_fn: str = "silu",
69
+ norm_num_groups: int = 32,
70
+ norm_eps: float = 1e-5,
71
+ cross_attention_dim: int = 1280,
72
+ attention_head_dim: Union[int, Tuple[int]] = 8,
73
+ dual_cross_attention: bool = False,
74
+ use_linear_projection: bool = False,
75
+ class_embed_type: Optional[str] = None,
76
+ num_class_embeds: Optional[int] = None,
77
+ upcast_attention: bool = False,
78
+ resnet_time_scale_shift: str = "default",
79
+
80
+ use_inflated_groupnorm=False,
81
+
82
+ # Additional
83
+ use_motion_module = False,
84
+ motion_module_resolutions = ( 1,2,4,8 ),
85
+ motion_module_mid_block = False,
86
+ motion_module_decoder_only = False,
87
+ motion_module_type = None,
88
+ motion_module_kwargs = {},
89
+ unet_use_cross_frame_attention = False,
90
+ unet_use_temporal_attention = False,
91
+ ):
92
+ super().__init__()
93
+
94
+ self.sample_size = sample_size
95
+ time_embed_dim = block_out_channels[0] * 4
96
+
97
+ # input
98
+ self.conv_in = InflatedConv3d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1))
99
+
100
+ # time
101
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
102
+ timestep_input_dim = block_out_channels[0]
103
+
104
+ self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
105
+
106
+ # class embedding
107
+ if class_embed_type is None and num_class_embeds is not None:
108
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
109
+ elif class_embed_type == "timestep":
110
+ self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
111
+ elif class_embed_type == "identity":
112
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
113
+ else:
114
+ self.class_embedding = None
115
+
116
+ self.down_blocks = nn.ModuleList([])
117
+ self.mid_block = None
118
+ self.up_blocks = nn.ModuleList([])
119
+
120
+ if isinstance(only_cross_attention, bool):
121
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
122
+
123
+ if isinstance(attention_head_dim, int):
124
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
125
+
126
+ # down
127
+ output_channel = block_out_channels[0]
128
+ for i, down_block_type in enumerate(down_block_types):
129
+ res = 2 ** i
130
+ input_channel = output_channel
131
+ output_channel = block_out_channels[i]
132
+ is_final_block = i == len(block_out_channels) - 1
133
+
134
+ down_block = get_down_block(
135
+ down_block_type,
136
+ num_layers=layers_per_block,
137
+ in_channels=input_channel,
138
+ out_channels=output_channel,
139
+ temb_channels=time_embed_dim,
140
+ add_downsample=not is_final_block,
141
+ resnet_eps=norm_eps,
142
+ resnet_act_fn=act_fn,
143
+ resnet_groups=norm_num_groups,
144
+ cross_attention_dim=cross_attention_dim,
145
+ attn_num_head_channels=attention_head_dim[i],
146
+ downsample_padding=downsample_padding,
147
+ dual_cross_attention=dual_cross_attention,
148
+ use_linear_projection=use_linear_projection,
149
+ only_cross_attention=only_cross_attention[i],
150
+ upcast_attention=upcast_attention,
151
+ resnet_time_scale_shift=resnet_time_scale_shift,
152
+
153
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
154
+ unet_use_temporal_attention=unet_use_temporal_attention,
155
+ use_inflated_groupnorm=use_inflated_groupnorm,
156
+
157
+ use_motion_module=use_motion_module and (res in motion_module_resolutions) and (not motion_module_decoder_only),
158
+ motion_module_type=motion_module_type,
159
+ motion_module_kwargs=motion_module_kwargs,
160
+ )
161
+ self.down_blocks.append(down_block)
162
+
163
+ # mid
164
+ if mid_block_type == "UNetMidBlock3DCrossAttn":
165
+ self.mid_block = UNetMidBlock3DCrossAttn(
166
+ in_channels=block_out_channels[-1],
167
+ temb_channels=time_embed_dim,
168
+ resnet_eps=norm_eps,
169
+ resnet_act_fn=act_fn,
170
+ output_scale_factor=mid_block_scale_factor,
171
+ resnet_time_scale_shift=resnet_time_scale_shift,
172
+ cross_attention_dim=cross_attention_dim,
173
+ attn_num_head_channels=attention_head_dim[-1],
174
+ resnet_groups=norm_num_groups,
175
+ dual_cross_attention=dual_cross_attention,
176
+ use_linear_projection=use_linear_projection,
177
+ upcast_attention=upcast_attention,
178
+
179
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
180
+ unet_use_temporal_attention=unet_use_temporal_attention,
181
+ use_inflated_groupnorm=use_inflated_groupnorm,
182
+
183
+ use_motion_module=use_motion_module and motion_module_mid_block,
184
+ motion_module_type=motion_module_type,
185
+ motion_module_kwargs=motion_module_kwargs,
186
+ )
187
+ else:
188
+ raise ValueError(f"unknown mid_block_type : {mid_block_type}")
189
+
190
+ # count how many layers upsample the videos
191
+ self.num_upsamplers = 0
192
+
193
+ # up
194
+ reversed_block_out_channels = list(reversed(block_out_channels))
195
+ reversed_attention_head_dim = list(reversed(attention_head_dim))
196
+ only_cross_attention = list(reversed(only_cross_attention))
197
+ output_channel = reversed_block_out_channels[0]
198
+ for i, up_block_type in enumerate(up_block_types):
199
+ res = 2 ** (3 - i)
200
+ is_final_block = i == len(block_out_channels) - 1
201
+
202
+ prev_output_channel = output_channel
203
+ output_channel = reversed_block_out_channels[i]
204
+ input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
205
+
206
+ # add upsample block for all BUT final layer
207
+ if not is_final_block:
208
+ add_upsample = True
209
+ self.num_upsamplers += 1
210
+ else:
211
+ add_upsample = False
212
+
213
+ up_block = get_up_block(
214
+ up_block_type,
215
+ num_layers=layers_per_block + 1,
216
+ in_channels=input_channel,
217
+ out_channels=output_channel,
218
+ prev_output_channel=prev_output_channel,
219
+ temb_channels=time_embed_dim,
220
+ add_upsample=add_upsample,
221
+ resnet_eps=norm_eps,
222
+ resnet_act_fn=act_fn,
223
+ resnet_groups=norm_num_groups,
224
+ cross_attention_dim=cross_attention_dim,
225
+ attn_num_head_channels=reversed_attention_head_dim[i],
226
+ dual_cross_attention=dual_cross_attention,
227
+ use_linear_projection=use_linear_projection,
228
+ only_cross_attention=only_cross_attention[i],
229
+ upcast_attention=upcast_attention,
230
+ resnet_time_scale_shift=resnet_time_scale_shift,
231
+
232
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
233
+ unet_use_temporal_attention=unet_use_temporal_attention,
234
+ use_inflated_groupnorm=use_inflated_groupnorm,
235
+
236
+ use_motion_module=use_motion_module and (res in motion_module_resolutions),
237
+ motion_module_type=motion_module_type,
238
+ motion_module_kwargs=motion_module_kwargs,
239
+ )
240
+ self.up_blocks.append(up_block)
241
+ prev_output_channel = output_channel
242
+
243
+ # out
244
+ if use_inflated_groupnorm:
245
+ self.conv_norm_out = InflatedGroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps)
246
+ else:
247
+ self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps)
248
+ self.conv_act = nn.SiLU()
249
+ self.conv_out = InflatedConv3d(block_out_channels[0], out_channels, kernel_size=3, padding=1)
250
+
251
+ def set_attention_slice(self, slice_size):
252
+ r"""
253
+ Enable sliced attention computation.
254
+
255
+ When this option is enabled, the attention module will split the input tensor in slices, to compute attention
256
+ in several steps. This is useful to save some memory in exchange for a small speed decrease.
257
+
258
+ Args:
259
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
260
+ When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
261
+ `"max"`, maxium amount of memory will be saved by running only one slice at a time. If a number is
262
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
263
+ must be a multiple of `slice_size`.
264
+ """
265
+ sliceable_head_dims = []
266
+
267
+ def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module):
268
+ if hasattr(module, "set_attention_slice"):
269
+ sliceable_head_dims.append(module.sliceable_head_dim)
270
+
271
+ for child in module.children():
272
+ fn_recursive_retrieve_slicable_dims(child)
273
+
274
+ # retrieve number of attention layers
275
+ for module in self.children():
276
+ fn_recursive_retrieve_slicable_dims(module)
277
+
278
+ num_slicable_layers = len(sliceable_head_dims)
279
+
280
+ if slice_size == "auto":
281
+ # half the attention head size is usually a good trade-off between
282
+ # speed and memory
283
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
284
+ elif slice_size == "max":
285
+ # make smallest slice possible
286
+ slice_size = num_slicable_layers * [1]
287
+
288
+ slice_size = num_slicable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
289
+
290
+ if len(slice_size) != len(sliceable_head_dims):
291
+ raise ValueError(
292
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
293
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
294
+ )
295
+
296
+ for i in range(len(slice_size)):
297
+ size = slice_size[i]
298
+ dim = sliceable_head_dims[i]
299
+ if size is not None and size > dim:
300
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
301
+
302
+ # Recursively walk through all the children.
303
+ # Any children which exposes the set_attention_slice method
304
+ # gets the message
305
+ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
306
+ if hasattr(module, "set_attention_slice"):
307
+ module.set_attention_slice(slice_size.pop())
308
+
309
+ for child in module.children():
310
+ fn_recursive_set_attention_slice(child, slice_size)
311
+
312
+ reversed_slice_size = list(reversed(slice_size))
313
+ for module in self.children():
314
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
315
+
316
+ def _set_gradient_checkpointing(self, module, value=False):
317
+ if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)):
318
+ module.gradient_checkpointing = value
319
+
320
+ def forward(
321
+ self,
322
+ sample: torch.FloatTensor,
323
+ timestep: Union[torch.Tensor, float, int],
324
+ encoder_hidden_states: torch.Tensor,
325
+ class_labels: Optional[torch.Tensor] = None,
326
+ attention_mask: Optional[torch.Tensor] = None,
327
+
328
+ # support controlnet
329
+ down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
330
+ mid_block_additional_residual: Optional[torch.Tensor] = None,
331
+
332
+ return_dict: bool = True,
333
+ ) -> Union[UNet3DConditionOutput, Tuple]:
334
+ r"""
335
+ Args:
336
+ sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
337
+ timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
338
+ encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
339
+ return_dict (`bool`, *optional*, defaults to `True`):
340
+ Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
341
+
342
+ Returns:
343
+ [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
344
+ [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When
345
+ returning a tuple, the first element is the sample tensor.
346
+ """
347
+ # By default samples have to be AT least a multiple of the overall upsampling factor.
348
+ # The overall upsampling factor is equal to 2 ** (# num of upsampling layears).
349
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
350
+ # on the fly if necessary.
351
+ default_overall_up_factor = 2**self.num_upsamplers
352
+
353
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
354
+ forward_upsample_size = False
355
+ upsample_size = None
356
+
357
+ if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
358
+ logger.info("Forward upsample size to force interpolation output size.")
359
+ forward_upsample_size = True
360
+
361
+ # prepare attention_mask
362
+ if attention_mask is not None:
363
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
364
+ attention_mask = attention_mask.unsqueeze(1)
365
+
366
+ # center input if necessary
367
+ if self.config.center_input_sample:
368
+ sample = 2 * sample - 1.0
369
+
370
+ # time
371
+ timesteps = timestep
372
+ if not torch.is_tensor(timesteps):
373
+ # This would be a good case for the `match` statement (Python 3.10+)
374
+ is_mps = sample.device.type == "mps"
375
+ if isinstance(timestep, float):
376
+ dtype = torch.float32 if is_mps else torch.float64
377
+ else:
378
+ dtype = torch.int32 if is_mps else torch.int64
379
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
380
+ elif len(timesteps.shape) == 0:
381
+ timesteps = timesteps[None].to(sample.device)
382
+
383
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
384
+ timesteps = timesteps.expand(sample.shape[0])
385
+
386
+ t_emb = self.time_proj(timesteps)
387
+
388
+ # timesteps does not contain any weights and will always return f32 tensors
389
+ # but time_embedding might actually be running in fp16. so we need to cast here.
390
+ # there might be better ways to encapsulate this.
391
+ t_emb = t_emb.to(dtype=self.dtype)
392
+ emb = self.time_embedding(t_emb)
393
+
394
+ if self.class_embedding is not None:
395
+ if class_labels is None:
396
+ raise ValueError("class_labels should be provided when num_class_embeds > 0")
397
+
398
+ if self.config.class_embed_type == "timestep":
399
+ class_labels = self.time_proj(class_labels)
400
+
401
+ class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
402
+ emb = emb + class_emb
403
+
404
+ # pre-process
405
+ sample = self.conv_in(sample)
406
+
407
+ # down
408
+ down_block_res_samples = (sample,)
409
+ for downsample_block in self.down_blocks:
410
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
411
+ sample, res_samples = downsample_block(
412
+ hidden_states=sample,
413
+ temb=emb,
414
+ encoder_hidden_states=encoder_hidden_states,
415
+ attention_mask=attention_mask,
416
+ )
417
+ else:
418
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb, encoder_hidden_states=encoder_hidden_states)
419
+
420
+ down_block_res_samples += res_samples
421
+
422
+ # support controlnet
423
+ down_block_res_samples = list(down_block_res_samples)
424
+ if down_block_additional_residuals is not None:
425
+ for i, down_block_additional_residual in enumerate(down_block_additional_residuals):
426
+ if down_block_additional_residual.dim() == 4: # boardcast
427
+ down_block_additional_residual = down_block_additional_residual.unsqueeze(2)
428
+ down_block_res_samples[i] = down_block_res_samples[i] + down_block_additional_residual
429
+
430
+ # mid
431
+ sample = self.mid_block(
432
+ sample, emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
433
+ )
434
+
435
+ # support controlnet
436
+ if mid_block_additional_residual is not None:
437
+ if mid_block_additional_residual.dim() == 4: # boardcast
438
+ mid_block_additional_residual = mid_block_additional_residual.unsqueeze(2)
439
+ sample = sample + mid_block_additional_residual
440
+
441
+ # up
442
+ for i, upsample_block in enumerate(self.up_blocks):
443
+ is_final_block = i == len(self.up_blocks) - 1
444
+
445
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
446
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
447
+
448
+ # if we have not reached the final block and need to forward the
449
+ # upsample size, we do it here
450
+ if not is_final_block and forward_upsample_size:
451
+ upsample_size = down_block_res_samples[-1].shape[2:]
452
+
453
+ if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
454
+ sample = upsample_block(
455
+ hidden_states=sample,
456
+ temb=emb,
457
+ res_hidden_states_tuple=res_samples,
458
+ encoder_hidden_states=encoder_hidden_states,
459
+ upsample_size=upsample_size,
460
+ attention_mask=attention_mask,
461
+ )
462
+ else:
463
+ sample = upsample_block(
464
+ hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size, encoder_hidden_states=encoder_hidden_states,
465
+ )
466
+
467
+ # post-process
468
+ sample = self.conv_norm_out(sample)
469
+ sample = self.conv_act(sample)
470
+ sample = self.conv_out(sample)
471
+
472
+ if not return_dict:
473
+ return (sample,)
474
+
475
+ return UNet3DConditionOutput(sample=sample)
476
+
477
+ @classmethod
478
+ def from_pretrained_2d(cls, pretrained_model_path, subfolder=None, unet_additional_kwargs=None):
479
+ if subfolder is not None:
480
+ pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
481
+ print(f"loaded 3D unet's pretrained weights from {pretrained_model_path} ...")
482
+
483
+ config_file = os.path.join(pretrained_model_path, 'config.json')
484
+ if not os.path.isfile(config_file):
485
+ raise RuntimeError(f"{config_file} does not exist")
486
+ with open(config_file, "r") as f:
487
+ config = json.load(f)
488
+ config["_class_name"] = cls.__name__
489
+ config["down_block_types"] = [
490
+ "CrossAttnDownBlock3D",
491
+ "CrossAttnDownBlock3D",
492
+ "CrossAttnDownBlock3D",
493
+ "DownBlock3D"
494
+ ]
495
+ config["up_block_types"] = [
496
+ "UpBlock3D",
497
+ "CrossAttnUpBlock3D",
498
+ "CrossAttnUpBlock3D",
499
+ "CrossAttnUpBlock3D"
500
+ ]
501
+
502
+ from diffusers.utils import WEIGHTS_NAME
503
+ model = cls.from_config(config, **unet_additional_kwargs)
504
+ model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
505
+ if not os.path.isfile(model_file):
506
+ raise RuntimeError(f"{model_file} does not exist")
507
+ state_dict = torch.load(model_file, map_location="cpu")
508
+
509
+ m, u = model.load_state_dict(state_dict, strict=False)
510
+ print(f"### motion keys will be loaded: {len(m)}; \n### unexpected keys: {len(u)};")
511
+
512
+ params = [p.numel() if "motion_modules." in n else 0 for n, p in model.named_parameters()]
513
+ print(f"### Motion Module Parameters: {sum(params) / 1e6} M")
514
+
515
+ return model
motionclone/models/unet_blocks.py ADDED
@@ -0,0 +1,760 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_blocks.py
2
+
3
+ import torch
4
+ from torch import nn
5
+
6
+ from .attention import Transformer3DModel
7
+ from .resnet import Downsample3D, ResnetBlock3D, Upsample3D
8
+ from .motion_module import get_motion_module
9
+
10
+ import pdb
11
+
12
+ def get_down_block(
13
+ down_block_type,
14
+ num_layers,
15
+ in_channels,
16
+ out_channels,
17
+ temb_channels,
18
+ add_downsample,
19
+ resnet_eps,
20
+ resnet_act_fn,
21
+ attn_num_head_channels,
22
+ resnet_groups=None,
23
+ cross_attention_dim=None,
24
+ downsample_padding=None,
25
+ dual_cross_attention=False,
26
+ use_linear_projection=False,
27
+ only_cross_attention=False,
28
+ upcast_attention=False,
29
+ resnet_time_scale_shift="default",
30
+
31
+ unet_use_cross_frame_attention=False,
32
+ unet_use_temporal_attention=False,
33
+ use_inflated_groupnorm=False,
34
+
35
+ use_motion_module=None,
36
+
37
+ motion_module_type=None,
38
+ motion_module_kwargs=None,
39
+ ):
40
+ down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
41
+ if down_block_type == "DownBlock3D":
42
+ return DownBlock3D(
43
+ num_layers=num_layers,
44
+ in_channels=in_channels,
45
+ out_channels=out_channels,
46
+ temb_channels=temb_channels,
47
+ add_downsample=add_downsample,
48
+ resnet_eps=resnet_eps,
49
+ resnet_act_fn=resnet_act_fn,
50
+ resnet_groups=resnet_groups,
51
+ downsample_padding=downsample_padding,
52
+ resnet_time_scale_shift=resnet_time_scale_shift,
53
+
54
+ use_inflated_groupnorm=use_inflated_groupnorm,
55
+
56
+ use_motion_module=use_motion_module,
57
+ motion_module_type=motion_module_type,
58
+ motion_module_kwargs=motion_module_kwargs,
59
+ )
60
+ elif down_block_type == "CrossAttnDownBlock3D":
61
+ if cross_attention_dim is None:
62
+ raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock3D")
63
+ return CrossAttnDownBlock3D(
64
+ num_layers=num_layers,
65
+ in_channels=in_channels,
66
+ out_channels=out_channels,
67
+ temb_channels=temb_channels,
68
+ add_downsample=add_downsample,
69
+ resnet_eps=resnet_eps,
70
+ resnet_act_fn=resnet_act_fn,
71
+ resnet_groups=resnet_groups,
72
+ downsample_padding=downsample_padding,
73
+ cross_attention_dim=cross_attention_dim,
74
+ attn_num_head_channels=attn_num_head_channels,
75
+ dual_cross_attention=dual_cross_attention,
76
+ use_linear_projection=use_linear_projection,
77
+ only_cross_attention=only_cross_attention,
78
+ upcast_attention=upcast_attention,
79
+ resnet_time_scale_shift=resnet_time_scale_shift,
80
+
81
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
82
+ unet_use_temporal_attention=unet_use_temporal_attention,
83
+ use_inflated_groupnorm=use_inflated_groupnorm,
84
+
85
+ use_motion_module=use_motion_module,
86
+ motion_module_type=motion_module_type,
87
+ motion_module_kwargs=motion_module_kwargs,
88
+ )
89
+ raise ValueError(f"{down_block_type} does not exist.")
90
+
91
+
92
+ def get_up_block(
93
+ up_block_type,
94
+ num_layers,
95
+ in_channels,
96
+ out_channels,
97
+ prev_output_channel,
98
+ temb_channels,
99
+ add_upsample,
100
+ resnet_eps,
101
+ resnet_act_fn,
102
+ attn_num_head_channels,
103
+ resnet_groups=None,
104
+ cross_attention_dim=None,
105
+ dual_cross_attention=False,
106
+ use_linear_projection=False,
107
+ only_cross_attention=False,
108
+ upcast_attention=False,
109
+ resnet_time_scale_shift="default",
110
+
111
+ unet_use_cross_frame_attention=False,
112
+ unet_use_temporal_attention=False,
113
+ use_inflated_groupnorm=False,
114
+
115
+ use_motion_module=None,
116
+ motion_module_type=None,
117
+ motion_module_kwargs=None,
118
+ ):
119
+ up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
120
+ if up_block_type == "UpBlock3D":
121
+ return UpBlock3D(
122
+ num_layers=num_layers,
123
+ in_channels=in_channels,
124
+ out_channels=out_channels,
125
+ prev_output_channel=prev_output_channel,
126
+ temb_channels=temb_channels,
127
+ add_upsample=add_upsample,
128
+ resnet_eps=resnet_eps,
129
+ resnet_act_fn=resnet_act_fn,
130
+ resnet_groups=resnet_groups,
131
+ resnet_time_scale_shift=resnet_time_scale_shift,
132
+
133
+ use_inflated_groupnorm=use_inflated_groupnorm,
134
+
135
+ use_motion_module=use_motion_module,
136
+ motion_module_type=motion_module_type,
137
+ motion_module_kwargs=motion_module_kwargs,
138
+ )
139
+ elif up_block_type == "CrossAttnUpBlock3D":
140
+ if cross_attention_dim is None:
141
+ raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock3D")
142
+ return CrossAttnUpBlock3D(
143
+ num_layers=num_layers,
144
+ in_channels=in_channels,
145
+ out_channels=out_channels,
146
+ prev_output_channel=prev_output_channel,
147
+ temb_channels=temb_channels,
148
+ add_upsample=add_upsample,
149
+ resnet_eps=resnet_eps,
150
+ resnet_act_fn=resnet_act_fn,
151
+ resnet_groups=resnet_groups,
152
+ cross_attention_dim=cross_attention_dim,
153
+ attn_num_head_channels=attn_num_head_channels,
154
+ dual_cross_attention=dual_cross_attention,
155
+ use_linear_projection=use_linear_projection,
156
+ only_cross_attention=only_cross_attention,
157
+ upcast_attention=upcast_attention,
158
+ resnet_time_scale_shift=resnet_time_scale_shift,
159
+
160
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
161
+ unet_use_temporal_attention=unet_use_temporal_attention,
162
+ use_inflated_groupnorm=use_inflated_groupnorm,
163
+
164
+ use_motion_module=use_motion_module,
165
+ motion_module_type=motion_module_type,
166
+ motion_module_kwargs=motion_module_kwargs,
167
+ )
168
+ raise ValueError(f"{up_block_type} does not exist.")
169
+
170
+
171
+ class UNetMidBlock3DCrossAttn(nn.Module):
172
+ def __init__(
173
+ self,
174
+ in_channels: int,
175
+ temb_channels: int,
176
+ dropout: float = 0.0,
177
+ num_layers: int = 1,
178
+ resnet_eps: float = 1e-6,
179
+ resnet_time_scale_shift: str = "default",
180
+ resnet_act_fn: str = "swish",
181
+ resnet_groups: int = 32,
182
+ resnet_pre_norm: bool = True,
183
+ attn_num_head_channels=1,
184
+ output_scale_factor=1.0,
185
+ cross_attention_dim=1280,
186
+ dual_cross_attention=False,
187
+ use_linear_projection=False,
188
+ upcast_attention=False,
189
+
190
+ unet_use_cross_frame_attention=False,
191
+ unet_use_temporal_attention=False,
192
+ use_inflated_groupnorm=False,
193
+
194
+ use_motion_module=None,
195
+
196
+ motion_module_type=None,
197
+ motion_module_kwargs=None,
198
+ ):
199
+ super().__init__()
200
+
201
+ self.has_cross_attention = True
202
+ self.attn_num_head_channels = attn_num_head_channels
203
+ resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
204
+
205
+ # there is always at least one resnet
206
+ resnets = [
207
+ ResnetBlock3D(
208
+ in_channels=in_channels,
209
+ out_channels=in_channels,
210
+ temb_channels=temb_channels,
211
+ eps=resnet_eps,
212
+ groups=resnet_groups,
213
+ dropout=dropout,
214
+ time_embedding_norm=resnet_time_scale_shift,
215
+ non_linearity=resnet_act_fn,
216
+ output_scale_factor=output_scale_factor,
217
+ pre_norm=resnet_pre_norm,
218
+
219
+ use_inflated_groupnorm=use_inflated_groupnorm,
220
+ )
221
+ ]
222
+ attentions = []
223
+ motion_modules = []
224
+
225
+ for _ in range(num_layers):
226
+ if dual_cross_attention:
227
+ raise NotImplementedError
228
+ attentions.append(
229
+ Transformer3DModel(
230
+ attn_num_head_channels,
231
+ in_channels // attn_num_head_channels,
232
+ in_channels=in_channels,
233
+ num_layers=1,
234
+ cross_attention_dim=cross_attention_dim,
235
+ norm_num_groups=resnet_groups,
236
+ use_linear_projection=use_linear_projection,
237
+ upcast_attention=upcast_attention,
238
+
239
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
240
+ unet_use_temporal_attention=unet_use_temporal_attention,
241
+ )
242
+ )
243
+ motion_modules.append(
244
+ get_motion_module(
245
+ in_channels=in_channels,
246
+ motion_module_type=motion_module_type,
247
+ motion_module_kwargs=motion_module_kwargs,
248
+ ) if use_motion_module else None
249
+ )
250
+ resnets.append(
251
+ ResnetBlock3D(
252
+ in_channels=in_channels,
253
+ out_channels=in_channels,
254
+ temb_channels=temb_channels,
255
+ eps=resnet_eps,
256
+ groups=resnet_groups,
257
+ dropout=dropout,
258
+ time_embedding_norm=resnet_time_scale_shift,
259
+ non_linearity=resnet_act_fn,
260
+ output_scale_factor=output_scale_factor,
261
+ pre_norm=resnet_pre_norm,
262
+
263
+ use_inflated_groupnorm=use_inflated_groupnorm,
264
+ )
265
+ )
266
+
267
+ self.attentions = nn.ModuleList(attentions)
268
+ self.resnets = nn.ModuleList(resnets)
269
+ self.motion_modules = nn.ModuleList(motion_modules)
270
+
271
+ def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None):
272
+ hidden_states = self.resnets[0](hidden_states, temb)
273
+ for attn, resnet, motion_module in zip(self.attentions, self.resnets[1:], self.motion_modules):
274
+ hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
275
+ hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states) if motion_module is not None else hidden_states
276
+ hidden_states = resnet(hidden_states, temb)
277
+
278
+ return hidden_states
279
+
280
+
281
+ class CrossAttnDownBlock3D(nn.Module):
282
+ def __init__(
283
+ self,
284
+ in_channels: int,
285
+ out_channels: int,
286
+ temb_channels: int,
287
+ dropout: float = 0.0,
288
+ num_layers: int = 1,
289
+ resnet_eps: float = 1e-6,
290
+ resnet_time_scale_shift: str = "default",
291
+ resnet_act_fn: str = "swish",
292
+ resnet_groups: int = 32,
293
+ resnet_pre_norm: bool = True,
294
+ attn_num_head_channels=1,
295
+ cross_attention_dim=1280,
296
+ output_scale_factor=1.0,
297
+ downsample_padding=1,
298
+ add_downsample=True,
299
+ dual_cross_attention=False,
300
+ use_linear_projection=False,
301
+ only_cross_attention=False,
302
+ upcast_attention=False,
303
+
304
+ unet_use_cross_frame_attention=False,
305
+ unet_use_temporal_attention=False,
306
+ use_inflated_groupnorm=False,
307
+
308
+ use_motion_module=None,
309
+
310
+ motion_module_type=None,
311
+ motion_module_kwargs=None,
312
+ ):
313
+ super().__init__()
314
+ resnets = []
315
+ attentions = []
316
+ motion_modules = []
317
+
318
+ self.has_cross_attention = True
319
+ self.attn_num_head_channels = attn_num_head_channels
320
+
321
+ for i in range(num_layers):
322
+ in_channels = in_channels if i == 0 else out_channels
323
+ resnets.append(
324
+ ResnetBlock3D(
325
+ in_channels=in_channels,
326
+ out_channels=out_channels,
327
+ temb_channels=temb_channels,
328
+ eps=resnet_eps,
329
+ groups=resnet_groups,
330
+ dropout=dropout,
331
+ time_embedding_norm=resnet_time_scale_shift,
332
+ non_linearity=resnet_act_fn,
333
+ output_scale_factor=output_scale_factor,
334
+ pre_norm=resnet_pre_norm,
335
+
336
+ use_inflated_groupnorm=use_inflated_groupnorm,
337
+ )
338
+ )
339
+ if dual_cross_attention:
340
+ raise NotImplementedError
341
+ attentions.append(
342
+ Transformer3DModel(
343
+ attn_num_head_channels,
344
+ out_channels // attn_num_head_channels,
345
+ in_channels=out_channels,
346
+ num_layers=1,
347
+ cross_attention_dim=cross_attention_dim,
348
+ norm_num_groups=resnet_groups,
349
+ use_linear_projection=use_linear_projection,
350
+ only_cross_attention=only_cross_attention,
351
+ upcast_attention=upcast_attention,
352
+
353
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
354
+ unet_use_temporal_attention=unet_use_temporal_attention,
355
+ )
356
+ )
357
+ motion_modules.append(
358
+ get_motion_module(
359
+ in_channels=out_channels,
360
+ motion_module_type=motion_module_type,
361
+ motion_module_kwargs=motion_module_kwargs,
362
+ ) if use_motion_module else None
363
+ )
364
+
365
+ self.attentions = nn.ModuleList(attentions)
366
+ self.resnets = nn.ModuleList(resnets)
367
+ self.motion_modules = nn.ModuleList(motion_modules)
368
+
369
+ if add_downsample:
370
+ self.downsamplers = nn.ModuleList(
371
+ [
372
+ Downsample3D(
373
+ out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
374
+ )
375
+ ]
376
+ )
377
+ else:
378
+ self.downsamplers = None
379
+
380
+ self.gradient_checkpointing = False
381
+
382
+ def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None):
383
+ output_states = ()
384
+
385
+ for resnet, attn, motion_module in zip(self.resnets, self.attentions, self.motion_modules):
386
+ if self.training and self.gradient_checkpointing:
387
+
388
+ def create_custom_forward(module, return_dict=None):
389
+ def custom_forward(*inputs):
390
+ if return_dict is not None:
391
+ return module(*inputs, return_dict=return_dict)
392
+ else:
393
+ return module(*inputs)
394
+
395
+ return custom_forward
396
+
397
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
398
+ hidden_states = torch.utils.checkpoint.checkpoint(
399
+ create_custom_forward(attn, return_dict=False),
400
+ hidden_states,
401
+ encoder_hidden_states,
402
+ )[0]
403
+ if motion_module is not None:
404
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(motion_module), hidden_states.requires_grad_(), temb, encoder_hidden_states)
405
+
406
+ else:
407
+ hidden_states = resnet(hidden_states, temb)
408
+ hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
409
+
410
+ # add motion module
411
+ hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states) if motion_module is not None else hidden_states
412
+
413
+ output_states += (hidden_states,)
414
+
415
+ if self.downsamplers is not None:
416
+ for downsampler in self.downsamplers:
417
+ hidden_states = downsampler(hidden_states)
418
+
419
+ output_states += (hidden_states,)
420
+
421
+ return hidden_states, output_states
422
+
423
+
424
+ class DownBlock3D(nn.Module):
425
+ def __init__(
426
+ self,
427
+ in_channels: int,
428
+ out_channels: int,
429
+ temb_channels: int,
430
+ dropout: float = 0.0,
431
+ num_layers: int = 1,
432
+ resnet_eps: float = 1e-6,
433
+ resnet_time_scale_shift: str = "default",
434
+ resnet_act_fn: str = "swish",
435
+ resnet_groups: int = 32,
436
+ resnet_pre_norm: bool = True,
437
+ output_scale_factor=1.0,
438
+ add_downsample=True,
439
+ downsample_padding=1,
440
+
441
+ use_inflated_groupnorm=False,
442
+
443
+ use_motion_module=None,
444
+ motion_module_type=None,
445
+ motion_module_kwargs=None,
446
+ ):
447
+ super().__init__()
448
+ resnets = []
449
+ motion_modules = []
450
+
451
+ for i in range(num_layers):
452
+ in_channels = in_channels if i == 0 else out_channels
453
+ resnets.append(
454
+ ResnetBlock3D(
455
+ in_channels=in_channels,
456
+ out_channels=out_channels,
457
+ temb_channels=temb_channels,
458
+ eps=resnet_eps,
459
+ groups=resnet_groups,
460
+ dropout=dropout,
461
+ time_embedding_norm=resnet_time_scale_shift,
462
+ non_linearity=resnet_act_fn,
463
+ output_scale_factor=output_scale_factor,
464
+ pre_norm=resnet_pre_norm,
465
+
466
+ use_inflated_groupnorm=use_inflated_groupnorm,
467
+ )
468
+ )
469
+ motion_modules.append(
470
+ get_motion_module(
471
+ in_channels=out_channels,
472
+ motion_module_type=motion_module_type,
473
+ motion_module_kwargs=motion_module_kwargs,
474
+ ) if use_motion_module else None
475
+ )
476
+
477
+ self.resnets = nn.ModuleList(resnets)
478
+ self.motion_modules = nn.ModuleList(motion_modules)
479
+
480
+ if add_downsample:
481
+ self.downsamplers = nn.ModuleList(
482
+ [
483
+ Downsample3D(
484
+ out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
485
+ )
486
+ ]
487
+ )
488
+ else:
489
+ self.downsamplers = None
490
+
491
+ self.gradient_checkpointing = False
492
+
493
+ def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
494
+ output_states = ()
495
+
496
+ for resnet, motion_module in zip(self.resnets, self.motion_modules):
497
+ if self.training and self.gradient_checkpointing:
498
+ def create_custom_forward(module):
499
+ def custom_forward(*inputs):
500
+ return module(*inputs)
501
+
502
+ return custom_forward
503
+
504
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
505
+ if motion_module is not None:
506
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(motion_module), hidden_states.requires_grad_(), temb, encoder_hidden_states)
507
+ else:
508
+ hidden_states = resnet(hidden_states, temb)
509
+
510
+ # add motion module
511
+ hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states) if motion_module is not None else hidden_states
512
+
513
+ output_states += (hidden_states,)
514
+
515
+ if self.downsamplers is not None:
516
+ for downsampler in self.downsamplers:
517
+ hidden_states = downsampler(hidden_states)
518
+
519
+ output_states += (hidden_states,)
520
+
521
+ return hidden_states, output_states
522
+
523
+
524
+ class CrossAttnUpBlock3D(nn.Module):
525
+ def __init__(
526
+ self,
527
+ in_channels: int,
528
+ out_channels: int,
529
+ prev_output_channel: int,
530
+ temb_channels: int,
531
+ dropout: float = 0.0,
532
+ num_layers: int = 1,
533
+ resnet_eps: float = 1e-6,
534
+ resnet_time_scale_shift: str = "default",
535
+ resnet_act_fn: str = "swish",
536
+ resnet_groups: int = 32,
537
+ resnet_pre_norm: bool = True,
538
+ attn_num_head_channels=1,
539
+ cross_attention_dim=1280,
540
+ output_scale_factor=1.0,
541
+ add_upsample=True,
542
+ dual_cross_attention=False,
543
+ use_linear_projection=False,
544
+ only_cross_attention=False,
545
+ upcast_attention=False,
546
+
547
+ unet_use_cross_frame_attention=False,
548
+ unet_use_temporal_attention=False,
549
+ use_inflated_groupnorm=False,
550
+
551
+ use_motion_module=None,
552
+
553
+ motion_module_type=None,
554
+ motion_module_kwargs=None,
555
+ ):
556
+ super().__init__()
557
+ resnets = []
558
+ attentions = []
559
+ motion_modules = []
560
+
561
+ self.has_cross_attention = True
562
+ self.attn_num_head_channels = attn_num_head_channels
563
+
564
+ for i in range(num_layers):
565
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
566
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
567
+
568
+ resnets.append(
569
+ ResnetBlock3D(
570
+ in_channels=resnet_in_channels + res_skip_channels,
571
+ out_channels=out_channels,
572
+ temb_channels=temb_channels,
573
+ eps=resnet_eps,
574
+ groups=resnet_groups,
575
+ dropout=dropout,
576
+ time_embedding_norm=resnet_time_scale_shift,
577
+ non_linearity=resnet_act_fn,
578
+ output_scale_factor=output_scale_factor,
579
+ pre_norm=resnet_pre_norm,
580
+
581
+ use_inflated_groupnorm=use_inflated_groupnorm,
582
+ )
583
+ )
584
+ if dual_cross_attention:
585
+ raise NotImplementedError
586
+ attentions.append(
587
+ Transformer3DModel(
588
+ attn_num_head_channels,
589
+ out_channels // attn_num_head_channels,
590
+ in_channels=out_channels,
591
+ num_layers=1,
592
+ cross_attention_dim=cross_attention_dim,
593
+ norm_num_groups=resnet_groups,
594
+ use_linear_projection=use_linear_projection,
595
+ only_cross_attention=only_cross_attention,
596
+ upcast_attention=upcast_attention,
597
+
598
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
599
+ unet_use_temporal_attention=unet_use_temporal_attention,
600
+ )
601
+ )
602
+ motion_modules.append(
603
+ get_motion_module(
604
+ in_channels=out_channels,
605
+ motion_module_type=motion_module_type,
606
+ motion_module_kwargs=motion_module_kwargs,
607
+ ) if use_motion_module else None
608
+ )
609
+
610
+ self.attentions = nn.ModuleList(attentions)
611
+ self.resnets = nn.ModuleList(resnets)
612
+ self.motion_modules = nn.ModuleList(motion_modules)
613
+
614
+ if add_upsample:
615
+ self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)])
616
+ else:
617
+ self.upsamplers = None
618
+
619
+ self.gradient_checkpointing = False
620
+
621
+ def forward(
622
+ self,
623
+ hidden_states,
624
+ res_hidden_states_tuple,
625
+ temb=None,
626
+ encoder_hidden_states=None,
627
+ upsample_size=None,
628
+ attention_mask=None,
629
+ ):
630
+ for resnet, attn, motion_module in zip(self.resnets, self.attentions, self.motion_modules):
631
+ # pop res hidden states
632
+ res_hidden_states = res_hidden_states_tuple[-1]
633
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
634
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
635
+
636
+ if self.training and self.gradient_checkpointing:
637
+
638
+ def create_custom_forward(module, return_dict=None):
639
+ def custom_forward(*inputs):
640
+ if return_dict is not None:
641
+ return module(*inputs, return_dict=return_dict)
642
+ else:
643
+ return module(*inputs)
644
+
645
+ return custom_forward
646
+
647
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
648
+ hidden_states = torch.utils.checkpoint.checkpoint(
649
+ create_custom_forward(attn, return_dict=False),
650
+ hidden_states,
651
+ encoder_hidden_states,
652
+ )[0]
653
+ if motion_module is not None:
654
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(motion_module), hidden_states.requires_grad_(), temb, encoder_hidden_states)
655
+
656
+ else:
657
+ hidden_states = resnet(hidden_states, temb)
658
+ hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
659
+
660
+ # add motion module
661
+ hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states) if motion_module is not None else hidden_states
662
+
663
+ if self.upsamplers is not None:
664
+ for upsampler in self.upsamplers:
665
+ hidden_states = upsampler(hidden_states, upsample_size)
666
+
667
+ return hidden_states
668
+
669
+
670
+ class UpBlock3D(nn.Module):
671
+ def __init__(
672
+ self,
673
+ in_channels: int,
674
+ prev_output_channel: int,
675
+ out_channels: int,
676
+ temb_channels: int,
677
+ dropout: float = 0.0,
678
+ num_layers: int = 1,
679
+ resnet_eps: float = 1e-6,
680
+ resnet_time_scale_shift: str = "default",
681
+ resnet_act_fn: str = "swish",
682
+ resnet_groups: int = 32,
683
+ resnet_pre_norm: bool = True,
684
+ output_scale_factor=1.0,
685
+ add_upsample=True,
686
+
687
+ use_inflated_groupnorm=False,
688
+
689
+ use_motion_module=None,
690
+ motion_module_type=None,
691
+ motion_module_kwargs=None,
692
+ ):
693
+ super().__init__()
694
+ resnets = []
695
+ motion_modules = []
696
+
697
+ for i in range(num_layers):
698
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
699
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
700
+
701
+ resnets.append(
702
+ ResnetBlock3D(
703
+ in_channels=resnet_in_channels + res_skip_channels,
704
+ out_channels=out_channels,
705
+ temb_channels=temb_channels,
706
+ eps=resnet_eps,
707
+ groups=resnet_groups,
708
+ dropout=dropout,
709
+ time_embedding_norm=resnet_time_scale_shift,
710
+ non_linearity=resnet_act_fn,
711
+ output_scale_factor=output_scale_factor,
712
+ pre_norm=resnet_pre_norm,
713
+
714
+ use_inflated_groupnorm=use_inflated_groupnorm,
715
+ )
716
+ )
717
+ motion_modules.append(
718
+ get_motion_module(
719
+ in_channels=out_channels,
720
+ motion_module_type=motion_module_type,
721
+ motion_module_kwargs=motion_module_kwargs,
722
+ ) if use_motion_module else None
723
+ )
724
+
725
+ self.resnets = nn.ModuleList(resnets)
726
+ self.motion_modules = nn.ModuleList(motion_modules)
727
+
728
+ if add_upsample:
729
+ self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)])
730
+ else:
731
+ self.upsamplers = None
732
+
733
+ self.gradient_checkpointing = False
734
+
735
+ def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, encoder_hidden_states=None,):
736
+ for resnet, motion_module in zip(self.resnets, self.motion_modules):
737
+ # pop res hidden states
738
+ res_hidden_states = res_hidden_states_tuple[-1]
739
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
740
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
741
+
742
+ if self.training and self.gradient_checkpointing:
743
+ def create_custom_forward(module):
744
+ def custom_forward(*inputs):
745
+ return module(*inputs)
746
+
747
+ return custom_forward
748
+
749
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
750
+ if motion_module is not None:
751
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(motion_module), hidden_states.requires_grad_(), temb, encoder_hidden_states)
752
+ else:
753
+ hidden_states = resnet(hidden_states, temb)
754
+ hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states) if motion_module is not None else hidden_states
755
+
756
+ if self.upsamplers is not None:
757
+ for upsampler in self.upsamplers:
758
+ hidden_states = upsampler(hidden_states, upsample_size)
759
+
760
+ return hidden_states