fffiloni commited on
Commit
bfed184
1 Parent(s): 43797a0

Migrated from GitHub

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +2 -0
  2. LICENSE.txt +201 -0
  3. ORIGINAL_README.md +70 -0
  4. assets/driving_video.mp4 +3 -0
  5. assets/source_image.png +0 -0
  6. assets/teaser/teaser.png +3 -0
  7. checkpoint/checkpoint_location +0 -0
  8. config/cldm_v15_appearance_pose_local_mm.yaml +130 -0
  9. core/test_xportrait.py +506 -0
  10. env_install.sh +2 -0
  11. model_lib/ControlNet/cldm/__pycache__/cldm.cpython-39.pyc +0 -0
  12. model_lib/ControlNet/cldm/__pycache__/model.cpython-39.pyc +0 -0
  13. model_lib/ControlNet/cldm/cldm.py +715 -0
  14. model_lib/ControlNet/cldm/model.py +28 -0
  15. model_lib/ControlNet/ldm/__pycache__/util.cpython-39.pyc +0 -0
  16. model_lib/ControlNet/ldm/data/__init__.py +0 -0
  17. model_lib/ControlNet/ldm/data/util.py +24 -0
  18. model_lib/ControlNet/ldm/models/__pycache__/autoencoder.cpython-39.pyc +0 -0
  19. model_lib/ControlNet/ldm/models/autoencoder.py +219 -0
  20. model_lib/ControlNet/ldm/models/diffusion/__init__.py +0 -0
  21. model_lib/ControlNet/ldm/models/diffusion/__pycache__/__init__.cpython-39.pyc +0 -0
  22. model_lib/ControlNet/ldm/models/diffusion/__pycache__/ddim.cpython-39.pyc +0 -0
  23. model_lib/ControlNet/ldm/models/diffusion/__pycache__/ddpm.cpython-39.pyc +0 -0
  24. model_lib/ControlNet/ldm/models/diffusion/ddim.py +763 -0
  25. model_lib/ControlNet/ldm/models/diffusion/ddpm.py +0 -0
  26. model_lib/ControlNet/ldm/models/diffusion/dpm_solver/__init__.py +1 -0
  27. model_lib/ControlNet/ldm/models/diffusion/dpm_solver/dpm_solver.py +1154 -0
  28. model_lib/ControlNet/ldm/models/diffusion/dpm_solver/sampler.py +87 -0
  29. model_lib/ControlNet/ldm/models/diffusion/plms.py +244 -0
  30. model_lib/ControlNet/ldm/models/diffusion/sampling_util.py +22 -0
  31. model_lib/ControlNet/ldm/modules/__pycache__/attention.cpython-39.pyc +0 -0
  32. model_lib/ControlNet/ldm/modules/__pycache__/ema.cpython-39.pyc +0 -0
  33. model_lib/ControlNet/ldm/modules/__pycache__/motion_module.cpython-39.pyc +0 -0
  34. model_lib/ControlNet/ldm/modules/attention.py +386 -0
  35. model_lib/ControlNet/ldm/modules/diffusionmodules/__init__.py +0 -0
  36. model_lib/ControlNet/ldm/modules/diffusionmodules/__pycache__/__init__.cpython-39.pyc +0 -0
  37. model_lib/ControlNet/ldm/modules/diffusionmodules/__pycache__/model.cpython-39.pyc +0 -0
  38. model_lib/ControlNet/ldm/modules/diffusionmodules/__pycache__/openaimodel.cpython-39.pyc +0 -0
  39. model_lib/ControlNet/ldm/modules/diffusionmodules/__pycache__/util.cpython-39.pyc +0 -0
  40. model_lib/ControlNet/ldm/modules/diffusionmodules/model.py +859 -0
  41. model_lib/ControlNet/ldm/modules/diffusionmodules/openaimodel.py +1212 -0
  42. model_lib/ControlNet/ldm/modules/diffusionmodules/upscaling.py +81 -0
  43. model_lib/ControlNet/ldm/modules/diffusionmodules/util.py +305 -0
  44. model_lib/ControlNet/ldm/modules/distributions/__init__.py +0 -0
  45. model_lib/ControlNet/ldm/modules/distributions/__pycache__/__init__.cpython-39.pyc +0 -0
  46. model_lib/ControlNet/ldm/modules/distributions/__pycache__/distributions.cpython-39.pyc +0 -0
  47. model_lib/ControlNet/ldm/modules/distributions/distributions.py +92 -0
  48. model_lib/ControlNet/ldm/modules/ema.py +80 -0
  49. model_lib/ControlNet/ldm/modules/encoders/__init__.py +0 -0
  50. model_lib/ControlNet/ldm/modules/encoders/__pycache__/__init__.cpython-39.pyc +0 -0
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ assets/driving_video.mp4 filter=lfs diff=lfs merge=lfs -text
37
+ assets/teaser/teaser.png filter=lfs diff=lfs merge=lfs -text
LICENSE.txt ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
ORIGINAL_README.md ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <!-- # magic-edit.github.io -->
2
+
3
+ <p align="center">
4
+
5
+ <h2 align="center">X-Portrait: Expressive Portrait Animation with Hierarchical Motion Attention</h2>
6
+ <p align="center">
7
+ <a href="https://scholar.google.com/citations?user=FV0eXhQAAAAJ&hl=en">You Xie</a>,
8
+ <a href="https://hongyixu37.github.io/homepage/">Hongyi Xu</a>,
9
+ <a href="https://guoxiansong.github.io/homepage/index.html">Guoxian Song</a>,
10
+ <a href="https://chaowang.info/">Chao Wang</a>,
11
+ <a href="https://seasonsh.github.io/">Yichun Shi</a>,
12
+ <a href="http://linjieluo.com/">Linjie Luo</a>
13
+ <br>
14
+ <b>&nbsp; ByteDance Inc. </b>
15
+ <br>
16
+ <br>
17
+ <a href="https://arxiv.org/abs/2403.15931"><img src='https://img.shields.io/badge/arXiv-X--Portrait-red' alt='Paper PDF'></a>
18
+ <a href='https://byteaigc.github.io/x-portrait/'><img src='https://img.shields.io/badge/Project_Page-X--Portrait-green' alt='Project Page'></a>
19
+ <a href='https://youtu.be/VGxt5XghRdw'>
20
+ <img src='https://img.shields.io/badge/YouTube-X--Portrait-rgb(255, 0, 0)' alt='Youtube'></a>
21
+ <br>
22
+ </p>
23
+
24
+ <table align="center">
25
+ <tr>
26
+ <td>
27
+ <img src="assets/teaser/teaser.png">
28
+ </td>
29
+ </tr>
30
+ </table>
31
+
32
+ This repository contains the video generation code of SIGGRAPH 2024 paper [X-Portrait](https://arxiv.org/pdf/2403.15931).
33
+
34
+ ## Installation
35
+ Note: Python 3.9 and Cuda 11.8 are required.
36
+ ```shell
37
+ bash env_install.sh
38
+ ```
39
+
40
+ ## Model
41
+ Please download pre-trained model from [here](https://drive.google.com/drive/folders/1Bq0n-w1VT5l99CoaVg02hFpqE5eGLo9O?usp=sharing), and save it under "checkpoint/"
42
+
43
+ ## Testing
44
+ ```shell
45
+ bash scripts/test_xportrait.sh
46
+ ```
47
+ parameters:
48
+ **model_config**: config file of the corresponding model
49
+ **output_dir**: output path for generated video
50
+ **source_image**: path of source image
51
+ **driving_video**: path of driving video
52
+ **best_frame**: specify the frame index in the driving video where the head pose best matches the source image (note: precision of best_frame index might affect the final quality)
53
+ **out_frames**: number of generation frames
54
+ **num_mix**: number of overlapping frames when applying prompt travelling during inference
55
+ **ddim_steps**: number of inference steps (e.g., 30 steps for ddim)
56
+
57
+ ## Performance Boost
58
+ **efficiency**: Our model is compatible with LCM LoRA (https://huggingface.co/latent-consistency/lcm-lora-sdv1-5), which helps reduce the number of inference steps.
59
+ **expressiveness**: Expressiveness of the results could be boosted if results of other face reenactment approaches, e.g., face vid2vid, could be provided via parameter "--initial_facevid2vid_results".
60
+
61
+ ## 🎓 Citation
62
+ If you find this codebase useful for your research, please use the following entry.
63
+ ```BibTeX
64
+ @inproceedings{xie2024x,
65
+ title={X-Portrait: Expressive Portrait Animation with Hierarchical Motion Attention},
66
+ author={Xie, You and Xu, Hongyi and Song, Guoxian and Wang, Chao and Shi, Yichun and Luo, Linjie},
67
+ journal={arXiv preprint arXiv:2403.15931},
68
+ year={2024}
69
+ }
70
+ ```
assets/driving_video.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:030c10c861e9fd4f6395eede5e9d4005dafc3fa56569e6a7167337b1b3675c08
3
+ size 3839556
assets/source_image.png ADDED
assets/teaser/teaser.png ADDED

Git LFS Details

  • SHA256: 32c8e4475ed8b1db09711d4258c49ff03b9cb9b461557dd7f715bf056940b3c7
  • Pointer size: 132 Bytes
  • Size of remote file: 5.33 MB
checkpoint/checkpoint_location ADDED
File without changes
config/cldm_v15_appearance_pose_local_mm.yaml ADDED
@@ -0,0 +1,130 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ target: model_lib.ControlNet.cldm.cldm.ControlLDMReferenceOnly_Temporal_Pose_Local
3
+ params:
4
+ linear_start: 0.00085
5
+ linear_end: 0.0120
6
+ num_timesteps_cond: 1
7
+ log_every_t: 200
8
+ timesteps: 1000
9
+ first_stage_key: "jpg"
10
+ cond_stage_key: "txt"
11
+ control_key: "hint"
12
+ image_size: 64
13
+ channels: 4
14
+ cond_stage_trainable: false
15
+ conditioning_key: crossattn
16
+ monitor: val/loss_simple_ema
17
+ scale_factor: 0.18215
18
+ use_ema: False
19
+ only_mid_control: False
20
+
21
+ appearance_control_stage_config:
22
+ target: model_lib.ControlNet.cldm.cldm.ControlNetReferenceOnly
23
+ params:
24
+ image_size: 32 # unused
25
+ in_channels: 4
26
+ hint_channels: 3
27
+ out_channels: 4
28
+ model_channels: 320
29
+ attention_resolutions: [ 4, 2, 1 ]
30
+ num_res_blocks: 2
31
+ channel_mult: [ 1, 2, 4, 4 ]
32
+ num_heads: 8
33
+ use_spatial_transformer: True
34
+ transformer_depth: 1
35
+ context_dim: 768
36
+ use_checkpoint: True
37
+ legacy: False
38
+
39
+ pose_control_stage_config:
40
+ target: model_lib.ControlNet.cldm.cldm.ControlNet
41
+ params:
42
+ image_size: 32 # unused
43
+ in_channels: 4
44
+ hint_channels: 3
45
+ model_channels: 320
46
+ attention_resolutions: [ 4, 2, 1 ]
47
+ num_res_blocks: 2
48
+ channel_mult: [ 1, 2, 4, 4 ]
49
+ num_heads: 8
50
+ use_spatial_transformer: True
51
+ transformer_depth: 1
52
+ context_dim: 768
53
+ use_checkpoint: True
54
+ legacy: False
55
+
56
+ local_pose_control_stage_config:
57
+ target: model_lib.ControlNet.cldm.cldm.ControlNet
58
+ params:
59
+ image_size: 32 # unused
60
+ in_channels: 4
61
+ hint_channels: 3
62
+ model_channels: 320
63
+ attention_resolutions: [ 4, 2, 1 ]
64
+ num_res_blocks: 2
65
+ channel_mult: [ 1, 2, 4, 4 ]
66
+ num_heads: 8
67
+ use_spatial_transformer: True
68
+ transformer_depth: 1
69
+ context_dim: 768
70
+ use_checkpoint: True
71
+ legacy: False
72
+
73
+ unet_config:
74
+ target: model_lib.ControlNet.cldm.cldm.ControlledUnetModelAttn_Temporal_Pose_Local
75
+ params:
76
+ image_size: 32 # unused
77
+ in_channels: 4
78
+ out_channels: 4
79
+ model_channels: 320
80
+ attention_resolutions: [ 4, 2, 1 ]
81
+ num_res_blocks: 2
82
+ channel_mult: [ 1, 2, 4, 4 ]
83
+ num_heads: 8
84
+ use_spatial_transformer: True
85
+ transformer_depth: 1
86
+ context_dim: 768
87
+ use_checkpoint: True
88
+ legacy: False
89
+
90
+ unet_additional_kwargs:
91
+ use_motion_module : true
92
+ motion_module_resolutions : [ 1,2,4,8 ]
93
+ unet_use_cross_frame_attention : false
94
+ unet_use_temporal_attention : false
95
+
96
+ motion_module_type: Vanilla
97
+ motion_module_kwargs:
98
+ num_attention_heads : 8
99
+ num_transformer_block : 1
100
+ attention_block_types : [ "Temporal_Self", "Temporal_Self" ]
101
+ temporal_position_encoding : true
102
+ temporal_position_encoding_max_len : 24
103
+ temporal_attention_dim_div : 1
104
+ zero_initialize : true
105
+
106
+ first_stage_config:
107
+ target: model_lib.ControlNet.ldm.models.autoencoder.AutoencoderKL
108
+ params:
109
+ embed_dim: 4
110
+ monitor: val/rec_loss
111
+ ddconfig:
112
+ double_z: true
113
+ z_channels: 4
114
+ resolution: 256
115
+ in_channels: 3
116
+ out_ch: 3
117
+ ch: 128
118
+ ch_mult:
119
+ - 1
120
+ - 2
121
+ - 4
122
+ - 4
123
+ num_res_blocks: 2
124
+ attn_resolutions: []
125
+ dropout: 0.0
126
+ lossconfig:
127
+ target: torch.nn.Identity
128
+
129
+ cond_stage_config:
130
+ target: model_lib.ControlNet.ldm.modules.encoders.modules.FrozenCLIPEmbedder
core/test_xportrait.py ADDED
@@ -0,0 +1,506 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # *************************************************************************
2
+ # This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s Mo-
3
+ # difications”). All Bytedance Inc.'s Modifications are Copyright (2023) B-
4
+ # ytedance Inc..
5
+ # *************************************************************************
6
+ import os
7
+ import argparse
8
+ import numpy as np
9
+ # torch
10
+ import torch
11
+ from ema_pytorch import EMA
12
+ from einops import rearrange
13
+ import cv2
14
+ # utils
15
+ from utils.utils import set_seed, count_param, print_peak_memory
16
+ # model
17
+ import imageio
18
+ from model_lib.ControlNet.cldm.model import create_model
19
+ import copy
20
+ import glob
21
+ import imageio
22
+ from skimage.transform import resize
23
+ from skimage import img_as_ubyte
24
+ import face_alignment
25
+ import sys
26
+ from decord import VideoReader
27
+ from decord import cpu, gpu
28
+
29
+ TORCH_VERSION = torch.__version__.split(".")[0]
30
+ FP16_DTYPE = torch.float16
31
+ print(f"TORCH_VERSION={TORCH_VERSION} FP16_DTYPE={FP16_DTYPE}")
32
+
33
+ def extract_local_feature_from_single_img(img, fa, remove_local=False, real_tocrop=None, target_res = 512):
34
+ device = img.device
35
+ pred = img.permute([1, 2, 0]).detach().cpu().numpy()
36
+
37
+ pred_lmks = img_as_ubyte(resize(pred, (256, 256)))
38
+
39
+ try:
40
+ lmks = fa.get_landmarks_from_image(pred_lmks, return_landmark_score=False)[0]
41
+ except:
42
+ print ('undetected faces!!')
43
+ if real_tocrop is None:
44
+ return torch.zeros_like(img) * 2 - 1., [196,196,320,320]
45
+ return torch.zeros_like(img), [196,196,320,320]
46
+
47
+ halfedge = 32
48
+ left_eye_center = (np.clip(np.round(np.mean(lmks[43:48], axis=0)), halfedge, 255-halfedge) * (target_res / 256)).astype(np.int32)
49
+ right_eye_center = (np.clip(np.round(np.mean(lmks[37:42], axis=0)), halfedge, 255-halfedge) * (target_res / 256)).astype(np.int32)
50
+ mouth_center = (np.clip(np.round(np.mean(lmks[49:68], axis=0)), halfedge, 255-halfedge) * (target_res / 256)).astype(np.int32)
51
+
52
+ if real_tocrop is not None:
53
+ pred = real_tocrop.permute([1, 2, 0]).detach().cpu().numpy()
54
+
55
+ half_size = target_res // 8 #64
56
+ if remove_local:
57
+ local_viz = pred
58
+ local_viz[left_eye_center[1] - half_size : left_eye_center[1] + half_size, left_eye_center[0] - half_size : left_eye_center[0] + half_size] = 0
59
+ local_viz[right_eye_center[1] - half_size : right_eye_center[1] + half_size, right_eye_center[0] - half_size : right_eye_center[0] + half_size] = 0
60
+ local_viz[mouth_center[1] - half_size : mouth_center[1] + half_size, mouth_center[0] - half_size : mouth_center[0] + half_size] = 0
61
+ else:
62
+ local_viz = np.zeros_like(pred)
63
+ local_viz[left_eye_center[1] - half_size : left_eye_center[1] + half_size, left_eye_center[0] - half_size : left_eye_center[0] + half_size] = pred[left_eye_center[1] - half_size : left_eye_center[1] + half_size, left_eye_center[0] - half_size : left_eye_center[0] + half_size]
64
+ local_viz[right_eye_center[1] - half_size : right_eye_center[1] + half_size, right_eye_center[0] - half_size : right_eye_center[0] + half_size] = pred[right_eye_center[1] - half_size : right_eye_center[1] + half_size, right_eye_center[0] - half_size : right_eye_center[0] + half_size]
65
+ local_viz[mouth_center[1] - half_size : mouth_center[1] + half_size, mouth_center[0] - half_size : mouth_center[0] + half_size] = pred[mouth_center[1] - half_size : mouth_center[1] + half_size, mouth_center[0] - half_size : mouth_center[0] + half_size]
66
+
67
+ local_viz = torch.from_numpy(local_viz).to(device)
68
+ local_viz = local_viz.permute([2, 0, 1])
69
+ if real_tocrop is None:
70
+ local_viz = local_viz * 2 - 1.
71
+ return local_viz
72
+
73
+ def find_best_frame_byheadpose_fa(source_image, driving_video, fa):
74
+ input = img_as_ubyte(resize(source_image, (256, 256)))
75
+ try:
76
+ src_pose_array = fa.get_landmarks_from_image(input, return_landmark_score=False)[0]
77
+ except:
78
+ print ('undetected faces in the source image!!')
79
+ src_pose_array = np.zeros((68,2))
80
+ if len(src_pose_array) == 0:
81
+ return 0
82
+ min_diff = 1e8
83
+ best_frame = 0
84
+
85
+ for i in range(len(driving_video)):
86
+ frame = img_as_ubyte(resize(driving_video[i], (256, 256)))
87
+ try:
88
+ drv_pose_array = fa.get_landmarks_from_image(frame, return_landmark_score=False)[0]
89
+ except:
90
+ print ('undetected faces in the %d-th driving image!!'%i)
91
+ drv_pose_array = np.zeros((68,2))
92
+ diff = np.sum(np.abs(np.array(src_pose_array)-np.array(drv_pose_array)))
93
+ if diff < min_diff:
94
+ best_frame = i
95
+ min_diff = diff
96
+
97
+ return best_frame
98
+
99
+ def adjust_driving_video_to_src_image(source_image, driving_video, fa, nm_res, nmd_res, best_frame=-1):
100
+ if best_frame == -2:
101
+ return [resize(frame, (nm_res, nm_res)) for frame in driving_video], [resize(frame, (nmd_res, nmd_res)) for frame in driving_video]
102
+ src = img_as_ubyte(resize(source_image[..., :3], (256, 256)))
103
+ if best_frame >= len(source_image):
104
+ raise ValueError(
105
+ f"please specify one frame in driving video of which the pose match best with the pose of source image"
106
+ )
107
+
108
+ if best_frame < 0:
109
+ best_frame = find_best_frame_byheadpose_fa(src, driving_video, fa)
110
+
111
+ print ('Best Frame: %d' % best_frame)
112
+ driving = img_as_ubyte(resize(driving_video[best_frame], (256, 256)))
113
+
114
+ src_lmks = fa.get_landmarks_from_image(src, return_landmark_score=False)
115
+ drv_lmks = fa.get_landmarks_from_image(driving, return_landmark_score=False)
116
+
117
+ if (src_lmks is None) or (drv_lmks is None):
118
+ return [resize(frame, (nm_res, nm_res)) for frame in driving_video], [resize(frame, (nmd_res, nmd_res)) for frame in driving_video]
119
+ src_lmks = src_lmks[0]
120
+ drv_lmks = drv_lmks[0]
121
+ src_centers = np.mean(src_lmks, axis=0)
122
+ drv_centers = np.mean(drv_lmks, axis=0)
123
+ edge_src = (np.max(src_lmks, axis=0) - np.min(src_lmks, axis=0))*0.5
124
+ edge_drv = (np.max(drv_lmks, axis=0) - np.min(drv_lmks, axis=0))*0.5
125
+
126
+ #matching three points
127
+ src_point=np.array([[src_centers[0]-edge_src[0],src_centers[1]-edge_src[1]],[src_centers[0]+edge_src[0],src_centers[1]-edge_src[1]],[src_centers[0]-edge_src[0],src_centers[1]+edge_src[1]],[src_centers[0]+edge_src[0],src_centers[1]+edge_src[1]]]).astype(np.float32)
128
+ dst_point=np.array([[drv_centers[0]-edge_drv[0],drv_centers[1]-edge_drv[1]],[drv_centers[0]+edge_drv[0],drv_centers[1]-edge_drv[1]],[drv_centers[0]-edge_drv[0],drv_centers[1]+edge_drv[1]],[drv_centers[0]+edge_drv[0],drv_centers[1]+edge_drv[1]]]).astype(np.float32)
129
+
130
+ adjusted_driving_video = []
131
+ adjusted_driving_video_hd = []
132
+
133
+ for frame in driving_video:
134
+ frame_ld = resize(frame, (nm_res, nm_res))
135
+ frame_hd = resize(frame, (nmd_res, nmd_res))
136
+ zoomed=cv2.warpAffine(frame_ld, cv2.getAffineTransform(dst_point[:3], src_point[:3]), (nm_res, nm_res))
137
+ zoomed_hd=cv2.warpAffine(frame_hd, cv2.getAffineTransform(dst_point[:3] * 2, src_point[:3] * 2), (nmd_res, nmd_res))
138
+ adjusted_driving_video.append(zoomed)
139
+ adjusted_driving_video_hd.append(zoomed_hd)
140
+
141
+ return adjusted_driving_video, adjusted_driving_video_hd
142
+
143
+ def x_portrait_data_prep(source_image_path, driving_video_path, device, best_frame_id=0, start_idx = 0, num_frames=0, skip=1, output_local=False, more_source_image_pattern="", target_resolution = 512):
144
+ source_image = imageio.imread(source_image_path)
145
+ if '.mp4' in driving_video_path:
146
+ reader = imageio.get_reader(driving_video_path)
147
+ fps = reader.get_meta_data()['fps']
148
+ driving_video = []
149
+ try:
150
+ for im in reader:
151
+ driving_video.append(im)
152
+ except RuntimeError:
153
+ pass
154
+ reader.close()
155
+ else:
156
+ driving_video = [imageio.imread(driving_video_path)[...,:3]]
157
+ fps = 1
158
+
159
+ nmd_res = target_resolution
160
+ nm_res = 256
161
+ source_image_hd = resize(source_image, (nmd_res, nmd_res))[..., :3]
162
+
163
+ if more_source_image_pattern:
164
+ more_source_paths = glob.glob(more_source_image_pattern)
165
+ more_sources_hd = []
166
+ for more_source_path in more_source_paths:
167
+ more_source_image = imageio.imread(more_source_path)
168
+ more_source_image_hd = resize(more_source_image, (nmd_res, nmd_res))[..., :3]
169
+ more_source_hd = torch.tensor(more_source_image_hd[np.newaxis].astype(np.float32)).permute(0, 3, 1, 2)
170
+ more_source_hd = more_source_hd.to(device)
171
+ more_sources_hd.append(more_source_hd)
172
+ more_sources_hd = torch.stack(more_sources_hd, dim = 1)
173
+ else:
174
+ more_sources_hd = None
175
+
176
+ fa = face_alignment.FaceAlignment(face_alignment.LandmarksType.TWO_D, flip_input=True, device='cuda')
177
+
178
+ driving_video, driving_video_hd = adjust_driving_video_to_src_image(source_image, driving_video, fa, nm_res, nmd_res, best_frame_id)
179
+
180
+ if num_frames == 0:
181
+ end_idx = len(driving_video)
182
+ else:
183
+ num_frames = min(len(driving_video), num_frames)
184
+ end_idx = start_idx + num_frames * skip
185
+
186
+ driving_video = driving_video[start_idx:end_idx][::skip]
187
+ driving_video_hd = driving_video_hd[start_idx:end_idx][::skip]
188
+ num_frames = len(driving_video)
189
+
190
+ with torch.no_grad():
191
+ real_source_hd = torch.tensor(source_image_hd[np.newaxis].astype(np.float32)).permute(0, 3, 1, 2)
192
+ real_source_hd = real_source_hd.to(device)
193
+
194
+ driving_hd = torch.tensor(np.array(driving_video_hd).astype(np.float32)).permute(0, 3, 1, 2).to(device)
195
+
196
+ local_features = []
197
+ raw_drivings=[]
198
+
199
+ for frame_idx in range(0, num_frames):
200
+ raw_drivings.append(driving_hd[frame_idx:frame_idx+1] * 2 - 1.)
201
+ if output_local:
202
+ local_feature_img = extract_local_feature_from_single_img(driving_hd[frame_idx], fa,target_res=nmd_res)
203
+ local_features.append(local_feature_img)
204
+
205
+
206
+ batch_data = {}
207
+ batch_data['fps'] = fps
208
+ real_source_hd = real_source_hd * 2 - 1
209
+ batch_data['sources'] = real_source_hd[:, None, :, :, :].repeat([num_frames, 1, 1, 1, 1])
210
+ if more_sources_hd is not None:
211
+ more_sources_hd = more_sources_hd * 2 - 1
212
+ batch_data['more_sources'] = more_sources_hd.repeat([num_frames, 1, 1, 1, 1])
213
+
214
+ raw_drivings = torch.stack(raw_drivings, dim = 0)
215
+ batch_data['conditions'] = raw_drivings
216
+ if output_local:
217
+ batch_data['local'] = torch.stack(local_features, dim = 0)
218
+
219
+ return batch_data
220
+
221
+ # You can now use the modified state_dict without the deleted keys
222
+ def load_state_dict(model, ckpt_path, reinit_hint_block=False, strict=True, map_location="cpu"):
223
+ print(f"Loading model state dict from {ckpt_path} ...")
224
+ state_dict = torch.load(ckpt_path, map_location=map_location)
225
+ state_dict = state_dict.get('state_dict', state_dict)
226
+ if reinit_hint_block:
227
+ print("Ignoring hint block parameters from checkpoint!")
228
+ for k in list(state_dict.keys()):
229
+ if k.startswith("control_model.input_hint_block"):
230
+ state_dict.pop(k)
231
+ model.load_state_dict(state_dict, strict=strict)
232
+ del state_dict
233
+
234
+ def get_cond_control(args, batch_data, control_type, device, start, end, model=None, batch_size=None, train=True, key=0):
235
+
236
+ control_type = copy.deepcopy(control_type)
237
+ vae_bs = 16
238
+ if control_type == "appearance_pose_local_mm":
239
+ src = batch_data['sources'][start:end, key].cuda()
240
+ c_cat_list = batch_data['conditions'][start:end].cuda()
241
+ cond_image = []
242
+ for k in range(0, end-start, vae_bs):
243
+ cond_image.append(model.get_first_stage_encoding(model.encode_first_stage(src[k:k+vae_bs])))
244
+ cond_image = torch.concat(cond_image, dim=0)
245
+ cond_img_cat = cond_image
246
+ p_local = batch_data['local'][start:end].cuda()
247
+ print ('Total frames:{}'.format(cond_img_cat.shape))
248
+ more_cond_imgs = []
249
+ if 'more_sources' in batch_data:
250
+ num_additional_cond_imgs = batch_data['more_sources'].shape[1]
251
+ for i in range(num_additional_cond_imgs):
252
+ m_cond_img = batch_data['more_sources'][start:end, i]
253
+ m_cond_img = model.get_first_stage_encoding(model.encode_first_stage(m_cond_img))
254
+ more_cond_imgs.append([m_cond_img.to(device)])
255
+
256
+ return [cond_img_cat.to(device), c_cat_list, p_local, more_cond_imgs]
257
+ else:
258
+ raise NotImplementedError(f"cond_type={control_type} not supported!")
259
+
260
+ def visualize_mm(args, name, batch_data, infer_model, nSample, local_image_dir, num_mix=4, preset_output_name=''):
261
+ driving_video_name = os.path.basename(batch_data['video_name']).split('.')[0]
262
+ source_name = os.path.basename(batch_data['source_name']).split('.')[0]
263
+
264
+ if not os.path.exists(local_image_dir):
265
+ os.mkdir(local_image_dir)
266
+
267
+ uc_scale = args.uc_scale
268
+ if preset_output_name:
269
+ preset_output_name = preset_output_name.split('.')[0]+'.mp4'
270
+ output_path = f"{local_image_dir}/{preset_output_name}"
271
+ else:
272
+ output_path = f"{local_image_dir}/{name}_{args.control_type}_uc{uc_scale}_{source_name}_by_{driving_video_name}_mix{num_mix}.mp4"
273
+
274
+ infer_model.eval()
275
+
276
+ gene_img_list = []
277
+
278
+ _, _, ch, h, w = batch_data['sources'].shape
279
+
280
+ vae_bs = 16
281
+
282
+ if args.initial_facevid2vid_results:
283
+ facevid2vid = []
284
+ facevid2vid_results = VideoReader(args.initial_facevid2vid_results, ctx=cpu(0))
285
+ for frame_id in range(len(facevid2vid_results)):
286
+ frame = cv2.resize(facevid2vid_results[frame_id].asnumpy(),(512,512)) / 255
287
+ facevid2vid.append(torch.from_numpy(frame * 2 - 1).permute(2,0,1))
288
+ cond = torch.stack(facevid2vid)[:nSample].float().to(args.device)
289
+ pre_noise=[]
290
+ for i in range(0, nSample, vae_bs):
291
+ pre_noise.append(infer_model.get_first_stage_encoding(infer_model.encode_first_stage(cond[i:i+vae_bs])))
292
+ pre_noise = torch.cat(pre_noise, dim=0)
293
+ pre_noise = infer_model.q_sample(x_start = pre_noise, t = torch.tensor([999]).to(pre_noise.device))
294
+ else:
295
+ cond = batch_data['sources'][:nSample].reshape([-1, ch, h, w])
296
+ pre_noise=[]
297
+ for i in range(0, nSample, vae_bs):
298
+ pre_noise.append(infer_model.get_first_stage_encoding(infer_model.encode_first_stage(cond[i:i+vae_bs])))
299
+ pre_noise = torch.cat(pre_noise, dim=0)
300
+ pre_noise = infer_model.q_sample(x_start = pre_noise, t = torch.tensor([999]).to(pre_noise.device))
301
+
302
+ text = ["" for _ in range(nSample)]
303
+
304
+ all_c_cat = get_cond_control(args, batch_data, args.control_type, args.device, start=0, end=nSample, model=infer_model, train=False)
305
+ cond_img_cat = [all_c_cat[0]]
306
+ pose_cond_list = [rearrange(all_c_cat[1], "b f c h w -> (b f) c h w")]
307
+ local_pose_cond_list = [all_c_cat[2]]
308
+
309
+ c_cross = infer_model.get_learned_conditioning(text)[:nSample]
310
+ uc_cross = infer_model.get_unconditional_conditioning(nSample)
311
+
312
+ c = {"c_crossattn": [c_cross], "image_control": cond_img_cat}
313
+ if "appearance_pose" in args.control_type:
314
+ c['c_concat'] = pose_cond_list
315
+ if "appearance_pose_local" in args.control_type:
316
+ c["local_c_concat"] = local_pose_cond_list
317
+
318
+ if len(all_c_cat) > 3 and len(all_c_cat[3]) > 0:
319
+ c['more_image_control'] = all_c_cat[3]
320
+
321
+ if args.control_mode == "controlnet_important":
322
+ uc = {"c_crossattn": [uc_cross]}
323
+ else:
324
+ uc = {"c_crossattn": [uc_cross], "image_control":cond_img_cat}
325
+
326
+ if "appearance_pose" in args.control_type:
327
+ uc['c_concat'] = [torch.zeros_like(pose_cond_list[0])]
328
+
329
+ if "appearance_pose_local" in args.control_type:
330
+ uc["local_c_concat"] = [torch.zeros_like(local_pose_cond_list[0])]
331
+
332
+ if len(all_c_cat) > 3 and len(all_c_cat[3]) > 0:
333
+ uc['more_image_control'] = all_c_cat[3]
334
+
335
+ if args.wonoise:
336
+ c['wonoise'] = True
337
+ uc['wonoise'] = True
338
+ else:
339
+ c['wonoise'] = False
340
+ uc['wonoise'] = False
341
+
342
+ noise = pre_noise.to(c_cross.device)
343
+
344
+ with torch.cuda.amp.autocast(enabled=args.use_fp16, dtype=FP16_DTYPE):
345
+ infer_model.to(args.device)
346
+ infer_model.eval()
347
+
348
+ gene_img, _ = infer_model.sample_log(cond=c,
349
+ batch_size=args.num_drivings, ddim=True,
350
+ ddim_steps=args.ddim_steps, eta=args.eta,
351
+ unconditional_guidance_scale=uc_scale,
352
+ unconditional_conditioning=uc,
353
+ inpaint=None,
354
+ x_T=noise,
355
+ num_overlap=num_mix,
356
+ )
357
+
358
+ for i in range(0, nSample, vae_bs):
359
+ gene_img_part = infer_model.decode_first_stage( gene_img[i:i+vae_bs] )
360
+ gene_img_list.append(gene_img_part.float().clamp(-1, 1).cpu())
361
+
362
+ _, c, h, w = gene_img_list[0].shape
363
+
364
+ cond_image = batch_data["conditions"].reshape([-1,c,h,w])[:nSample].cpu()
365
+ l_cond_image = batch_data["local"].reshape([-1,c,h,w])[:nSample].cpu()
366
+ orig_image = batch_data["sources"][:nSample, 0].cpu()
367
+
368
+ output_img = torch.cat(gene_img_list + [cond_image.cpu()]+[l_cond_image.cpu()]+[orig_image.cpu()]).float().clamp(-1,1).add(1).mul(0.5)
369
+
370
+ num_cols = 4
371
+ output_img = output_img.reshape([num_cols, 1, nSample, c, h, w]).permute([1, 0, 2, 3, 4,5])
372
+
373
+ output_img = output_img.permute([2, 3, 0, 4, 1, 5]).reshape([-1, c, h, num_cols * w])
374
+ output_img = torch.permute(output_img, [0, 2, 3, 1])
375
+
376
+ output_img = output_img.data.cpu().numpy()
377
+ output_img = img_as_ubyte(output_img)
378
+ imageio.mimsave(output_path, output_img[:,:,:512], fps=batch_data['fps'], quality=10, pixelformat='yuv420p', codec='libx264')
379
+
380
+ def main(args):
381
+
382
+ # ******************************
383
+ # initialize training
384
+ # ******************************
385
+ args.world_size = 1
386
+ args.local_rank = 0
387
+ args.rank = 0
388
+ args.device = torch.device("cuda", args.local_rank)
389
+
390
+ # set seed for reproducibility
391
+ set_seed(args.seed)
392
+
393
+ # ******************************
394
+ # create model
395
+ # ******************************
396
+ model = create_model(args.model_config).cpu()
397
+ model.sd_locked = args.sd_locked
398
+ model.only_mid_control = args.only_mid_control
399
+ model.to(args.local_rank)
400
+ if not os.path.exists(args.output_dir):
401
+ os.makedirs(args.output_dir)
402
+ if args.local_rank == 0:
403
+ print('Total base parameters {:.02f}M'.format(count_param([model])))
404
+ if args.ema_rate is not None and args.ema_rate > 0 and args.rank == 0:
405
+ print(f"Creating EMA model at ema_rate={args.ema_rate}")
406
+ model_ema = EMA(model, beta=args.ema_rate, update_after_step=0, update_every=1)
407
+ else:
408
+ model_ema = None
409
+
410
+ # ******************************
411
+ # load pre-trained models
412
+ # ******************************
413
+ if args.resume_dir is not None:
414
+ if args.local_rank == 0:
415
+ load_state_dict(model, args.resume_dir, strict=False)
416
+ else:
417
+ print('please privide the correct resume_dir!')
418
+ exit()
419
+
420
+ # ******************************
421
+ # create DDP model
422
+ # ******************************
423
+ if args.compile and TORCH_VERSION == "2":
424
+ model = torch.compile(model)
425
+
426
+ torch.cuda.set_device(args.local_rank)
427
+ print_peak_memory("Max memory allocated after creating DDP", args.local_rank)
428
+ infer_model = model.module if hasattr(model, "module") else model
429
+
430
+ with torch.no_grad():
431
+ driving_videos = glob.glob(args.driving_video)
432
+ for driving_video in driving_videos:
433
+ print ('working on {}'.format(os.path.basename(driving_video)))
434
+ infer_batch_data = x_portrait_data_prep(args.source_image, driving_video, args.device, args.best_frame, start_idx = args.start_idx, num_frames = args.out_frames, skip=args.skip, output_local=True)
435
+ infer_batch_data['video_name'] = os.path.basename(driving_video)
436
+ infer_batch_data['source_name'] = args.source_image
437
+ nSample = infer_batch_data['sources'].shape[0]
438
+ visualize_mm(args, "inference", infer_batch_data, infer_model, nSample=nSample, local_image_dir=args.output_dir, num_mix=args.num_mix)
439
+
440
+
441
+ if __name__ == "__main__":
442
+
443
+ str2bool = lambda arg: bool(int(arg))
444
+ parser = argparse.ArgumentParser(description='Control Net training')
445
+ ## Model
446
+ parser.add_argument('--model_config', type=str, default="model_lib/ControlNet/models/cldm_v15_video_appearance.yaml",
447
+ help="The path of model config file")
448
+ parser.add_argument('--reinit_hint_block', action='store_true', default=False,
449
+ help="Re-initialize hint blocks for channel mis-match")
450
+ parser.add_argument('--sd_locked', type =str2bool, default=True,
451
+ help='Freeze parameters in original stable-diffusion decoder')
452
+ parser.add_argument('--only_mid_control', type =str2bool, default=False,
453
+ help='Only control middle blocks')
454
+ parser.add_argument('--control_type', type=str, default="appearance_pose_local_mm",
455
+ help='The type of conditioning')
456
+ parser.add_argument("--control_mode", type=str, default="controlnet_important",
457
+ help="Set controlnet is more important or balance.")
458
+ parser.add_argument('--wonoise', action='store_false', default=True,
459
+ help='Use with referenceonly, remove adding noise on reference image')
460
+
461
+ ## Training
462
+ parser.add_argument("--local_rank", type=int, default=0)
463
+ parser.add_argument("--world_size", type=int, default=1)
464
+ parser.add_argument('--seed', type=int, default=42,
465
+ help='random seed for initialization')
466
+ parser.add_argument('--use_fp16', action='store_false', default=True,
467
+ help='Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit')
468
+ parser.add_argument('--compile', type=str2bool, default=False,
469
+ help='compile model (for torch 2)')
470
+ parser.add_argument('--eta', type = float, default = 0.0,
471
+ help='eta during DDIM Sampling')
472
+ parser.add_argument('--ema_rate', type = float, default = 0,
473
+ help='rate for ema')
474
+ ## inference
475
+ parser.add_argument("--initial_facevid2vid_results", type=str, default=None,
476
+ help="facevid2vid results for noise initialization")
477
+ parser.add_argument('--ddim_steps', type = int, default = 1,
478
+ help='denoising steps')
479
+ parser.add_argument('--uc_scale', type = int, default = 5,
480
+ help='cfg')
481
+ parser.add_argument("--num_drivings", type = int, default = 16,
482
+ help="Number of driving images in a single sequence of video.")
483
+ parser.add_argument("--output_dir", type=str, default=None, required=True,
484
+ help="The output directory where the model predictions and checkpoints will be written.")
485
+ parser.add_argument("--resume_dir", type=str, default=None,
486
+ help="The resume directory where the model checkpoints will be loaded.")
487
+ parser.add_argument("--source_image", type=str, default="",
488
+ help="The source image for neural motion.")
489
+ parser.add_argument("--more_source_image_pattern", type=str, default="",
490
+ help="The source image for neural motion.")
491
+ parser.add_argument("--driving_video", type=str, default="",
492
+ help="The source image mask for neural motion.")
493
+ parser.add_argument('--best_frame', type=int, default=0,
494
+ help='best matching frame index')
495
+ parser.add_argument('--start_idx', type=int, default=0,
496
+ help='starting frame index')
497
+ parser.add_argument('--skip', type=int, default=1,
498
+ help='skip frame')
499
+ parser.add_argument('--num_mix', type=int, default=4,
500
+ help='num overlapping frames')
501
+ parser.add_argument('--out_frames', type=int, default=0,
502
+ help='num frames')
503
+ args = parser.parse_args()
504
+
505
+ main(args)
506
+
env_install.sh ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ pip install -r requirements.txt
2
+ sudo apt install python3-tk
model_lib/ControlNet/cldm/__pycache__/cldm.cpython-39.pyc ADDED
Binary file (13.4 kB). View file
 
model_lib/ControlNet/cldm/__pycache__/model.cpython-39.pyc ADDED
Binary file (1.18 kB). View file
 
model_lib/ControlNet/cldm/cldm.py ADDED
@@ -0,0 +1,715 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from re import I
2
+ import torch
3
+ import torch as th
4
+ import torch.nn as nn
5
+ from model_lib.ControlNet.ldm.modules.diffusionmodules.util import (
6
+ conv_nd,
7
+ linear,
8
+ zero_module,
9
+ timestep_embedding,
10
+ )
11
+
12
+ from model_lib.ControlNet.ldm.modules.attention import SpatialTransformer
13
+ from model_lib.ControlNet.ldm.modules.diffusionmodules.openaimodel import TimestepEmbedSequential, ResBlock, Downsample, AttentionBlock,Upsample, UNetModel_Temporal
14
+ from model_lib.ControlNet.ldm.models.diffusion.ddpm import LatentDiffusionReferenceOnly
15
+ from model_lib.ControlNet.ldm.util import exists, instantiate_from_config
16
+
17
+ ## TODO: here UNet
18
+ class ControlledUnetModelAttn_Temporal_Pose_Local(UNetModel_Temporal):
19
+ def forward(self, x, timesteps=None, context=None, control=None, pose_control=None,local_pose_control=None,only_mid_control=False, attention_mode=None,uc=False, **kwargs):
20
+ hs = []
21
+ bank_attn = control
22
+ attn_index = 0
23
+
24
+ t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
25
+ emb = self.time_embed(t_emb)
26
+ h = x.type(self.dtype)
27
+ num_input_motion_module = 0
28
+ if uc:
29
+ for i, module in enumerate(self.input_blocks):
30
+ if i in [1,2,4,5,7,8,10,11]:
31
+ motion_module = self.input_blocks_motion_module[num_input_motion_module]
32
+ h = module(h, emb, context,uc=uc) # Attn here
33
+ h = motion_module(h, emb, context)
34
+ num_input_motion_module += 1
35
+ else:
36
+ h = module(h, emb, context,uc=uc) # Attn here
37
+ hs.append(h)
38
+
39
+ h = self.middle_block(h, emb, context,uc=uc) # Attn here
40
+
41
+ for i, module in enumerate(self.output_blocks):
42
+ output_block_motion_module = self.output_blocks_motion_module[i]
43
+ if only_mid_control:
44
+ h = torch.cat([h, hs.pop()], dim=1)
45
+ h = module(h, emb, context,uc=uc)
46
+ else:
47
+ h = torch.cat([h, hs.pop()], dim=1)
48
+ h = module(h, emb, context,uc=uc) # Attn here
49
+ h = output_block_motion_module(h, emb, context)
50
+
51
+ else:
52
+ num_input_motion_module = 0
53
+ for i, module in enumerate(self.input_blocks):
54
+ if i in [1,2,4,5,7,8,10,11]:
55
+ motion_module = self.input_blocks_motion_module[num_input_motion_module]
56
+ h, attn_index = module(h, emb, context, bank_attn, attention_mode, attn_index)
57
+ h = motion_module(h, emb, context)
58
+ num_input_motion_module += 1
59
+ else:
60
+ h, attn_index = module(h, emb, context, bank_attn, attention_mode, attn_index) # Attn here
61
+ hs.append(h)
62
+
63
+ h, attn_index = self.middle_block(h, emb, context, bank_attn, attention_mode, attn_index) # Attn here
64
+
65
+ amplify_f = 1.
66
+
67
+ if pose_control is not None:
68
+ h += pose_control.pop() * amplify_f
69
+
70
+ if local_pose_control is not None:
71
+ h += local_pose_control.pop() * amplify_f
72
+
73
+ for i, module in enumerate(self.output_blocks):
74
+ output_block_motion_module = self.output_blocks_motion_module[i]
75
+ if only_mid_control or (bank_attn is None):
76
+ h = torch.cat([h, hs.pop()], dim=1)
77
+ h = module(h, emb, context)
78
+ else:
79
+ if pose_control is not None and local_pose_control is not None:
80
+ h = torch.cat([h, hs.pop() + pose_control.pop() * amplify_f + local_pose_control.pop() * amplify_f], dim=1)
81
+ elif pose_control is not None:
82
+ h = torch.cat([h, hs.pop() + pose_control.pop() * amplify_f], dim=1)
83
+ elif local_pose_control is not None:
84
+ h = torch.cat([h, hs.pop() + local_pose_control.pop() * amplify_f], dim=1)
85
+ else:
86
+ h = torch.cat([h, hs.pop()], dim=1)
87
+
88
+ h, attn_index = module(h, emb, context, bank_attn, attention_mode, attn_index) # Attn here
89
+ h = output_block_motion_module(h, emb, context)
90
+
91
+ h = h.type(x.dtype)
92
+ return self.out(h)
93
+
94
+
95
+ ## ControlNet Reference Only-Like Attention
96
+ class ControlNetReferenceOnly(nn.Module):
97
+ def __init__(
98
+ self,
99
+ image_size,
100
+ in_channels,
101
+ model_channels,
102
+ hint_channels,
103
+ out_channels,
104
+ num_res_blocks,
105
+ attention_resolutions,
106
+ dropout=0,
107
+ channel_mult=(1, 2, 4, 8),
108
+ conv_resample=True,
109
+ dims=2,
110
+ use_checkpoint=False,
111
+ use_fp16=False,
112
+ num_heads=-1,
113
+ num_head_channels=-1,
114
+ num_heads_upsample=-1,
115
+ use_scale_shift_norm=False,
116
+ resblock_updown=False,
117
+ use_new_attention_order=False,
118
+ use_spatial_transformer=False, # custom transformer support
119
+ transformer_depth=1, # custom transformer support
120
+ context_dim=None, # custom transformer support
121
+ n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
122
+ legacy=True,
123
+ disable_self_attentions=None,
124
+ num_attention_blocks=None,
125
+ disable_middle_self_attn=False,
126
+ use_linear_in_transformer=False,
127
+ ):
128
+ super().__init__()
129
+ if use_spatial_transformer:
130
+ assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
131
+
132
+ if context_dim is not None:
133
+ assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
134
+ from omegaconf.listconfig import ListConfig
135
+ if type(context_dim) == ListConfig:
136
+ context_dim = list(context_dim)
137
+
138
+ if num_heads_upsample == -1:
139
+ num_heads_upsample = num_heads
140
+
141
+ if num_heads == -1:
142
+ assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'
143
+
144
+ if num_head_channels == -1:
145
+ assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
146
+
147
+ self.dims = dims
148
+ self.image_size = image_size
149
+ self.in_channels = in_channels
150
+ self.out_channels = out_channels
151
+ self.model_channels = model_channels
152
+ if isinstance(num_res_blocks, int):
153
+ self.num_res_blocks = len(channel_mult) * [num_res_blocks]
154
+ else:
155
+ if len(num_res_blocks) != len(channel_mult):
156
+ raise ValueError("provide num_res_blocks either as an int (globally constant) or "
157
+ "as a list/tuple (per-level) with the same length as channel_mult")
158
+ self.num_res_blocks = num_res_blocks
159
+ if disable_self_attentions is not None:
160
+ # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
161
+ assert len(disable_self_attentions) == len(channel_mult)
162
+ if num_attention_blocks is not None:
163
+ assert len(num_attention_blocks) == len(self.num_res_blocks)
164
+ assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))
165
+ print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
166
+ f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
167
+ f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
168
+ f"attention will still not be set.")
169
+
170
+ self.attention_resolutions = attention_resolutions
171
+ self.dropout = dropout
172
+ self.channel_mult = channel_mult
173
+ self.conv_resample = conv_resample
174
+ self.use_checkpoint = use_checkpoint
175
+ self.dtype = th.float16 if use_fp16 else th.float32
176
+ self.num_heads = num_heads
177
+ self.num_head_channels = num_head_channels
178
+ self.num_heads_upsample = num_heads_upsample
179
+ self.predict_codebook_ids = n_embed is not None
180
+
181
+ time_embed_dim = model_channels * 4
182
+ self.time_embed = nn.Sequential(
183
+ linear(model_channels, time_embed_dim),
184
+ nn.SiLU(),
185
+ linear(time_embed_dim, time_embed_dim),
186
+ )
187
+
188
+ self.input_blocks = nn.ModuleList(
189
+ [
190
+ TimestepEmbedSequential(
191
+ conv_nd(dims, in_channels, model_channels, 3, padding=1)
192
+ )
193
+ ]
194
+ )
195
+
196
+ self.input_hint_block = TimestepEmbedSequential(
197
+ conv_nd(dims, hint_channels, 16, 3, padding=1),
198
+ nn.SiLU(),
199
+ conv_nd(dims, 16, 16, 3, padding=1),
200
+ nn.SiLU(),
201
+ conv_nd(dims, 16, 32, 3, padding=1, stride=2),
202
+ nn.SiLU(),
203
+ conv_nd(dims, 32, 32, 3, padding=1),
204
+ nn.SiLU(),
205
+ conv_nd(dims, 32, 96, 3, padding=1, stride=2),
206
+ nn.SiLU(),
207
+ conv_nd(dims, 96, 96, 3, padding=1),
208
+ nn.SiLU(),
209
+ conv_nd(dims, 96, 256, 3, padding=1, stride=2),
210
+ nn.SiLU(),
211
+ zero_module(conv_nd(dims, 256, model_channels, 3, padding=1))
212
+ )
213
+
214
+ self._feature_size = model_channels
215
+ input_block_chans = [model_channels]
216
+ ch = model_channels
217
+ ds = 1
218
+ for level, mult in enumerate(channel_mult):
219
+ for nr in range(self.num_res_blocks[level]):
220
+ layers = [
221
+ ResBlock(
222
+ ch,
223
+ time_embed_dim,
224
+ dropout,
225
+ out_channels=mult * model_channels,
226
+ dims=dims,
227
+ use_checkpoint=use_checkpoint,
228
+ use_scale_shift_norm=use_scale_shift_norm,
229
+ )
230
+ ]
231
+ ch = mult * model_channels
232
+ if ds in attention_resolutions:
233
+ if num_head_channels == -1:
234
+ dim_head = ch // num_heads
235
+ else:
236
+ num_heads = ch // num_head_channels
237
+ dim_head = num_head_channels
238
+ if legacy:
239
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
240
+ if exists(disable_self_attentions):
241
+ disabled_sa = disable_self_attentions[level]
242
+ else:
243
+ disabled_sa = False
244
+
245
+ if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
246
+ layers.append(
247
+ AttentionBlock(
248
+ ch,
249
+ use_checkpoint=use_checkpoint,
250
+ num_heads=num_heads,
251
+ num_head_channels=dim_head,
252
+ use_new_attention_order=use_new_attention_order,
253
+ ) if not use_spatial_transformer else SpatialTransformer(
254
+ ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
255
+ disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
256
+ use_checkpoint=use_checkpoint
257
+ )
258
+ )
259
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
260
+ self._feature_size += ch
261
+ input_block_chans.append(ch)
262
+ if level != len(channel_mult) - 1:
263
+ out_ch = ch
264
+ self.input_blocks.append(
265
+ TimestepEmbedSequential(
266
+ ResBlock(
267
+ ch,
268
+ time_embed_dim,
269
+ dropout,
270
+ out_channels=out_ch,
271
+ dims=dims,
272
+ use_checkpoint=use_checkpoint,
273
+ use_scale_shift_norm=use_scale_shift_norm,
274
+ down=True,
275
+ )
276
+ if resblock_updown
277
+ else Downsample(
278
+ ch, conv_resample, dims=dims, out_channels=out_ch
279
+ )
280
+ )
281
+ )
282
+ ch = out_ch
283
+ input_block_chans.append(ch)
284
+ ds *= 2
285
+ self._feature_size += ch
286
+
287
+ if num_head_channels == -1:
288
+ dim_head = ch // num_heads
289
+ else:
290
+ num_heads = ch // num_head_channels
291
+ dim_head = num_head_channels
292
+ if legacy:
293
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
294
+ self.middle_block = TimestepEmbedSequential(
295
+ ResBlock(
296
+ ch,
297
+ time_embed_dim,
298
+ dropout,
299
+ dims=dims,
300
+ use_checkpoint=use_checkpoint,
301
+ use_scale_shift_norm=use_scale_shift_norm,
302
+ ),
303
+ AttentionBlock(
304
+ ch,
305
+ use_checkpoint=use_checkpoint,
306
+ num_heads=num_heads,
307
+ num_head_channels=dim_head,
308
+ use_new_attention_order=use_new_attention_order,
309
+ ) if not use_spatial_transformer else SpatialTransformer( # always uses a self-attn
310
+ ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
311
+ disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
312
+ use_checkpoint=use_checkpoint
313
+ ),
314
+ ResBlock(
315
+ ch,
316
+ time_embed_dim,
317
+ dropout,
318
+ dims=dims,
319
+ use_checkpoint=use_checkpoint,
320
+ use_scale_shift_norm=use_scale_shift_norm,
321
+ ),
322
+ )
323
+ self._feature_size += ch
324
+
325
+
326
+ self.output_blocks = nn.ModuleList([])
327
+ for level, mult in list(enumerate(channel_mult))[::-1]:
328
+ for i in range(self.num_res_blocks[level] + 1):
329
+ ich = input_block_chans.pop()
330
+ layers = [
331
+ ResBlock(
332
+ ch + ich,
333
+ time_embed_dim,
334
+ dropout,
335
+ out_channels=model_channels * mult,
336
+ dims=dims,
337
+ use_checkpoint=use_checkpoint,
338
+ use_scale_shift_norm=use_scale_shift_norm,
339
+ )
340
+ ]
341
+ ch = model_channels * mult
342
+ if ds in attention_resolutions:
343
+ if num_head_channels == -1:
344
+ dim_head = ch // num_heads
345
+ else:
346
+ num_heads = ch // num_head_channels
347
+ dim_head = num_head_channels
348
+ if legacy:
349
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
350
+ if exists(disable_self_attentions):
351
+ disabled_sa = disable_self_attentions[level]
352
+ else:
353
+ disabled_sa = False
354
+
355
+ if not exists(num_attention_blocks) or i < num_attention_blocks[level]:
356
+ layers.append(
357
+ AttentionBlock(
358
+ ch,
359
+ use_checkpoint=use_checkpoint,
360
+ num_heads=num_heads_upsample,
361
+ num_head_channels=dim_head,
362
+ use_new_attention_order=use_new_attention_order,
363
+ ) if not use_spatial_transformer else SpatialTransformer(
364
+ ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
365
+ disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
366
+ use_checkpoint=use_checkpoint
367
+ )
368
+ )
369
+ if level and i == self.num_res_blocks[level]:
370
+ out_ch = ch
371
+ layers.append(
372
+ ResBlock(
373
+ ch,
374
+ time_embed_dim,
375
+ dropout,
376
+ out_channels=out_ch,
377
+ dims=dims,
378
+ use_checkpoint=use_checkpoint,
379
+ use_scale_shift_norm=use_scale_shift_norm,
380
+ up=True,
381
+ )
382
+ if resblock_updown
383
+ else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
384
+ )
385
+ ds //= 2
386
+ self.output_blocks.append(TimestepEmbedSequential(*layers))
387
+ self._feature_size += ch
388
+
389
+
390
+ def forward(self, x, hint, timesteps, context, attention_bank=None, attention_mode=None,uc=False, **kwargs):
391
+ hs = []
392
+ t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
393
+ emb = self.time_embed(t_emb)
394
+ banks = attention_bank
395
+ outs = []
396
+ h = x.type(self.dtype)
397
+ for module in self.input_blocks:
398
+ h = module(h, emb, context, banks, attention_mode,uc)
399
+ hs.append(h)
400
+
401
+ h = self.middle_block(h, emb, context, banks, attention_mode,uc)
402
+
403
+ for module in self.output_blocks:
404
+ h = th.cat([h, hs.pop()], dim=1)
405
+ h = module(h, emb, context, banks, attention_mode,uc)
406
+
407
+ return outs
408
+
409
+ ### Original ControlNet
410
+ class ControlNet(nn.Module):
411
+ def __init__(
412
+ self,
413
+ image_size,
414
+ in_channels,
415
+ model_channels,
416
+ hint_channels,
417
+ num_res_blocks,
418
+ attention_resolutions,
419
+ dropout=0,
420
+ channel_mult=(1, 2, 4, 8),
421
+ conv_resample=True,
422
+ dims=2,
423
+ use_checkpoint=False,
424
+ use_fp16=False,
425
+ num_heads=-1,
426
+ num_head_channels=-1,
427
+ num_heads_upsample=-1,
428
+ use_scale_shift_norm=False,
429
+ resblock_updown=False,
430
+ use_new_attention_order=False,
431
+ use_spatial_transformer=False, # custom transformer support
432
+ transformer_depth=1, # custom transformer support
433
+ context_dim=None, # custom transformer support
434
+ n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
435
+ legacy=True,
436
+ disable_self_attentions=None,
437
+ num_attention_blocks=None,
438
+ disable_middle_self_attn=False,
439
+ use_linear_in_transformer=False,
440
+ ):
441
+ super().__init__()
442
+ if use_spatial_transformer:
443
+ assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
444
+
445
+ if context_dim is not None:
446
+ assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
447
+ from omegaconf.listconfig import ListConfig
448
+ if type(context_dim) == ListConfig:
449
+ context_dim = list(context_dim)
450
+
451
+ if num_heads_upsample == -1:
452
+ num_heads_upsample = num_heads
453
+
454
+ if num_heads == -1:
455
+ assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'
456
+
457
+ if num_head_channels == -1:
458
+ assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
459
+
460
+ self.dims = dims
461
+ self.image_size = image_size
462
+ self.in_channels = in_channels
463
+ self.model_channels = model_channels
464
+ if isinstance(num_res_blocks, int):
465
+ self.num_res_blocks = len(channel_mult) * [num_res_blocks]
466
+ else:
467
+ if len(num_res_blocks) != len(channel_mult):
468
+ raise ValueError("provide num_res_blocks either as an int (globally constant) or "
469
+ "as a list/tuple (per-level) with the same length as channel_mult")
470
+ self.num_res_blocks = num_res_blocks
471
+ if disable_self_attentions is not None:
472
+ # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
473
+ assert len(disable_self_attentions) == len(channel_mult)
474
+ if num_attention_blocks is not None:
475
+ assert len(num_attention_blocks) == len(self.num_res_blocks)
476
+ assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))
477
+ print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
478
+ f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
479
+ f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
480
+ f"attention will still not be set.")
481
+
482
+ self.attention_resolutions = attention_resolutions
483
+ self.dropout = dropout
484
+ self.channel_mult = channel_mult
485
+ self.conv_resample = conv_resample
486
+ self.use_checkpoint = use_checkpoint
487
+ self.dtype = th.float16 if use_fp16 else th.float32
488
+ self.num_heads = num_heads
489
+ self.num_head_channels = num_head_channels
490
+ self.num_heads_upsample = num_heads_upsample
491
+ self.predict_codebook_ids = n_embed is not None
492
+
493
+ time_embed_dim = model_channels * 4
494
+ self.time_embed = nn.Sequential(
495
+ linear(model_channels, time_embed_dim),
496
+ nn.SiLU(),
497
+ linear(time_embed_dim, time_embed_dim),
498
+ )
499
+
500
+ self.input_blocks = nn.ModuleList(
501
+ [
502
+ TimestepEmbedSequential(
503
+ conv_nd(dims, in_channels, model_channels, 3, padding=1)
504
+ )
505
+ ]
506
+ )
507
+ self.zero_convs = nn.ModuleList([self.make_zero_conv(model_channels)])
508
+
509
+ self.input_hint_block = TimestepEmbedSequential(
510
+ conv_nd(dims, hint_channels, 16, 3, padding=1),
511
+ nn.SiLU(),
512
+ conv_nd(dims, 16, 16, 3, padding=1),
513
+ nn.SiLU(),
514
+ conv_nd(dims, 16, 32, 3, padding=1, stride=2),
515
+ nn.SiLU(),
516
+ conv_nd(dims, 32, 32, 3, padding=1),
517
+ nn.SiLU(),
518
+ conv_nd(dims, 32, 96, 3, padding=1, stride=2),
519
+ nn.SiLU(),
520
+ conv_nd(dims, 96, 96, 3, padding=1),
521
+ nn.SiLU(),
522
+ conv_nd(dims, 96, 256, 3, padding=1, stride=2),
523
+ nn.SiLU(),
524
+ zero_module(conv_nd(dims, 256, model_channels, 3, padding=1))
525
+ )
526
+
527
+ self._feature_size = model_channels
528
+ input_block_chans = [model_channels]
529
+ ch = model_channels
530
+ ds = 1
531
+ for level, mult in enumerate(channel_mult):
532
+ for nr in range(self.num_res_blocks[level]):
533
+ layers = [
534
+ ResBlock(
535
+ ch,
536
+ time_embed_dim,
537
+ dropout,
538
+ out_channels=mult * model_channels,
539
+ dims=dims,
540
+ use_checkpoint=use_checkpoint,
541
+ use_scale_shift_norm=use_scale_shift_norm,
542
+ )
543
+ ]
544
+ ch = mult * model_channels
545
+ if ds in attention_resolutions:
546
+ if num_head_channels == -1:
547
+ dim_head = ch // num_heads
548
+ else:
549
+ num_heads = ch // num_head_channels
550
+ dim_head = num_head_channels
551
+ if legacy:
552
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
553
+ if exists(disable_self_attentions):
554
+ disabled_sa = disable_self_attentions[level]
555
+ else:
556
+ disabled_sa = False
557
+
558
+ if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
559
+ layers.append(
560
+ AttentionBlock(
561
+ ch,
562
+ use_checkpoint=use_checkpoint,
563
+ num_heads=num_heads,
564
+ num_head_channels=dim_head,
565
+ use_new_attention_order=use_new_attention_order,
566
+ ) if not use_spatial_transformer else SpatialTransformer(
567
+ ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
568
+ disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
569
+ use_checkpoint=use_checkpoint
570
+ )
571
+ )
572
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
573
+ self.zero_convs.append(self.make_zero_conv(ch))
574
+ self._feature_size += ch
575
+ input_block_chans.append(ch)
576
+ if level != len(channel_mult) - 1:
577
+ out_ch = ch
578
+ self.input_blocks.append(
579
+ TimestepEmbedSequential(
580
+ ResBlock(
581
+ ch,
582
+ time_embed_dim,
583
+ dropout,
584
+ out_channels=out_ch,
585
+ dims=dims,
586
+ use_checkpoint=use_checkpoint,
587
+ use_scale_shift_norm=use_scale_shift_norm,
588
+ down=True,
589
+ )
590
+ if resblock_updown
591
+ else Downsample(
592
+ ch, conv_resample, dims=dims, out_channels=out_ch
593
+ )
594
+ )
595
+ )
596
+ ch = out_ch
597
+ input_block_chans.append(ch)
598
+ self.zero_convs.append(self.make_zero_conv(ch))
599
+ ds *= 2
600
+ self._feature_size += ch
601
+
602
+ if num_head_channels == -1:
603
+ dim_head = ch // num_heads
604
+ else:
605
+ num_heads = ch // num_head_channels
606
+ dim_head = num_head_channels
607
+ if legacy:
608
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
609
+ self.middle_block = TimestepEmbedSequential(
610
+ ResBlock(
611
+ ch,
612
+ time_embed_dim,
613
+ dropout,
614
+ dims=dims,
615
+ use_checkpoint=use_checkpoint,
616
+ use_scale_shift_norm=use_scale_shift_norm,
617
+ ),
618
+ AttentionBlock(
619
+ ch,
620
+ use_checkpoint=use_checkpoint,
621
+ num_heads=num_heads,
622
+ num_head_channels=dim_head,
623
+ use_new_attention_order=use_new_attention_order,
624
+ ) if not use_spatial_transformer else SpatialTransformer( # always uses a self-attn
625
+ ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
626
+ disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
627
+ use_checkpoint=use_checkpoint
628
+ ),
629
+ ResBlock(
630
+ ch,
631
+ time_embed_dim,
632
+ dropout,
633
+ dims=dims,
634
+ use_checkpoint=use_checkpoint,
635
+ use_scale_shift_norm=use_scale_shift_norm,
636
+ ),
637
+ )
638
+ self.middle_block_out = self.make_zero_conv(ch)
639
+ self._feature_size += ch
640
+
641
+ def make_zero_conv(self, channels):
642
+ return TimestepEmbedSequential(zero_module(conv_nd(self.dims, channels, channels, 1, padding=0)))
643
+
644
+ def forward(self, x, hint, timesteps, context, **kwargs):
645
+ t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
646
+ emb = self.time_embed(t_emb)
647
+
648
+ guided_hint = self.input_hint_block(hint, emb, context)
649
+
650
+ outs = []
651
+ h = x.type(self.dtype)
652
+
653
+ for module, zero_conv in zip(self.input_blocks, self.zero_convs):
654
+ if guided_hint is not None:
655
+ h = module(h, emb, context)
656
+ h += guided_hint
657
+ guided_hint = None
658
+ else:
659
+ h = module(h, emb, context)
660
+ outs.append(zero_conv(h, emb, context))
661
+
662
+ h = self.middle_block(h, emb, context)
663
+ outs.append(self.middle_block_out(h, emb, context))
664
+
665
+ return outs
666
+
667
+ class ControlLDMReferenceOnly_Temporal_Pose_Local(LatentDiffusionReferenceOnly):
668
+
669
+ def __init__(self, control_key, only_mid_control,appearance_control_stage_config, pose_control_stage_config, local_pose_control_stage_config, *args, **kwargs):
670
+ super().__init__(*args, **kwargs)
671
+ print(args)
672
+ print(kwargs)
673
+ self.control_key = control_key
674
+ self.only_mid_control = only_mid_control
675
+ self.control_enabled = True
676
+ self.appearance_control_model = instantiate_from_config(appearance_control_stage_config)
677
+ self.pose_control_model = instantiate_from_config(pose_control_stage_config)
678
+ self.local_pose_control_model = instantiate_from_config(local_pose_control_stage_config)
679
+
680
+ def apply_model(self, x_noisy, t, cond, reference_image_noisy, more_reference_image_noisy=[], uc=False,*args, **kwargs):
681
+ assert isinstance(cond, dict)
682
+ diffusion_model = self.model.diffusion_model
683
+ cond_txt = torch.cat(cond['c_crossattn'], 1)
684
+ if self.control_enabled and 'c_crossattn_void' in cond and cond['c_crossattn_void'] is not None:
685
+ cond_txt_void = torch.cat(cond['c_crossattn_void'], 1)
686
+ else:
687
+ cond_txt_void = cond_txt
688
+ attention_bank = []
689
+
690
+ if reference_image_noisy is not None:
691
+ empty_outs = self.appearance_control_model(x=reference_image_noisy, hint=None, timesteps=t, context=cond_txt_void, attention_bank=attention_bank, attention_mode='write',uc=uc)
692
+ for m_reference_image_noisy in more_reference_image_noisy:
693
+ l_attention_bank = []
694
+ empty_outs = self.appearance_control_model(x=m_reference_image_noisy, hint=None, timesteps=t, context=cond_txt_void, attention_bank=l_attention_bank, attention_mode='write',uc=uc)
695
+ for j in range(len(attention_bank)):
696
+ for k in range(len(attention_bank[j])):
697
+ attention_bank[j][k] = torch.concat([attention_bank[j][k], l_attention_bank[j][k]], dim=1)
698
+
699
+ if not uc:
700
+ if self.control_enabled and 'c_concat' in cond and cond['c_concat'] is not None:
701
+ cond_hint = torch.cat(cond['c_concat'], 1)
702
+ pose_control = self.pose_control_model(x=x_noisy, hint=cond_hint, timesteps=t, context=cond_txt_void)
703
+
704
+ if self.control_enabled and 'local_c_concat' in cond and cond['local_c_concat'] is not None:
705
+ cond_hint = torch.cat(cond['local_c_concat'], 1)
706
+ local_pose_control = self.local_pose_control_model(x=x_noisy, hint=cond_hint, timesteps=t, context=cond_txt_void)
707
+ else:
708
+ pose_control = None
709
+ local_pose_control = None
710
+ eps = diffusion_model(x=x_noisy, timesteps=t, context=cond_txt, control=attention_bank, pose_control=pose_control, local_pose_control=local_pose_control, only_mid_control=self.only_mid_control, attention_mode='read',uc=uc)
711
+ return eps
712
+
713
+ @torch.no_grad()
714
+ def get_unconditional_conditioning(self, N):
715
+ return self.get_learned_conditioning([""] * N)
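apply_model above expects cond to be a dict whose keys select the different control branches. A hedged sketch of how such a dict might be assembled for a single frame (all tensor shapes and resolutions are assumptions for illustration, not values taken from this repository; `model` stands for an instantiated ControlLDMReferenceOnly_Temporal_Pose_Local):

import torch

# Hypothetical single-frame batch; shapes are assumptions.
cond = {
    "c_crossattn": [torch.randn(1, 77, 768)],        # cross-attention context
    "c_crossattn_void": [torch.randn(1, 77, 768)],   # context used for the appearance (reference) branch
    "c_concat": [torch.randn(1, 3, 512, 512)],       # hint for pose_control_model
    "local_c_concat": [torch.randn(1, 3, 512, 512)], # hint for local_pose_control_model
}
x_noisy = torch.randn(1, 4, 64, 64)                  # latent being denoised
t = torch.full((1,), 500, dtype=torch.long)
reference_image_noisy = torch.randn(1, 4, 64, 64)    # encoded reference portrait
# eps = model.apply_model(x_noisy, t, cond, reference_image_noisy)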
model_lib/ControlNet/cldm/model.py ADDED
@@ -0,0 +1,28 @@
1
+ import os
2
+ import torch
3
+
4
+ from omegaconf import OmegaConf
5
+ from model_lib.ControlNet.ldm.util import instantiate_from_config
6
+
7
+
8
+ def get_state_dict(d):
9
+ return d.get('state_dict', d)
10
+
11
+
12
+ def load_state_dict(ckpt_path, location='cpu'):
13
+ _, extension = os.path.splitext(ckpt_path)
14
+ if extension.lower() == ".safetensors":
15
+ import safetensors.torch
16
+ state_dict = safetensors.torch.load_file(ckpt_path, device=location)
17
+ else:
18
+ state_dict = get_state_dict(torch.load(ckpt_path, map_location=torch.device(location)))
19
+ state_dict = get_state_dict(state_dict)
20
+ print(f'Loaded state_dict from [{ckpt_path}]')
21
+ return state_dict
22
+
23
+
24
+ def create_model(config_path):
25
+ config = OmegaConf.load(config_path)
26
+ model = instantiate_from_config(config.model).cpu()
27
+ print(f'Loaded model config from [{config_path}]')
28
+ return model
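A hedged example of how the two helpers above are typically combined to build and initialize a model (the config and checkpoint paths are placeholders, not paths taken from this repository):

from model_lib.ControlNet.cldm.model import create_model, load_state_dict

model = create_model("path/to/model_config.yaml")                # placeholder config path
sd = load_state_dict("path/to/checkpoint.ckpt", location="cpu")  # .ckpt or .safetensors
missing, unexpected = model.load_state_dict(sd, strict=False)
print(f"missing keys: {len(missing)}, unexpected keys: {len(unexpected)}")
model = model.cuda().eval()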
model_lib/ControlNet/ldm/__pycache__/util.cpython-39.pyc ADDED
Binary file (6.23 kB).
 
model_lib/ControlNet/ldm/data/__init__.py ADDED
File without changes
model_lib/ControlNet/ldm/data/util.py ADDED
@@ -0,0 +1,24 @@
1
+ import torch
2
+
3
+ from model_lib.ControlNet.ldm.modules.midas.api import load_midas_transform
4
+
5
+
6
+ class AddMiDaS(object):
7
+ def __init__(self, model_type):
8
+ super().__init__()
9
+ self.transform = load_midas_transform(model_type)
10
+
11
+ def pt2np(self, x):
12
+ x = ((x + 1.0) * .5).detach().cpu().numpy()
13
+ return x
14
+
15
+ def np2pt(self, x):
16
+ x = torch.from_numpy(x) * 2 - 1.
17
+ return x
18
+
19
+ def __call__(self, sample):
20
+ # sample['jpg'] is tensor hwc in [-1, 1] at this point
21
+ x = self.pt2np(sample['jpg'])
22
+ x = self.transform({"image": x})["image"]
23
+ sample['midas_in'] = x
24
+ return sample
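A hedged usage sketch of AddMiDaS (the "dpt_hybrid" model type follows the usual MiDaS naming and is an assumption here; the sample layout matches the comment in __call__ above):

import torch
from model_lib.ControlNet.ldm.data.util import AddMiDaS

add_midas = AddMiDaS(model_type="dpt_hybrid")         # assumed MiDaS variant
sample = {"jpg": torch.rand(512, 512, 3) * 2 - 1.0}   # HWC tensor in [-1, 1]
sample = add_midas(sample)                            # adds the "midas_in" entry
print(sample["midas_in"].shape)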
model_lib/ControlNet/ldm/models/__pycache__/autoencoder.cpython-39.pyc ADDED
Binary file (7.83 kB).
 
model_lib/ControlNet/ldm/models/autoencoder.py ADDED
@@ -0,0 +1,219 @@
1
+ import torch
2
+ import pytorch_lightning as pl
3
+ import torch.nn.functional as F
4
+ from contextlib import contextmanager
5
+
6
+ from model_lib.ControlNet.ldm.modules.diffusionmodules.model import Encoder, Decoder
7
+ from model_lib.ControlNet.ldm.modules.distributions.distributions import DiagonalGaussianDistribution
8
+
9
+ from model_lib.ControlNet.ldm.util import instantiate_from_config
10
+ from model_lib.ControlNet.ldm.modules.ema import LitEma
11
+
12
+
13
+ class AutoencoderKL(pl.LightningModule):
14
+ def __init__(self,
15
+ ddconfig,
16
+ lossconfig,
17
+ embed_dim,
18
+ ckpt_path=None,
19
+ ignore_keys=[],
20
+ image_key="image",
21
+ colorize_nlabels=None,
22
+ monitor=None,
23
+ ema_decay=None,
24
+ learn_logvar=False
25
+ ):
26
+ super().__init__()
27
+ self.learn_logvar = learn_logvar
28
+ self.image_key = image_key
29
+ self.encoder = Encoder(**ddconfig)
30
+ self.decoder = Decoder(**ddconfig)
31
+ self.loss = instantiate_from_config(lossconfig)
32
+ assert ddconfig["double_z"]
33
+ self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1)
34
+ self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
35
+ self.embed_dim = embed_dim
36
+ if colorize_nlabels is not None:
37
+ assert type(colorize_nlabels)==int
38
+ self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
39
+ if monitor is not None:
40
+ self.monitor = monitor
41
+
42
+ self.use_ema = ema_decay is not None
43
+ if self.use_ema:
44
+ self.ema_decay = ema_decay
45
+ assert 0. < ema_decay < 1.
46
+ self.model_ema = LitEma(self, decay=ema_decay)
47
+ print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
48
+
49
+ if ckpt_path is not None:
50
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
51
+
52
+ def init_from_ckpt(self, path, ignore_keys=list()):
53
+ sd = torch.load(path, map_location="cpu")["state_dict"]
54
+ keys = list(sd.keys())
55
+ for k in keys:
56
+ for ik in ignore_keys:
57
+ if k.startswith(ik):
58
+ print("Deleting key {} from state_dict.".format(k))
59
+ del sd[k]
60
+ self.load_state_dict(sd, strict=False)
61
+ print(f"Restored from {path}")
62
+
63
+ @contextmanager
64
+ def ema_scope(self, context=None):
65
+ if self.use_ema:
66
+ self.model_ema.store(self.parameters())
67
+ self.model_ema.copy_to(self)
68
+ if context is not None:
69
+ print(f"{context}: Switched to EMA weights")
70
+ try:
71
+ yield None
72
+ finally:
73
+ if self.use_ema:
74
+ self.model_ema.restore(self.parameters())
75
+ if context is not None:
76
+ print(f"{context}: Restored training weights")
77
+
78
+ def on_train_batch_end(self, *args, **kwargs):
79
+ if self.use_ema:
80
+ self.model_ema(self)
81
+
82
+ def encode(self, x):
83
+ h = self.encoder(x)
84
+ moments = self.quant_conv(h)
85
+ posterior = DiagonalGaussianDistribution(moments)
86
+ return posterior
87
+
88
+ def decode(self, z):
89
+ z = self.post_quant_conv(z)
90
+ dec = self.decoder(z)
91
+ return dec
92
+
93
+ def forward(self, input, sample_posterior=True):
94
+ posterior = self.encode(input)
95
+ if sample_posterior:
96
+ z = posterior.sample()
97
+ else:
98
+ z = posterior.mode()
99
+ dec = self.decode(z)
100
+ return dec, posterior
101
+
102
+ def get_input(self, batch, k):
103
+ x = batch[k]
104
+ if len(x.shape) == 3:
105
+ x = x[..., None]
106
+ x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
107
+ return x
108
+
109
+ def training_step(self, batch, batch_idx, optimizer_idx):
110
+ inputs = self.get_input(batch, self.image_key)
111
+ reconstructions, posterior = self(inputs)
112
+
113
+ if optimizer_idx == 0:
114
+ # train encoder+decoder+logvar
115
+ aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
116
+ last_layer=self.get_last_layer(), split="train")
117
+ self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
118
+ self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False)
119
+ return aeloss
120
+
121
+ if optimizer_idx == 1:
122
+ # train the discriminator
123
+ discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
124
+ last_layer=self.get_last_layer(), split="train")
125
+
126
+ self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
127
+ self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False)
128
+ return discloss
129
+
130
+ def validation_step(self, batch, batch_idx):
131
+ log_dict = self._validation_step(batch, batch_idx)
132
+ with self.ema_scope():
133
+ log_dict_ema = self._validation_step(batch, batch_idx, postfix="_ema")
134
+ return log_dict
135
+
136
+ def _validation_step(self, batch, batch_idx, postfix=""):
137
+ inputs = self.get_input(batch, self.image_key)
138
+ reconstructions, posterior = self(inputs)
139
+ aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step,
140
+ last_layer=self.get_last_layer(), split="val"+postfix)
141
+
142
+ discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step,
143
+ last_layer=self.get_last_layer(), split="val"+postfix)
144
+
145
+ self.log(f"val{postfix}/rec_loss", log_dict_ae[f"val{postfix}/rec_loss"])
146
+ self.log_dict(log_dict_ae)
147
+ self.log_dict(log_dict_disc)
148
+ return self.log_dict
149
+
150
+ def configure_optimizers(self):
151
+ lr = self.learning_rate
152
+ ae_params_list = list(self.encoder.parameters()) + list(self.decoder.parameters()) + list(
153
+ self.quant_conv.parameters()) + list(self.post_quant_conv.parameters())
154
+ if self.learn_logvar:
155
+ print(f"{self.__class__.__name__}: Learning logvar")
156
+ ae_params_list.append(self.loss.logvar)
157
+ opt_ae = torch.optim.Adam(ae_params_list,
158
+ lr=lr, betas=(0.5, 0.9))
159
+ opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
160
+ lr=lr, betas=(0.5, 0.9))
161
+ return [opt_ae, opt_disc], []
162
+
163
+ def get_last_layer(self):
164
+ return self.decoder.conv_out.weight
165
+
166
+ @torch.no_grad()
167
+ def log_images(self, batch, only_inputs=False, log_ema=False, **kwargs):
168
+ log = dict()
169
+ x = self.get_input(batch, self.image_key)
170
+ x = x.to(self.device)
171
+ if not only_inputs:
172
+ xrec, posterior = self(x)
173
+ if x.shape[1] > 3:
174
+ # colorize with random projection
175
+ assert xrec.shape[1] > 3
176
+ x = self.to_rgb(x)
177
+ xrec = self.to_rgb(xrec)
178
+ log["samples"] = self.decode(torch.randn_like(posterior.sample()))
179
+ log["reconstructions"] = xrec
180
+ if log_ema or self.use_ema:
181
+ with self.ema_scope():
182
+ xrec_ema, posterior_ema = self(x)
183
+ if x.shape[1] > 3:
184
+ # colorize with random projection
185
+ assert xrec_ema.shape[1] > 3
186
+ xrec_ema = self.to_rgb(xrec_ema)
187
+ log["samples_ema"] = self.decode(torch.randn_like(posterior_ema.sample()))
188
+ log["reconstructions_ema"] = xrec_ema
189
+ log["inputs"] = x
190
+ return log
191
+
192
+ def to_rgb(self, x):
193
+ assert self.image_key == "segmentation"
194
+ if not hasattr(self, "colorize"):
195
+ self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
196
+ x = F.conv2d(x, weight=self.colorize)
197
+ x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
198
+ return x
199
+
200
+
201
+ class IdentityFirstStage(torch.nn.Module):
202
+ def __init__(self, *args, vq_interface=False, **kwargs):
203
+ self.vq_interface = vq_interface
204
+ super().__init__()
205
+
206
+ def encode(self, x, *args, **kwargs):
207
+ return x
208
+
209
+ def decode(self, x, *args, **kwargs):
210
+ return x
211
+
212
+ def quantize(self, x, *args, **kwargs):
213
+ if self.vq_interface:
214
+ return x, None, [None, None, None]
215
+ return x
216
+
217
+ def forward(self, x, *args, **kwargs):
218
+ return x
219
+
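A hedged sketch of the encode/decode round trip with a constructed AutoencoderKL (building `ae` from a config and checkpoint is omitted; the 8x spatial downsampling noted in the comment is the usual Stable Diffusion setting, stated as an assumption):

import torch

@torch.no_grad()
def roundtrip(ae, images):
    """Encode a batch of images in [-1, 1] to latents and decode them back."""
    posterior = ae.encode(images)   # DiagonalGaussianDistribution
    z = posterior.sample()          # (B, embed_dim, H/8, W/8) under the usual config
    return ae.decode(z)

# usage (assuming `ae` is an AutoencoderKL instance):
# recon = roundtrip(ae, torch.randn(1, 3, 512, 512))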
model_lib/ControlNet/ldm/models/diffusion/__init__.py ADDED
File without changes
model_lib/ControlNet/ldm/models/diffusion/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (228 Bytes).
 
model_lib/ControlNet/ldm/models/diffusion/__pycache__/ddim.cpython-39.pyc ADDED
Binary file (18.2 kB).
 
model_lib/ControlNet/ldm/models/diffusion/__pycache__/ddpm.cpython-39.pyc ADDED
Binary file (69 kB).
 
model_lib/ControlNet/ldm/models/diffusion/ddim.py ADDED
@@ -0,0 +1,763 @@
1
+ """SAMPLING ONLY."""
2
+
3
+ import pdb
4
+ import random
5
+
6
+ import numpy as np
7
+ import torch
8
+ from model_lib.ControlNet.ldm.modules.diffusionmodules.util import (
9
+ extract_into_tensor, make_ddim_sampling_parameters, make_ddim_timesteps,
10
+ noise_like)
11
+ from model_lib.ControlNet.ldm.util import default
12
+ from tqdm import tqdm
13
+
14
+
15
+ class DDIMSampler(object):
16
+ def __init__(self, model, schedule="linear", **kwargs):
17
+ super().__init__()
18
+ self.model = model
19
+ self.ddpm_num_timesteps = model.num_timesteps
20
+ self.schedule = schedule
21
+
22
+ def register_buffer(self, name, attr):
23
+ if type(attr) == torch.Tensor:
24
+ if attr.device != torch.device("cuda"):
25
+ attr = attr.to(torch.device("cuda"))
26
+ setattr(self, name, attr)
27
+
28
+ def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
29
+ self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
30
+ num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
31
+ alphas_cumprod = self.model.alphas_cumprod
32
+ assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
33
+ to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
34
+
35
+ self.register_buffer('betas', to_torch(self.model.betas))
36
+ self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
37
+ self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
38
+
39
+ # calculations for diffusion q(x_t | x_{t-1}) and others
40
+ self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
41
+ self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
42
+ self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
43
+ self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
44
+ self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
45
+
46
+ # ddim sampling parameters
47
+ ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
48
+ ddim_timesteps=self.ddim_timesteps,
49
+ eta=ddim_eta,verbose=verbose)
50
+ self.register_buffer('ddim_sigmas', ddim_sigmas)
51
+ self.register_buffer('ddim_alphas', ddim_alphas)
52
+ self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
53
+ self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
54
+ sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
55
+ (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
56
+ 1 - self.alphas_cumprod / self.alphas_cumprod_prev))
57
+ self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
58
+
59
+ @torch.no_grad()
60
+ def sample(self,
61
+ S,
62
+ batch_size,
63
+ shape,
64
+ conditioning=None,
65
+ callback=None,
66
+ normals_sequence=None,
67
+ img_callback=None,
68
+ quantize_x0=False,
69
+ eta=0.,
70
+ mask=None,
71
+ x0=None,
72
+ temperature=1.,
73
+ noise_dropout=0.,
74
+ score_corrector=None,
75
+ corrector_kwargs=None,
76
+ verbose=True,
77
+ x_T=None,
78
+ log_every_t=100,
79
+ unconditional_guidance_scale=1.,
80
+ unconditional_conditioning=None, # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
81
+ dynamic_threshold=None,
82
+ ucg_schedule=None,
83
+ inpaint=None,
84
+ **kwargs
85
+ ):
86
+ if conditioning is not None:
87
+ if isinstance(conditioning, dict):
88
+ ctmp = conditioning[list(conditioning.keys())[0]]
89
+ while isinstance(ctmp, list): ctmp = ctmp[0]
90
+ cbs = ctmp.shape[0]
91
+ if cbs != batch_size:
92
+ print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
93
+
94
+ elif isinstance(conditioning, list):
95
+ for ctmp in conditioning:
96
+ if ctmp.shape[0] != batch_size:
97
+ print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
98
+
99
+ else:
100
+ if conditioning.shape[0] != batch_size:
101
+ print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
102
+
103
+ self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
104
+ # sampling
105
+ C, H, W = shape
106
+ size = (batch_size, C, H, W)
107
+ print(f'Data shape for DDIM sampling is {C, H, W}')
108
+
109
+ samples, intermediates = self.ddim_sampling(conditioning, size,
110
+ callback=callback,
111
+ img_callback=img_callback,
112
+ quantize_denoised=quantize_x0,
113
+ mask=mask, x0=x0,
114
+ ddim_use_original_steps=False,
115
+ noise_dropout=noise_dropout,
116
+ temperature=temperature,
117
+ score_corrector=score_corrector,
118
+ corrector_kwargs=corrector_kwargs,
119
+ x_T=x_T,
120
+ log_every_t=log_every_t,
121
+ unconditional_guidance_scale=unconditional_guidance_scale,
122
+ unconditional_conditioning=unconditional_conditioning,
123
+ dynamic_threshold=dynamic_threshold,
124
+ ucg_schedule=ucg_schedule,
125
+ inpaint=inpaint
126
+ )
127
+ return samples, intermediates
128
+
129
+ @torch.no_grad()
130
+ def ddim_sampling(self, cond, shape,
131
+ x_T=None, ddim_use_original_steps=False,
132
+ callback=None, timesteps=None, quantize_denoised=False,
133
+ mask=None, x0=None, img_callback=None, log_every_t=100,
134
+ temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
135
+ unconditional_guidance_scale=1., unconditional_conditioning=None, dynamic_threshold=None,
136
+ ucg_schedule=None,inpaint=None):
137
+ device = self.model.betas.device
138
+ b = shape[0]
139
+ if x_T is None:
140
+ img = torch.randn(shape, device=device)
141
+ else:
142
+ img = x_T
143
+
144
+ if timesteps is None:
145
+ timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
146
+ elif timesteps is not None and not ddim_use_original_steps:
147
+ subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
148
+ timesteps = self.ddim_timesteps[:subset_end]
149
+
150
+ intermediates = {'x_inter': [img], 'pred_x0': [img]}
151
+ time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps)
152
+ total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
153
+ print(f"Running DDIM Sampling with {total_steps} timesteps")
154
+
155
+ iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)
156
+
157
+ for i, step in enumerate(iterator):
158
+ index = total_steps - i - 1
159
+ ts = torch.full((b,), step, device=device, dtype=torch.long)
160
+
161
+ if mask is not None:
162
+ assert x0 is not None
163
+ img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass?
164
+ img = img_orig * mask + (1. - mask) * img
165
+
166
+ if ucg_schedule is not None:
167
+ assert len(ucg_schedule) == len(time_range)
168
+ unconditional_guidance_scale = ucg_schedule[i]
169
+
170
+ model_output = self.p_sample_ddim(img, cond, ts,
171
+ unconditional_guidance_scale=unconditional_guidance_scale,
172
+ unconditional_conditioning=unconditional_conditioning,
173
+ inpaint=inpaint)
174
+ outs = self.pred_x_prev_from_eps(img, cond, ts, model_output, index=index, use_original_steps=ddim_use_original_steps,
175
+ quantize_denoised=quantize_denoised, temperature=temperature,
176
+ noise_dropout=noise_dropout, score_corrector=score_corrector,
177
+ corrector_kwargs=corrector_kwargs,
178
+ dynamic_threshold=dynamic_threshold)
179
+ img, pred_x0 = outs
180
+ if callback: callback(i)
181
+ if img_callback: img_callback(pred_x0, i)
182
+
183
+ if index % log_every_t == 0 or index == total_steps - 1:
184
+ intermediates['x_inter'].append(img)
185
+ intermediates['pred_x0'].append(pred_x0)
186
+
187
+ return img, intermediates
188
+
189
+ @torch.no_grad()
190
+ def p_sample_ddim(self, x, c, t, unconditional_guidance_scale=1., unconditional_conditioning=None, inpaint=None):
191
+
192
+ if inpaint is None:
193
+ x_In = x
194
+ else:
195
+ x_In = torch.cat([x,inpaint],dim=1)
196
+
197
+ if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
198
+ model_output = self.model.apply_model(x_In, t, c)
199
+ else:
200
+ x_in = torch.cat([x_In] * 2)
201
+ t_in = torch.cat([t] * 2)
202
+ if isinstance(c, dict):
203
+ assert isinstance(unconditional_conditioning, dict)
204
+ c_in = dict()
205
+ for k in c:
206
+ if isinstance(c[k], list):
207
+ c_in[k] = [torch.cat([
208
+ unconditional_conditioning[k][i],
209
+ c[k][i]]) for i in range(len(c[k]))]
210
+ else:
211
+ c_in[k] = torch.cat([
212
+ unconditional_conditioning[k],
213
+ c[k]])
214
+ elif isinstance(c, list):
215
+ c_in = list()
216
+ assert isinstance(unconditional_conditioning, list)
217
+ for i in range(len(c)):
218
+ c_in.append(torch.cat([unconditional_conditioning[i], c[i]]))
219
+ else:
220
+ c_in = torch.cat([unconditional_conditioning, c])
221
+ model_uncond, model_t = self.model.apply_model(x_in, t_in, c_in).chunk(2) # , reference_image_noisy
222
+ model_output = model_uncond + unconditional_guidance_scale * (model_t - model_uncond)
223
+
224
+ return model_output
225
+
226
+ @torch.no_grad()
227
+ def pred_x_prev_from_eps(self, x, c, t, model_output, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
228
+ temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
229
+ dynamic_threshold=None):
230
+ b, *_, device = *x.shape, x.device
231
+ if self.model.parameterization == "v":
232
+ e_t = self.model.predict_eps_from_z_and_v(x, t, model_output)
233
+ else:
234
+ e_t = model_output
235
+
236
+ if score_corrector is not None:
237
+ assert self.model.parameterization == "eps", 'not implemented'
238
+ e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
239
+
240
+ alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
241
+ alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
242
+ sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
243
+ sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
244
+ # select parameters corresponding to the currently considered timestep
245
+ a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
246
+ a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
247
+ sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
248
+ sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
249
+ # current prediction for x_0
250
+ if self.model.parameterization != "v":
251
+ pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
252
+ else:
253
+ pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output)
254
+
255
+ if quantize_denoised:
256
+ pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
257
+
258
+ if dynamic_threshold is not None:
259
+ raise NotImplementedError()
260
+
261
+ # direction pointing to x_t
262
+ dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
263
+ noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
264
+ if noise_dropout > 0.:
265
+ noise = torch.nn.functional.dropout(noise, p=noise_dropout)
266
+ x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
267
+ return x_prev, pred_x0
268
+
269
+ @torch.no_grad()
270
+ def encode(self, x0, c, t_enc, use_original_steps=False, return_intermediates=None,
271
+ unconditional_guidance_scale=1.0, unconditional_conditioning=None, callback=None):
272
+ num_reference_steps = self.ddpm_num_timesteps if use_original_steps else self.ddim_timesteps.shape[0]
273
+
274
+ assert t_enc <= num_reference_steps
275
+ num_steps = t_enc
276
+
277
+ if use_original_steps:
278
+ alphas_next = self.alphas_cumprod[:num_steps]
279
+ alphas = self.alphas_cumprod_prev[:num_steps]
280
+ else:
281
+ alphas_next = self.ddim_alphas[:num_steps]
282
+ alphas = torch.tensor(self.ddim_alphas_prev[:num_steps])
283
+
284
+ x_next = x0
285
+ intermediates = []
286
+ inter_steps = []
287
+ for i in tqdm(range(num_steps), desc='Encoding Image'):
288
+ t = torch.full((x0.shape[0],), i, device=self.model.device, dtype=torch.long)
289
+ if unconditional_guidance_scale == 1.:
290
+ noise_pred = self.model.apply_model(x_next, t, c)
291
+ else:
292
+ assert unconditional_conditioning is not None
293
+ e_t_uncond, noise_pred = torch.chunk(
294
+ self.model.apply_model(torch.cat((x_next, x_next)), torch.cat((t, t)),
295
+ torch.cat((unconditional_conditioning, c))), 2)
296
+ noise_pred = e_t_uncond + unconditional_guidance_scale * (noise_pred - e_t_uncond)
297
+
298
+ xt_weighted = (alphas_next[i] / alphas[i]).sqrt() * x_next
299
+ weighted_noise_pred = alphas_next[i].sqrt() * (
300
+ (1 / alphas_next[i] - 1).sqrt() - (1 / alphas[i] - 1).sqrt()) * noise_pred
301
+ x_next = xt_weighted + weighted_noise_pred
302
+ if return_intermediates and i % (
303
+ num_steps // return_intermediates) == 0 and i < num_steps - 1:
304
+ intermediates.append(x_next)
305
+ inter_steps.append(i)
306
+ elif return_intermediates and i >= num_steps - 2:
307
+ intermediates.append(x_next)
308
+ inter_steps.append(i)
309
+ if callback: callback(i)
310
+
311
+ out = {'x_encoded': x_next, 'intermediate_steps': inter_steps}
312
+ if return_intermediates:
313
+ out.update({'intermediates': intermediates})
314
+ return x_next, out
315
+
316
+ @torch.no_grad()
317
+ def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
318
+ # fast, but does not allow for exact reconstruction
319
+ # t serves as an index to gather the correct alphas
320
+ if use_original_steps:
321
+ sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
322
+ sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod
323
+ else:
324
+ sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
325
+ sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas
326
+
327
+ if noise is None:
328
+ noise = torch.randn_like(x0)
329
+ return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 +
330
+ extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise)
331
+
332
+ @torch.no_grad()
333
+ def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
334
+ use_original_steps=False, callback=None, inpaint=None):
335
+
336
+ timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps
337
+ timesteps = timesteps[:t_start]
338
+
339
+ time_range = np.flip(timesteps)
340
+ total_steps = timesteps.shape[0]
341
+ print(f"Running DDIM Sampling with {total_steps} timesteps")
342
+
343
+ iterator = tqdm(time_range, desc='Decoding image', total=total_steps)
344
+ x_dec = x_latent
345
+ for i, step in enumerate(iterator):
346
+ index = total_steps - i - 1
347
+ ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long)
348
+ model_output = self.p_sample_ddim(x_dec, cond, ts,
349
+ unconditional_guidance_scale=unconditional_guidance_scale,
350
+ unconditional_conditioning=unconditional_conditioning, inpaint=inpaint)
351
+ x_dec, _ = self.pred_x_prev_from_eps(x_dec, cond, ts, model_output, index=index, use_original_steps=use_original_steps)
352
+
353
+ if callback: callback(i)
354
+ return x_dec
355
+
356
+
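A hedged sketch of classifier-free-guided sampling with the DDIMSampler above (`ldm` stands for an already-loaded latent diffusion model, `cond`/`uncond` must match whatever its apply_model expects; decode_first_stage and the 4x64x64 latent shape are the usual LDM conventions, assumed here):

import torch

@torch.no_grad()
def sample_with_cfg(ldm, cond, uncond, steps=50, scale=7.5):
    sampler = DDIMSampler(ldm)
    latents, _ = sampler.sample(
        S=steps, batch_size=1, shape=(4, 64, 64),
        conditioning=cond,
        unconditional_conditioning=uncond,
        unconditional_guidance_scale=scale,
        eta=0.0, verbose=False)
    return ldm.decode_first_stage(latents)   # back to image space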
357
+ class DDIMSampler_ReferenceOnly(object):
358
+ def __init__(self, model, schedule="linear", **kwargs):
359
+ super().__init__()
360
+ self.model = model
361
+ self.ddpm_num_timesteps = model.num_timesteps
362
+ self.schedule = schedule
363
+
364
+ def register_buffer(self, name, attr):
365
+ if type(attr) == torch.Tensor:
366
+ if attr.device != torch.device("cuda"):
367
+ attr = attr.to(torch.device("cuda"))
368
+ setattr(self, name, attr)
369
+
370
+ def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
371
+ self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
372
+ num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
373
+ alphas_cumprod = self.model.alphas_cumprod
374
+ assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
375
+ to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
376
+
377
+ self.register_buffer('betas', to_torch(self.model.betas))
378
+ self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
379
+ self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
380
+
381
+ # calculations for diffusion q(x_t | x_{t-1}) and others
382
+ self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
383
+ self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
384
+ self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
385
+ self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
386
+ self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
387
+
388
+ # ddim sampling parameters
389
+ ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
390
+ ddim_timesteps=self.ddim_timesteps,
391
+ eta=ddim_eta,verbose=verbose)
392
+ self.register_buffer('ddim_sigmas', ddim_sigmas)
393
+ self.register_buffer('ddim_alphas', ddim_alphas)
394
+ self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
395
+ self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
396
+ sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
397
+ (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
398
+ 1 - self.alphas_cumprod / self.alphas_cumprod_prev))
399
+ self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
400
+
401
+ @torch.no_grad()
402
+ def sample(self,
403
+ S,
404
+ batch_size,
405
+ shape,
406
+ conditioning=None,
407
+ callback=None,
408
+ normals_sequence=None,
409
+ img_callback=None,
410
+ quantize_x0=False,
411
+ eta=0.,
412
+ mask=None,
413
+ x0=None,
414
+ temperature=1.,
415
+ noise_dropout=0.,
416
+ score_corrector=None,
417
+ corrector_kwargs=None,
418
+ verbose=True,
419
+ x_T=None,
420
+ log_every_t=100,
421
+ unconditional_guidance_scale=1.,
422
+ unconditional_conditioning=None, # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
423
+ dynamic_threshold=None,
424
+ ucg_schedule=None,
425
+ inpaint=None,
426
+ num_overlap=0,
427
+ **kwargs
428
+ ):
429
+ if conditioning is not None:
430
+ if isinstance(conditioning, dict):
431
+ ctmp = conditioning[list(conditioning.keys())[0]]
432
+ while isinstance(ctmp, list): ctmp = ctmp[0]
433
+ cbs = ctmp.shape[0]
434
+ if cbs != batch_size:
435
+ print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
436
+
437
+ elif isinstance(conditioning, list):
438
+ for ctmp in conditioning:
439
+ if ctmp.shape[0] != batch_size:
440
+ print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
441
+
442
+ else:
443
+ if conditioning.shape[0] != batch_size:
444
+ print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
445
+ self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
446
+ # sampling
447
+ C, H, W = shape
448
+ size = (batch_size, C, H, W)
449
+ print(f'Data shape for DDIM sampling is {C, H, W}')
450
+
451
+ samples, intermediates = self.ddim_sampling(conditioning, size,
452
+ callback=callback,
453
+ img_callback=img_callback,
454
+ quantize_denoised=quantize_x0,
455
+ mask=mask, x0=x0,
456
+ ddim_use_original_steps=False,
457
+ noise_dropout=noise_dropout,
458
+ temperature=temperature,
459
+ score_corrector=score_corrector,
460
+ corrector_kwargs=corrector_kwargs,
461
+ x_T=x_T,
462
+ log_every_t=log_every_t,
463
+ unconditional_guidance_scale=unconditional_guidance_scale,
464
+ unconditional_conditioning=unconditional_conditioning,
465
+ dynamic_threshold=dynamic_threshold,
466
+ ucg_schedule=ucg_schedule,
467
+ inpaint=inpaint,
468
+ num_overlap=num_overlap
469
+ )
470
+ return samples, intermediates
471
+
472
+ @torch.no_grad()
473
+ def ddim_sampling(self, cond, shape,
474
+ x_T=None, ddim_use_original_steps=False,
475
+ callback=None, timesteps=None, quantize_denoised=False,
476
+ mask=None, x0=None, img_callback=None, log_every_t=100,
477
+ temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
478
+ unconditional_guidance_scale=1., unconditional_conditioning=None, dynamic_threshold=None,
479
+ ucg_schedule=None,inpaint=None,num_overlap=0):
480
+ device = self.model.betas.device
481
+ b = shape[0]
482
+ if x_T is None:
483
+ img = torch.randn(shape, device=device)
484
+ else:
485
+ img = x_T
486
+
487
+ if timesteps is None:
488
+ timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
489
+ elif timesteps is not None and not ddim_use_original_steps:
490
+ subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
491
+ timesteps = self.ddim_timesteps[:subset_end]
492
+
493
+ intermediates = {'x_inter': [img], 'pred_x0': [img]}
494
+ time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps)
495
+ total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
496
+
497
+
498
+ iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)
499
+
500
+ num_frames = img.shape[0]
501
+
502
+ for i, step in enumerate(iterator):
503
+ index = total_steps - i - 1
504
+ ts = torch.full((b,), step, device=device, dtype=torch.long)
505
+
506
+ if mask is not None:
507
+ assert x0 is not None
508
+ img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass?
509
+ img = img_orig * mask + (1. - mask) * img
510
+ if ucg_schedule is not None:
511
+ assert len(ucg_schedule) == len(time_range)
512
+ unconditional_guidance_scale = ucg_schedule[i]
513
+ if num_overlap == 0:
514
+ model_output = self.p_sample_ddim(img, cond, ts, unconditional_guidance_scale=unconditional_guidance_scale,
515
+ unconditional_conditioning=unconditional_conditioning, inpaint=inpaint)
516
+ else:
517
+ model_output_all = torch.zeros_like(img)
518
+ counts = torch.zeros(num_frames).cuda()
519
+ offset = random.randint(0, num_frames-1)
520
+ skip = b - num_overlap
521
+ for start_idx in range(offset, offset+num_frames-num_overlap, skip):
522
+ indices = torch.arange(start_idx, start_idx + b) % num_frames
523
+ sel_cond = {}
524
+ for k, v in cond.items():
525
+ if isinstance(v, list) and k != 'more_image_control':
526
+ sel_cond[k] = [c[indices] for c in v]
527
+ elif k == 'more_image_control':
528
+ num_more_refs = len(v)
529
+ sel_cond[k] = []
530
+ for j in range(num_more_refs):
531
+ sel_cond[k].append([c[indices] for c in v[j]])
532
+ else:
533
+ sel_cond[k] = v
534
+ sel_uncond = {}
535
+ for k, v in unconditional_conditioning.items():
536
+ if isinstance(v, list) and k != 'more_image_control':
537
+ sel_uncond[k] = [c[indices] for c in v]
538
+ elif k == 'more_image_control':
539
+ num_more_refs = len(v)
540
+ sel_uncond[k] = []
541
+ for j in range(num_more_refs):
542
+ sel_uncond[k].append([c[indices] for c in v[j]])
543
+ else:
544
+ sel_uncond[k] = v
545
+ model_output = self.p_sample_ddim(img[indices], sel_cond, ts, unconditional_guidance_scale=unconditional_guidance_scale,
546
+ unconditional_conditioning=sel_uncond, inpaint=inpaint)
547
+ model_output_all[indices] += model_output
548
+ counts[indices] += 1
549
+ model_output = model_output_all / counts.reshape(-1, 1, 1, 1)
550
+
551
+ outs = self.pred_x_prev_from_eps(img, cond, ts, model_output, index=index, temperature=temperature,
552
+ noise_dropout=noise_dropout, score_corrector=score_corrector,
553
+ corrector_kwargs=corrector_kwargs, dynamic_threshold=dynamic_threshold)
554
+
555
+ img, pred_x0 = outs
556
+ if callback: callback(i)
557
+ if img_callback: img_callback(pred_x0, i)
558
+
559
+ if index % log_every_t == 0 or index == total_steps - 1:
560
+ intermediates['x_inter'].append(img)
561
+ intermediates['pred_x0'].append(pred_x0)
562
+
563
+ return img, intermediates
564
+
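When num_overlap > 0, the loop above denoises the video in sliding windows of b frames (the batch dimension) that wrap around circularly and overlap by num_overlap frames, averaging the noise predictions wherever windows meet. A standalone sketch of just that bookkeeping (the random offset is dropped for clarity; eps_fn stands in for the per-window call to p_sample_ddim):

import torch

def windowed_average(eps_fn, num_frames, window, overlap):
    """Average per-frame predictions over circular sliding windows.
    eps_fn(idx) must return a tensor whose first dimension matches len(idx)."""
    out, counts = None, torch.zeros(num_frames)
    skip = window - overlap
    for start in range(0, num_frames - overlap, skip):
        idx = torch.arange(start, start + window) % num_frames   # circular window
        pred = eps_fn(idx)
        if out is None:
            out = torch.zeros((num_frames,) + pred.shape[1:])
        out[idx] += pred
        counts[idx] += 1
    # every frame is covered by at least one window, so the division is safe
    return out / counts.reshape(-1, *([1] * (out.dim() - 1)))

print(windowed_average(lambda idx: torch.ones(len(idx), 4, 8, 8),
                       num_frames=16, window=8, overlap=4).shape)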
565
+ @torch.no_grad()
566
+ def p_sample_ddim(self, x, c, t, unconditional_guidance_scale=1., unconditional_conditioning=None, inpaint=None):
567
+ if inpaint is None:
568
+ x_In = x
569
+ else:
570
+ x_In = torch.cat([x,inpaint],dim=1)
571
+
572
+ if 'image_control' in c and c['image_control'] is not None:
573
+ cond_image_start = torch.cat(c['image_control'], 1)
574
+ if c['wonoise']:
575
+ reference_image_noisy = cond_image_start
576
+ else:
577
+ reference_image_noisy = self.model.q_sample(cond_image_start,t)
578
+
579
+ more_reference_image_noisy = []
580
+ if 'more_image_control' in c and c['more_image_control'] is not None:
581
+ num_additional_ref_imgs = len(c['more_image_control'])
582
+ for i in range(num_additional_ref_imgs):
583
+ m_ref_img_noisy = torch.cat(c['more_image_control'][i], 1)
584
+ if not c['wonoise']:
585
+ m_ref_img_noisy = self.model.q_sample(m_ref_img_noisy, t)
586
+ more_reference_image_noisy.append(m_ref_img_noisy)
587
+
588
+ if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
589
+ model_output = self.model.apply_model(x_In, t, c)
590
+ else:
591
+ if 'image_control' in unconditional_conditioning and unconditional_conditioning['image_control'] is not None:
592
+ x_in = torch.cat([x_In] * 2)
593
+ t_in = torch.cat([t] * 2)
594
+ reference_image_noisy_in = torch.cat([reference_image_noisy] * 2)
595
+ more_reference_image_noisy = [torch.cat([m_ref_img] * 2) for m_ref_img in more_reference_image_noisy]
596
+ if isinstance(c, dict):
597
+ assert isinstance(unconditional_conditioning, dict)
598
+ c_in = dict()
599
+ for k in c:
600
+ if isinstance(c[k], list):
601
+ c_in[k] = [torch.cat([
602
+ unconditional_conditioning[k][i],
603
+ c[k][i]]) for i in range(len(c[k]))]
604
+ else:
605
+ try:
606
+ c_in[k] = torch.cat([
607
+ unconditional_conditioning[k],
608
+ c[k]])
609
+ except Exception:  # non-tensor entries (e.g. boolean flags) cannot be concatenated; fall back to the unconditional value
610
+ c_in[k] = unconditional_conditioning[k]
611
+ elif isinstance(c, list):
612
+ c_in = list()
613
+ assert isinstance(unconditional_conditioning, list)
614
+ for i in range(len(c)):
615
+ c_in.append(torch.cat([unconditional_conditioning[i], c[i]]))
616
+ else:
617
+ c_in = torch.cat([unconditional_conditioning, c])
618
+ # pdb.set_trace()
619
+
620
+ model_uncond, model_t = self.model.apply_model(x_in, t_in, c_in, reference_image_noisy_in, more_reference_image_noisy = more_reference_image_noisy).chunk(2) # , reference_image_noisy
621
+ else:
622
+ x_in = x_In
623
+ t_in = t
624
+ c_in = c
625
+ reference_image_noisy_in = reference_image_noisy
626
+ model_t = self.model.apply_model(x_in, t_in, c_in, reference_image_noisy_in, more_reference_image_noisy = more_reference_image_noisy)
627
+ model_uncond = self.model.apply_model(x_in, t_in, unconditional_conditioning, None,uc=True)
628
+ # pdb.set_trace()
629
+ model_output = model_uncond + unconditional_guidance_scale * (model_t - model_uncond)
630
+
631
+ return model_output
632
+
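For reference, the line `model_output = model_uncond + unconditional_guidance_scale * (model_t - model_uncond)` above is the standard classifier-free guidance combination; in equation form, with guidance scale s:

```latex
\tilde{\epsilon}_\theta(x_t, t, c)
  = \epsilon_\theta(x_t, t, \varnothing)
  + s\,\bigl(\epsilon_\theta(x_t, t, c) - \epsilon_\theta(x_t, t, \varnothing)\bigr)
```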
633
+ @torch.no_grad()
634
+ def pred_x_prev_from_eps(self, x, c, t, model_output, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
635
+ temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
636
+ dynamic_threshold=None):
637
+ b, *_, device = *x.shape, x.device
638
+ if self.model.parameterization == "v":
639
+ e_t = self.model.predict_eps_from_z_and_v(x, t, model_output)
640
+ else:
641
+ e_t = model_output
642
+ if score_corrector is not None:
643
+ assert self.model.parameterization == "eps", 'not implemented'
644
+ e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
645
+
646
+ alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
647
+ alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
648
+ sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
649
+ sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
650
+
651
+ # select parameters corresponding to the currently considered timestep
652
+ a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
653
+ a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
654
+ sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
655
+ # print ('sigma_t: {}'.format(sigma_t[0, 0, 0, 0]))
656
+ sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
657
+ # current prediction for x_0
658
+ if self.model.parameterization != "v":
659
+ pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
660
+ else:
661
+ pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output)
662
+ if quantize_denoised:
663
+ pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
664
+
665
+ if dynamic_threshold is not None:
666
+ raise NotImplementedError()
667
+ # direction pointing to x_t
668
+ dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
669
+ noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
670
+ if noise_dropout > 0.:
671
+ noise = torch.nn.functional.dropout(noise, p=noise_dropout)
672
+
673
+ x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
674
+
675
+ return x_prev, pred_x0
676
+
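In equation form, `pred_x_prev_from_eps` above is the standard DDIM update (here a denotes the cumulative alpha selected at the current index):

```latex
\hat{x}_0 = \frac{x_t - \sqrt{1-a_t}\,\epsilon_\theta(x_t,t)}{\sqrt{a_t}}, \qquad
x_{t-1} = \sqrt{a_{t-1}}\,\hat{x}_0
        + \sqrt{1 - a_{t-1} - \sigma_t^2}\,\epsilon_\theta(x_t,t)
        + \sigma_t\, z, \quad z \sim \mathcal{N}(0, I)
```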
677
+ @torch.no_grad()
678
+ def encode(self, x0, c, t_enc, use_original_steps=False, return_intermediates=None,
679
+ unconditional_guidance_scale=1.0, unconditional_conditioning=None, callback=None):
680
+ num_reference_steps = self.ddpm_num_timesteps if use_original_steps else self.ddim_timesteps.shape[0]
681
+
682
+ assert t_enc <= num_reference_steps
683
+ num_steps = t_enc
684
+
685
+ if use_original_steps:
686
+ alphas_next = self.alphas_cumprod[:num_steps]
687
+ alphas = self.alphas_cumprod_prev[:num_steps]
688
+ else:
689
+ alphas_next = self.ddim_alphas[:num_steps]
690
+ alphas = torch.tensor(self.ddim_alphas_prev[:num_steps])
691
+
692
+ x_next = x0
693
+ intermediates = []
694
+ inter_steps = []
695
+ for i in tqdm(range(num_steps), desc='Encoding Image'):
696
+ t = torch.full((x0.shape[0],), i, device=self.model.device, dtype=torch.long)
697
+ if unconditional_guidance_scale == 1.:
698
+ noise_pred = self.model.apply_model(x_next, t, c)
699
+ else:
700
+ assert unconditional_conditioning is not None
701
+ e_t_uncond, noise_pred = torch.chunk(
702
+ self.model.apply_model(torch.cat((x_next, x_next)), torch.cat((t, t)),
703
+ torch.cat((unconditional_conditioning, c))), 2)
704
+ noise_pred = e_t_uncond + unconditional_guidance_scale * (noise_pred - e_t_uncond)
705
+
706
+ xt_weighted = (alphas_next[i] / alphas[i]).sqrt() * x_next
707
+ weighted_noise_pred = alphas_next[i].sqrt() * (
708
+ (1 / alphas_next[i] - 1).sqrt() - (1 / alphas[i] - 1).sqrt()) * noise_pred
709
+ x_next = xt_weighted + weighted_noise_pred
710
+ if return_intermediates and i % (
711
+ num_steps // return_intermediates) == 0 and i < num_steps - 1:
712
+ intermediates.append(x_next)
713
+ inter_steps.append(i)
714
+ elif return_intermediates and i >= num_steps - 2:
715
+ intermediates.append(x_next)
716
+ inter_steps.append(i)
717
+ if callback: callback(i)
718
+
719
+ out = {'x_encoded': x_next, 'intermediate_steps': inter_steps}
720
+ if return_intermediates:
721
+ out.update({'intermediates': intermediates})
722
+ return x_next, out
723
+
724
+ @torch.no_grad()
725
+ def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
726
+ # fast, but does not allow for exact reconstruction
727
+ # t serves as an index to gather the correct alphas
728
+ if use_original_steps:
729
+ sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
730
+ sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod
731
+ else:
732
+ sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
733
+ sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas
734
+
735
+ if noise is None:
736
+ noise = torch.randn_like(x0)
737
+ return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 +
738
+ extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise)
739
+
740
+ @torch.no_grad()
741
+ def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
742
+ use_original_steps=False, callback=None, inpaint=None):
743
+
744
+ timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps
745
+ timesteps = timesteps[:t_start]
746
+
747
+ time_range = np.flip(timesteps)
748
+ total_steps = timesteps.shape[0]
749
+ print(f"Running DDIM Sampling with {total_steps} timesteps")
750
+
751
+ iterator = tqdm(time_range, desc='Decoding image', total=total_steps)
752
+ x_dec = x_latent
753
+ for i, step in enumerate(iterator):
754
+ index = total_steps - i - 1
755
+ ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long)
756
+
757
+ model_output = self.p_sample_ddim(x_dec, cond, ts,
758
+ unconditional_guidance_scale=unconditional_guidance_scale,
759
+ unconditional_conditioning=unconditional_conditioning, inpaint=inpaint)
760
+
761
+ x_dec, _ = self.pred_x_prev_from_eps(x_dec, cond, ts, model_output, index)
762
+ if callback: callback(i)
763
+ return x_dec
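The `num_overlap > 0` branch of `p_sample_loop` above accumulates per-frame predictions over cyclic, overlapping windows and normalizes by the visit counts. The snippet below is a minimal, self-contained sketch of just that bookkeeping; `windowed_average` and `denoise_fn` are illustrative names, not part of this repository's API.

```python
import torch

def windowed_average(frames: torch.Tensor, denoise_fn, window: int, overlap: int, offset: int = 0):
    """Average per-frame outputs of `denoise_fn` over cyclic windows of size `window`
    that overlap by `overlap` frames, mirroring the accumulation in p_sample_loop."""
    num_frames = frames.shape[0]
    out = torch.zeros_like(frames)
    counts = torch.zeros(num_frames, device=frames.device)
    skip = window - overlap
    for start in range(offset, offset + num_frames - overlap, skip):
        idx = torch.arange(start, start + window) % num_frames   # cyclic window
        out[idx] += denoise_fn(frames[idx])
        counts[idx] += 1
    return out / counts.reshape(-1, *([1] * (frames.dim() - 1)))

# Sanity check with an identity "denoiser": averaging copies of a frame gives the frame back.
x = torch.randn(16, 4, 8, 8)
assert torch.allclose(windowed_average(x, lambda v: v, window=8, overlap=4), x, atol=1e-6)
```

Because the windows wrap around, neighbouring frames share several denoising passes, which is what lets predictions blend across window boundaries.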
model_lib/ControlNet/ldm/models/diffusion/ddpm.py ADDED
The diff for this file is too large to render. See raw diff
 
model_lib/ControlNet/ldm/models/diffusion/dpm_solver/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .sampler import DPMSolverSampler
model_lib/ControlNet/ldm/models/diffusion/dpm_solver/dpm_solver.py ADDED
@@ -0,0 +1,1154 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+ import math
4
+ from tqdm import tqdm
5
+
6
+
7
+ class NoiseScheduleVP:
8
+ def __init__(
9
+ self,
10
+ schedule='discrete',
11
+ betas=None,
12
+ alphas_cumprod=None,
13
+ continuous_beta_0=0.1,
14
+ continuous_beta_1=20.,
15
+ ):
16
+ """Create a wrapper class for the forward SDE (VP type).
17
+ ***
18
+ Update: We support discrete-time diffusion models by implementing a piecewise linear interpolation for log_alpha_t.
19
+ We recommend using schedule='discrete' for discrete-time diffusion models, especially for high-resolution images.
20
+ ***
21
+ The forward SDE ensures that the conditional distribution q_{t|0}(x_t | x_0) = N ( alpha_t * x_0, sigma_t^2 * I ).
22
+ We further define lambda_t = log(alpha_t) - log(sigma_t), which is the half-logSNR (described in the DPM-Solver paper).
23
+ Therefore, we implement the functions for computing alpha_t, sigma_t and lambda_t. For t in [0, T], we have:
24
+ log_alpha_t = self.marginal_log_mean_coeff(t)
25
+ sigma_t = self.marginal_std(t)
26
+ lambda_t = self.marginal_lambda(t)
27
+ Moreover, as lambda(t) is an invertible function, we also support its inverse function:
28
+ t = self.inverse_lambda(lambda_t)
29
+ ===============================================================
30
+ We support both discrete-time DPMs (trained on n = 0, 1, ..., N-1) and continuous-time DPMs (trained on t in [t_0, T]).
31
+ 1. For discrete-time DPMs:
32
+ For discrete-time DPMs trained on n = 0, 1, ..., N-1, we convert the discrete steps to continuous time steps by:
33
+ t_i = (i + 1) / N
34
+ e.g. for N = 1000, we have t_0 = 1e-3 and T = t_{N-1} = 1.
35
+ We solve the corresponding diffusion ODE from time T = 1 to time t_0 = 1e-3.
36
+ Args:
37
+ betas: A `torch.Tensor`. The beta array for the discrete-time DPM. (See the original DDPM paper for details)
38
+ alphas_cumprod: A `torch.Tensor`. The cumprod alphas for the discrete-time DPM. (See the original DDPM paper for details)
39
+ Note that we always have alphas_cumprod = cumprod(1 - betas). Therefore, we only need to set one of `betas` and `alphas_cumprod`.
40
+ **Important**: Please pay special attention for the args for `alphas_cumprod`:
41
+ The `alphas_cumprod` is the \hat{alpha_n} arrays in the notations of DDPM. Specifically, DDPMs assume that
42
+ q_{t_n | 0}(x_{t_n} | x_0) = N ( \sqrt{\hat{alpha_n}} * x_0, (1 - \hat{alpha_n}) * I ).
43
+ Therefore, the notation \hat{alpha_n} is different from the notation alpha_t in DPM-Solver. In fact, we have
44
+ alpha_{t_n} = \sqrt{\hat{alpha_n}},
45
+ and
46
+ log(alpha_{t_n}) = 0.5 * log(\hat{alpha_n}).
47
+ 2. For continuous-time DPMs:
48
+ We support two types of VPSDEs: linear (DDPM) and cosine (improved-DDPM). The hyperparameters for the noise
49
+ schedule are the default settings in DDPM and improved-DDPM:
50
+ Args:
51
+ beta_min: A `float` number. The smallest beta for the linear schedule.
52
+ beta_max: A `float` number. The largest beta for the linear schedule.
53
+ cosine_s: A `float` number. The hyperparameter in the cosine schedule.
54
+ cosine_beta_max: A `float` number. The hyperparameter in the cosine schedule.
55
+ T: A `float` number. The ending time of the forward process.
56
+ ===============================================================
57
+ Args:
58
+ schedule: A `str`. The noise schedule of the forward SDE. 'discrete' for discrete-time DPMs,
59
+ 'linear' or 'cosine' for continuous-time DPMs.
60
+ Returns:
61
+ A wrapper object of the forward SDE (VP type).
62
+
63
+ ===============================================================
64
+ Example:
65
+ # For discrete-time DPMs, given betas (the beta array for n = 0, 1, ..., N - 1):
66
+ >>> ns = NoiseScheduleVP('discrete', betas=betas)
67
+ # For discrete-time DPMs, given alphas_cumprod (the \hat{alpha_n} array for n = 0, 1, ..., N - 1):
68
+ >>> ns = NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod)
69
+ # For continuous-time DPMs (VPSDE), linear schedule:
70
+ >>> ns = NoiseScheduleVP('linear', continuous_beta_0=0.1, continuous_beta_1=20.)
71
+ """
72
+
73
+ if schedule not in ['discrete', 'linear', 'cosine']:
74
+ raise ValueError(
75
+ "Unsupported noise schedule {}. The schedule needs to be 'discrete' or 'linear' or 'cosine'".format(
76
+ schedule))
77
+
78
+ self.schedule = schedule
79
+ if schedule == 'discrete':
80
+ if betas is not None:
81
+ log_alphas = 0.5 * torch.log(1 - betas).cumsum(dim=0)
82
+ else:
83
+ assert alphas_cumprod is not None
84
+ log_alphas = 0.5 * torch.log(alphas_cumprod)
85
+ self.total_N = len(log_alphas)
86
+ self.T = 1.
87
+ self.t_array = torch.linspace(0., 1., self.total_N + 1)[1:].reshape((1, -1))
88
+ self.log_alpha_array = log_alphas.reshape((1, -1,))
89
+ else:
90
+ self.total_N = 1000
91
+ self.beta_0 = continuous_beta_0
92
+ self.beta_1 = continuous_beta_1
93
+ self.cosine_s = 0.008
94
+ self.cosine_beta_max = 999.
95
+ self.cosine_t_max = math.atan(self.cosine_beta_max * (1. + self.cosine_s) / math.pi) * 2. * (
96
+ 1. + self.cosine_s) / math.pi - self.cosine_s
97
+ self.cosine_log_alpha_0 = math.log(math.cos(self.cosine_s / (1. + self.cosine_s) * math.pi / 2.))
98
+ self.schedule = schedule
99
+ if schedule == 'cosine':
100
+ # For the cosine schedule, T = 1 will have numerical issues. So we manually set the ending time T.
101
+ # Note that T = 0.9946 may not be the optimal setting. However, we find it works well.
102
+ self.T = 0.9946
103
+ else:
104
+ self.T = 1.
105
+
106
+ def marginal_log_mean_coeff(self, t):
107
+ """
108
+ Compute log(alpha_t) of a given continuous-time label t in [0, T].
109
+ """
110
+ if self.schedule == 'discrete':
111
+ return interpolate_fn(t.reshape((-1, 1)), self.t_array.to(t.device),
112
+ self.log_alpha_array.to(t.device)).reshape((-1))
113
+ elif self.schedule == 'linear':
114
+ return -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0
115
+ elif self.schedule == 'cosine':
116
+ log_alpha_fn = lambda s: torch.log(torch.cos((s + self.cosine_s) / (1. + self.cosine_s) * math.pi / 2.))
117
+ log_alpha_t = log_alpha_fn(t) - self.cosine_log_alpha_0
118
+ return log_alpha_t
119
+
120
+ def marginal_alpha(self, t):
121
+ """
122
+ Compute alpha_t of a given continuous-time label t in [0, T].
123
+ """
124
+ return torch.exp(self.marginal_log_mean_coeff(t))
125
+
126
+ def marginal_std(self, t):
127
+ """
128
+ Compute sigma_t of a given continuous-time label t in [0, T].
129
+ """
130
+ return torch.sqrt(1. - torch.exp(2. * self.marginal_log_mean_coeff(t)))
131
+
132
+ def marginal_lambda(self, t):
133
+ """
134
+ Compute lambda_t = log(alpha_t) - log(sigma_t) of a given continuous-time label t in [0, T].
135
+ """
136
+ log_mean_coeff = self.marginal_log_mean_coeff(t)
137
+ log_std = 0.5 * torch.log(1. - torch.exp(2. * log_mean_coeff))
138
+ return log_mean_coeff - log_std
139
+
140
+ def inverse_lambda(self, lamb):
141
+ """
142
+ Compute the continuous-time label t in [0, T] of a given half-logSNR lambda_t.
143
+ """
144
+ if self.schedule == 'linear':
145
+ tmp = 2. * (self.beta_1 - self.beta_0) * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb))
146
+ Delta = self.beta_0 ** 2 + tmp
147
+ return tmp / (torch.sqrt(Delta) + self.beta_0) / (self.beta_1 - self.beta_0)
148
+ elif self.schedule == 'discrete':
149
+ log_alpha = -0.5 * torch.logaddexp(torch.zeros((1,)).to(lamb.device), -2. * lamb)
150
+ t = interpolate_fn(log_alpha.reshape((-1, 1)), torch.flip(self.log_alpha_array.to(lamb.device), [1]),
151
+ torch.flip(self.t_array.to(lamb.device), [1]))
152
+ return t.reshape((-1,))
153
+ else:
154
+ log_alpha = -0.5 * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb))
155
+ t_fn = lambda log_alpha_t: torch.arccos(torch.exp(log_alpha_t + self.cosine_log_alpha_0)) * 2. * (
156
+ 1. + self.cosine_s) / math.pi - self.cosine_s
157
+ t = t_fn(log_alpha)
158
+ return t
159
+
160
+
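A hedged usage sketch for `NoiseScheduleVP` as defined above: the linear beta schedule and the tolerances are illustrative choices, not values prescribed by this file.

```python
import torch

betas = torch.linspace(1e-4, 2e-2, 1000)         # example DDPM-style discrete schedule
ns = NoiseScheduleVP('discrete', betas=betas)     # class defined above

t = torch.tensor([0.5])                           # a continuous-time label in (0, T]
alpha_t = ns.marginal_alpha(t)
sigma_t = ns.marginal_std(t)
lambda_t = ns.marginal_lambda(t)

# lambda_t is the half-logSNR, i.e. log(alpha_t) - log(sigma_t) ...
assert torch.allclose(lambda_t, torch.log(alpha_t) - torch.log(sigma_t), atol=1e-5)
# ... and inverse_lambda maps it back (approximately) to the same t.
assert torch.allclose(ns.inverse_lambda(lambda_t), t, atol=1e-3)
```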
161
+ def model_wrapper(
162
+ model,
163
+ noise_schedule,
164
+ model_type="noise",
165
+ model_kwargs={},
166
+ guidance_type="uncond",
167
+ condition=None,
168
+ unconditional_condition=None,
169
+ guidance_scale=1.,
170
+ classifier_fn=None,
171
+ classifier_kwargs={},
172
+ ):
173
+ """Create a wrapper function for the noise prediction model.
174
+ DPM-Solver needs to solve the continuous-time diffusion ODEs. For DPMs trained on discrete-time labels, we need to
175
+ first wrap the model function into a noise prediction model that accepts the continuous time as the input.
176
+ We support four types of the diffusion model by setting `model_type`:
177
+ 1. "noise": noise prediction model. (Trained by predicting noise).
178
+ 2. "x_start": data prediction model. (Trained by predicting the data x_0 at time 0).
179
+ 3. "v": velocity prediction model. (Trained by predicting the velocity).
180
+ The "v" prediction is derivation detailed in Appendix D of [1], and is used in Imagen-Video [2].
181
+ [1] Salimans, Tim, and Jonathan Ho. "Progressive distillation for fast sampling of diffusion models."
182
+ arXiv preprint arXiv:2202.00512 (2022).
183
+ [2] Ho, Jonathan, et al. "Imagen Video: High Definition Video Generation with Diffusion Models."
184
+ arXiv preprint arXiv:2210.02303 (2022).
185
+
186
+ 4. "score": marginal score function. (Trained by denoising score matching).
187
+ Note that the score function and the noise prediction model follows a simple relationship:
188
+ ```
189
+ noise(x_t, t) = -sigma_t * score(x_t, t)
190
+ ```
191
+ We support three types of guided sampling by DPMs by setting `guidance_type`:
192
+ 1. "uncond": unconditional sampling by DPMs.
193
+ The input `model` has the following format:
194
+ ``
195
+ model(x, t_input, **model_kwargs) -> noise | x_start | v | score
196
+ ``
197
+ 2. "classifier": classifier guidance sampling [3] by DPMs and another classifier.
198
+ The input `model` has the following format:
199
+ ``
200
+ model(x, t_input, **model_kwargs) -> noise | x_start | v | score
201
+ ``
202
+ The input `classifier_fn` has the following format:
203
+ ``
204
+ classifier_fn(x, t_input, cond, **classifier_kwargs) -> logits(x, t_input, cond)
205
+ ``
206
+ [3] P. Dhariwal and A. Q. Nichol, "Diffusion models beat GANs on image synthesis,"
207
+ in Advances in Neural Information Processing Systems, vol. 34, 2021, pp. 8780-8794.
208
+ 3. "classifier-free": classifier-free guidance sampling by conditional DPMs.
209
+ The input `model` has the following format:
210
+ ``
211
+ model(x, t_input, cond, **model_kwargs) -> noise | x_start | v | score
212
+ ``
213
+ And if cond == `unconditional_condition`, the model output is the unconditional DPM output.
214
+ [4] Ho, Jonathan, and Tim Salimans. "Classifier-free diffusion guidance."
215
+ arXiv preprint arXiv:2207.12598 (2022).
216
+
217
+ The `t_input` is the time label of the model, which may be discrete-time labels (i.e. 0 to 999)
218
+ or continuous-time labels (i.e. epsilon to T).
219
+ We wrap the model function to accept only `x` and `t_continuous` as inputs, and output the predicted noise:
220
+ ``
221
+ def model_fn(x, t_continuous) -> noise:
222
+ t_input = get_model_input_time(t_continuous)
223
+ return noise_pred(model, x, t_input, **model_kwargs)
224
+ ``
225
+ where `t_continuous` is the continuous time labels (i.e. epsilon to T). And we use `model_fn` for DPM-Solver.
226
+ ===============================================================
227
+ Args:
228
+ model: A diffusion model with the corresponding format described above.
229
+ noise_schedule: A noise schedule object, such as NoiseScheduleVP.
230
+ model_type: A `str`. The parameterization type of the diffusion model.
231
+ "noise" or "x_start" or "v" or "score".
232
+ model_kwargs: A `dict`. A dict for the other inputs of the model function.
233
+ guidance_type: A `str`. The type of the guidance for sampling.
234
+ "uncond" or "classifier" or "classifier-free".
235
+ condition: A pytorch tensor. The condition for the guided sampling.
236
+ Only used for "classifier" or "classifier-free" guidance type.
237
+ unconditional_condition: A pytorch tensor. The condition for the unconditional sampling.
238
+ Only used for "classifier-free" guidance type.
239
+ guidance_scale: A `float`. The scale for the guided sampling.
240
+ classifier_fn: A classifier function. Only used for the classifier guidance.
241
+ classifier_kwargs: A `dict`. A dict for the other inputs of the classifier function.
242
+ Returns:
243
+ A noise prediction model that accepts the noised data and the continuous time as the inputs.
244
+ """
245
+
246
+ def get_model_input_time(t_continuous):
247
+ """
248
+ Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time.
249
+ For discrete-time DPMs, we convert `t_continuous` in [1 / N, 1] to `t_input` in [0, 1000 * (N - 1) / N].
250
+ For continuous-time DPMs, we just use `t_continuous`.
251
+ """
252
+ if noise_schedule.schedule == 'discrete':
253
+ return (t_continuous - 1. / noise_schedule.total_N) * 1000.
254
+ else:
255
+ return t_continuous
256
+
257
+ def noise_pred_fn(x, t_continuous, cond=None):
258
+ if t_continuous.reshape((-1,)).shape[0] == 1:
259
+ t_continuous = t_continuous.expand((x.shape[0]))
260
+ t_input = get_model_input_time(t_continuous)
261
+ if cond is None:
262
+ output = model(x, t_input, **model_kwargs)
263
+ else:
264
+ output = model(x, t_input, cond, **model_kwargs)
265
+ if model_type == "noise":
266
+ return output
267
+ elif model_type == "x_start":
268
+ alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
269
+ dims = x.dim()
270
+ return (x - expand_dims(alpha_t, dims) * output) / expand_dims(sigma_t, dims)
271
+ elif model_type == "v":
272
+ alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
273
+ dims = x.dim()
274
+ return expand_dims(alpha_t, dims) * output + expand_dims(sigma_t, dims) * x
275
+ elif model_type == "score":
276
+ sigma_t = noise_schedule.marginal_std(t_continuous)
277
+ dims = x.dim()
278
+ return -expand_dims(sigma_t, dims) * output
279
+
280
+ def cond_grad_fn(x, t_input):
281
+ """
282
+ Compute the gradient of the classifier, i.e. nabla_{x} log p_t(cond | x_t).
283
+ """
284
+ with torch.enable_grad():
285
+ x_in = x.detach().requires_grad_(True)
286
+ log_prob = classifier_fn(x_in, t_input, condition, **classifier_kwargs)
287
+ return torch.autograd.grad(log_prob.sum(), x_in)[0]
288
+
289
+ def model_fn(x, t_continuous):
290
+ """
291
+ The noise prediction model function that is used for DPM-Solver.
292
+ """
293
+ if t_continuous.reshape((-1,)).shape[0] == 1:
294
+ t_continuous = t_continuous.expand((x.shape[0]))
295
+ if guidance_type == "uncond":
296
+ return noise_pred_fn(x, t_continuous)
297
+ elif guidance_type == "classifier":
298
+ assert classifier_fn is not None
299
+ t_input = get_model_input_time(t_continuous)
300
+ cond_grad = cond_grad_fn(x, t_input)
301
+ sigma_t = noise_schedule.marginal_std(t_continuous)
302
+ noise = noise_pred_fn(x, t_continuous)
303
+ return noise - guidance_scale * expand_dims(sigma_t, dims=cond_grad.dim()) * cond_grad
304
+ elif guidance_type == "classifier-free":
305
+ if guidance_scale == 1. or unconditional_condition is None:
306
+ return noise_pred_fn(x, t_continuous, cond=condition)
307
+ else:
308
+ x_in = torch.cat([x] * 2)
309
+ t_in = torch.cat([t_continuous] * 2)
310
+ c_in = torch.cat([unconditional_condition, condition])
311
+ noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2)
312
+ return noise_uncond + guidance_scale * (noise - noise_uncond)
313
+
314
+ assert model_type in ["noise", "x_start", "v"]
315
+ assert guidance_type in ["uncond", "classifier", "classifier-free"]
316
+ return model_fn
317
+
318
+
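A hedged usage sketch for `model_wrapper` above, wrapping a toy discrete-time noise-prediction model so it can be queried at continuous time labels; `toy_eps_model` and its shapes are illustrative only.

```python
import torch

def toy_eps_model(x, t_discrete):
    # stand-in for a real UNet epsilon-predictor: same output shape as the input
    return torch.zeros_like(x)

betas = torch.linspace(1e-4, 2e-2, 1000)
ns = NoiseScheduleVP('discrete', betas=betas)
model_fn = model_wrapper(toy_eps_model, ns, model_type="noise", guidance_type="uncond")

x = torch.randn(2, 4, 8, 8)
eps = model_fn(x, torch.tensor([0.7]))            # continuous time label in (1/N, 1]
assert eps.shape == x.shape                       # ready to be used by DPM_Solver
```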
319
+ class DPM_Solver:
320
+ def __init__(self, model_fn, noise_schedule, predict_x0=False, thresholding=False, max_val=1.):
321
+ """Construct a DPM-Solver.
322
+ We support both the noise prediction model ("predicting epsilon") and the data prediction model ("predicting x0").
323
+ If `predict_x0` is False, we use the solver for the noise prediction model (DPM-Solver).
324
+ If `predict_x0` is True, we use the solver for the data prediction model (DPM-Solver++).
325
+ In such a case, we further support the "dynamic thresholding" in [1] when `thresholding` is True.
326
+ The "dynamic thresholding" can greatly improve the sample quality for pixel-space DPMs with large guidance scales.
327
+ Args:
328
+ model_fn: A noise prediction model function which accepts the continuous-time input (t in [epsilon, T]):
329
+ ``
330
+ def model_fn(x, t_continuous):
331
+ return noise
332
+ ``
333
+ noise_schedule: A noise schedule object, such as NoiseScheduleVP.
334
+ predict_x0: A `bool`. If true, use the data prediction model; else, use the noise prediction model.
335
+ thresholding: A `bool`. Valid when `predict_x0` is True. Whether to use the "dynamic thresholding" in [1].
336
+ max_val: A `float`. Valid when both `predict_x0` and `thresholding` are True. The max value for thresholding.
337
+
338
+ [1] Chitwan Saharia, William Chan, Saurabh Saxena, Lala Li, Jay Whang, Emily Denton, Seyed Kamyar Seyed Ghasemipour, Burcu Karagol Ayan, S Sara Mahdavi, Rapha Gontijo Lopes, et al. Photorealistic text-to-image diffusion models with deep language understanding. arXiv preprint arXiv:2205.11487, 2022b.
339
+ """
340
+ self.model = model_fn
341
+ self.noise_schedule = noise_schedule
342
+ self.predict_x0 = predict_x0
343
+ self.thresholding = thresholding
344
+ self.max_val = max_val
345
+
346
+ def noise_prediction_fn(self, x, t):
347
+ """
348
+ Return the noise prediction model.
349
+ """
350
+ return self.model(x, t)
351
+
352
+ def data_prediction_fn(self, x, t):
353
+ """
354
+ Return the data prediction model (with thresholding).
355
+ """
356
+ noise = self.noise_prediction_fn(x, t)
357
+ dims = x.dim()
358
+ alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t)
359
+ x0 = (x - expand_dims(sigma_t, dims) * noise) / expand_dims(alpha_t, dims)
360
+ if self.thresholding:
361
+ p = 0.995 # A hyperparameter in the paper of "Imagen" [1].
362
+ s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
363
+ s = expand_dims(torch.maximum(s, self.max_val * torch.ones_like(s).to(s.device)), dims)
364
+ x0 = torch.clamp(x0, -s, s) / s
365
+ return x0
366
+
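The `thresholding` branch above applies the "dynamic thresholding" of [1]: clamp each predicted x0 to its per-sample p-quantile of |x0| (lower-bounded by `max_val`) and rescale. Below is a minimal standalone sketch of the same operation; the function name is illustrative.

```python
import torch

def dynamic_threshold(x0: torch.Tensor, p: float = 0.995, max_val: float = 1.0) -> torch.Tensor:
    """Clamp each sample of x0 to [-s, s] and rescale to [-1, 1], where s is the
    p-quantile of |x0| over that sample (but never below max_val)."""
    s = torch.quantile(torch.abs(x0).reshape(x0.shape[0], -1), p, dim=1)
    s = torch.clamp(s, min=max_val).reshape(-1, *([1] * (x0.dim() - 1)))
    return torch.clamp(x0, -s, s) / s
```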
367
+ def model_fn(self, x, t):
368
+ """
369
+ Convert the model to the noise prediction model or the data prediction model.
370
+ """
371
+ if self.predict_x0:
372
+ return self.data_prediction_fn(x, t)
373
+ else:
374
+ return self.noise_prediction_fn(x, t)
375
+
376
+ def get_time_steps(self, skip_type, t_T, t_0, N, device):
377
+ """Compute the intermediate time steps for sampling.
378
+ Args:
379
+ skip_type: A `str`. The type for the spacing of the time steps. We support three types:
380
+ - 'logSNR': uniform logSNR for the time steps.
381
+ - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolution data**.)
382
+ - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolution data.)
383
+ t_T: A `float`. The starting time of the sampling (default is T).
384
+ t_0: A `float`. The ending time of the sampling (default is epsilon).
385
+ N: A `int`. The total number of the spacing of the time steps.
386
+ device: A torch device.
387
+ Returns:
388
+ A pytorch tensor of the time steps, with the shape (N + 1,).
389
+ """
390
+ if skip_type == 'logSNR':
391
+ lambda_T = self.noise_schedule.marginal_lambda(torch.tensor(t_T).to(device))
392
+ lambda_0 = self.noise_schedule.marginal_lambda(torch.tensor(t_0).to(device))
393
+ logSNR_steps = torch.linspace(lambda_T.cpu().item(), lambda_0.cpu().item(), N + 1).to(device)
394
+ return self.noise_schedule.inverse_lambda(logSNR_steps)
395
+ elif skip_type == 'time_uniform':
396
+ return torch.linspace(t_T, t_0, N + 1).to(device)
397
+ elif skip_type == 'time_quadratic':
398
+ t_order = 2
399
+ t = torch.linspace(t_T ** (1. / t_order), t_0 ** (1. / t_order), N + 1).pow(t_order).to(device)
400
+ return t
401
+ else:
402
+ raise ValueError(
403
+ "Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(skip_type))
404
+
405
+ def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, skip_type, t_T, t_0, device):
406
+ """
407
+ Get the order of each step for sampling by the singlestep DPM-Solver.
408
+ We combine DPM-Solver-1, 2 and 3 to use all the function evaluations; this combination is named "DPM-Solver-fast".
409
+ Given a fixed number of function evaluations by `steps`, the sampling procedure by DPM-Solver-fast is:
410
+ - If order == 1:
411
+ We take `steps` of DPM-Solver-1 (i.e. DDIM).
412
+ - If order == 2:
413
+ - Denote K = (steps // 2). We take K or (K + 1) intermediate time steps for sampling.
414
+ - If steps % 2 == 0, we use K steps of DPM-Solver-2.
415
+ - If steps % 2 == 1, we use K steps of DPM-Solver-2 and 1 step of DPM-Solver-1.
416
+ - If order == 3:
417
+ - Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling.
418
+ - If steps % 3 == 0, we use (K - 2) steps of DPM-Solver-3, and 1 step of DPM-Solver-2 and 1 step of DPM-Solver-1.
419
+ - If steps % 3 == 1, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-1.
420
+ - If steps % 3 == 2, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-2.
421
+ ============================================
422
+ Args:
423
+ order: A `int`. The max order for the solver (2 or 3).
424
+ steps: A `int`. The total number of function evaluations (NFE).
425
+ skip_type: A `str`. The type for the spacing of the time steps. We support three types:
426
+ - 'logSNR': uniform logSNR for the time steps.
427
+ - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolution data**.)
428
+ - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolution data.)
429
+ t_T: A `float`. The starting time of the sampling (default is T).
430
+ t_0: A `float`. The ending time of the sampling (default is epsilon).
431
+ device: A torch device.
432
+ Returns:
433
+ orders: A list of the solver order of each step.
434
+ """
435
+ if order == 3:
436
+ K = steps // 3 + 1
437
+ if steps % 3 == 0:
438
+ orders = [3, ] * (K - 2) + [2, 1]
439
+ elif steps % 3 == 1:
440
+ orders = [3, ] * (K - 1) + [1]
441
+ else:
442
+ orders = [3, ] * (K - 1) + [2]
443
+ elif order == 2:
444
+ if steps % 2 == 0:
445
+ K = steps // 2
446
+ orders = [2, ] * K
447
+ else:
448
+ K = steps // 2 + 1
449
+ orders = [2, ] * (K - 1) + [1]
450
+ elif order == 1:
451
+ K = 1
452
+ orders = [1, ] * steps
453
+ else:
454
+ raise ValueError("'order' must be '1' or '2' or '3'.")
455
+ if skip_type == 'logSNR':
456
+ # To reproduce the results in DPM-Solver paper
457
+ timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, K, device)
458
+ else:
459
+ timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, steps, device)[
460
+ torch.cumsum(torch.tensor([0, ] + orders), dim=0).to(device)]  # torch.cumsum requires an explicit dim
461
+ return timesteps_outer, orders
462
+
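A short worked example of the order split described in the docstring above (computed by hand, independent of the class; the value of `steps` is illustrative):

```python
steps = 20                      # total number of function evaluations (NFE), order == 3
K = steps // 3 + 1              # number of outer windows
if steps % 3 == 0:
    orders = [3] * (K - 2) + [2, 1]
elif steps % 3 == 1:
    orders = [3] * (K - 1) + [1]
else:
    orders = [3] * (K - 1) + [2]
assert sum(orders) == steps and len(orders) == K    # here: [3]*6 + [2], i.e. 20 NFE over 7 windows
```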
463
+ def denoise_to_zero_fn(self, x, s):
464
+ """
465
+ Denoise at the final step, which is equivalent to solving the ODE from lambda_s to infty with a first-order discretization.
466
+ """
467
+ return self.data_prediction_fn(x, s)
468
+
469
+ def dpm_solver_first_update(self, x, s, t, model_s=None, return_intermediate=False):
470
+ """
471
+ DPM-Solver-1 (equivalent to DDIM) from time `s` to time `t`.
472
+ Args:
473
+ x: A pytorch tensor. The initial value at time `s`.
474
+ s: A pytorch tensor. The starting time, with the shape (x.shape[0],).
475
+ t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
476
+ model_s: A pytorch tensor. The model function evaluated at time `s`.
477
+ If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it.
478
+ return_intermediate: A `bool`. If true, also return the model value at time `s`.
479
+ Returns:
480
+ x_t: A pytorch tensor. The approximated solution at time `t`.
481
+ """
482
+ ns = self.noise_schedule
483
+ dims = x.dim()
484
+ lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
485
+ h = lambda_t - lambda_s
486
+ log_alpha_s, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(t)
487
+ sigma_s, sigma_t = ns.marginal_std(s), ns.marginal_std(t)
488
+ alpha_t = torch.exp(log_alpha_t)
489
+
490
+ if self.predict_x0:
491
+ phi_1 = torch.expm1(-h)
492
+ if model_s is None:
493
+ model_s = self.model_fn(x, s)
494
+ x_t = (
495
+ expand_dims(sigma_t / sigma_s, dims) * x
496
+ - expand_dims(alpha_t * phi_1, dims) * model_s
497
+ )
498
+ if return_intermediate:
499
+ return x_t, {'model_s': model_s}
500
+ else:
501
+ return x_t
502
+ else:
503
+ phi_1 = torch.expm1(h)
504
+ if model_s is None:
505
+ model_s = self.model_fn(x, s)
506
+ x_t = (
507
+ expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
508
+ - expand_dims(sigma_t * phi_1, dims) * model_s
509
+ )
510
+ if return_intermediate:
511
+ return x_t, {'model_s': model_s}
512
+ else:
513
+ return x_t
514
+
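Written out, the two branches above (data prediction with `predict_x0=True`, and noise prediction otherwise) are, with h = lambda_t - lambda_s:

```latex
x_t = \frac{\sigma_t}{\sigma_s}\, x_s \;-\; \alpha_t\,\bigl(e^{-h} - 1\bigr)\, x_\theta(x_s, s)
\qquad\text{and}\qquad
x_t = \frac{\alpha_t}{\alpha_s}\, x_s \;-\; \sigma_t\,\bigl(e^{h} - 1\bigr)\, \epsilon_\theta(x_s, s)
```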
515
+ def singlestep_dpm_solver_second_update(self, x, s, t, r1=0.5, model_s=None, return_intermediate=False,
516
+ solver_type='dpm_solver'):
517
+ """
518
+ Singlestep solver DPM-Solver-2 from time `s` to time `t`.
519
+ Args:
520
+ x: A pytorch tensor. The initial value at time `s`.
521
+ s: A pytorch tensor. The starting time, with the shape (x.shape[0],).
522
+ t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
523
+ r1: A `float`. The hyperparameter of the second-order solver.
524
+ model_s: A pytorch tensor. The model function evaluated at time `s`.
525
+ If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it.
526
+ return_intermediate: A `bool`. If true, also return the model value at time `s` and `s1` (the intermediate time).
527
+ solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
528
+ The type slightly impacts the performance. We recommend using the 'dpm_solver' type.
529
+ Returns:
530
+ x_t: A pytorch tensor. The approximated solution at time `t`.
531
+ """
532
+ if solver_type not in ['dpm_solver', 'taylor']:
533
+ raise ValueError("'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(solver_type))
534
+ if r1 is None:
535
+ r1 = 0.5
536
+ ns = self.noise_schedule
537
+ dims = x.dim()
538
+ lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
539
+ h = lambda_t - lambda_s
540
+ lambda_s1 = lambda_s + r1 * h
541
+ s1 = ns.inverse_lambda(lambda_s1)
542
+ log_alpha_s, log_alpha_s1, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(
543
+ s1), ns.marginal_log_mean_coeff(t)
544
+ sigma_s, sigma_s1, sigma_t = ns.marginal_std(s), ns.marginal_std(s1), ns.marginal_std(t)
545
+ alpha_s1, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_t)
546
+
547
+ if self.predict_x0:
548
+ phi_11 = torch.expm1(-r1 * h)
549
+ phi_1 = torch.expm1(-h)
550
+
551
+ if model_s is None:
552
+ model_s = self.model_fn(x, s)
553
+ x_s1 = (
554
+ expand_dims(sigma_s1 / sigma_s, dims) * x
555
+ - expand_dims(alpha_s1 * phi_11, dims) * model_s
556
+ )
557
+ model_s1 = self.model_fn(x_s1, s1)
558
+ if solver_type == 'dpm_solver':
559
+ x_t = (
560
+ expand_dims(sigma_t / sigma_s, dims) * x
561
+ - expand_dims(alpha_t * phi_1, dims) * model_s
562
+ - (0.5 / r1) * expand_dims(alpha_t * phi_1, dims) * (model_s1 - model_s)
563
+ )
564
+ elif solver_type == 'taylor':
565
+ x_t = (
566
+ expand_dims(sigma_t / sigma_s, dims) * x
567
+ - expand_dims(alpha_t * phi_1, dims) * model_s
568
+ + (1. / r1) * expand_dims(alpha_t * ((torch.exp(-h) - 1.) / h + 1.), dims) * (
569
+ model_s1 - model_s)
570
+ )
571
+ else:
572
+ phi_11 = torch.expm1(r1 * h)
573
+ phi_1 = torch.expm1(h)
574
+
575
+ if model_s is None:
576
+ model_s = self.model_fn(x, s)
577
+ x_s1 = (
578
+ expand_dims(torch.exp(log_alpha_s1 - log_alpha_s), dims) * x
579
+ - expand_dims(sigma_s1 * phi_11, dims) * model_s
580
+ )
581
+ model_s1 = self.model_fn(x_s1, s1)
582
+ if solver_type == 'dpm_solver':
583
+ x_t = (
584
+ expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
585
+ - expand_dims(sigma_t * phi_1, dims) * model_s
586
+ - (0.5 / r1) * expand_dims(sigma_t * phi_1, dims) * (model_s1 - model_s)
587
+ )
588
+ elif solver_type == 'taylor':
589
+ x_t = (
590
+ expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
591
+ - expand_dims(sigma_t * phi_1, dims) * model_s
592
+ - (1. / r1) * expand_dims(sigma_t * ((torch.exp(h) - 1.) / h - 1.), dims) * (model_s1 - model_s)
593
+ )
594
+ if return_intermediate:
595
+ return x_t, {'model_s': model_s, 'model_s1': model_s1}
596
+ else:
597
+ return x_t
598
+
599
+ def singlestep_dpm_solver_third_update(self, x, s, t, r1=1. / 3., r2=2. / 3., model_s=None, model_s1=None,
600
+ return_intermediate=False, solver_type='dpm_solver'):
601
+ """
602
+ Singlestep solver DPM-Solver-3 from time `s` to time `t`.
603
+ Args:
604
+ x: A pytorch tensor. The initial value at time `s`.
605
+ s: A pytorch tensor. The starting time, with the shape (x.shape[0],).
606
+ t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
607
+ r1: A `float`. The hyperparameter of the third-order solver.
608
+ r2: A `float`. The hyperparameter of the third-order solver.
609
+ model_s: A pytorch tensor. The model function evaluated at time `s`.
610
+ If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it.
611
+ model_s1: A pytorch tensor. The model function evaluated at time `s1` (the intermediate time given by `r1`).
612
+ If `model_s1` is None, we evaluate the model at `s1`; otherwise we directly use it.
613
+ return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times).
614
+ solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
615
+ The type slightly impacts the performance. We recommend using the 'dpm_solver' type.
616
+ Returns:
617
+ x_t: A pytorch tensor. The approximated solution at time `t`.
618
+ """
619
+ if solver_type not in ['dpm_solver', 'taylor']:
620
+ raise ValueError("'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(solver_type))
621
+ if r1 is None:
622
+ r1 = 1. / 3.
623
+ if r2 is None:
624
+ r2 = 2. / 3.
625
+ ns = self.noise_schedule
626
+ dims = x.dim()
627
+ lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
628
+ h = lambda_t - lambda_s
629
+ lambda_s1 = lambda_s + r1 * h
630
+ lambda_s2 = lambda_s + r2 * h
631
+ s1 = ns.inverse_lambda(lambda_s1)
632
+ s2 = ns.inverse_lambda(lambda_s2)
633
+ log_alpha_s, log_alpha_s1, log_alpha_s2, log_alpha_t = ns.marginal_log_mean_coeff(
634
+ s), ns.marginal_log_mean_coeff(s1), ns.marginal_log_mean_coeff(s2), ns.marginal_log_mean_coeff(t)
635
+ sigma_s, sigma_s1, sigma_s2, sigma_t = ns.marginal_std(s), ns.marginal_std(s1), ns.marginal_std(
636
+ s2), ns.marginal_std(t)
637
+ alpha_s1, alpha_s2, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_s2), torch.exp(log_alpha_t)
638
+
639
+ if self.predict_x0:
640
+ phi_11 = torch.expm1(-r1 * h)
641
+ phi_12 = torch.expm1(-r2 * h)
642
+ phi_1 = torch.expm1(-h)
643
+ phi_22 = torch.expm1(-r2 * h) / (r2 * h) + 1.
644
+ phi_2 = phi_1 / h + 1.
645
+ phi_3 = phi_2 / h - 0.5
646
+
647
+ if model_s is None:
648
+ model_s = self.model_fn(x, s)
649
+ if model_s1 is None:
650
+ x_s1 = (
651
+ expand_dims(sigma_s1 / sigma_s, dims) * x
652
+ - expand_dims(alpha_s1 * phi_11, dims) * model_s
653
+ )
654
+ model_s1 = self.model_fn(x_s1, s1)
655
+ x_s2 = (
656
+ expand_dims(sigma_s2 / sigma_s, dims) * x
657
+ - expand_dims(alpha_s2 * phi_12, dims) * model_s
658
+ + r2 / r1 * expand_dims(alpha_s2 * phi_22, dims) * (model_s1 - model_s)
659
+ )
660
+ model_s2 = self.model_fn(x_s2, s2)
661
+ if solver_type == 'dpm_solver':
662
+ x_t = (
663
+ expand_dims(sigma_t / sigma_s, dims) * x
664
+ - expand_dims(alpha_t * phi_1, dims) * model_s
665
+ + (1. / r2) * expand_dims(alpha_t * phi_2, dims) * (model_s2 - model_s)
666
+ )
667
+ elif solver_type == 'taylor':
668
+ D1_0 = (1. / r1) * (model_s1 - model_s)
669
+ D1_1 = (1. / r2) * (model_s2 - model_s)
670
+ D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1)
671
+ D2 = 2. * (D1_1 - D1_0) / (r2 - r1)
672
+ x_t = (
673
+ expand_dims(sigma_t / sigma_s, dims) * x
674
+ - expand_dims(alpha_t * phi_1, dims) * model_s
675
+ + expand_dims(alpha_t * phi_2, dims) * D1
676
+ - expand_dims(alpha_t * phi_3, dims) * D2
677
+ )
678
+ else:
679
+ phi_11 = torch.expm1(r1 * h)
680
+ phi_12 = torch.expm1(r2 * h)
681
+ phi_1 = torch.expm1(h)
682
+ phi_22 = torch.expm1(r2 * h) / (r2 * h) - 1.
683
+ phi_2 = phi_1 / h - 1.
684
+ phi_3 = phi_2 / h - 0.5
685
+
686
+ if model_s is None:
687
+ model_s = self.model_fn(x, s)
688
+ if model_s1 is None:
689
+ x_s1 = (
690
+ expand_dims(torch.exp(log_alpha_s1 - log_alpha_s), dims) * x
691
+ - expand_dims(sigma_s1 * phi_11, dims) * model_s
692
+ )
693
+ model_s1 = self.model_fn(x_s1, s1)
694
+ x_s2 = (
695
+ expand_dims(torch.exp(log_alpha_s2 - log_alpha_s), dims) * x
696
+ - expand_dims(sigma_s2 * phi_12, dims) * model_s
697
+ - r2 / r1 * expand_dims(sigma_s2 * phi_22, dims) * (model_s1 - model_s)
698
+ )
699
+ model_s2 = self.model_fn(x_s2, s2)
700
+ if solver_type == 'dpm_solver':
701
+ x_t = (
702
+ expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
703
+ - expand_dims(sigma_t * phi_1, dims) * model_s
704
+ - (1. / r2) * expand_dims(sigma_t * phi_2, dims) * (model_s2 - model_s)
705
+ )
706
+ elif solver_type == 'taylor':
707
+ D1_0 = (1. / r1) * (model_s1 - model_s)
708
+ D1_1 = (1. / r2) * (model_s2 - model_s)
709
+ D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1)
710
+ D2 = 2. * (D1_1 - D1_0) / (r2 - r1)
711
+ x_t = (
712
+ expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
713
+ - expand_dims(sigma_t * phi_1, dims) * model_s
714
+ - expand_dims(sigma_t * phi_2, dims) * D1
715
+ - expand_dims(sigma_t * phi_3, dims) * D2
716
+ )
717
+
718
+ if return_intermediate:
719
+ return x_t, {'model_s': model_s, 'model_s1': model_s1, 'model_s2': model_s2}
720
+ else:
721
+ return x_t
722
+
723
+ def multistep_dpm_solver_second_update(self, x, model_prev_list, t_prev_list, t, solver_type="dpm_solver"):
724
+ """
725
+ Multistep solver DPM-Solver-2 from time `t_prev_list[-1]` to time `t`.
726
+ Args:
727
+ x: A pytorch tensor. The initial value at time `s`.
728
+ model_prev_list: A list of pytorch tensor. The previous computed model values.
729
+ t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (x.shape[0],)
730
+ t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
731
+ solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
732
+ The type slightly impacts the performance. We recommend using the 'dpm_solver' type.
733
+ Returns:
734
+ x_t: A pytorch tensor. The approximated solution at time `t`.
735
+ """
736
+ if solver_type not in ['dpm_solver', 'taylor']:
737
+ raise ValueError("'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(solver_type))
738
+ ns = self.noise_schedule
739
+ dims = x.dim()
740
+ model_prev_1, model_prev_0 = model_prev_list
741
+ t_prev_1, t_prev_0 = t_prev_list
742
+ lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda(t_prev_1), ns.marginal_lambda(
743
+ t_prev_0), ns.marginal_lambda(t)
744
+ log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
745
+ sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
746
+ alpha_t = torch.exp(log_alpha_t)
747
+
748
+ h_0 = lambda_prev_0 - lambda_prev_1
749
+ h = lambda_t - lambda_prev_0
750
+ r0 = h_0 / h
751
+ D1_0 = expand_dims(1. / r0, dims) * (model_prev_0 - model_prev_1)
752
+ if self.predict_x0:
753
+ if solver_type == 'dpm_solver':
754
+ x_t = (
755
+ expand_dims(sigma_t / sigma_prev_0, dims) * x
756
+ - expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * model_prev_0
757
+ - 0.5 * expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * D1_0
758
+ )
759
+ elif solver_type == 'taylor':
760
+ x_t = (
761
+ expand_dims(sigma_t / sigma_prev_0, dims) * x
762
+ - expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * model_prev_0
763
+ + expand_dims(alpha_t * ((torch.exp(-h) - 1.) / h + 1.), dims) * D1_0
764
+ )
765
+ else:
766
+ if solver_type == 'dpm_solver':
767
+ x_t = (
768
+ expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x
769
+ - expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * model_prev_0
770
+ - 0.5 * expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * D1_0
771
+ )
772
+ elif solver_type == 'taylor':
773
+ x_t = (
774
+ expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x
775
+ - expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * model_prev_0
776
+ - expand_dims(sigma_t * ((torch.exp(h) - 1.) / h - 1.), dims) * D1_0
777
+ )
778
+ return x_t
779
+
780
+ def multistep_dpm_solver_third_update(self, x, model_prev_list, t_prev_list, t, solver_type='dpm_solver'):
781
+ """
782
+ Multistep solver DPM-Solver-3 from time `t_prev_list[-1]` to time `t`.
783
+ Args:
784
+ x: A pytorch tensor. The initial value at time `s`.
785
+ model_prev_list: A list of pytorch tensor. The previous computed model values.
786
+ t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (x.shape[0],)
787
+ t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
788
+ solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
789
+ The type slightly impacts the performance. We recommend using the 'dpm_solver' type.
790
+ Returns:
791
+ x_t: A pytorch tensor. The approximated solution at time `t`.
792
+ """
793
+ ns = self.noise_schedule
794
+ dims = x.dim()
795
+ model_prev_2, model_prev_1, model_prev_0 = model_prev_list
796
+ t_prev_2, t_prev_1, t_prev_0 = t_prev_list
797
+ lambda_prev_2, lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda(t_prev_2), ns.marginal_lambda(
798
+ t_prev_1), ns.marginal_lambda(t_prev_0), ns.marginal_lambda(t)
799
+ log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
800
+ sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
801
+ alpha_t = torch.exp(log_alpha_t)
802
+
803
+ h_1 = lambda_prev_1 - lambda_prev_2
804
+ h_0 = lambda_prev_0 - lambda_prev_1
805
+ h = lambda_t - lambda_prev_0
806
+ r0, r1 = h_0 / h, h_1 / h
807
+ D1_0 = expand_dims(1. / r0, dims) * (model_prev_0 - model_prev_1)
808
+ D1_1 = expand_dims(1. / r1, dims) * (model_prev_1 - model_prev_2)
809
+ D1 = D1_0 + expand_dims(r0 / (r0 + r1), dims) * (D1_0 - D1_1)
810
+ D2 = expand_dims(1. / (r0 + r1), dims) * (D1_0 - D1_1)
811
+ if self.predict_x0:
812
+ x_t = (
813
+ expand_dims(sigma_t / sigma_prev_0, dims) * x
814
+ - expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * model_prev_0
815
+ + expand_dims(alpha_t * ((torch.exp(-h) - 1.) / h + 1.), dims) * D1
816
+ - expand_dims(alpha_t * ((torch.exp(-h) - 1. + h) / h ** 2 - 0.5), dims) * D2
817
+ )
818
+ else:
819
+ x_t = (
820
+ expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x
821
+ - expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * model_prev_0
822
+ - expand_dims(sigma_t * ((torch.exp(h) - 1.) / h - 1.), dims) * D1
823
+ - expand_dims(sigma_t * ((torch.exp(h) - 1. - h) / h ** 2 - 0.5), dims) * D2
824
+ )
825
+ return x_t
826
+
827
+ def singlestep_dpm_solver_update(self, x, s, t, order, return_intermediate=False, solver_type='dpm_solver', r1=None,
828
+ r2=None):
829
+ """
830
+ Singlestep DPM-Solver with the order `order` from time `s` to time `t`.
831
+ Args:
832
+ x: A pytorch tensor. The initial value at time `s`.
833
+ s: A pytorch tensor. The starting time, with the shape (x.shape[0],).
834
+ t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
835
+ order: A `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3.
836
+ return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times).
837
+ solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
838
+ The type slightly impacts the performance. We recommend using the 'dpm_solver' type.
839
+ r1: A `float`. The hyperparameter of the second-order or third-order solver.
840
+ r2: A `float`. The hyperparameter of the third-order solver.
841
+ Returns:
842
+ x_t: A pytorch tensor. The approximated solution at time `t`.
843
+ """
844
+ if order == 1:
845
+ return self.dpm_solver_first_update(x, s, t, return_intermediate=return_intermediate)
846
+ elif order == 2:
847
+ return self.singlestep_dpm_solver_second_update(x, s, t, return_intermediate=return_intermediate,
848
+ solver_type=solver_type, r1=r1)
849
+ elif order == 3:
850
+ return self.singlestep_dpm_solver_third_update(x, s, t, return_intermediate=return_intermediate,
851
+ solver_type=solver_type, r1=r1, r2=r2)
852
+ else:
853
+ raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order))
854
+
855
+ def multistep_dpm_solver_update(self, x, model_prev_list, t_prev_list, t, order, solver_type='dpm_solver'):
856
+ """
857
+ Multistep DPM-Solver with the order `order` from time `t_prev_list[-1]` to time `t`.
858
+ Args:
859
+ x: A pytorch tensor. The initial value at time `s`.
860
+ model_prev_list: A list of pytorch tensor. The previous computed model values.
861
+ t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (x.shape[0],)
862
+ t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
863
+ order: A `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3.
864
+ solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
865
+ The type slightly impacts the performance. We recommend using the 'dpm_solver' type.
866
+ Returns:
867
+ x_t: A pytorch tensor. The approximated solution at time `t`.
868
+ """
869
+ if order == 1:
870
+ return self.dpm_solver_first_update(x, t_prev_list[-1], t, model_s=model_prev_list[-1])
871
+ elif order == 2:
872
+ return self.multistep_dpm_solver_second_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type)
873
+ elif order == 3:
874
+ return self.multistep_dpm_solver_third_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type)
875
+ else:
876
+ raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order))
877
+
878
+ def dpm_solver_adaptive(self, x, order, t_T, t_0, h_init=0.05, atol=0.0078, rtol=0.05, theta=0.9, t_err=1e-5,
879
+ solver_type='dpm_solver'):
880
+ """
881
+ The adaptive step size solver based on singlestep DPM-Solver.
882
+ Args:
883
+ x: A pytorch tensor. The initial value at time `t_T`.
884
+ order: A `int`. The (higher) order of the solver. We only support order == 2 or 3.
885
+ t_T: A `float`. The starting time of the sampling (default is T).
886
+ t_0: A `float`. The ending time of the sampling (default is epsilon).
887
+ h_init: A `float`. The initial step size (for logSNR).
888
+ atol: A `float`. The absolute tolerance of the solver. For image data, the default setting is 0.0078, following [1].
889
+ rtol: A `float`. The relative tolerance of the solver. The default setting is 0.05.
890
+ theta: A `float`. The safety hyperparameter for adapting the step size. The default setting is 0.9, following [1].
891
+ t_err: A `float`. The tolerance for the time. We solve the diffusion ODE until the absolute error between the
892
+ current time and `t_0` is less than `t_err`. The default setting is 1e-5.
893
+ solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
894
+ The type slightly impacts the performance. We recommend using the 'dpm_solver' type.
895
+ Returns:
896
+ x_0: A pytorch tensor. The approximated solution at time `t_0`.
897
+ [1] A. Jolicoeur-Martineau, K. Li, R. Piché-Taillefer, T. Kachman, and I. Mitliagkas, "Gotta go fast when generating data with score-based models," arXiv preprint arXiv:2105.14080, 2021.
898
+ """
899
+ ns = self.noise_schedule
900
+ s = t_T * torch.ones((x.shape[0],)).to(x)
901
+ lambda_s = ns.marginal_lambda(s)
902
+ lambda_0 = ns.marginal_lambda(t_0 * torch.ones_like(s).to(x))
903
+ h = h_init * torch.ones_like(s).to(x)
904
+ x_prev = x
905
+ nfe = 0
906
+ if order == 2:
907
+ r1 = 0.5
908
+ lower_update = lambda x, s, t: self.dpm_solver_first_update(x, s, t, return_intermediate=True)
909
+ higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_second_update(x, s, t, r1=r1,
910
+ solver_type=solver_type,
911
+ **kwargs)
912
+ elif order == 3:
913
+ r1, r2 = 1. / 3., 2. / 3.
914
+ lower_update = lambda x, s, t: self.singlestep_dpm_solver_second_update(x, s, t, r1=r1,
915
+ return_intermediate=True,
916
+ solver_type=solver_type)
917
+ higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_third_update(x, s, t, r1=r1, r2=r2,
918
+ solver_type=solver_type,
919
+ **kwargs)
920
+ else:
921
+ raise ValueError("For adaptive step size solver, order must be 2 or 3, got {}".format(order))
922
+ while torch.abs((s - t_0)).mean() > t_err:
923
+ t = ns.inverse_lambda(lambda_s + h)
924
+ x_lower, lower_noise_kwargs = lower_update(x, s, t)
925
+ x_higher = higher_update(x, s, t, **lower_noise_kwargs)
926
+ delta = torch.max(torch.ones_like(x).to(x) * atol, rtol * torch.max(torch.abs(x_lower), torch.abs(x_prev)))
927
+ norm_fn = lambda v: torch.sqrt(torch.square(v.reshape((v.shape[0], -1))).mean(dim=-1, keepdim=True))
928
+ E = norm_fn((x_higher - x_lower) / delta).max()
929
+ if torch.all(E <= 1.):
930
+ x = x_higher
931
+ s = t
932
+ x_prev = x_lower
933
+ lambda_s = ns.marginal_lambda(s)
934
+ h = torch.min(theta * h * torch.float_power(E, -1. / order).float(), lambda_0 - lambda_s)
935
+ nfe += order
936
+ print('adaptive solver nfe', nfe)
937
+ return x
938
+
939
+ def sample(self, x, steps=20, t_start=None, t_end=None, order=3, skip_type='time_uniform',
940
+ method='singlestep', lower_order_final=True, denoise_to_zero=False, solver_type='dpm_solver',
941
+ atol=0.0078, rtol=0.05,
942
+ ):
943
+ """
944
+ Compute the sample at time `t_end` by DPM-Solver, given the initial `x` at time `t_start`.
945
+ =====================================================
946
+ We support the following algorithms for both noise prediction model and data prediction model:
947
+ - 'singlestep':
948
+ Singlestep DPM-Solver (i.e. "DPM-Solver-fast" in the paper), which combines different orders of singlestep DPM-Solver.
949
+ We combine all the singlestep solvers with order <= `order` to use up all the function evaluations (steps).
950
+ The total number of function evaluations (NFE) == `steps`.
951
+ Given a fixed NFE == `steps`, the sampling procedure is:
952
+ - If `order` == 1:
953
+ - Denote K = steps. We use K steps of DPM-Solver-1 (i.e. DDIM).
954
+ - If `order` == 2:
955
+ - Denote K = (steps // 2) + (steps % 2). We take K intermediate time steps for sampling.
956
+ - If steps % 2 == 0, we use K steps of singlestep DPM-Solver-2.
957
+ - If steps % 2 == 1, we use (K - 1) steps of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1.
958
+ - If `order` == 3:
959
+ - Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling.
960
+ - If steps % 3 == 0, we use (K - 2) steps of singlestep DPM-Solver-3, and 1 step of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1.
961
+ - If steps % 3 == 1, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of DPM-Solver-1.
962
+ - If steps % 3 == 2, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of singlestep DPM-Solver-2.
963
+ - 'multistep':
964
+ Multistep DPM-Solver with the order of `order`. The total number of function evaluations (NFE) == `steps`.
965
+ We initialize the first `order` values by lower order multistep solvers.
966
+ Given a fixed NFE == `steps`, the sampling procedure is:
967
+ Denote K = steps.
968
+ - If `order` == 1:
969
+ - We use K steps of DPM-Solver-1 (i.e. DDIM).
970
+ - If `order` == 2:
+ - We first use 1 step of DPM-Solver-1, then (K - 1) steps of multistep DPM-Solver-2.
+ - If `order` == 3:
+ - We first use 1 step of DPM-Solver-1, then 1 step of multistep DPM-Solver-2, then (K - 2) steps of multistep DPM-Solver-3.
974
+ - 'singlestep_fixed':
975
+ Fixed order singlestep DPM-Solver (i.e. DPM-Solver-1 or singlestep DPM-Solver-2 or singlestep DPM-Solver-3).
976
+ We use singlestep DPM-Solver-`order` for `order`=1 or 2 or 3, with total [`steps` // `order`] * `order` NFE.
977
+ - 'adaptive':
978
+ Adaptive step size DPM-Solver (i.e. "DPM-Solver-12" and "DPM-Solver-23" in the paper).
979
+ We ignore `steps` and use adaptive step size DPM-Solver with a higher order of `order`.
980
+ You can adjust the absolute tolerance `atol` and the relative tolerance `rtol` to balance the computation costs
981
+ (NFE) and the sample quality.
982
+ - If `order` == 2, we use DPM-Solver-12 which combines DPM-Solver-1 and singlestep DPM-Solver-2.
983
+ - If `order` == 3, we use DPM-Solver-23 which combines singlestep DPM-Solver-2 and singlestep DPM-Solver-3.
984
+ =====================================================
985
+ Some advice for choosing the algorithm:
986
+ - For **unconditional sampling** or **guided sampling with small guidance scale** by DPMs:
987
+ Use singlestep DPM-Solver ("DPM-Solver-fast" in the paper) with `order = 3`.
988
+ e.g.
989
+ >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=False)
990
+ >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=3,
991
+ skip_type='time_uniform', method='singlestep')
992
+ - For **guided sampling with large guidance scale** by DPMs:
993
+ Use multistep DPM-Solver with `predict_x0 = True` and `order = 2`.
994
+ e.g.
995
+ >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=True)
996
+ >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=2,
997
+ skip_type='time_uniform', method='multistep')
998
+ We support three types of `skip_type`:
999
+ - 'logSNR': uniform logSNR for the time steps. **Recommended for low-resolution images**.
+ - 'time_uniform': uniform time for the time steps. **Recommended for high-resolution images**.
1001
+ - 'time_quadratic': quadratic time for the time steps.
1002
+ =====================================================
1003
+ Args:
1004
+ x: A pytorch tensor. The initial value at time `t_start`
1005
+ e.g. if `t_start` == T, then `x` is a sample from the standard normal distribution.
1006
+ steps: An `int`. The total number of function evaluations (NFE).
1007
+ t_start: A `float`. The starting time of the sampling.
1008
+ If `T` is None, we use self.noise_schedule.T (default is 1.0).
1009
+ t_end: A `float`. The ending time of the sampling.
1010
+ If `t_end` is None, we use 1. / self.noise_schedule.total_N.
1011
+ e.g. if total_N == 1000, we have `t_end` == 1e-3.
1012
+ For discrete-time DPMs:
1013
+ - We recommend `t_end` == 1. / self.noise_schedule.total_N.
1014
+ For continuous-time DPMs:
1015
+ - We recommend `t_end` == 1e-3 when `steps` <= 15; and `t_end` == 1e-4 when `steps` > 15.
1016
+ order: An `int`. The order of DPM-Solver.
1017
+ skip_type: A `str`. The type for the spacing of the time steps. 'time_uniform' or 'logSNR' or 'time_quadratic'.
1018
+ method: A `str`. The method for sampling. 'singlestep' or 'multistep' or 'singlestep_fixed' or 'adaptive'.
1019
+ denoise_to_zero: A `bool`. Whether to denoise to time 0 at the final step.
+ Default is `False`. If `denoise_to_zero` is `True`, the total NFE is (`steps` + 1).
+ This trick was first proposed by DDPM (https://arxiv.org/abs/2006.11239) and
+ score_sde (https://arxiv.org/abs/2011.13456). It can improve the FID
+ of diffusion models sampled by diffusion SDEs on low-resolution images
+ (such as CIFAR-10). However, we observed that it does not matter for
+ high-resolution images. As it needs an additional NFE, we do not recommend
+ it for high-resolution images.
+ lower_order_final: A `bool`. Whether to use lower order solvers at the final steps.
+ Only valid for `method=multistep` and `steps < 15`. We empirically find that
+ this trick is key to stabilizing sampling with DPM-Solver at very few steps
+ (especially for steps <= 10), so we recommend setting it to `True`.
1031
+ solver_type: A `str`. The Taylor expansion type for the solver. `dpm_solver` or `taylor`. We recommend `dpm_solver`.
1032
+ atol: A `float`. The absolute tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'.
1033
+ rtol: A `float`. The relative tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'.
1034
+ Returns:
1035
+ x_end: A pytorch tensor. The approximated solution at time `t_end`.
1036
+ """
1037
+ t_0 = 1. / self.noise_schedule.total_N if t_end is None else t_end
1038
+ t_T = self.noise_schedule.T if t_start is None else t_start
1039
+ device = x.device
1040
+ if method == 'adaptive':
1041
+ with torch.no_grad():
1042
+ x = self.dpm_solver_adaptive(x, order=order, t_T=t_T, t_0=t_0, atol=atol, rtol=rtol,
1043
+ solver_type=solver_type)
1044
+ elif method == 'multistep':
1045
+ assert steps >= order
1046
+ timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device)
1047
+ assert timesteps.shape[0] - 1 == steps
1048
+ with torch.no_grad():
1049
+ vec_t = timesteps[0].expand((x.shape[0]))
1050
+ model_prev_list = [self.model_fn(x, vec_t)]
1051
+ t_prev_list = [vec_t]
1052
+ # Init the first `order` values by lower order multistep DPM-Solver.
1053
+ for init_order in tqdm(range(1, order), desc="DPM init order"):
1054
+ vec_t = timesteps[init_order].expand(x.shape[0])
1055
+ x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, vec_t, init_order,
1056
+ solver_type=solver_type)
1057
+ model_prev_list.append(self.model_fn(x, vec_t))
1058
+ t_prev_list.append(vec_t)
1059
+ # Compute the remaining values by `order`-th order multistep DPM-Solver.
1060
+ for step in tqdm(range(order, steps + 1), desc="DPM multistep"):
1061
+ vec_t = timesteps[step].expand(x.shape[0])
1062
+ if lower_order_final and steps < 15:
1063
+ step_order = min(order, steps + 1 - step)
1064
+ else:
1065
+ step_order = order
1066
+ x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, vec_t, step_order,
1067
+ solver_type=solver_type)
1068
+ for i in range(order - 1):
1069
+ t_prev_list[i] = t_prev_list[i + 1]
1070
+ model_prev_list[i] = model_prev_list[i + 1]
1071
+ t_prev_list[-1] = vec_t
1072
+ # We do not need to evaluate the final model value.
1073
+ if step < steps:
1074
+ model_prev_list[-1] = self.model_fn(x, vec_t)
1075
+ elif method in ['singlestep', 'singlestep_fixed']:
1076
+ if method == 'singlestep':
1077
+ timesteps_outer, orders = self.get_orders_and_timesteps_for_singlestep_solver(steps=steps, order=order,
1078
+ skip_type=skip_type,
1079
+ t_T=t_T, t_0=t_0,
1080
+ device=device)
1081
+ elif method == 'singlestep_fixed':
1082
+ K = steps // order
1083
+ orders = [order, ] * K
1084
+ timesteps_outer = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=K, device=device)
1085
+ for i, order in enumerate(orders):
1086
+ t_T_inner, t_0_inner = timesteps_outer[i], timesteps_outer[i + 1]
1087
+ timesteps_inner = self.get_time_steps(skip_type=skip_type, t_T=t_T_inner.item(), t_0=t_0_inner.item(),
1088
+ N=order, device=device)
1089
+ lambda_inner = self.noise_schedule.marginal_lambda(timesteps_inner)
1090
+ vec_s, vec_t = t_T_inner.tile(x.shape[0]), t_0_inner.tile(x.shape[0])
1091
+ h = lambda_inner[-1] - lambda_inner[0]
1092
+ r1 = None if order <= 1 else (lambda_inner[1] - lambda_inner[0]) / h
1093
+ r2 = None if order <= 2 else (lambda_inner[2] - lambda_inner[0]) / h
1094
+ x = self.singlestep_dpm_solver_update(x, vec_s, vec_t, order, solver_type=solver_type, r1=r1, r2=r2)
1095
+ if denoise_to_zero:
1096
+ x = self.denoise_to_zero_fn(x, torch.ones((x.shape[0],)).to(device) * t_0)
1097
+ return x
1098
+
1099
+
1100
+ #############################################################
1101
+ # other utility functions
1102
+ #############################################################
1103
+
1104
+ def interpolate_fn(x, xp, yp):
1105
+ """
1106
+ A piecewise linear function y = f(x), using xp and yp as keypoints.
1107
+ We implement f(x) in a differentiable way (i.e. applicable for autograd).
1108
+ The function f(x) is well-defined for all x-axis. (For x beyond the bounds of xp, we use the outmost points of xp to define the linear function.)
1109
+ Args:
1110
+ x: PyTorch tensor with shape [N, C], where N is the batch size, C is the number of channels (we use C = 1 for DPM-Solver).
1111
+ xp: PyTorch tensor with shape [C, K], where K is the number of keypoints.
1112
+ yp: PyTorch tensor with shape [C, K].
1113
+ Returns:
1114
+ The function values f(x), with shape [N, C].
1115
+ """
1116
+ N, K = x.shape[0], xp.shape[1]
1117
+ all_x = torch.cat([x.unsqueeze(2), xp.unsqueeze(0).repeat((N, 1, 1))], dim=2)
1118
+ sorted_all_x, x_indices = torch.sort(all_x, dim=2)
1119
+ x_idx = torch.argmin(x_indices, dim=2)
1120
+ cand_start_idx = x_idx - 1
1121
+ start_idx = torch.where(
1122
+ torch.eq(x_idx, 0),
1123
+ torch.tensor(1, device=x.device),
1124
+ torch.where(
1125
+ torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx,
1126
+ ),
1127
+ )
1128
+ end_idx = torch.where(torch.eq(start_idx, cand_start_idx), start_idx + 2, start_idx + 1)
1129
+ start_x = torch.gather(sorted_all_x, dim=2, index=start_idx.unsqueeze(2)).squeeze(2)
1130
+ end_x = torch.gather(sorted_all_x, dim=2, index=end_idx.unsqueeze(2)).squeeze(2)
1131
+ start_idx2 = torch.where(
1132
+ torch.eq(x_idx, 0),
1133
+ torch.tensor(0, device=x.device),
1134
+ torch.where(
1135
+ torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx,
1136
+ ),
1137
+ )
1138
+ y_positions_expanded = yp.unsqueeze(0).expand(N, -1, -1)
1139
+ start_y = torch.gather(y_positions_expanded, dim=2, index=start_idx2.unsqueeze(2)).squeeze(2)
1140
+ end_y = torch.gather(y_positions_expanded, dim=2, index=(start_idx2 + 1).unsqueeze(2)).squeeze(2)
1141
+ cand = start_y + (x - start_x) * (end_y - start_y) / (end_x - start_x)
1142
+ return cand
1143
+
1144
+
1145
+ def expand_dims(v, dims):
1146
+ """
1147
+ Expand the tensor `v` to the dim `dims`.
1148
+ Args:
1149
+ `v`: a PyTorch tensor with shape [N].
1150
+ `dim`: a `int`.
1151
+ Returns:
1152
+ a PyTorch tensor with shape [N, 1, 1, ..., 1] and the total dimension is `dims`.
1153
+ """
1154
+ return v[(...,) + (None,) * (dims - 1)]
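
To make the `sample` docstring above concrete, here is a minimal, self-contained sketch (not part of the repository) that wires `NoiseScheduleVP`, `model_wrapper` and `DPM_Solver` together the same way `sampler.py` below does. The linear beta schedule, the `dummy_eps_model`, the placeholder conditioning tensors and the latent shape are illustrative assumptions only, and the import assumes the repository root is on PYTHONPATH.

import torch
from model_lib.ControlNet.ldm.models.diffusion.dpm_solver.dpm_solver import (
    NoiseScheduleVP, model_wrapper, DPM_Solver,
)

# Placeholder discrete schedule (assumption): a standard linear-beta DDPM schedule.
betas = torch.linspace(1e-4, 2e-2, 1000)
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)
ns = NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod)

def dummy_eps_model(x, t, c):
    # Stand-in for a real noise-prediction network such as model.apply_model(x, t, c).
    return torch.zeros_like(x)

cond = torch.zeros(2, 77, 768)  # placeholder conditioning; the shape is an assumption
model_fn = model_wrapper(
    dummy_eps_model, ns, model_type="noise",
    guidance_type="classifier-free",
    condition=cond, unconditional_condition=cond,
    guidance_scale=1.0,
)

dpm_solver = DPM_Solver(model_fn, ns, predict_x0=True, thresholding=False)
x_T = torch.randn(2, 4, 32, 32)  # illustrative latent shape
with torch.no_grad():
    x_0 = dpm_solver.sample(x_T, steps=10, skip_type="time_uniform",
                            method="multistep", order=2, lower_order_final=True)
print(x_0.shape)  # torch.Size([2, 4, 32, 32])

With a real eps-model in place of the dummy, this is exactly the multistep, order-2 configuration the docstring recommends for guided sampling with large guidance scales.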
model_lib/ControlNet/ldm/models/diffusion/dpm_solver/sampler.py ADDED
@@ -0,0 +1,87 @@
+ """SAMPLING ONLY."""
+ import torch
+ 
+ from .dpm_solver import NoiseScheduleVP, model_wrapper, DPM_Solver
+ 
+ 
+ MODEL_TYPES = {
+     "eps": "noise",
+     "v": "v"
+ }
+ 
+ 
+ class DPMSolverSampler(object):
+     def __init__(self, model, **kwargs):
+         super().__init__()
+         self.model = model
+         to_torch = lambda x: x.clone().detach().to(torch.float32).to(model.device)
+         self.register_buffer('alphas_cumprod', to_torch(model.alphas_cumprod))
+ 
+     def register_buffer(self, name, attr):
+         if type(attr) == torch.Tensor:
+             if attr.device != torch.device("cuda"):
+                 attr = attr.to(torch.device("cuda"))
+         setattr(self, name, attr)
+ 
+     @torch.no_grad()
+     def sample(self,
+                S,
+                batch_size,
+                shape,
+                conditioning=None,
+                callback=None,
+                normals_sequence=None,
+                img_callback=None,
+                quantize_x0=False,
+                eta=0.,
+                mask=None,
+                x0=None,
+                temperature=1.,
+                noise_dropout=0.,
+                score_corrector=None,
+                corrector_kwargs=None,
+                verbose=True,
+                x_T=None,
+                log_every_t=100,
+                unconditional_guidance_scale=1.,
+                unconditional_conditioning=None,
+                # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
+                **kwargs
+                ):
+         if conditioning is not None:
+             if isinstance(conditioning, dict):
+                 cbs = conditioning[list(conditioning.keys())[0]].shape[0]
+                 if cbs != batch_size:
+                     print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
+             else:
+                 if conditioning.shape[0] != batch_size:
+                     print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
+ 
+         # sampling
+         C, H, W = shape
+         size = (batch_size, C, H, W)
+ 
+         print(f'Data shape for DPM-Solver sampling is {size}, sampling steps {S}')
+ 
+         device = self.model.betas.device
+         if x_T is None:
+             img = torch.randn(size, device=device)
+         else:
+             img = x_T
+ 
+         ns = NoiseScheduleVP('discrete', alphas_cumprod=self.alphas_cumprod)
+ 
+         model_fn = model_wrapper(
+             lambda x, t, c: self.model.apply_model(x, t, c),
+             ns,
+             model_type=MODEL_TYPES[self.model.parameterization],
+             guidance_type="classifier-free",
+             condition=conditioning,
+             unconditional_condition=unconditional_conditioning,
+             guidance_scale=unconditional_guidance_scale,
+         )
+ 
+         dpm_solver = DPM_Solver(model_fn, ns, predict_x0=True, thresholding=False)
+         x = dpm_solver.sample(img, steps=S, skip_type="time_uniform", method="multistep", order=2, lower_order_final=True)
+ 
+         return x.to(device), None
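
The `guidance_scale=unconditional_guidance_scale` argument above is what enables classifier-free guidance inside `model_wrapper`. As a hedged reminder of what that means numerically (a standalone sketch, not repository code), the combination rule reduces to the `e_t_uncond + scale * (e_t - e_t_uncond)` pattern that also reappears in `plms.py` below:

import torch

def cfg_combine(e_t_uncond, e_t_cond, guidance_scale):
    # Classifier-free guidance: push the unconditional prediction toward the
    # conditional one; guidance_scale == 1.0 simply returns the conditional prediction.
    return e_t_uncond + guidance_scale * (e_t_cond - e_t_uncond)

e_uncond = torch.randn(2, 4, 8, 8)
e_cond = torch.randn(2, 4, 8, 8)
print(cfg_combine(e_uncond, e_cond, 7.5).shape)  # torch.Size([2, 4, 8, 8])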
model_lib/ControlNet/ldm/models/diffusion/plms.py ADDED
@@ -0,0 +1,244 @@
1
+ """SAMPLING ONLY."""
2
+
3
+ import torch
4
+ import numpy as np
5
+ from tqdm import tqdm
6
+ from functools import partial
7
+
8
+ from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
9
+ from ldm.models.diffusion.sampling_util import norm_thresholding
10
+
11
+
12
+ class PLMSSampler(object):
13
+ def __init__(self, model, schedule="linear", **kwargs):
14
+ super().__init__()
15
+ self.model = model
16
+ self.ddpm_num_timesteps = model.num_timesteps
17
+ self.schedule = schedule
18
+
19
+ def register_buffer(self, name, attr):
20
+ if type(attr) == torch.Tensor:
21
+ if attr.device != torch.device("cuda"):
22
+ attr = attr.to(torch.device("cuda"))
23
+ setattr(self, name, attr)
24
+
25
+ def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
26
+ if ddim_eta != 0:
27
+ raise ValueError('ddim_eta must be 0 for PLMS')
28
+ self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
29
+ num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
30
+ alphas_cumprod = self.model.alphas_cumprod
31
+ assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
32
+ to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
33
+
34
+ self.register_buffer('betas', to_torch(self.model.betas))
35
+ self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
36
+ self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
37
+
38
+ # calculations for diffusion q(x_t | x_{t-1}) and others
39
+ self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
40
+ self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
41
+ self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
42
+ self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
43
+ self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
44
+
45
+ # ddim sampling parameters
46
+ ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
47
+ ddim_timesteps=self.ddim_timesteps,
48
+ eta=ddim_eta,verbose=verbose)
49
+ self.register_buffer('ddim_sigmas', ddim_sigmas)
50
+ self.register_buffer('ddim_alphas', ddim_alphas)
51
+ self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
52
+ self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
53
+ sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
54
+ (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
55
+ 1 - self.alphas_cumprod / self.alphas_cumprod_prev))
56
+ self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
57
+
58
+ @torch.no_grad()
59
+ def sample(self,
60
+ S,
61
+ batch_size,
62
+ shape,
63
+ conditioning=None,
64
+ callback=None,
65
+ normals_sequence=None,
66
+ img_callback=None,
67
+ quantize_x0=False,
68
+ eta=0.,
69
+ mask=None,
70
+ x0=None,
71
+ temperature=1.,
72
+ noise_dropout=0.,
73
+ score_corrector=None,
74
+ corrector_kwargs=None,
75
+ verbose=True,
76
+ x_T=None,
77
+ log_every_t=100,
78
+ unconditional_guidance_scale=1.,
79
+ unconditional_conditioning=None,
80
+ # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
81
+ dynamic_threshold=None,
82
+ **kwargs
83
+ ):
84
+ if conditioning is not None:
85
+ if isinstance(conditioning, dict):
86
+ cbs = conditioning[list(conditioning.keys())[0]].shape[0]
87
+ if cbs != batch_size:
88
+ print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
89
+ else:
90
+ if conditioning.shape[0] != batch_size:
91
+ print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
92
+
93
+ self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
94
+ # sampling
95
+ C, H, W = shape
96
+ size = (batch_size, C, H, W)
97
+ print(f'Data shape for PLMS sampling is {size}')
98
+
99
+ samples, intermediates = self.plms_sampling(conditioning, size,
100
+ callback=callback,
101
+ img_callback=img_callback,
102
+ quantize_denoised=quantize_x0,
103
+ mask=mask, x0=x0,
104
+ ddim_use_original_steps=False,
105
+ noise_dropout=noise_dropout,
106
+ temperature=temperature,
107
+ score_corrector=score_corrector,
108
+ corrector_kwargs=corrector_kwargs,
109
+ x_T=x_T,
110
+ log_every_t=log_every_t,
111
+ unconditional_guidance_scale=unconditional_guidance_scale,
112
+ unconditional_conditioning=unconditional_conditioning,
113
+ dynamic_threshold=dynamic_threshold,
114
+ )
115
+ return samples, intermediates
116
+
117
+ @torch.no_grad()
118
+ def plms_sampling(self, cond, shape,
119
+ x_T=None, ddim_use_original_steps=False,
120
+ callback=None, timesteps=None, quantize_denoised=False,
121
+ mask=None, x0=None, img_callback=None, log_every_t=100,
122
+ temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
123
+ unconditional_guidance_scale=1., unconditional_conditioning=None,
124
+ dynamic_threshold=None):
125
+ device = self.model.betas.device
126
+ b = shape[0]
127
+ if x_T is None:
128
+ img = torch.randn(shape, device=device)
129
+ else:
130
+ img = x_T
131
+
132
+ if timesteps is None:
133
+ timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
134
+ elif timesteps is not None and not ddim_use_original_steps:
135
+ subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
136
+ timesteps = self.ddim_timesteps[:subset_end]
137
+
138
+ intermediates = {'x_inter': [img], 'pred_x0': [img]}
139
+ time_range = list(reversed(range(0,timesteps))) if ddim_use_original_steps else np.flip(timesteps)
140
+ total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
141
+ print(f"Running PLMS Sampling with {total_steps} timesteps")
142
+
143
+ iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps)
144
+ old_eps = []
145
+
146
+ for i, step in enumerate(iterator):
147
+ index = total_steps - i - 1
148
+ ts = torch.full((b,), step, device=device, dtype=torch.long)
149
+ ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=device, dtype=torch.long)
150
+
151
+ if mask is not None:
152
+ assert x0 is not None
153
+ img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass?
154
+ img = img_orig * mask + (1. - mask) * img
155
+
156
+ outs = self.p_sample_plms(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
157
+ quantize_denoised=quantize_denoised, temperature=temperature,
158
+ noise_dropout=noise_dropout, score_corrector=score_corrector,
159
+ corrector_kwargs=corrector_kwargs,
160
+ unconditional_guidance_scale=unconditional_guidance_scale,
161
+ unconditional_conditioning=unconditional_conditioning,
162
+ old_eps=old_eps, t_next=ts_next,
163
+ dynamic_threshold=dynamic_threshold)
164
+ img, pred_x0, e_t = outs
165
+ old_eps.append(e_t)
166
+ if len(old_eps) >= 4:
167
+ old_eps.pop(0)
168
+ if callback: callback(i)
169
+ if img_callback: img_callback(pred_x0, i)
170
+
171
+ if index % log_every_t == 0 or index == total_steps - 1:
172
+ intermediates['x_inter'].append(img)
173
+ intermediates['pred_x0'].append(pred_x0)
174
+
175
+ return img, intermediates
176
+
177
+ @torch.no_grad()
178
+ def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
179
+ temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
180
+ unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None,
181
+ dynamic_threshold=None):
182
+ b, *_, device = *x.shape, x.device
183
+
184
+ def get_model_output(x, t):
185
+ if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
186
+ e_t = self.model.apply_model(x, t, c)
187
+ else:
188
+ x_in = torch.cat([x] * 2)
189
+ t_in = torch.cat([t] * 2)
190
+ c_in = torch.cat([unconditional_conditioning, c])
191
+ e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
192
+ e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
193
+
194
+ if score_corrector is not None:
195
+ assert self.model.parameterization == "eps"
196
+ e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
197
+
198
+ return e_t
199
+
200
+ alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
201
+ alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
202
+ sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
203
+ sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
204
+
205
+ def get_x_prev_and_pred_x0(e_t, index):
206
+ # select parameters corresponding to the currently considered timestep
207
+ a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
208
+ a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
209
+ sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
210
+ sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
211
+
212
+ # current prediction for x_0
213
+ pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
214
+ if quantize_denoised:
215
+ pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
216
+ if dynamic_threshold is not None:
217
+ pred_x0 = norm_thresholding(pred_x0, dynamic_threshold)
218
+ # direction pointing to x_t
219
+ dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
220
+ noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
221
+ if noise_dropout > 0.:
222
+ noise = torch.nn.functional.dropout(noise, p=noise_dropout)
223
+ x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
224
+ return x_prev, pred_x0
225
+
226
+ e_t = get_model_output(x, t)
227
+ if len(old_eps) == 0:
228
+ # Pseudo Improved Euler (2nd order)
229
+ x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
230
+ e_t_next = get_model_output(x_prev, t_next)
231
+ e_t_prime = (e_t + e_t_next) / 2
232
+ elif len(old_eps) == 1:
233
+ # 2nd order Pseudo Linear Multistep (Adams-Bashforth)
234
+ e_t_prime = (3 * e_t - old_eps[-1]) / 2
235
+ elif len(old_eps) == 2:
236
+ # 3rd order Pseudo Linear Multistep (Adams-Bashforth)
237
+ e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
238
+ elif len(old_eps) >= 3:
239
+ # 4th order Pseudo Linear Multistep (Adams-Bashforth)
240
+ e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24
241
+
242
+ x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)
243
+
244
+ return x_prev, pred_x0, e_t
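
As a quick illustration of the pseudo linear multistep update in `p_sample_plms` above, the following standalone sketch of just the coefficient logic (not repository code) shows how `e_t_prime` is formed from the buffered `old_eps` values:

import torch

def plms_eps_prime(e_t, old_eps):
    # Mirrors the branches in p_sample_plms; old_eps holds the most recent model
    # outputs, newest last, and is capped at length 3 by the sampling loop.
    if len(old_eps) == 0:
        # The real sampler handles this case differently: it averages e_t with a
        # second model evaluation at t_next (pseudo improved Euler) instead of reusing e_t.
        return e_t
    if len(old_eps) == 1:
        return (3 * e_t - old_eps[-1]) / 2
    if len(old_eps) == 2:
        return (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
    return (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24

history = [torch.randn(1, 4, 8, 8) for _ in range(3)]
print(plms_eps_prime(torch.randn(1, 4, 8, 8), history).shape)  # torch.Size([1, 4, 8, 8])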
model_lib/ControlNet/ldm/models/diffusion/sampling_util.py ADDED
@@ -0,0 +1,22 @@
+ import torch
+ import numpy as np
+ 
+ 
+ def append_dims(x, target_dims):
+     """Appends dimensions to the end of a tensor until it has target_dims dimensions.
+     From https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/utils.py"""
+     dims_to_append = target_dims - x.ndim
+     if dims_to_append < 0:
+         raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less')
+     return x[(...,) + (None,) * dims_to_append]
+ 
+ 
+ def norm_thresholding(x0, value):
+     s = append_dims(x0.pow(2).flatten(1).mean(1).sqrt().clamp(min=value), x0.ndim)
+     return x0 * (value / s)
+ 
+ 
+ def spatial_norm_thresholding(x0, value):
+     # b c h w
+     s = x0.pow(2).mean(1, keepdim=True).sqrt().clamp(min=value)
+     return x0 * (value / s)
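
A hedged usage sketch for the helpers above (not part of the repository; assumes the repository root is on PYTHONPATH): `append_dims` broadcasts a per-sample scalar back over a latent, and `norm_thresholding` rescales any sample whose RMS exceeds the threshold.

import torch
from model_lib.ControlNet.ldm.models.diffusion.sampling_util import append_dims, norm_thresholding

# append_dims: turn a per-sample scalar into something broadcastable over a 4-D latent.
print(append_dims(torch.ones(2), 4).shape)        # torch.Size([2, 1, 1, 1])

# norm_thresholding: samples whose per-sample RMS exceeds `value` are rescaled to that RMS.
x0 = 3.0 * torch.randn(2, 4, 64, 64)
clamped = norm_thresholding(x0, 1.0)
print(clamped.flatten(1).pow(2).mean(1).sqrt())   # both entries are <= 1.0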
model_lib/ControlNet/ldm/modules/__pycache__/attention.cpython-39.pyc ADDED
Binary file (11 kB).
model_lib/ControlNet/ldm/modules/__pycache__/ema.cpython-39.pyc ADDED
Binary file (3.25 kB).
model_lib/ControlNet/ldm/modules/__pycache__/motion_module.cpython-39.pyc ADDED
Binary file (8.6 kB).
model_lib/ControlNet/ldm/modules/attention.py ADDED
@@ -0,0 +1,386 @@
1
+ from inspect import isfunction
2
+ import math
3
+ # from turtle import forward
4
+ import torch
5
+ import torch.nn.functional as F
6
+ from torch import nn, einsum
7
+ from einops import rearrange, repeat
8
+ from typing import Optional, Any
9
+ import pdb
10
+ from model_lib.ControlNet.ldm.modules.diffusionmodules.util import checkpoint
11
+
12
+
13
+ try:
14
+ import xformers
15
+ import xformers.ops
16
+ XFORMERS_IS_AVAILBLE = True
17
+ except:
18
+ XFORMERS_IS_AVAILBLE = False
19
+
20
+ # CrossAttn precision handling
21
+ import os
22
+ _ATTN_PRECISION = os.environ.get("ATTN_PRECISION", "fp32")
23
+
24
+ def exists(val):
25
+ return val is not None
26
+
27
+
28
+ def uniq(arr):
29
+ return{el: True for el in arr}.keys()
30
+
31
+
32
+ def default(val, d):
33
+ if exists(val):
34
+ return val
35
+ return d() if isfunction(d) else d
36
+
37
+
38
+ def max_neg_value(t):
39
+ return -torch.finfo(t.dtype).max
40
+
41
+
42
+ def init_(tensor):
43
+ dim = tensor.shape[-1]
44
+ std = 1 / math.sqrt(dim)
45
+ tensor.uniform_(-std, std)
46
+ return tensor
47
+
48
+
49
+ # feedforward
50
+ class GEGLU(nn.Module):
51
+ def __init__(self, dim_in, dim_out):
52
+ super().__init__()
53
+ self.proj = nn.Linear(dim_in, dim_out * 2)
54
+
55
+ def forward(self, x):
56
+ x, gate = self.proj(x).chunk(2, dim=-1)
57
+ return x * F.gelu(gate)
58
+
59
+
60
+ class FeedForward(nn.Module):
61
+ def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
62
+ super().__init__()
63
+ inner_dim = int(dim * mult)
64
+ dim_out = default(dim_out, dim)
65
+ project_in = nn.Sequential(
66
+ nn.Linear(dim, inner_dim),
67
+ nn.GELU()
68
+ ) if not glu else GEGLU(dim, inner_dim)
69
+
70
+ self.net = nn.Sequential(
71
+ project_in,
72
+ nn.Dropout(dropout),
73
+ nn.Linear(inner_dim, dim_out)
74
+ )
75
+
76
+ def forward(self, x):
77
+ return self.net(x)
78
+
79
+
80
+ def zero_module(module):
81
+ """
82
+ Zero out the parameters of a module and return it.
83
+ """
84
+ for p in module.parameters():
85
+ p.detach().zero_()
86
+ return module
87
+
88
+
89
+ def Normalize(in_channels):
90
+ return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
91
+
92
+
93
+ class SpatialSelfAttention(nn.Module):
94
+ def __init__(self, in_channels):
95
+ super().__init__()
96
+ self.in_channels = in_channels
97
+
98
+ self.norm = Normalize(in_channels)
99
+ self.q = torch.nn.Conv2d(in_channels,
100
+ in_channels,
101
+ kernel_size=1,
102
+ stride=1,
103
+ padding=0)
104
+ self.k = torch.nn.Conv2d(in_channels,
105
+ in_channels,
106
+ kernel_size=1,
107
+ stride=1,
108
+ padding=0)
109
+ self.v = torch.nn.Conv2d(in_channels,
110
+ in_channels,
111
+ kernel_size=1,
112
+ stride=1,
113
+ padding=0)
114
+ self.proj_out = torch.nn.Conv2d(in_channels,
115
+ in_channels,
116
+ kernel_size=1,
117
+ stride=1,
118
+ padding=0)
119
+
120
+ def forward(self, x):
121
+ h_ = x
122
+ h_ = self.norm(h_)
123
+ q = self.q(h_)
124
+ k = self.k(h_)
125
+ v = self.v(h_)
126
+
127
+ # compute attention
128
+ b,c,h,w = q.shape
129
+ q = rearrange(q, 'b c h w -> b (h w) c')
130
+ k = rearrange(k, 'b c h w -> b c (h w)')
131
+ w_ = torch.einsum('bij,bjk->bik', q, k)
132
+
133
+ w_ = w_ * (int(c)**(-0.5))
134
+ w_ = torch.nn.functional.softmax(w_, dim=2)
135
+
136
+ # attend to values
137
+ v = rearrange(v, 'b c h w -> b c (h w)')
138
+ w_ = rearrange(w_, 'b i j -> b j i')
139
+ h_ = torch.einsum('bij,bjk->bik', v, w_)
140
+ h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h)
141
+ h_ = self.proj_out(h_)
142
+
143
+ return x+h_
144
+
145
+
146
+ class CrossAttention(nn.Module):
147
+ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., checkpoint=False):
148
+ super().__init__()
149
+ inner_dim = dim_head * heads
150
+ context_dim = default(context_dim, query_dim)
151
+
152
+ self.scale = dim_head ** -0.5
153
+ self.heads = heads
154
+
155
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
156
+ self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
157
+ self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
158
+
159
+ self.to_out = nn.Sequential(
160
+ nn.Linear(inner_dim, query_dim),
161
+ nn.Dropout(dropout)
162
+ )
163
+ self.checkpoint = checkpoint
164
+
165
+ def forward(self, x, context=None):
166
+ return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)
167
+
168
+ def _forward(self, x, context=None):
169
+ h = self.heads
170
+
171
+ q = self.to_q(x)
172
+ context = default(context, x)
173
+ k = self.to_k(context)
174
+ v = self.to_v(context)
175
+
176
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
177
+
178
+ # force cast to fp32 to avoid overflowing
179
+ if _ATTN_PRECISION =="fp32":
180
+ with torch.autocast(enabled=False, device_type = 'cuda'):
181
+ q, k = q.float(), k.float()
182
+ sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
183
+ else:
184
+ sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
185
+
186
+ del q, k
187
+
188
+ # if exists(mask):
189
+ # mask = rearrange(mask, 'b ... -> b (...)')
190
+ # max_neg_value = -torch.finfo(sim.dtype).max
191
+ # mask = repeat(mask, 'b j -> (b h) () j', h=h)
192
+ # sim.masked_fill_(~mask, max_neg_value)
193
+
194
+ # attention, what we cannot get enough of
195
+ sim = sim.softmax(dim=-1)
196
+
197
+ out = einsum('b i j, b j d -> b i d', sim, v)
198
+ out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
199
+ return self.to_out(out)
200
+
201
+
202
+ class MemoryEfficientCrossAttention(nn.Module):
203
+ # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
204
+ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0, checkpoint=False):
205
+ super().__init__()
206
+ print(f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using "
207
+ f"{heads} heads.")
208
+ inner_dim = dim_head * heads
209
+ context_dim = default(context_dim, query_dim)
210
+
211
+ self.heads = heads
212
+ self.dim_head = dim_head
213
+
214
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
215
+ self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
216
+ self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
217
+
218
+ self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
219
+ self.attention_op: Optional[Any] = None
220
+ self.checkpoint = checkpoint
221
+
222
+ def forward(self, x, context=None):
223
+ return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)
224
+
225
+ def _forward(self, x, context=None):
226
+ q = self.to_q(x)
227
+ context = default(context, x)
228
+ k = self.to_k(context)
229
+ v = self.to_v(context)
230
+
231
+ b, _, _ = q.shape
232
+ q, k, v = map(
233
+ lambda t: t.unsqueeze(3)
234
+ .reshape(b, t.shape[1], self.heads, self.dim_head)
235
+ .permute(0, 2, 1, 3)
236
+ .reshape(b * self.heads, t.shape[1], self.dim_head)
237
+ .contiguous(),
238
+ (q, k, v),
239
+ )
240
+
241
+ # actually compute the attention, what we cannot get enough of
242
+ out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)
243
+
244
+ out = (
245
+ out.unsqueeze(0)
246
+ .reshape(b, self.heads, out.shape[1], self.dim_head)
247
+ .permute(0, 2, 1, 3)
248
+ .reshape(b, out.shape[1], self.heads * self.dim_head)
249
+ )
250
+ return self.to_out(out)
251
+
252
+
253
+ class BasicTransformerBlock(nn.Module):
254
+ ATTENTION_MODES = {
255
+ "softmax": CrossAttention, # vanilla attention
256
+ "softmax-xformers": MemoryEfficientCrossAttention
257
+ }
258
+ def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True,
259
+ disable_self_attn=False):
260
+ super().__init__()
261
+ attn_mode = "softmax-xformers" if XFORMERS_IS_AVAILBLE else "softmax"
262
+ assert attn_mode in self.ATTENTION_MODES
263
+ attn_cls = self.ATTENTION_MODES[attn_mode]
264
+ self.disable_self_attn = disable_self_attn
265
+ self.attn1 = attn_cls(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout, checkpoint=checkpoint,
266
+ context_dim=context_dim if self.disable_self_attn else None) # is a self-attention if not self.disable_self_attn
267
+ self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
268
+ self.attn2 = attn_cls(query_dim=dim, context_dim=context_dim,
269
+ heads=n_heads, dim_head=d_head, dropout=dropout, checkpoint=checkpoint,) # is self-attn if context is none
270
+ self.norm1 = nn.LayerNorm(dim)
271
+ self.norm2 = nn.LayerNorm(dim)
272
+ self.norm3 = nn.LayerNorm(dim)
273
+ self.checkpoint = checkpoint
274
+
275
+ # def forward(self, x, context=None, banks=None, attention_mode=None, attn_index=None,uc=False):
276
+ # return checkpoint(self._forward, (x, context, banks, attention_mode, attn_index,uc), self.parameters(), self.checkpoint)
277
+
278
+ def forward(self, x, context=None, banks=None, attention_mode=None, attn_index=None,uc=False):
279
+
280
+ if uc or attention_mode is None:
281
+ x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None) + x
282
+ else:
283
+
284
+ x_norm1 = self.norm1(x)
285
+ self_attn1 = None
286
+ self_attention_context = x_norm1
287
+ if attention_mode == 'write':
288
+ bank = []
289
+ bank.append(self_attention_context)
290
+ # bank.append(self_attention_context.detach().clone())
291
+
292
+ banks.append(bank)
293
+
294
+ if self_attn1 is None:
295
+
296
+ self_attn1 = self.attn1(x_norm1, context=self_attention_context)
297
+
298
+ x = self_attn1 + x
299
+
300
+
301
+ elif attention_mode == 'read':
302
+
303
+ current_bank = banks[attn_index]
304
+
305
+ if len(current_bank) > 0:
306
+ tmp = [self_attention_context] + current_bank
307
+ self_attn1 = self.attn1(x_norm1, context=torch.cat([self_attention_context] + current_bank, dim=1))
308
+
309
+ if self_attn1 is None:
310
+
311
+ self_attn1 = self.attn1(x_norm1, context=self_attention_context)
312
+
313
+ x = self_attn1 + x
314
+
315
+ else:
316
+ raise NotImplementedError
317
+
318
+ x = self.attn2(self.norm2(x), context=context) + x
319
+ x = self.ff(self.norm3(x)) + x
320
+ return x
321
+
322
+
323
+ class SpatialTransformer(nn.Module):
324
+ """
325
+ Transformer block for image-like data.
326
+ First, project the input (aka embedding)
327
+ and reshape to b, t, d.
328
+ Then apply standard transformer action.
329
+ Finally, reshape to image
330
+ NEW: use_linear for more efficiency instead of the 1x1 convs
331
+ """
332
+ def __init__(self, in_channels, n_heads, d_head,
333
+ depth=1, dropout=0., context_dim=None,
334
+ disable_self_attn=False, use_linear=False,
335
+ use_checkpoint=True):
336
+ super().__init__()
337
+ if exists(context_dim) and not isinstance(context_dim, list):
338
+ context_dim = [context_dim]
339
+ self.in_channels = in_channels
340
+ inner_dim = n_heads * d_head
341
+ self.norm = Normalize(in_channels)
342
+ if not use_linear:
343
+ self.proj_in = nn.Conv2d(in_channels,
344
+ inner_dim,
345
+ kernel_size=1,
346
+ stride=1,
347
+ padding=0)
348
+ else:
349
+ self.proj_in = nn.Linear(in_channels, inner_dim)
350
+
351
+ self.transformer_blocks = nn.ModuleList(
352
+ [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d],
353
+ disable_self_attn=disable_self_attn, checkpoint=use_checkpoint)
354
+ for d in range(depth)]
355
+ )
356
+ if not use_linear:
357
+ self.proj_out = zero_module(nn.Conv2d(inner_dim,
358
+ in_channels,
359
+ kernel_size=1,
360
+ stride=1,
361
+ padding=0))
362
+ else:
363
+ self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
364
+ self.use_linear = use_linear
365
+
366
+ def forward(self, x, context=None, banks=None, attention_mode=None, attn_index=None,uc=False):
367
+ # note: if no context is given, cross-attention defaults to self-attention
368
+ if not isinstance(context, list):
369
+ context = [context]
370
+ b, c, h, w = x.shape
371
+ x_in = x
372
+ x = self.norm(x)
373
+ if not self.use_linear:
374
+ x = self.proj_in(x)
375
+ x = rearrange(x, 'b c h w -> b (h w) c').contiguous()
376
+ if self.use_linear:
377
+ x = self.proj_in(x)
378
+ for i, block in enumerate(self.transformer_blocks):
379
+ x = block(x, context=context[i], banks=banks, attention_mode=attention_mode, attn_index=attn_index,uc=uc)
380
+ if self.use_linear:
381
+ x = self.proj_out(x)
382
+ x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
383
+ if not self.use_linear:
384
+ x = self.proj_out(x)
385
+ return x + x_in
386
+
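A hedged shape check for the `CrossAttention` module above (illustrative only, not repository code; the 320/768/77 dimensions are arbitrary assumptions, and the import assumes the repository root is on PYTHONPATH with a recent torch and einops installed):

import torch
from model_lib.ControlNet.ldm.modules.attention import CrossAttention

attn = CrossAttention(query_dim=320, context_dim=768, heads=8, dim_head=40)
x = torch.randn(2, 64 * 64, 320)    # b, (h*w) tokens, query channels
ctx = torch.randn(2, 77, 768)       # e.g. text-encoder hidden states
with torch.no_grad():
    out = attn(x, context=ctx)      # checkpoint flag defaults to False
print(out.shape)                    # torch.Size([2, 4096, 320])
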
model_lib/ControlNet/ldm/modules/diffusionmodules/__init__.py ADDED
File without changes
model_lib/ControlNet/ldm/modules/diffusionmodules/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (236 Bytes).
model_lib/ControlNet/ldm/modules/diffusionmodules/__pycache__/model.cpython-39.pyc ADDED
Binary file (21.6 kB).
model_lib/ControlNet/ldm/modules/diffusionmodules/__pycache__/openaimodel.cpython-39.pyc ADDED
Binary file (26.4 kB).
model_lib/ControlNet/ldm/modules/diffusionmodules/__pycache__/util.cpython-39.pyc ADDED
Binary file (10.2 kB).
model_lib/ControlNet/ldm/modules/diffusionmodules/model.py ADDED
@@ -0,0 +1,859 @@
1
+ # pytorch_diffusion + derived encoder decoder
2
+ import math
3
+ import torch
4
+ import torch.nn as nn
5
+ import numpy as np
6
+ from einops import rearrange
7
+ from typing import Optional, Any
8
+
9
+ from model_lib.ControlNet.ldm.modules.attention import MemoryEfficientCrossAttention
10
+
11
+ try:
+     import xformers
+     import xformers.ops
+     XFORMERS_IS_AVAILBLE = True
+ except:
+     XFORMERS_IS_AVAILBLE = False
+     print("No module 'xformers'. Proceeding without it.")
18
+
19
+
20
+ def get_timestep_embedding(timesteps, embedding_dim):
21
+ """
22
+ This matches the implementation in Denoising Diffusion Probabilistic Models:
23
+ From Fairseq.
24
+ Build sinusoidal embeddings.
25
+ This matches the implementation in tensor2tensor, but differs slightly
26
+ from the description in Section 3.5 of "Attention Is All You Need".
27
+ """
28
+ assert len(timesteps.shape) == 1
29
+
30
+ half_dim = embedding_dim // 2
31
+ emb = math.log(10000) / (half_dim - 1)
32
+ emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
33
+ emb = emb.to(device=timesteps.device)
34
+ emb = timesteps.float()[:, None] * emb[None, :]
35
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
36
+ if embedding_dim % 2 == 1: # zero pad
37
+ emb = torch.nn.functional.pad(emb, (0,1,0,0))
38
+ return emb
39
+
40
+
41
+ def nonlinearity(x):
42
+ # swish
43
+ return x*torch.sigmoid(x)
44
+
45
+
46
+ def Normalize(in_channels, num_groups=32):
47
+ return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
48
+
49
+
50
+ class Upsample(nn.Module):
51
+ def __init__(self, in_channels, with_conv):
52
+ super().__init__()
53
+ self.with_conv = with_conv
54
+ if self.with_conv:
55
+ self.conv = torch.nn.Conv2d(in_channels,
56
+ in_channels,
57
+ kernel_size=3,
58
+ stride=1,
59
+ padding=1)
60
+
61
+ def nearest_neighbor_upsample(self, x: torch.Tensor, scale_factor: int):
62
+ # Upsample {x} (NCHW) by scale factor {scale_factor} using nearest neighbor interpolation.
63
+ s = scale_factor
64
+ return x.reshape(*x.shape, 1, 1).expand(*x.shape, s, s).transpose(-2, -3).reshape(*x.shape[:2], *(s * hw for hw in x.shape[2:]))
65
+
66
+
67
+ def forward(self, x):
68
+ # x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
69
+ x = self.nearest_neighbor_upsample(x, scale_factor=2)
70
+ if self.with_conv:
71
+ x = self.conv(x)
72
+ return x
73
+
74
+
75
+ class Downsample(nn.Module):
76
+ def __init__(self, in_channels, with_conv):
77
+ super().__init__()
78
+ self.with_conv = with_conv
79
+ if self.with_conv:
80
+ # no asymmetric padding in torch conv, must do it ourselves
81
+ self.conv = torch.nn.Conv2d(in_channels,
82
+ in_channels,
83
+ kernel_size=3,
84
+ stride=2,
85
+ padding=0)
86
+
87
+ def forward(self, x):
88
+ if self.with_conv:
89
+ pad = (0,1,0,1)
90
+ x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
91
+ x = self.conv(x)
92
+ else:
93
+ x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
94
+ return x
95
+
96
+
97
+ class ResnetBlock(nn.Module):
98
+ def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
99
+ dropout, temb_channels=512):
100
+ super().__init__()
101
+ self.in_channels = in_channels
102
+ out_channels = in_channels if out_channels is None else out_channels
103
+ self.out_channels = out_channels
104
+ self.use_conv_shortcut = conv_shortcut
105
+
106
+ self.norm1 = Normalize(in_channels)
107
+ self.conv1 = torch.nn.Conv2d(in_channels,
108
+ out_channels,
109
+ kernel_size=3,
110
+ stride=1,
111
+ padding=1)
112
+ if temb_channels > 0:
113
+ self.temb_proj = torch.nn.Linear(temb_channels,
114
+ out_channels)
115
+ self.norm2 = Normalize(out_channels)
116
+ self.dropout = torch.nn.Dropout(dropout)
117
+ self.conv2 = torch.nn.Conv2d(out_channels,
118
+ out_channels,
119
+ kernel_size=3,
120
+ stride=1,
121
+ padding=1)
122
+ if self.in_channels != self.out_channels:
123
+ if self.use_conv_shortcut:
124
+ self.conv_shortcut = torch.nn.Conv2d(in_channels,
125
+ out_channels,
126
+ kernel_size=3,
127
+ stride=1,
128
+ padding=1)
129
+ else:
130
+ self.nin_shortcut = torch.nn.Conv2d(in_channels,
131
+ out_channels,
132
+ kernel_size=1,
133
+ stride=1,
134
+ padding=0)
135
+
136
+ def forward(self, x, temb):
137
+ h = x
138
+ h = self.norm1(h)
139
+ h = nonlinearity(h)
140
+ h = self.conv1(h)
141
+
142
+ if temb is not None:
143
+ h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None]
144
+
145
+ h = self.norm2(h)
146
+ h = nonlinearity(h)
147
+ h = self.dropout(h)
148
+ h = self.conv2(h)
149
+
150
+ if self.in_channels != self.out_channels:
151
+ if self.use_conv_shortcut:
152
+ x = self.conv_shortcut(x)
153
+ else:
154
+ x = self.nin_shortcut(x)
155
+
156
+ return x+h
157
+
158
+
159
+ class AttnBlock(nn.Module):
160
+ def __init__(self, in_channels):
161
+ super().__init__()
162
+ self.in_channels = in_channels
163
+
164
+ self.norm = Normalize(in_channels)
165
+ self.q = torch.nn.Conv2d(in_channels,
166
+ in_channels,
167
+ kernel_size=1,
168
+ stride=1,
169
+ padding=0)
170
+ self.k = torch.nn.Conv2d(in_channels,
171
+ in_channels,
172
+ kernel_size=1,
173
+ stride=1,
174
+ padding=0)
175
+ self.v = torch.nn.Conv2d(in_channels,
176
+ in_channels,
177
+ kernel_size=1,
178
+ stride=1,
179
+ padding=0)
180
+ self.proj_out = torch.nn.Conv2d(in_channels,
181
+ in_channels,
182
+ kernel_size=1,
183
+ stride=1,
184
+ padding=0)
185
+
186
+ def forward(self, x):
187
+ h_ = x
188
+ h_ = self.norm(h_)
189
+ q = self.q(h_)
190
+ k = self.k(h_)
191
+ v = self.v(h_)
192
+
193
+ # compute attention
194
+ b,c,h,w = q.shape
195
+ q = q.reshape(b,c,h*w)
196
+ q = q.permute(0,2,1) # b,hw,c
197
+ k = k.reshape(b,c,h*w) # b,c,hw
198
+ w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
199
+ w_ = w_ * (int(c)**(-0.5))
200
+ w_ = torch.nn.functional.softmax(w_, dim=2)
201
+
202
+ # attend to values
203
+ v = v.reshape(b,c,h*w)
204
+ w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q)
205
+ h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
206
+ h_ = h_.reshape(b,c,h,w)
207
+
208
+ h_ = self.proj_out(h_)
209
+
210
+ return x+h_
211
+
212
+ class MemoryEfficientAttnBlock(nn.Module):
213
+ """
214
+ Uses xformers efficient implementation,
215
+ see https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
216
+ Note: this is a single-head self-attention operation
217
+ """
218
+ #
219
+ def __init__(self, in_channels):
220
+ super().__init__()
221
+ self.in_channels = in_channels
222
+
223
+ self.norm = Normalize(in_channels)
224
+ self.q = torch.nn.Conv2d(in_channels,
225
+ in_channels,
226
+ kernel_size=1,
227
+ stride=1,
228
+ padding=0)
229
+ self.k = torch.nn.Conv2d(in_channels,
230
+ in_channels,
231
+ kernel_size=1,
232
+ stride=1,
233
+ padding=0)
234
+ self.v = torch.nn.Conv2d(in_channels,
235
+ in_channels,
236
+ kernel_size=1,
237
+ stride=1,
238
+ padding=0)
239
+ self.proj_out = torch.nn.Conv2d(in_channels,
240
+ in_channels,
241
+ kernel_size=1,
242
+ stride=1,
243
+ padding=0)
244
+ self.attention_op: Optional[Any] = None
245
+
246
+ def forward(self, x):
247
+ h_ = x
248
+ h_ = self.norm(h_)
249
+ q = self.q(h_)
250
+ k = self.k(h_)
251
+ v = self.v(h_)
252
+
253
+ # compute attention
254
+ B, C, H, W = q.shape
255
+ q, k, v = map(lambda x: rearrange(x, 'b c h w -> b (h w) c'), (q, k, v))
256
+
257
+ q, k, v = map(
258
+ lambda t: t.unsqueeze(3)
259
+ .reshape(B, t.shape[1], 1, C)
260
+ .permute(0, 2, 1, 3)
261
+ .reshape(B * 1, t.shape[1], C)
262
+ .contiguous(),
263
+ (q, k, v),
264
+ )
265
+ out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)
266
+
267
+ out = (
268
+ out.unsqueeze(0)
269
+ .reshape(B, 1, out.shape[1], C)
270
+ .permute(0, 2, 1, 3)
271
+ .reshape(B, out.shape[1], C)
272
+ )
273
+ out = rearrange(out, 'b (h w) c -> b c h w', b=B, h=H, w=W, c=C)
274
+ out = self.proj_out(out)
275
+ return x+out
276
+
277
+
278
+ class MemoryEfficientCrossAttentionWrapper(MemoryEfficientCrossAttention):
279
+ def forward(self, x, context=None, mask=None):
280
+ b, c, h, w = x.shape
281
+ x = rearrange(x, 'b c h w -> b (h w) c')
282
+ out = super().forward(x, context=context, mask=mask)
283
+ out = rearrange(out, 'b (h w) c -> b c h w', h=h, w=w, c=c)
284
+ return x + out
285
+
286
+
287
+ def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
288
+ assert attn_type in ["vanilla", "vanilla-xformers", "memory-efficient-cross-attn", "linear", "none"], f'attn_type {attn_type} unknown'
289
+ if XFORMERS_IS_AVAILBLE and attn_type == "vanilla":
290
+ attn_type = "vanilla-xformers"
291
+ print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
292
+ if attn_type == "vanilla":
293
+ assert attn_kwargs is None
294
+ return AttnBlock(in_channels)
295
+ elif attn_type == "vanilla-xformers":
296
+ print(f"building MemoryEfficientAttnBlock with {in_channels} in_channels...")
297
+ return MemoryEfficientAttnBlock(in_channels)
298
+ elif attn_type == "memory-efficient-cross-attn":
299
+ attn_kwargs["query_dim"] = in_channels
300
+ return MemoryEfficientCrossAttentionWrapper(**attn_kwargs)
301
+ elif attn_type == "none":
302
+ return nn.Identity(in_channels)
303
+ else:
304
+ raise NotImplementedError()
305
+
306
+
307
+ class Model(nn.Module):
308
+ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
309
+ attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
310
+ resolution, use_timestep=True, use_linear_attn=False, attn_type="vanilla"):
311
+ super().__init__()
312
+ if use_linear_attn: attn_type = "linear"
313
+ self.ch = ch
314
+ self.temb_ch = self.ch*4
315
+ self.num_resolutions = len(ch_mult)
316
+ self.num_res_blocks = num_res_blocks
317
+ self.resolution = resolution
318
+ self.in_channels = in_channels
319
+
320
+ self.use_timestep = use_timestep
321
+ if self.use_timestep:
322
+ # timestep embedding
323
+ self.temb = nn.Module()
324
+ self.temb.dense = nn.ModuleList([
325
+ torch.nn.Linear(self.ch,
326
+ self.temb_ch),
327
+ torch.nn.Linear(self.temb_ch,
328
+ self.temb_ch),
329
+ ])
330
+
331
+ # downsampling
332
+ self.conv_in = torch.nn.Conv2d(in_channels,
333
+ self.ch,
334
+ kernel_size=3,
335
+ stride=1,
336
+ padding=1)
337
+
338
+ curr_res = resolution
339
+ in_ch_mult = (1,)+tuple(ch_mult)
340
+ self.down = nn.ModuleList()
341
+ for i_level in range(self.num_resolutions):
342
+ block = nn.ModuleList()
343
+ attn = nn.ModuleList()
344
+ block_in = ch*in_ch_mult[i_level]
345
+ block_out = ch*ch_mult[i_level]
346
+ for i_block in range(self.num_res_blocks):
347
+ block.append(ResnetBlock(in_channels=block_in,
348
+ out_channels=block_out,
349
+ temb_channels=self.temb_ch,
350
+ dropout=dropout))
351
+ block_in = block_out
352
+ if curr_res in attn_resolutions:
353
+ attn.append(make_attn(block_in, attn_type=attn_type))
354
+ down = nn.Module()
355
+ down.block = block
356
+ down.attn = attn
357
+ if i_level != self.num_resolutions-1:
358
+ down.downsample = Downsample(block_in, resamp_with_conv)
359
+ curr_res = curr_res // 2
360
+ self.down.append(down)
361
+
362
+ # middle
363
+ self.mid = nn.Module()
364
+ self.mid.block_1 = ResnetBlock(in_channels=block_in,
365
+ out_channels=block_in,
366
+ temb_channels=self.temb_ch,
367
+ dropout=dropout)
368
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
369
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
370
+ out_channels=block_in,
371
+ temb_channels=self.temb_ch,
372
+ dropout=dropout)
373
+
374
+ # upsampling
375
+ self.up = nn.ModuleList()
376
+ for i_level in reversed(range(self.num_resolutions)):
377
+ block = nn.ModuleList()
378
+ attn = nn.ModuleList()
379
+ block_out = ch*ch_mult[i_level]
380
+ skip_in = ch*ch_mult[i_level]
381
+ for i_block in range(self.num_res_blocks+1):
382
+ if i_block == self.num_res_blocks:
383
+ skip_in = ch*in_ch_mult[i_level]
384
+ block.append(ResnetBlock(in_channels=block_in+skip_in,
385
+ out_channels=block_out,
386
+ temb_channels=self.temb_ch,
387
+ dropout=dropout))
388
+ block_in = block_out
389
+ if curr_res in attn_resolutions:
390
+ attn.append(make_attn(block_in, attn_type=attn_type))
391
+ up = nn.Module()
392
+ up.block = block
393
+ up.attn = attn
394
+ if i_level != 0:
395
+ up.upsample = Upsample(block_in, resamp_with_conv)
396
+ curr_res = curr_res * 2
397
+ self.up.insert(0, up) # prepend to get consistent order
398
+
399
+ # end
400
+ self.norm_out = Normalize(block_in)
401
+ self.conv_out = torch.nn.Conv2d(block_in,
402
+ out_ch,
403
+ kernel_size=3,
404
+ stride=1,
405
+ padding=1)
406
+
407
+ def forward(self, x, t=None, context=None):
408
+ #assert x.shape[2] == x.shape[3] == self.resolution
409
+ if context is not None:
410
+ # assume aligned context, cat along channel axis
411
+ x = torch.cat((x, context), dim=1)
412
+ if self.use_timestep:
413
+ # timestep embedding
414
+ assert t is not None
415
+ temb = get_timestep_embedding(t, self.ch)
416
+ temb = self.temb.dense[0](temb)
417
+ temb = nonlinearity(temb)
418
+ temb = self.temb.dense[1](temb)
419
+ else:
420
+ temb = None
421
+
422
+ # downsampling
423
+ hs = [self.conv_in(x)]
424
+ for i_level in range(self.num_resolutions):
425
+ for i_block in range(self.num_res_blocks):
426
+ h = self.down[i_level].block[i_block](hs[-1], temb)
427
+ if len(self.down[i_level].attn) > 0:
428
+ h = self.down[i_level].attn[i_block](h)
429
+ hs.append(h)
430
+ if i_level != self.num_resolutions-1:
431
+ hs.append(self.down[i_level].downsample(hs[-1]))
432
+
433
+ # middle
434
+ h = hs[-1]
435
+ h = self.mid.block_1(h, temb)
436
+ h = self.mid.attn_1(h)
437
+ h = self.mid.block_2(h, temb)
438
+
439
+ # upsampling
440
+ for i_level in reversed(range(self.num_resolutions)):
441
+ for i_block in range(self.num_res_blocks+1):
442
+ h = self.up[i_level].block[i_block](
443
+ torch.cat([h, hs.pop()], dim=1), temb)
444
+ if len(self.up[i_level].attn) > 0:
445
+ h = self.up[i_level].attn[i_block](h)
446
+ if i_level != 0:
447
+ h = self.up[i_level].upsample(h)
448
+
449
+ # end
450
+ h = self.norm_out(h)
451
+ h = nonlinearity(h)
452
+ h = self.conv_out(h)
453
+ return h
454
+
455
+ def get_last_layer(self):
456
+ return self.conv_out.weight
457
+
458
+
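The constructor above interleaves channel widening with spatial downsampling. The standalone sketch below (not repository code; `ch`, `ch_mult`, and `resolution` are illustrative assumptions) prints the per-level schedule that the downsampling loop produces.

```python
# Standalone sketch of the per-level channel/resolution schedule (illustrative values).
ch, ch_mult, resolution = 128, (1, 2, 4, 8), 256
in_ch_mult = (1,) + tuple(ch_mult)

curr_res = resolution
for i_level, mult in enumerate(ch_mult):
    block_in, block_out = ch * in_ch_mult[i_level], ch * mult
    print(f"level {i_level}: {block_in} -> {block_out} channels at {curr_res}x{curr_res}")
    if i_level != len(ch_mult) - 1:
        curr_res //= 2   # a Downsample module sits between levels, mirroring the loop above
# level 0: 128 -> 128 channels at 256x256
# level 1: 128 -> 256 channels at 128x128
# level 2: 256 -> 512 channels at 64x64
# level 3: 512 -> 1024 channels at 32x32
```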
459
+ class Encoder(nn.Module):
460
+ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
461
+ attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
462
+ resolution, z_channels, double_z=True, use_linear_attn=False, attn_type="vanilla",
463
+ **ignore_kwargs):
464
+ super().__init__()
465
+ if use_linear_attn: attn_type = "linear"
466
+ self.ch = ch
467
+ self.temb_ch = 0
468
+ self.num_resolutions = len(ch_mult)
469
+ self.num_res_blocks = num_res_blocks
470
+ self.resolution = resolution
471
+ self.in_channels = in_channels
472
+
473
+ # downsampling
474
+ self.conv_in = torch.nn.Conv2d(in_channels,
475
+ self.ch,
476
+ kernel_size=3,
477
+ stride=1,
478
+ padding=1)
479
+
480
+ curr_res = resolution
481
+ in_ch_mult = (1,)+tuple(ch_mult)
482
+ self.in_ch_mult = in_ch_mult
483
+ self.down = nn.ModuleList()
484
+ for i_level in range(self.num_resolutions):
485
+ block = nn.ModuleList()
486
+ attn = nn.ModuleList()
487
+ block_in = ch*in_ch_mult[i_level]
488
+ block_out = ch*ch_mult[i_level]
489
+ for i_block in range(self.num_res_blocks):
490
+ block.append(ResnetBlock(in_channels=block_in,
491
+ out_channels=block_out,
492
+ temb_channels=self.temb_ch,
493
+ dropout=dropout))
494
+ block_in = block_out
495
+ if curr_res in attn_resolutions:
496
+ attn.append(make_attn(block_in, attn_type=attn_type))
497
+ down = nn.Module()
498
+ down.block = block
499
+ down.attn = attn
500
+ if i_level != self.num_resolutions-1:
501
+ down.downsample = Downsample(block_in, resamp_with_conv)
502
+ curr_res = curr_res // 2
503
+ self.down.append(down)
504
+
505
+ # middle
506
+ self.mid = nn.Module()
507
+ self.mid.block_1 = ResnetBlock(in_channels=block_in,
508
+ out_channels=block_in,
509
+ temb_channels=self.temb_ch,
510
+ dropout=dropout)
511
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
512
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
513
+ out_channels=block_in,
514
+ temb_channels=self.temb_ch,
515
+ dropout=dropout)
516
+
517
+ # end
518
+ self.norm_out = Normalize(block_in)
519
+ self.conv_out = torch.nn.Conv2d(block_in,
520
+ 2*z_channels if double_z else z_channels,
521
+ kernel_size=3,
522
+ stride=1,
523
+ padding=1)
524
+
525
+ def forward(self, x):
526
+ # timestep embedding
527
+ temb = None
528
+
529
+ # downsampling
530
+ hs = [self.conv_in(x)]
531
+ for i_level in range(self.num_resolutions):
532
+ for i_block in range(self.num_res_blocks):
533
+ h = self.down[i_level].block[i_block](hs[-1], temb)
534
+ if len(self.down[i_level].attn) > 0:
535
+ h = self.down[i_level].attn[i_block](h)
536
+ hs.append(h)
537
+ if i_level != self.num_resolutions-1:
538
+ hs.append(self.down[i_level].downsample(hs[-1]))
539
+
540
+ # middle
541
+ h = hs[-1]
542
+ h = self.mid.block_1(h, temb)
543
+ h = self.mid.attn_1(h)
544
+ h = self.mid.block_2(h, temb)
545
+
546
+ # end
547
+ h = self.norm_out(h)
548
+ h = nonlinearity(h)
549
+ h = self.conv_out(h)
550
+ return h
551
+
552
+
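A hedged usage sketch of the encoder above (it assumes the repo import path; the hyperparameters are illustrative, not the project's configuration). With `double_z=True` the output carries `2*z_channels` channels at `resolution / 2**(len(ch_mult)-1)`.

```python
# Hedged usage sketch; import path and parameter values are assumptions for illustration.
import torch
from model_lib.ControlNet.ldm.modules.diffusionmodules.model import Encoder

enc = Encoder(ch=64, out_ch=3, ch_mult=(1, 2, 4), num_res_blocks=2,
              attn_resolutions=[], in_channels=3, resolution=64,
              z_channels=4, double_z=True)
x = torch.randn(1, 3, 64, 64)
print(enc(x).shape)  # torch.Size([1, 8, 16, 16]): 2*z_channels, resolution / 2**(levels-1)
```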
553
+ class Decoder(nn.Module):
554
+ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
555
+ attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
556
+ resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False,
557
+ attn_type="vanilla", **ignorekwargs):
558
+ super().__init__()
559
+ if use_linear_attn: attn_type = "linear"
560
+ self.ch = ch
561
+ self.temb_ch = 0
562
+ self.num_resolutions = len(ch_mult)
563
+ self.num_res_blocks = num_res_blocks
564
+ self.resolution = resolution
565
+ self.in_channels = in_channels
566
+ self.give_pre_end = give_pre_end
567
+ self.tanh_out = tanh_out
568
+
569
+ # compute in_ch_mult, block_in and curr_res at lowest res
570
+ in_ch_mult = (1,)+tuple(ch_mult)
571
+ block_in = ch*ch_mult[self.num_resolutions-1]
572
+ curr_res = resolution // 2**(self.num_resolutions-1)
573
+ self.z_shape = (1,z_channels,curr_res,curr_res)
574
+ print("Working with z of shape {} = {} dimensions.".format(
575
+ self.z_shape, np.prod(self.z_shape)))
576
+
577
+ # z to block_in
578
+ self.conv_in = torch.nn.Conv2d(z_channels,
579
+ block_in,
580
+ kernel_size=3,
581
+ stride=1,
582
+ padding=1)
583
+
584
+ # middle
585
+ self.mid = nn.Module()
586
+ self.mid.block_1 = ResnetBlock(in_channels=block_in,
587
+ out_channels=block_in,
588
+ temb_channels=self.temb_ch,
589
+ dropout=dropout)
590
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
591
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
592
+ out_channels=block_in,
593
+ temb_channels=self.temb_ch,
594
+ dropout=dropout)
595
+
596
+ # upsampling
597
+ self.up = nn.ModuleList()
598
+ for i_level in reversed(range(self.num_resolutions)):
599
+ block = nn.ModuleList()
600
+ attn = nn.ModuleList()
601
+ block_out = ch*ch_mult[i_level]
602
+ for i_block in range(self.num_res_blocks+1):
603
+ block.append(ResnetBlock(in_channels=block_in,
604
+ out_channels=block_out,
605
+ temb_channels=self.temb_ch,
606
+ dropout=dropout))
607
+ block_in = block_out
608
+ if curr_res in attn_resolutions:
609
+ attn.append(make_attn(block_in, attn_type=attn_type))
610
+ up = nn.Module()
611
+ up.block = block
612
+ up.attn = attn
613
+ if i_level != 0:
614
+ up.upsample = Upsample(block_in, resamp_with_conv)
615
+ curr_res = curr_res * 2
616
+ self.up.insert(0, up) # prepend to get consistent order
617
+
618
+ # end
619
+ self.norm_out = Normalize(block_in)
620
+ self.conv_out = torch.nn.Conv2d(block_in,
621
+ out_ch,
622
+ kernel_size=3,
623
+ stride=1,
624
+ padding=1)
625
+
626
+ def forward(self, z):
627
+ #assert z.shape[1:] == self.z_shape[1:]
628
+ self.last_z_shape = z.shape
629
+
630
+ # timestep embedding
631
+ temb = None
632
+
633
+ # z to block_in
634
+ h = self.conv_in(z)
635
+
636
+ # middle
637
+ h = self.mid.block_1(h, temb)
638
+ h = self.mid.attn_1(h)
639
+ h = self.mid.block_2(h, temb)
640
+
641
+ # upsampling
642
+ for i_level in reversed(range(self.num_resolutions)):
643
+ for i_block in range(self.num_res_blocks+1):
644
+ h = self.up[i_level].block[i_block](h, temb)
645
+ if len(self.up[i_level].attn) > 0:
646
+ h = self.up[i_level].attn[i_block](h)
647
+ if i_level != 0:
648
+ h = self.up[i_level].upsample(h)
649
+
650
+ # end
651
+ if self.give_pre_end:
652
+ return h
653
+
654
+ h = self.norm_out(h)
655
+ h = nonlinearity(h)
656
+ h = self.conv_out(h)
657
+ if self.tanh_out:
658
+ h = torch.tanh(h)
659
+ return h
660
+
661
+
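A hedged counterpart to the encoder sketch above (same assumptions: repo import path, illustrative values). The decoder consumes a latent matching the `z_shape` printed in its constructor and upsamples back to the full resolution.

```python
# Hedged usage sketch; values chosen only to mirror the Encoder example above.
import torch
from model_lib.ControlNet.ldm.modules.diffusionmodules.model import Decoder

dec = Decoder(ch=64, out_ch=3, ch_mult=(1, 2, 4), num_res_blocks=2,
              attn_resolutions=[], in_channels=3, resolution=64, z_channels=4)
z = torch.randn(1, 4, 16, 16)   # matches the printed z_shape for these settings
print(dec(z).shape)             # torch.Size([1, 3, 64, 64])
```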
662
+ class SimpleDecoder(nn.Module):
663
+ def __init__(self, in_channels, out_channels, *args, **kwargs):
664
+ super().__init__()
665
+ self.model = nn.ModuleList([nn.Conv2d(in_channels, in_channels, 1),
666
+ ResnetBlock(in_channels=in_channels,
667
+ out_channels=2 * in_channels,
668
+ temb_channels=0, dropout=0.0),
669
+ ResnetBlock(in_channels=2 * in_channels,
670
+ out_channels=4 * in_channels,
671
+ temb_channels=0, dropout=0.0),
672
+ ResnetBlock(in_channels=4 * in_channels,
673
+ out_channels=2 * in_channels,
674
+ temb_channels=0, dropout=0.0),
675
+ nn.Conv2d(2*in_channels, in_channels, 1),
676
+ Upsample(in_channels, with_conv=True)])
677
+ # end
678
+ self.norm_out = Normalize(in_channels)
679
+ self.conv_out = torch.nn.Conv2d(in_channels,
680
+ out_channels,
681
+ kernel_size=3,
682
+ stride=1,
683
+ padding=1)
684
+
685
+ def forward(self, x):
686
+ for i, layer in enumerate(self.model):
687
+ if i in [1,2,3]:
688
+ x = layer(x, None)
689
+ else:
690
+ x = layer(x)
691
+
692
+ h = self.norm_out(x)
693
+ h = nonlinearity(h)
694
+ x = self.conv_out(h)
695
+ return x
696
+
697
+
698
+ class UpsampleDecoder(nn.Module):
699
+ def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution,
700
+ ch_mult=(2,2), dropout=0.0):
701
+ super().__init__()
702
+ # upsampling
703
+ self.temb_ch = 0
704
+ self.num_resolutions = len(ch_mult)
705
+ self.num_res_blocks = num_res_blocks
706
+ block_in = in_channels
707
+ curr_res = resolution // 2 ** (self.num_resolutions - 1)
708
+ self.res_blocks = nn.ModuleList()
709
+ self.upsample_blocks = nn.ModuleList()
710
+ for i_level in range(self.num_resolutions):
711
+ res_block = []
712
+ block_out = ch * ch_mult[i_level]
713
+ for i_block in range(self.num_res_blocks + 1):
714
+ res_block.append(ResnetBlock(in_channels=block_in,
715
+ out_channels=block_out,
716
+ temb_channels=self.temb_ch,
717
+ dropout=dropout))
718
+ block_in = block_out
719
+ self.res_blocks.append(nn.ModuleList(res_block))
720
+ if i_level != self.num_resolutions - 1:
721
+ self.upsample_blocks.append(Upsample(block_in, True))
722
+ curr_res = curr_res * 2
723
+
724
+ # end
725
+ self.norm_out = Normalize(block_in)
726
+ self.conv_out = torch.nn.Conv2d(block_in,
727
+ out_channels,
728
+ kernel_size=3,
729
+ stride=1,
730
+ padding=1)
731
+
732
+ def forward(self, x):
733
+ # upsampling
734
+ h = x
735
+ for k, i_level in enumerate(range(self.num_resolutions)):
736
+ for i_block in range(self.num_res_blocks + 1):
737
+ h = self.res_blocks[i_level][i_block](h, None)
738
+ if i_level != self.num_resolutions - 1:
739
+ h = self.upsample_blocks[k](h)
740
+ h = self.norm_out(h)
741
+ h = nonlinearity(h)
742
+ h = self.conv_out(h)
743
+ return h
744
+
745
+
746
+ class LatentRescaler(nn.Module):
747
+ def __init__(self, factor, in_channels, mid_channels, out_channels, depth=2):
748
+ super().__init__()
749
+ # residual block, interpolate, residual block
750
+ self.factor = factor
751
+ self.conv_in = nn.Conv2d(in_channels,
752
+ mid_channels,
753
+ kernel_size=3,
754
+ stride=1,
755
+ padding=1)
756
+ self.res_block1 = nn.ModuleList([ResnetBlock(in_channels=mid_channels,
757
+ out_channels=mid_channels,
758
+ temb_channels=0,
759
+ dropout=0.0) for _ in range(depth)])
760
+ self.attn = AttnBlock(mid_channels)
761
+ self.res_block2 = nn.ModuleList([ResnetBlock(in_channels=mid_channels,
762
+ out_channels=mid_channels,
763
+ temb_channels=0,
764
+ dropout=0.0) for _ in range(depth)])
765
+
766
+ self.conv_out = nn.Conv2d(mid_channels,
767
+ out_channels,
768
+ kernel_size=1,
769
+ )
770
+
771
+ def forward(self, x):
772
+ x = self.conv_in(x)
773
+ for block in self.res_block1:
774
+ x = block(x, None)
775
+ x = torch.nn.functional.interpolate(x, size=(int(round(x.shape[2]*self.factor)), int(round(x.shape[3]*self.factor))))
776
+ x = self.attn(x)
777
+ for block in self.res_block2:
778
+ x = block(x, None)
779
+ x = self.conv_out(x)
780
+ return x
781
+
782
+
783
+ class MergedRescaleEncoder(nn.Module):
784
+ def __init__(self, in_channels, ch, resolution, out_ch, num_res_blocks,
785
+ attn_resolutions, dropout=0.0, resamp_with_conv=True,
786
+ ch_mult=(1,2,4,8), rescale_factor=1.0, rescale_module_depth=1):
787
+ super().__init__()
788
+ intermediate_chn = ch * ch_mult[-1]
789
+ self.encoder = Encoder(in_channels=in_channels, num_res_blocks=num_res_blocks, ch=ch, ch_mult=ch_mult,
790
+ z_channels=intermediate_chn, double_z=False, resolution=resolution,
791
+ attn_resolutions=attn_resolutions, dropout=dropout, resamp_with_conv=resamp_with_conv,
792
+ out_ch=None)
793
+ self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=intermediate_chn,
794
+ mid_channels=intermediate_chn, out_channels=out_ch, depth=rescale_module_depth)
795
+
796
+ def forward(self, x):
797
+ x = self.encoder(x)
798
+ x = self.rescaler(x)
799
+ return x
800
+
801
+
802
+ class MergedRescaleDecoder(nn.Module):
803
+ def __init__(self, z_channels, out_ch, resolution, num_res_blocks, attn_resolutions, ch, ch_mult=(1,2,4,8),
804
+ dropout=0.0, resamp_with_conv=True, rescale_factor=1.0, rescale_module_depth=1):
805
+ super().__init__()
806
+ tmp_chn = z_channels*ch_mult[-1]
807
+ self.decoder = Decoder(out_ch=out_ch, z_channels=tmp_chn, attn_resolutions=attn_resolutions, dropout=dropout,
808
+ resamp_with_conv=resamp_with_conv, in_channels=None, num_res_blocks=num_res_blocks,
809
+ ch_mult=ch_mult, resolution=resolution, ch=ch)
810
+ self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=z_channels, mid_channels=tmp_chn,
811
+ out_channels=tmp_chn, depth=rescale_module_depth)
812
+
813
+ def forward(self, x):
814
+ x = self.rescaler(x)
815
+ x = self.decoder(x)
816
+ return x
817
+
818
+
819
+ class Upsampler(nn.Module):
820
+ def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2):
821
+ super().__init__()
822
+ assert out_size >= in_size
823
+ num_blocks = int(np.log2(out_size//in_size))+1
824
+ factor_up = 1.+ (out_size % in_size)
825
+ print(f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}")
826
+ self.rescaler = LatentRescaler(factor=factor_up, in_channels=in_channels, mid_channels=2*in_channels,
827
+ out_channels=in_channels)
828
+ self.decoder = Decoder(out_ch=out_channels, resolution=out_size, z_channels=in_channels, num_res_blocks=2,
829
+ attn_resolutions=[], in_channels=None, ch=in_channels,
830
+ ch_mult=[ch_mult for _ in range(num_blocks)])
831
+
832
+ def forward(self, x):
833
+ x = self.rescaler(x)
834
+ x = self.decoder(x)
835
+ return x
836
+
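A quick standalone check (illustrative sizes, not repository code) of the block-count arithmetic in `Upsampler.__init__`: the number of decoder levels grows with `log2(out_size // in_size)`, and the residual rescale factor stays 1.0 for power-of-two ratios.

```python
# Standalone arithmetic sketch of Upsampler's num_blocks / factor_up computation.
import numpy as np

for in_size, out_size in [(32, 64), (32, 128), (64, 256)]:
    num_blocks = int(np.log2(out_size // in_size)) + 1
    factor_up = 1.0 + (out_size % in_size)
    print(f"{in_size} -> {out_size}: num_blocks={num_blocks}, factor_up={factor_up}")
# 32 -> 64: num_blocks=2, factor_up=1.0
# 32 -> 128: num_blocks=3, factor_up=1.0
# 64 -> 256: num_blocks=3, factor_up=1.0
```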
837
+
838
+ class Resize(nn.Module):
839
+ def __init__(self, in_channels=None, learned=False, mode="bilinear"):
840
+ super().__init__()
841
+ self.with_conv = learned
842
+ self.mode = mode
843
+ if self.with_conv:
844
+ print(f"Note: {self.__class__.__name__} uses learned downsampling and will ignore the fixed {mode} mode")
845
+ raise NotImplementedError()
846
+ assert in_channels is not None
847
+ # no asymmetric padding in torch conv, must do it ourselves
848
+ self.conv = torch.nn.Conv2d(in_channels,
849
+ in_channels,
850
+ kernel_size=4,
851
+ stride=2,
852
+ padding=1)
853
+
854
+ def forward(self, x, scale_factor=1.0):
855
+ if scale_factor==1.0:
856
+ return x
857
+ else:
858
+ x = torch.nn.functional.interpolate(x, mode=self.mode, align_corners=False, scale_factor=scale_factor)
859
+ return x
model_lib/ControlNet/ldm/modules/diffusionmodules/openaimodel.py ADDED
@@ -0,0 +1,1212 @@
1
+ from abc import abstractmethod
2
+ import math
3
+
4
+ import numpy as np
5
+ import torch as th
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ import pdb
9
+ from model_lib.ControlNet.ldm.modules.diffusionmodules.util import (
10
+ checkpoint,
11
+ conv_nd,
12
+ linear,
13
+ avg_pool_nd,
14
+ zero_module,
15
+ normalization,
16
+ timestep_embedding,
17
+ )
18
+ from model_lib.ControlNet.ldm.modules.attention import SpatialTransformer
19
+ from model_lib.ControlNet.ldm.util import exists
20
+ from model_lib.ControlNet.ldm.modules.motion_module import get_motion_module, VanillaTemporalModule, TemporalTransformer3DModel
21
+
22
+ # dummy replace
23
+ def convert_module_to_f16(x):
24
+ pass
25
+
26
+ def convert_module_to_f32(x):
27
+ pass
28
+
29
+
30
+ ## go
31
+ class AttentionPool2d(nn.Module):
32
+ """
33
+ Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
34
+ """
35
+
36
+ def __init__(
37
+ self,
38
+ spacial_dim: int,
39
+ embed_dim: int,
40
+ num_heads_channels: int,
41
+ output_dim: int = None,
42
+ ):
43
+ super().__init__()
44
+ self.positional_embedding = nn.Parameter(th.randn(embed_dim, spacial_dim ** 2 + 1) / embed_dim ** 0.5)
45
+ self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
46
+ self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
47
+ self.num_heads = embed_dim // num_heads_channels
48
+ self.attention = QKVAttention(self.num_heads)
49
+
50
+ def forward(self, x):
51
+ b, c, *_spatial = x.shape
52
+ x = x.reshape(b, c, -1) # NC(HW)
53
+ x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1)
54
+ x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1)
55
+ x = self.qkv_proj(x)
56
+ x = self.attention(x)
57
+ x = self.c_proj(x)
58
+ return x[:, :, 0]
59
+
60
+
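A hedged usage sketch of the CLIP-style attention pooling above (it assumes this file is importable as `model_lib.ControlNet.ldm.modules.diffusionmodules.openaimodel`; sizes are illustrative). The mean token is prepended, attended over, and returned as a single pooled vector per sample.

```python
# Hedged usage sketch; import path and sizes are assumptions for illustration.
import torch
from model_lib.ControlNet.ldm.modules.diffusionmodules.openaimodel import AttentionPool2d

pool = AttentionPool2d(spacial_dim=8, embed_dim=64, num_heads_channels=16, output_dim=32)
feats = torch.randn(2, 64, 8, 8)   # NCHW feature map
print(pool(feats).shape)           # torch.Size([2, 32]): one pooled vector per sample
```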
61
+ class TimestepBlock(nn.Module):
62
+ """
63
+ Any module where forward() takes timestep embeddings as a second argument.
64
+ """
65
+
66
+ @abstractmethod
67
+ def forward(self, x, emb):
68
+ """
69
+ Apply the module to `x` given `emb` timestep embeddings.
70
+ """
71
+
72
+
73
+ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
74
+ """
75
+ A sequential module that passes timestep embeddings to the children that
76
+ support it as an extra input.
77
+ """
78
+
79
+ def forward(self, x, emb, context=None, banks=None, attention_mode=None, attn_index=None, uc=False):
80
+ for layer in self:
81
+
82
+ if isinstance(layer, TimestepBlock):
83
+ # print("layer TimestepBlock")
84
+ x = layer(x, emb)
85
+ elif isinstance(layer, SpatialTransformer):
86
+ # print("layer SpatialTransformer")
87
+ if uc:
88
+ x = layer(x, context,uc=uc)
89
+ else:
90
+ # pdb.set_trace()
91
+ x = layer(x, context, banks, attention_mode, attn_index)
92
+ if attention_mode == 'read':
93
+ attn_index+=1
94
+ elif isinstance(layer, VanillaTemporalModule):
95
+ # print("layer Motion Module")
96
+ # pdb.set_trace()
97
+ x = layer(x, context)
98
+ else:
99
+ # print("layer others")
100
+ # pdb.set_trace()
101
+ x = layer(x)
102
+
103
+ if attention_mode == 'write':
104
+ return x
105
+ if attention_mode == 'read':
106
+ return x, attn_index
107
+ else:
108
+ return x
109
+
110
+
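The class above routes extra inputs only to the children that accept them. The self-contained toy below (not the repository class; all names are illustrative) shows the same dispatch idea in miniature: a Sequential-like container that forwards a timestep embedding only to layers that declare they need it.

```python
# Minimal standalone illustration of the "forward extras only to layers that take them" pattern.
import torch
import torch.nn as nn

class NeedsEmb(nn.Module):
    def forward(self, x, emb):
        return x + emb.view(1, -1, 1, 1)   # consumes the embedding

class PlainLayer(nn.Module):
    def forward(self, x):
        return x * 2                       # ignores it

class TinyTimestepSequential(nn.Sequential):
    def forward(self, x, emb):
        for layer in self:
            x = layer(x, emb) if isinstance(layer, NeedsEmb) else layer(x)
        return x

seq = TinyTimestepSequential(NeedsEmb(), PlainLayer())
out = seq(torch.zeros(1, 4, 8, 8), torch.ones(4))
print(out.mean().item())  # 2.0: embedding added first, then doubled
```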
111
+ class Upsample(nn.Module):
112
+ """
113
+ An upsampling layer with an optional convolution.
114
+ :param channels: channels in the inputs and outputs.
115
+ :param use_conv: a bool determining if a convolution is applied.
116
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
117
+ upsampling occurs in the inner-two dimensions.
118
+ """
119
+
120
+ def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
121
+ super().__init__()
122
+ self.channels = channels
123
+ self.out_channels = out_channels or channels
124
+ self.use_conv = use_conv
125
+ self.dims = dims
126
+ if use_conv:
127
+ self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding)
128
+
129
+ def nearest_neighbor_upsample(self, x: th.Tensor, scale_factor: int):
130
+ # Upsample x (NCHW) by scale_factor using nearest-neighbor interpolation.
131
+ s = scale_factor
132
+ return x.reshape(*x.shape, 1, 1).expand(*x.shape, s, s).transpose(-2, -3).reshape(*x.shape[:2], *(s * hw for hw in x.shape[2:]))
133
+
134
+
135
+ def forward(self, x):
136
+ assert x.shape[1] == self.channels
137
+ if self.dims == 3:
138
+ x = F.interpolate(
139
+ x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
140
+ )
141
+ else:
142
+ # x = F.interpolate(x, scale_factor=2, mode="nearest")
143
+ x = self.nearest_neighbor_upsample(x, scale_factor=2)
144
+
145
+ if self.use_conv:
146
+ x = self.conv(x)
147
+ return x
148
+
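A standalone sanity check (illustrative shapes) that the reshape/expand trick in `nearest_neighbor_upsample` above reproduces what `F.interpolate(..., mode="nearest")` would compute for an integer scale factor.

```python
# Standalone equivalence check for the reshape/expand nearest-neighbor upsampling trick.
import torch
import torch.nn.functional as F

def nn_upsample(x, s):
    return (x.reshape(*x.shape, 1, 1)
             .expand(*x.shape, s, s)
             .transpose(-2, -3)
             .reshape(*x.shape[:2], *(s * hw for hw in x.shape[2:])))

x = torch.randn(2, 3, 5, 7)
assert torch.equal(nn_upsample(x, 2), F.interpolate(x, scale_factor=2, mode="nearest"))
```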
149
+ class TransposedUpsample(nn.Module):
150
+ 'Learned 2x upsampling without padding'
151
+ def __init__(self, channels, out_channels=None, ks=5):
152
+ super().__init__()
153
+ self.channels = channels
154
+ self.out_channels = out_channels or channels
155
+
156
+ self.up = nn.ConvTranspose2d(self.channels,self.out_channels,kernel_size=ks,stride=2)
157
+
158
+ def forward(self,x):
159
+ return self.up(x)
160
+
161
+
162
+ class Downsample(nn.Module):
163
+ """
164
+ A downsampling layer with an optional convolution.
165
+ :param channels: channels in the inputs and outputs.
166
+ :param use_conv: a bool determining if a convolution is applied.
167
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
168
+ downsampling occurs in the inner-two dimensions.
169
+ """
170
+
171
+ def __init__(self, channels, use_conv, dims=2, out_channels=None,padding=1):
172
+ super().__init__()
173
+ self.channels = channels
174
+ self.out_channels = out_channels or channels
175
+ self.use_conv = use_conv
176
+ self.dims = dims
177
+ stride = 2 if dims != 3 else (1, 2, 2)
178
+ if use_conv:
179
+ self.op = conv_nd(
180
+ dims, self.channels, self.out_channels, 3, stride=stride, padding=padding
181
+ )
182
+ else:
183
+ assert self.channels == self.out_channels
184
+ self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
185
+
186
+ def forward(self, x):
187
+ assert x.shape[1] == self.channels
188
+ return self.op(x)
189
+
190
+
191
+ class ResBlock(TimestepBlock):
192
+ """
193
+ A residual block that can optionally change the number of channels.
194
+ :param channels: the number of input channels.
195
+ :param emb_channels: the number of timestep embedding channels.
196
+ :param dropout: the rate of dropout.
197
+ :param out_channels: if specified, the number of out channels.
198
+ :param use_conv: if True and out_channels is specified, use a spatial
199
+ convolution instead of a smaller 1x1 convolution to change the
200
+ channels in the skip connection.
201
+ :param dims: determines if the signal is 1D, 2D, or 3D.
202
+ :param use_checkpoint: if True, use gradient checkpointing on this module.
203
+ :param up: if True, use this block for upsampling.
204
+ :param down: if True, use this block for downsampling.
205
+ """
206
+
207
+ def __init__(
208
+ self,
209
+ channels,
210
+ emb_channels,
211
+ dropout,
212
+ out_channels=None,
213
+ use_conv=False,
214
+ use_scale_shift_norm=False,
215
+ dims=2,
216
+ use_checkpoint=False,
217
+ up=False,
218
+ down=False,
219
+ ):
220
+ super().__init__()
221
+ self.channels = channels
222
+ self.emb_channels = emb_channels
223
+ self.dropout = dropout
224
+ self.out_channels = out_channels or channels
225
+ self.use_conv = use_conv
226
+ self.use_checkpoint = use_checkpoint
227
+ self.use_scale_shift_norm = use_scale_shift_norm
228
+
229
+ self.in_layers = nn.Sequential(
230
+ normalization(channels),
231
+ nn.SiLU(),
232
+ conv_nd(dims, channels, self.out_channels, 3, padding=1),
233
+ )
234
+
235
+ self.updown = up or down
236
+
237
+ if up:
238
+ self.h_upd = Upsample(channels, False, dims)
239
+ self.x_upd = Upsample(channels, False, dims)
240
+ elif down:
241
+ self.h_upd = Downsample(channels, False, dims)
242
+ self.x_upd = Downsample(channels, False, dims)
243
+ else:
244
+ self.h_upd = self.x_upd = nn.Identity()
245
+
246
+ self.emb_layers = nn.Sequential(
247
+ nn.SiLU(),
248
+ linear(
249
+ emb_channels,
250
+ 2 * self.out_channels if use_scale_shift_norm else self.out_channels,
251
+ ),
252
+ )
253
+ self.out_layers = nn.Sequential(
254
+ normalization(self.out_channels),
255
+ nn.SiLU(),
256
+ nn.Dropout(p=dropout),
257
+ zero_module(
258
+ conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
259
+ ),
260
+ )
261
+
262
+ if self.out_channels == channels:
263
+ self.skip_connection = nn.Identity()
264
+ elif use_conv:
265
+ self.skip_connection = conv_nd(
266
+ dims, channels, self.out_channels, 3, padding=1
267
+ )
268
+ else:
269
+ self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
270
+
271
+ def forward(self, x, emb):
272
+ """
273
+ Apply the block to a Tensor, conditioned on a timestep embedding.
274
+ :param x: an [N x C x ...] Tensor of features.
275
+ :param emb: an [N x emb_channels] Tensor of timestep embeddings.
276
+ :return: an [N x C x ...] Tensor of outputs.
277
+ """
278
+ return checkpoint(
279
+ self._forward, (x, emb), self.parameters(), self.use_checkpoint
280
+ )
281
+
282
+
283
+ def _forward(self, x, emb):
284
+ if self.updown:
285
+ in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
286
+ h = in_rest(x)
287
+ h = self.h_upd(h)
288
+ x = self.x_upd(x)
289
+ h = in_conv(h)
290
+ else:
291
+ h = self.in_layers(x)
292
+ emb_out = self.emb_layers(emb).type(h.dtype)
293
+ while len(emb_out.shape) < len(h.shape):
294
+ emb_out = emb_out[..., None]
295
+ if self.use_scale_shift_norm:
296
+ out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
297
+ scale, shift = th.chunk(emb_out, 2, dim=1)
298
+ h = out_norm(h) * (1 + scale) + shift
299
+ h = out_rest(h)
300
+ else:
301
+ h = h + emb_out
302
+ h = self.out_layers(h)
303
+ return self.skip_connection(x) + h
304
+
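A minimal standalone sketch (illustrative shapes, not repository code) of the scale-shift ("FiLM-like") conditioning branch taken in `_forward` above when `use_scale_shift_norm` is enabled: the projected timestep embedding is split into a per-channel scale and shift applied after normalization.

```python
# Standalone sketch of the scale-shift (FiLM-like) conditioning step.
import torch
import torch.nn as nn

h = torch.randn(2, 8, 16, 16)                  # feature map before the out norm
emb_out = torch.randn(2, 16, 1, 1)             # projected timestep embedding, 2*C channels
scale, shift = torch.chunk(emb_out, 2, dim=1)  # split into per-channel scale and shift
norm = nn.GroupNorm(4, 8)
out = norm(h) * (1 + scale) + shift            # condition the normalized activations
print(out.shape)                               # torch.Size([2, 8, 16, 16])
```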
305
+ class AttentionBlock(nn.Module):
306
+ """
307
+ An attention block that allows spatial positions to attend to each other.
308
+ Originally ported from here, but adapted to the N-d case.
309
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
310
+ """
311
+
312
+ def __init__(
313
+ self,
314
+ channels,
315
+ num_heads=1,
316
+ num_head_channels=-1,
317
+ use_checkpoint=False,
318
+ use_new_attention_order=False,
319
+ ):
320
+ super().__init__()
321
+ self.channels = channels
322
+ if num_head_channels == -1:
323
+ self.num_heads = num_heads
324
+ else:
325
+ assert (
326
+ channels % num_head_channels == 0
327
+ ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
328
+ self.num_heads = channels // num_head_channels
329
+ self.use_checkpoint = use_checkpoint
330
+ self.norm = normalization(channels)
331
+ self.qkv = conv_nd(1, channels, channels * 3, 1)
332
+ if use_new_attention_order:
333
+ # split qkv before split heads
334
+ self.attention = QKVAttention(self.num_heads)
335
+ else:
336
+ # split heads before split qkv
337
+ self.attention = QKVAttentionLegacy(self.num_heads)
338
+
339
+ self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
340
+
341
+ def forward(self, x):
342
+ return checkpoint(self._forward, (x,), self.parameters(), True) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!!
343
+ #return pt_checkpoint(self._forward, x) # pytorch
344
+
345
+ def _forward(self, x):
346
+ b, c, *spatial = x.shape
347
+ x = x.reshape(b, c, -1)
348
+ qkv = self.qkv(self.norm(x))
349
+ h = self.attention(qkv)
350
+ h = self.proj_out(h)
351
+ return (x + h).reshape(b, c, *spatial)
352
+
353
+
354
+ def count_flops_attn(model, _x, y):
355
+ """
356
+ A counter for the `thop` package to count the operations in an
357
+ attention operation.
358
+ Meant to be used like:
359
+ macs, params = thop.profile(
360
+ model,
361
+ inputs=(inputs, timestamps),
362
+ custom_ops={QKVAttention: QKVAttention.count_flops},
363
+ )
364
+ """
365
+ b, c, *spatial = y[0].shape
366
+ num_spatial = int(np.prod(spatial))
367
+ # We perform two matmuls with the same number of ops.
368
+ # The first computes the weight matrix, the second computes
369
+ # the combination of the value vectors.
370
+ matmul_ops = 2 * b * (num_spatial ** 2) * c
371
+ model.total_ops += th.DoubleTensor([matmul_ops])
372
+
373
+
374
+ class QKVAttentionLegacy(nn.Module):
375
+ """
376
+ A module which performs QKV attention. Matches legacy QKVAttention + input/output heads shaping
377
+ """
378
+
379
+ def __init__(self, n_heads):
380
+ super().__init__()
381
+ self.n_heads = n_heads
382
+
383
+ def forward(self, qkv):
384
+ """
385
+ Apply QKV attention.
386
+ :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
387
+ :return: an [N x (H * C) x T] tensor after attention.
388
+ """
389
+ bs, width, length = qkv.shape
390
+ assert width % (3 * self.n_heads) == 0
391
+ ch = width // (3 * self.n_heads)
392
+ q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
393
+ scale = 1 / math.sqrt(math.sqrt(ch))
394
+ weight = th.einsum(
395
+ "bct,bcs->bts", q * scale, k * scale
396
+ ) # More stable with f16 than dividing afterwards
397
+ weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
398
+ a = th.einsum("bts,bcs->bct", weight, v)
399
+ return a.reshape(bs, -1, length)
400
+
401
+ @staticmethod
402
+ def count_flops(model, _x, y):
403
+ return count_flops_attn(model, _x, y)
404
+
405
+
406
+ class QKVAttention(nn.Module):
407
+ """
408
+ A module which performs QKV attention and splits in a different order.
409
+ """
410
+
411
+ def __init__(self, n_heads):
412
+ super().__init__()
413
+ self.n_heads = n_heads
414
+
415
+ def forward(self, qkv):
416
+ """
417
+ Apply QKV attention.
418
+ :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
419
+ :return: an [N x (H * C) x T] tensor after attention.
420
+ """
421
+ bs, width, length = qkv.shape
422
+ assert width % (3 * self.n_heads) == 0
423
+ ch = width // (3 * self.n_heads)
424
+ q, k, v = qkv.chunk(3, dim=1)
425
+ scale = 1 / math.sqrt(math.sqrt(ch))
426
+ weight = th.einsum(
427
+ "bct,bcs->bts",
428
+ (q * scale).view(bs * self.n_heads, ch, length),
429
+ (k * scale).view(bs * self.n_heads, ch, length),
430
+ ) # More stable with f16 than dividing afterwards
431
+ weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
432
+ a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))
433
+ return a.reshape(bs, -1, length)
434
+
435
+ @staticmethod
436
+ def count_flops(model, _x, y):
437
+ return count_flops_attn(model, _x, y)
438
+
439
+
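Both attention modules above consume a fused QKV tensor of shape `[N x (3*H*C) x T]`. The standalone sketch below (illustrative head counts and lengths, not repository code) walks the `QKVAttention` ordering through its shape changes.

```python
# Standalone shape walkthrough of the fused-QKV attention (QKVAttention ordering).
import math
import torch

n_heads, ch, length, bs = 4, 16, 10, 2
qkv = torch.randn(bs, 3 * n_heads * ch, length)   # [N x (3*H*C) x T]

q, k, v = qkv.chunk(3, dim=1)
scale = 1 / math.sqrt(math.sqrt(ch))
weight = torch.einsum(
    "bct,bcs->bts",
    (q * scale).view(bs * n_heads, ch, length),
    (k * scale).view(bs * n_heads, ch, length),
)
weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
a = torch.einsum("bts,bcs->bct", weight, v.reshape(bs * n_heads, ch, length))
print(a.reshape(bs, -1, length).shape)  # torch.Size([2, 64, 10]) = [N x (H*C) x T]
```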
440
+ class UNetModel(nn.Module):
441
+ """
442
+ The full UNet model with attention and timestep embedding.
443
+ :param in_channels: channels in the input Tensor.
444
+ :param model_channels: base channel count for the model.
445
+ :param out_channels: channels in the output Tensor.
446
+ :param num_res_blocks: number of residual blocks per downsample.
447
+ :param attention_resolutions: a collection of downsample rates at which
448
+ attention will take place. May be a set, list, or tuple.
449
+ For example, if this contains 4, then at 4x downsampling, attention
450
+ will be used.
451
+ :param dropout: the dropout probability.
452
+ :param channel_mult: channel multiplier for each level of the UNet.
453
+ :param conv_resample: if True, use learned convolutions for upsampling and
454
+ downsampling.
455
+ :param dims: determines if the signal is 1D, 2D, or 3D.
456
+ :param num_classes: if specified (as an int), then this model will be
457
+ class-conditional with `num_classes` classes.
458
+ :param use_checkpoint: use gradient checkpointing to reduce memory usage.
459
+ :param num_heads: the number of attention heads in each attention layer.
460
+ :param num_head_channels: if specified, ignore num_heads and instead use
461
+ a fixed channel width per attention head.
462
+ :param num_heads_upsample: works with num_heads to set a different number
463
+ of heads for upsampling. Deprecated.
464
+ :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
465
+ :param resblock_updown: use residual blocks for up/downsampling.
466
+ :param use_new_attention_order: use a different attention pattern for potentially
467
+ increased efficiency.
468
+ """
469
+
470
+ def __init__(
471
+ self,
472
+ image_size,
473
+ in_channels,
474
+ model_channels,
475
+ out_channels,
476
+ num_res_blocks,
477
+ attention_resolutions,
478
+ dropout=0,
479
+ channel_mult=(1, 2, 4, 8),
480
+ conv_resample=True,
481
+ dims=2,
482
+ num_classes=None,
483
+ use_checkpoint=False,
484
+ use_fp16=False,
485
+ num_heads=-1,
486
+ num_head_channels=-1,
487
+ num_heads_upsample=-1,
488
+ use_scale_shift_norm=False,
489
+ resblock_updown=False,
490
+ use_new_attention_order=False,
491
+ use_spatial_transformer=False, # custom transformer support
492
+ transformer_depth=1, # custom transformer support
493
+ context_dim=None, # custom transformer support
494
+ n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
495
+ legacy=True,
496
+ disable_self_attentions=None,
497
+ num_attention_blocks=None,
498
+ disable_middle_self_attn=False,
499
+ use_linear_in_transformer=False,
500
+ ):
501
+ super().__init__()
502
+ if use_spatial_transformer:
503
+ assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
504
+
505
+ if context_dim is not None:
506
+ assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
507
+ from omegaconf.listconfig import ListConfig
508
+ if type(context_dim) == ListConfig:
509
+ context_dim = list(context_dim)
510
+
511
+ if num_heads_upsample == -1:
512
+ num_heads_upsample = num_heads
513
+
514
+ if num_heads == -1:
515
+ assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'
516
+
517
+ if num_head_channels == -1:
518
+ assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
519
+
520
+ self.image_size = image_size
521
+ self.in_channels = in_channels
522
+ self.model_channels = model_channels
523
+ self.out_channels = out_channels
524
+ if isinstance(num_res_blocks, int):
525
+ self.num_res_blocks = len(channel_mult) * [num_res_blocks]
526
+ else:
527
+ if len(num_res_blocks) != len(channel_mult):
528
+ raise ValueError("provide num_res_blocks either as an int (globally constant) or "
529
+ "as a list/tuple (per-level) with the same length as channel_mult")
530
+ self.num_res_blocks = num_res_blocks
531
+ if disable_self_attentions is not None:
532
+ # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
533
+ assert len(disable_self_attentions) == len(channel_mult)
534
+ if num_attention_blocks is not None:
535
+ assert len(num_attention_blocks) == len(self.num_res_blocks)
536
+ assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))
537
+ print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
538
+ f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
539
+ f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
540
+ f"attention will still not be set.")
541
+
542
+ self.attention_resolutions = attention_resolutions
543
+ self.dropout = dropout
544
+ self.channel_mult = channel_mult
545
+ self.conv_resample = conv_resample
546
+ self.num_classes = num_classes
547
+ self.use_checkpoint = use_checkpoint
548
+ self.dtype = th.float16 if use_fp16 else th.float32
549
+ self.num_heads = num_heads
550
+ self.num_head_channels = num_head_channels
551
+ self.num_heads_upsample = num_heads_upsample
552
+ self.predict_codebook_ids = n_embed is not None
553
+
554
+ time_embed_dim = model_channels * 4
555
+ self.time_embed = nn.Sequential(
556
+ linear(model_channels, time_embed_dim),
557
+ nn.SiLU(),
558
+ linear(time_embed_dim, time_embed_dim),
559
+ )
560
+
561
+ if self.num_classes is not None:
562
+ if isinstance(self.num_classes, int):
563
+ self.label_emb = nn.Embedding(num_classes, time_embed_dim)
564
+ elif self.num_classes == "continuous":
565
+ print("setting up linear c_adm embedding layer")
566
+ self.label_emb = nn.Linear(1, time_embed_dim)
567
+ else:
568
+ raise ValueError()
569
+
570
+ self.input_blocks = nn.ModuleList(
571
+ [
572
+ TimestepEmbedSequential(
573
+ conv_nd(dims, in_channels, model_channels, 3, padding=1)
574
+ )
575
+ ]
576
+ )
577
+ self._feature_size = model_channels
578
+ input_block_chans = [model_channels]
579
+ ch = model_channels
580
+ ds = 1
581
+ for level, mult in enumerate(channel_mult):
582
+ for nr in range(self.num_res_blocks[level]):
583
+ layers = [
584
+ ResBlock(
585
+ ch,
586
+ time_embed_dim,
587
+ dropout,
588
+ out_channels=mult * model_channels,
589
+ dims=dims,
590
+ use_checkpoint=use_checkpoint,
591
+ use_scale_shift_norm=use_scale_shift_norm,
592
+ )
593
+ ]
594
+ ch = mult * model_channels
595
+ if ds in attention_resolutions:
596
+ if num_head_channels == -1:
597
+ dim_head = ch // num_heads
598
+ else:
599
+ num_heads = ch // num_head_channels
600
+ dim_head = num_head_channels
601
+ if legacy:
602
+ #num_heads = 1
603
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
604
+ if exists(disable_self_attentions):
605
+ disabled_sa = disable_self_attentions[level]
606
+ else:
607
+ disabled_sa = False
608
+
609
+ if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
610
+ layers.append(
611
+ AttentionBlock(
612
+ ch,
613
+ use_checkpoint=use_checkpoint,
614
+ num_heads=num_heads,
615
+ num_head_channels=dim_head,
616
+ use_new_attention_order=use_new_attention_order,
617
+ ) if not use_spatial_transformer else SpatialTransformer(
618
+ ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
619
+ disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
620
+ use_checkpoint=use_checkpoint
621
+ )
622
+ )
623
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
624
+ self._feature_size += ch
625
+ input_block_chans.append(ch)
626
+ if level != len(channel_mult) - 1:
627
+ out_ch = ch
628
+ self.input_blocks.append(
629
+ TimestepEmbedSequential(
630
+ ResBlock(
631
+ ch,
632
+ time_embed_dim,
633
+ dropout,
634
+ out_channels=out_ch,
635
+ dims=dims,
636
+ use_checkpoint=use_checkpoint,
637
+ use_scale_shift_norm=use_scale_shift_norm,
638
+ down=True,
639
+ )
640
+ if resblock_updown
641
+ else Downsample(
642
+ ch, conv_resample, dims=dims, out_channels=out_ch
643
+ )
644
+ )
645
+ )
646
+ ch = out_ch
647
+ input_block_chans.append(ch)
648
+ ds *= 2
649
+ self._feature_size += ch
650
+
651
+ if num_head_channels == -1:
652
+ dim_head = ch // num_heads
653
+ else:
654
+ num_heads = ch // num_head_channels
655
+ dim_head = num_head_channels
656
+ if legacy:
657
+ #num_heads = 1
658
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
659
+ self.middle_block = TimestepEmbedSequential(
660
+ ResBlock(
661
+ ch,
662
+ time_embed_dim,
663
+ dropout,
664
+ dims=dims,
665
+ use_checkpoint=use_checkpoint,
666
+ use_scale_shift_norm=use_scale_shift_norm,
667
+ ),
668
+ AttentionBlock(
669
+ ch,
670
+ use_checkpoint=use_checkpoint,
671
+ num_heads=num_heads,
672
+ num_head_channels=dim_head,
673
+ use_new_attention_order=use_new_attention_order,
674
+ ) if not use_spatial_transformer else SpatialTransformer( # always uses a self-attn
675
+ ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
676
+ disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
677
+ use_checkpoint=use_checkpoint
678
+ ),
679
+ ResBlock(
680
+ ch,
681
+ time_embed_dim,
682
+ dropout,
683
+ dims=dims,
684
+ use_checkpoint=use_checkpoint,
685
+ use_scale_shift_norm=use_scale_shift_norm,
686
+ ),
687
+ )
688
+ self._feature_size += ch
689
+
690
+ self.output_blocks = nn.ModuleList([])
691
+ for level, mult in list(enumerate(channel_mult))[::-1]:
692
+ for i in range(self.num_res_blocks[level] + 1):
693
+ ich = input_block_chans.pop()
694
+ layers = [
695
+ ResBlock(
696
+ ch + ich,
697
+ time_embed_dim,
698
+ dropout,
699
+ out_channels=model_channels * mult,
700
+ dims=dims,
701
+ use_checkpoint=use_checkpoint,
702
+ use_scale_shift_norm=use_scale_shift_norm,
703
+ )
704
+ ]
705
+ ch = model_channels * mult
706
+ if ds in attention_resolutions:
707
+ if num_head_channels == -1:
708
+ dim_head = ch // num_heads
709
+ else:
710
+ num_heads = ch // num_head_channels
711
+ dim_head = num_head_channels
712
+ if legacy:
713
+ #num_heads = 1
714
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
715
+ if exists(disable_self_attentions):
716
+ disabled_sa = disable_self_attentions[level]
717
+ else:
718
+ disabled_sa = False
719
+
720
+ if not exists(num_attention_blocks) or i < num_attention_blocks[level]:
721
+ layers.append(
722
+ AttentionBlock(
723
+ ch,
724
+ use_checkpoint=use_checkpoint,
725
+ num_heads=num_heads_upsample,
726
+ num_head_channels=dim_head,
727
+ use_new_attention_order=use_new_attention_order,
728
+ ) if not use_spatial_transformer else SpatialTransformer(
729
+ ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
730
+ disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
731
+ use_checkpoint=use_checkpoint
732
+ )
733
+ )
734
+ if level and i == self.num_res_blocks[level]:
735
+ out_ch = ch
736
+ layers.append(
737
+ ResBlock(
738
+ ch,
739
+ time_embed_dim,
740
+ dropout,
741
+ out_channels=out_ch,
742
+ dims=dims,
743
+ use_checkpoint=use_checkpoint,
744
+ use_scale_shift_norm=use_scale_shift_norm,
745
+ up=True,
746
+ )
747
+ if resblock_updown
748
+ else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
749
+ )
750
+ ds //= 2
751
+ self.output_blocks.append(TimestepEmbedSequential(*layers))
752
+ self._feature_size += ch
753
+
754
+ self.out = nn.Sequential(
755
+ normalization(ch),
756
+ nn.SiLU(),
757
+ zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
758
+ )
759
+ if self.predict_codebook_ids:
760
+ self.id_predictor = nn.Sequential(
761
+ normalization(ch),
762
+ conv_nd(dims, model_channels, n_embed, 1),
763
+ #nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
764
+ )
765
+
766
+ def convert_to_fp16(self):
767
+ """
768
+ Convert the torso of the model to float16.
769
+ """
770
+ self.input_blocks.apply(convert_module_to_f16)
771
+ self.middle_block.apply(convert_module_to_f16)
772
+ self.output_blocks.apply(convert_module_to_f16)
773
+
774
+ def convert_to_fp32(self):
775
+ """
776
+ Convert the torso of the model to float32.
777
+ """
778
+ self.input_blocks.apply(convert_module_to_f32)
779
+ self.middle_block.apply(convert_module_to_f32)
780
+ self.output_blocks.apply(convert_module_to_f32)
781
+
782
+ def forward(self, x, timesteps=None, context=None, y=None,**kwargs):
783
+ """
784
+ Apply the model to an input batch.
785
+ :param x: an [N x C x ...] Tensor of inputs.
786
+ :param timesteps: a 1-D batch of timesteps.
787
+ :param context: conditioning plugged in via crossattn
788
+ :param y: an [N] Tensor of labels, if class-conditional.
789
+ :return: an [N x C x ...] Tensor of outputs.
790
+ """
791
+ assert (y is not None) == (
792
+ self.num_classes is not None
793
+ ), "must specify y if and only if the model is class-conditional"
794
+ hs = []
795
+ t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
796
+ emb = self.time_embed(t_emb)
797
+
798
+ if self.num_classes is not None:
799
+ assert y.shape[0] == x.shape[0]
800
+ emb = emb + self.label_emb(y)
801
+
802
+ h = x.type(self.dtype)
803
+ for module in self.input_blocks:
804
+ h = module(h, emb, context)
805
+ hs.append(h)
806
+ h = self.middle_block(h, emb, context)
807
+ for module in self.output_blocks:
808
+ h = th.cat([h, hs.pop()], dim=1)
809
+ h = module(h, emb, context)
810
+ h = h.type(x.dtype)
811
+ if self.predict_codebook_ids:
812
+ return self.id_predictor(h)
813
+ else:
814
+ return self.out(h)
815
+
816
+
817
+
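The constructor above decides where attention layers go by tracking the current downsample rate `ds` and checking it against `attention_resolutions`. The standalone sketch below (the settings are an assumption, loosely following typical SD-style configs, not this project's config) prints that bookkeeping.

```python
# Standalone sketch of the ds / attention_resolutions bookkeeping (illustrative settings).
channel_mult = (1, 2, 4, 8)
attention_resolutions = (4, 2, 1)   # an assumed, typical setting, not this repo's config
ds = 1
for level, mult in enumerate(channel_mult):
    has_attn = "yes" if ds in attention_resolutions else "no"
    print(f"level {level}: ds={ds}, attention={has_attn}")
    if level != len(channel_mult) - 1:
        ds *= 2   # doubled after each downsampling block, as in the constructor above
# level 0: ds=1, attention=yes
# level 1: ds=2, attention=yes
# level 2: ds=4, attention=yes
# level 3: ds=8, attention=no
```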
818
+
819
+ class UNetModel_Temporal(nn.Module):
820
+ """
821
+ The full UNet model with attention and timestep embedding.
822
+ :param in_channels: channels in the input Tensor.
823
+ :param model_channels: base channel count for the model.
824
+ :param out_channels: channels in the output Tensor.
825
+ :param num_res_blocks: number of residual blocks per downsample.
826
+ :param attention_resolutions: a collection of downsample rates at which
827
+ attention will take place. May be a set, list, or tuple.
828
+ For example, if this contains 4, then at 4x downsampling, attention
829
+ will be used.
830
+ :param dropout: the dropout probability.
831
+ :param channel_mult: channel multiplier for each level of the UNet.
832
+ :param conv_resample: if True, use learned convolutions for upsampling and
833
+ downsampling.
834
+ :param dims: determines if the signal is 1D, 2D, or 3D.
835
+ :param num_classes: if specified (as an int), then this model will be
836
+ class-conditional with `num_classes` classes.
837
+ :param use_checkpoint: use gradient checkpointing to reduce memory usage.
838
+ :param num_heads: the number of attention heads in each attention layer.
839
+ :param num_head_channels: if specified, ignore num_heads and instead use
840
+ a fixed channel width per attention head.
841
+ :param num_heads_upsample: works with num_heads to set a different number
842
+ of heads for upsampling. Deprecated.
843
+ :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
844
+ :param resblock_updown: use residual blocks for up/downsampling.
845
+ :param use_new_attention_order: use a different attention pattern for potentially
846
+ increased efficiency.
847
+ """
848
+
849
+ def __init__(
850
+ self,
851
+ image_size,
852
+ in_channels,
853
+ model_channels,
854
+ out_channels,
855
+ num_res_blocks,
856
+ attention_resolutions,
857
+ dropout=0,
858
+ channel_mult=(1, 2, 4, 8),
859
+ conv_resample=True,
860
+ dims=2,
861
+ num_classes=None,
862
+ use_checkpoint=False,
863
+ use_fp16=False,
864
+ num_heads=-1,
865
+ num_head_channels=-1,
866
+ num_heads_upsample=-1,
867
+ use_scale_shift_norm=False,
868
+ resblock_updown=False,
869
+ use_new_attention_order=False,
870
+ use_spatial_transformer=False, # custom transformer support
871
+ transformer_depth=1, # custom transformer support
872
+ context_dim=None, # custom transformer support
873
+ n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
874
+ legacy=True,
875
+ disable_self_attentions=None,
876
+ num_attention_blocks=None,
877
+ disable_middle_self_attn=False,
878
+ use_linear_in_transformer=False,
879
+ unet_additional_kwargs=None,
880
+ ):
881
+ super().__init__()
882
+
883
+ ## Motion Module Kwargs
884
+ self.unet_additional_kwargs = unet_additional_kwargs
885
+ self.use_motion_module = self.unet_additional_kwargs['use_motion_module']
886
+ self.motion_module_resolutions = self.unet_additional_kwargs['motion_module_resolutions']
887
+ self.unet_use_cross_frame_attention = self.unet_additional_kwargs['unet_use_cross_frame_attention']
888
+ self.unet_use_temporal_attention = self.unet_additional_kwargs['unet_use_temporal_attention']
889
+ self.motion_module_type = self.unet_additional_kwargs['motion_module_type']
890
+
891
+ self.motion_module_kwargs = self.unet_additional_kwargs['motion_module_kwargs']
892
+ self.num_attention_heads = self.motion_module_kwargs['num_attention_heads']
893
+ self.num_transformer_block = self.motion_module_kwargs['num_transformer_block']
894
+ self.attention_block_types = self.motion_module_kwargs['attention_block_types']
895
+ self.temporal_position_encoding = self.motion_module_kwargs['temporal_position_encoding']
896
+ self.temporal_position_encoding_max_len = self.motion_module_kwargs['temporal_position_encoding_max_len']
897
+ self.temporal_attention_dim_div = self.motion_module_kwargs['temporal_attention_dim_div']
898
+ self.zero_initialize = self.motion_module_kwargs['zero_initialize']
899
+
900
+
901
+ if use_spatial_transformer:
902
+ assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
903
+
904
+ if context_dim is not None:
905
+ assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
906
+ from omegaconf.listconfig import ListConfig
907
+ if type(context_dim) == ListConfig:
908
+ context_dim = list(context_dim)
909
+
910
+ if num_heads_upsample == -1:
911
+ num_heads_upsample = num_heads
912
+
913
+ if num_heads == -1:
914
+ assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'
915
+
916
+ if num_head_channels == -1:
917
+ assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
918
+
919
+ self.image_size = image_size
920
+ self.in_channels = in_channels
921
+ self.model_channels = model_channels
922
+ self.out_channels = out_channels
923
+ if isinstance(num_res_blocks, int):
924
+ self.num_res_blocks = len(channel_mult) * [num_res_blocks] # 4 * [2]
925
+ else:
926
+ if len(num_res_blocks) != len(channel_mult):
927
+ raise ValueError("provide num_res_blocks either as an int (globally constant) or "
928
+ "as a list/tuple (per-level) with the same length as channel_mult")
929
+ self.num_res_blocks = num_res_blocks
930
+ if disable_self_attentions is not None:
931
+ # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
932
+ assert len(disable_self_attentions) == len(channel_mult)
933
+ if num_attention_blocks is not None:
934
+ assert len(num_attention_blocks) == len(self.num_res_blocks)
935
+ assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))
936
+ print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
937
+ f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
938
+ f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
939
+ f"attention will still not be set.")
940
+
941
+ self.attention_resolutions = attention_resolutions
942
+ self.dropout = dropout
943
+ self.channel_mult = channel_mult
944
+ self.conv_resample = conv_resample
945
+ self.num_classes = num_classes
946
+ self.use_checkpoint = use_checkpoint
947
+ self.dtype = th.float16 if use_fp16 else th.float32
948
+ self.num_heads = num_heads
949
+ self.num_head_channels = num_head_channels
950
+ self.num_heads_upsample = num_heads_upsample
951
+ self.predict_codebook_ids = n_embed is not None
952
+
953
+ time_embed_dim = model_channels * 4
954
+ self.time_embed = nn.Sequential(
955
+ linear(model_channels, time_embed_dim),
956
+ nn.SiLU(),
957
+ linear(time_embed_dim, time_embed_dim),
958
+ )
959
+
960
+ if self.num_classes is not None:
961
+ if isinstance(self.num_classes, int):
962
+ self.label_emb = nn.Embedding(num_classes, time_embed_dim)
963
+ elif self.num_classes == "continuous":
964
+ print("setting up linear c_adm embedding layer")
965
+ self.label_emb = nn.Linear(1, time_embed_dim)
966
+ else:
967
+ raise ValueError()
968
+
969
+ self.input_blocks = nn.ModuleList(
970
+ [
971
+ TimestepEmbedSequential(
972
+ conv_nd(dims, in_channels, model_channels, 3, padding=1)
973
+ )
974
+ ]
975
+ )
976
+ self.input_blocks_motion_module = nn.ModuleList([])
977
+ self._feature_size = model_channels
978
+ input_block_chans = [model_channels]
979
+ ch = model_channels
980
+ ds = 1
981
+ for level, mult in enumerate(channel_mult):
982
+ for nr in range(self.num_res_blocks[level]):
983
+ layers = [
984
+ ResBlock(
985
+ ch,
986
+ time_embed_dim,
987
+ dropout,
988
+ out_channels=mult * model_channels,
989
+ dims=dims,
990
+ use_checkpoint=use_checkpoint,
991
+ use_scale_shift_norm=use_scale_shift_norm,
992
+ )
993
+ ]
994
+ ch = mult * model_channels
995
+ if ds in attention_resolutions: # [1,2,4]
996
+ if num_head_channels == -1:
997
+ dim_head = ch // num_heads
998
+ else:
999
+ num_heads = ch // num_head_channels
1000
+ dim_head = num_head_channels
1001
+ if legacy:
1002
+ #num_heads = 1
1003
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
1004
+ if exists(disable_self_attentions):
1005
+ disabled_sa = disable_self_attentions[level]
1006
+ else:
1007
+ disabled_sa = False
1008
+
1009
+ if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
1010
+ layers.append(
1011
+ AttentionBlock(
1012
+ ch,
1013
+ use_checkpoint=use_checkpoint,
1014
+ num_heads=num_heads,
1015
+ num_head_channels=dim_head,
1016
+ use_new_attention_order=use_new_attention_order,
1017
+ ) if not use_spatial_transformer else SpatialTransformer(
1018
+ ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
1019
+ disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
1020
+ use_checkpoint=use_checkpoint
1021
+ )
1022
+ )
1023
+ if self.use_motion_module:
1024
+ layers_motion_module=[
1025
+ get_motion_module(
1026
+ in_channels=ch,
1027
+ motion_module_type=self.motion_module_type,
1028
+ motion_module_kwargs=self.motion_module_kwargs,
1029
+ )]
1030
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
1031
+ if self.use_motion_module:
1032
+ self.input_blocks_motion_module.append(TimestepEmbedSequential(*layers_motion_module))
1033
+ self._feature_size += ch
1034
+ input_block_chans.append(ch)
1035
+ if level != len(channel_mult) - 1:
1036
+ out_ch = ch
1037
+ self.input_blocks.append(
1038
+ TimestepEmbedSequential(
1039
+ ResBlock(
1040
+ ch,
1041
+ time_embed_dim,
1042
+ dropout,
1043
+ out_channels=out_ch,
1044
+ dims=dims,
1045
+ use_checkpoint=use_checkpoint,
1046
+ use_scale_shift_norm=use_scale_shift_norm,
1047
+ down=True,
1048
+ )
1049
+ if resblock_updown
1050
+ else Downsample(
1051
+ ch, conv_resample, dims=dims, out_channels=out_ch
1052
+ )
1053
+ )
1054
+ )
1055
+ ch = out_ch
1056
+ input_block_chans.append(ch)
1057
+ ds *= 2
1058
+ self._feature_size += ch
1059
+ # motion module [1,2,4,5,7,8,10,11] !!!! Conv RST RST Down RST RST Down RST RST Down RT RT
1060
+ if num_head_channels == -1:
1061
+ dim_head = ch // num_heads
1062
+ else:
1063
+ num_heads = ch // num_head_channels
1064
+ dim_head = num_head_channels
1065
+ if legacy:
1066
+ #num_heads = 1
1067
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
1068
+ self.middle_block = TimestepEmbedSequential(
1069
+ ResBlock(
1070
+ ch,
1071
+ time_embed_dim,
1072
+ dropout,
1073
+ dims=dims,
1074
+ use_checkpoint=use_checkpoint,
1075
+ use_scale_shift_norm=use_scale_shift_norm,
1076
+ ),
1077
+ AttentionBlock(
1078
+ ch,
1079
+ use_checkpoint=use_checkpoint,
1080
+ num_heads=num_heads,
1081
+ num_head_channels=dim_head,
1082
+ use_new_attention_order=use_new_attention_order,
1083
+ ) if not use_spatial_transformer else SpatialTransformer( # always uses a self-attn
1084
+ ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
1085
+ disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
1086
+ use_checkpoint=use_checkpoint
1087
+            ), # Followed by a motion module
1088
+ ResBlock(
1089
+ ch,
1090
+ time_embed_dim,
1091
+ dropout,
1092
+ dims=dims,
1093
+ use_checkpoint=use_checkpoint,
1094
+ use_scale_shift_norm=use_scale_shift_norm,
1095
+ ),
1096
+ )
1097
+
1098
+ self._feature_size += ch
1099
+
1100
+ self.output_blocks = nn.ModuleList([])
1101
+ self.output_blocks_motion_module = nn.ModuleList([])
1102
+ for level, mult in list(enumerate(channel_mult))[::-1]:
1103
+ for i in range(self.num_res_blocks[level] + 1):
1104
+ ich = input_block_chans.pop()
1105
+ layers = [
1106
+ ResBlock(
1107
+ ch + ich,
1108
+ time_embed_dim,
1109
+ dropout,
1110
+ out_channels=model_channels * mult,
1111
+ dims=dims,
1112
+ use_checkpoint=use_checkpoint,
1113
+ use_scale_shift_norm=use_scale_shift_norm,
1114
+ )
1115
+ ]
1116
+ ch = model_channels * mult
1117
+ if ds in attention_resolutions:
1118
+ if num_head_channels == -1:
1119
+ dim_head = ch // num_heads
1120
+ else:
1121
+ num_heads = ch // num_head_channels
1122
+ dim_head = num_head_channels
1123
+ if legacy:
1124
+ #num_heads = 1
1125
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
1126
+ if exists(disable_self_attentions):
1127
+ disabled_sa = disable_self_attentions[level]
1128
+ else:
1129
+ disabled_sa = False
1130
+
1131
+ if not exists(num_attention_blocks) or i < num_attention_blocks[level]:
1132
+ layers.append(
1133
+ AttentionBlock(
1134
+ ch,
1135
+ use_checkpoint=use_checkpoint,
1136
+ num_heads=num_heads_upsample,
1137
+ num_head_channels=dim_head,
1138
+ use_new_attention_order=use_new_attention_order,
1139
+ ) if not use_spatial_transformer else SpatialTransformer(
1140
+ ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
1141
+ disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
1142
+ use_checkpoint=use_checkpoint
1143
+ )
1144
+ )
1145
+ if self.use_motion_module:
1146
+ layers_motion_module=[
1147
+ get_motion_module(
1148
+ in_channels=ch,
1149
+ motion_module_type=self.motion_module_type,
1150
+ motion_module_kwargs=self.motion_module_kwargs,
1151
+ )]
1152
+ if level and i == self.num_res_blocks[level]:
1153
+ out_ch = ch
1154
+ layers.append(
1155
+ ResBlock(
1156
+ ch,
1157
+ time_embed_dim,
1158
+ dropout,
1159
+ out_channels=out_ch,
1160
+ dims=dims,
1161
+ use_checkpoint=use_checkpoint,
1162
+ use_scale_shift_norm=use_scale_shift_norm,
1163
+ up=True,
1164
+ )
1165
+ if resblock_updown
1166
+ else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
1167
+ )
1168
+
1169
+ # if self.use_motion_module:
1170
+ # in_channel_mm_up = out_ch or ch
1171
+ # layers_motion_module.append(
1172
+ # get_motion_module(
1173
+ # in_channels=in_channel_mm_up,
1174
+ # motion_module_type=self.motion_module_type,
1175
+ # motion_module_kwargs=self.motion_module_kwargs,
1176
+ # )
1177
+ # )
1178
+ ds //= 2
1179
+ self.output_blocks.append(TimestepEmbedSequential(*layers))
1180
+                 if self.use_motion_module:
+                     self.output_blocks_motion_module.append(TimestepEmbedSequential(*layers_motion_module))
1181
+ self._feature_size += ch
1182
+        # motion modules attach at output blocks [0,1,2,4,5,6,8,9,10,12,13,14]: RT RT RT Up RST RST RST Up RST RST RST Up RST RST RST
1183
+ self.out = nn.Sequential(
1184
+ normalization(ch),
1185
+ nn.SiLU(),
1186
+ zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
1187
+ )
1188
+ if self.predict_codebook_ids:
1189
+ self.id_predictor = nn.Sequential(
1190
+ normalization(ch),
1191
+ conv_nd(dims, model_channels, n_embed, 1),
1192
+ #nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
1193
+ )
1194
+
1195
+ def convert_to_fp16(self):
1196
+ """
1197
+ Convert the torso of the model to float16.
1198
+ """
1199
+ self.input_blocks.apply(convert_module_to_f16)
1200
+ self.middle_block.apply(convert_module_to_f16)
1201
+ self.output_blocks.apply(convert_module_to_f16)
1202
+
1203
+ def convert_to_fp32(self):
1204
+ """
1205
+ Convert the torso of the model to float32.
1206
+ """
1207
+ self.input_blocks.apply(convert_module_to_f32)
1208
+ self.middle_block.apply(convert_module_to_f32)
1209
+ self.output_blocks.apply(convert_module_to_f32)
1210
+
1211
+ def forward(self, x, timesteps=None, context=None, y=None,**kwargs):
1212
+ pass
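Note: `forward` is left as a stub in this copy of the UNet. For orientation only, below is a minimal sketch of the standard latent-diffusion UNet forward pass, assuming an instance `unet` of the class above that exposes the usual `time_embed` and `model_channels` attributes, and ignoring the motion-module lists for brevity; this is not the implementation used by this repo's ControlNet wrapper.

import torch
from model_lib.ControlNet.ldm.modules.diffusionmodules.util import timestep_embedding

def sketch_unet_forward(unet, x, timesteps, context=None, y=None):
    # Embed the timestep and (optionally) the label/conditioning vector.
    emb = unet.time_embed(timestep_embedding(timesteps, unet.model_channels))
    if unet.num_classes is not None:
        emb = emb + unet.label_emb(y)
    hs = []
    h = x
    for module in unet.input_blocks:           # encoder: keep skip activations
        h = module(h, emb, context)
        hs.append(h)
    h = unet.middle_block(h, emb, context)
    for module in unet.output_blocks:          # decoder: concat matching skips
        h = torch.cat([h, hs.pop()], dim=1)
        h = module(h, emb, context)
    return unet.out(h)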
model_lib/ControlNet/ldm/modules/diffusionmodules/upscaling.py ADDED
@@ -0,0 +1,81 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import numpy as np
4
+ from functools import partial
5
+
6
+ from model_lib.ControlNet.ldm.modules.diffusionmodules.util import extract_into_tensor, make_beta_schedule
7
+ from model_lib.ControlNet.ldm.util import default
8
+
9
+
10
+ class AbstractLowScaleModel(nn.Module):
11
+ # for concatenating a downsampled image to the latent representation
12
+ def __init__(self, noise_schedule_config=None):
13
+ super(AbstractLowScaleModel, self).__init__()
14
+ if noise_schedule_config is not None:
15
+ self.register_schedule(**noise_schedule_config)
16
+
17
+ def register_schedule(self, beta_schedule="linear", timesteps=1000,
18
+ linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
19
+ betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end,
20
+ cosine_s=cosine_s)
21
+ alphas = 1. - betas
22
+ alphas_cumprod = np.cumprod(alphas, axis=0)
23
+ alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
24
+
25
+ timesteps, = betas.shape
26
+ self.num_timesteps = int(timesteps)
27
+ self.linear_start = linear_start
28
+ self.linear_end = linear_end
29
+ assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'
30
+
31
+ to_torch = partial(torch.tensor, dtype=torch.float32)
32
+
33
+ self.register_buffer('betas', to_torch(betas))
34
+ self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
35
+ self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))
36
+
37
+ # calculations for diffusion q(x_t | x_{t-1}) and others
38
+ self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
39
+ self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
40
+ self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
41
+ self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
42
+ self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))
43
+
44
+ def q_sample(self, x_start, t, noise=None):
45
+ noise = default(noise, lambda: torch.randn_like(x_start))
46
+ return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
47
+ extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)
48
+
49
+ def forward(self, x):
50
+ return x, None
51
+
52
+ def decode(self, x):
53
+ return x
54
+
55
+
56
+ class SimpleImageConcat(AbstractLowScaleModel):
57
+ # no noise level conditioning
58
+ def __init__(self):
59
+ super(SimpleImageConcat, self).__init__(noise_schedule_config=None)
60
+ self.max_noise_level = 0
61
+
62
+ def forward(self, x):
63
+ # fix to constant noise level
64
+ return x, torch.zeros(x.shape[0], device=x.device).long()
65
+
66
+
67
+ class ImageConcatWithNoiseAugmentation(AbstractLowScaleModel):
68
+ def __init__(self, noise_schedule_config, max_noise_level=1000, to_cuda=False):
69
+ super().__init__(noise_schedule_config=noise_schedule_config)
70
+ self.max_noise_level = max_noise_level
71
+
72
+ def forward(self, x, noise_level=None):
73
+ if noise_level is None:
74
+ noise_level = torch.randint(0, self.max_noise_level, (x.shape[0],), device=x.device).long()
75
+ else:
76
+ assert isinstance(noise_level, torch.Tensor)
77
+ z = self.q_sample(x, noise_level)
78
+ return z, noise_level
79
+
80
+
81
+
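A brief usage sketch of the noise-augmentation module above; the schedule values and tensor shapes are illustrative, not taken from this repo's configs. `q_sample` realizes z = sqrt(alphas_cumprod[t]) * x + sqrt(1 - alphas_cumprod[t]) * noise and `forward` returns both z and the sampled noise level.

import torch

aug = ImageConcatWithNoiseAugmentation(
    noise_schedule_config=dict(
        beta_schedule="linear", timesteps=1000,
        linear_start=1e-4, linear_end=2e-2,
    ),
    max_noise_level=350,   # illustrative value
)

x_low = torch.randn(4, 3, 64, 64)   # e.g. a downsampled conditioning image
z, noise_level = aug(x_low)         # noised input + the sampled noise levels
z_fixed, _ = aug(x_low, noise_level=torch.full((4,), 100, dtype=torch.long))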
model_lib/ControlNet/ldm/modules/diffusionmodules/util.py ADDED
@@ -0,0 +1,305 @@
1
+ # adapted from
2
+ # https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
3
+ # and
4
+ # https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
5
+ # and
6
+ # https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
7
+ #
8
+ # thanks!
9
+
10
+
11
+ import os
12
+ import math
13
+ import torch
14
+ import torch.nn as nn
15
+ import numpy as np
16
+ from einops import repeat
17
+ from model_lib.ControlNet.ldm.util import instantiate_from_config
18
+ import pdb
19
+
20
+ def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
21
+ if schedule == "linear":
22
+ betas = (
23
+ torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
24
+ )
25
+
26
+ elif schedule == "cosine":
27
+ timesteps = (
28
+ torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
29
+ )
30
+ alphas = timesteps / (1 + cosine_s) * np.pi / 2
31
+ alphas = torch.cos(alphas).pow(2)
32
+ alphas = alphas / alphas[0]
33
+ betas = 1 - alphas[1:] / alphas[:-1]
34
+ betas = np.clip(betas, a_min=0, a_max=0.999)
35
+
36
+ elif schedule == "sqrt_linear":
37
+ betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
38
+ elif schedule == "sqrt":
39
+ betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5
40
+ else:
41
+ raise ValueError(f"schedule '{schedule}' unknown.")
42
+ return betas.numpy()
43
+
44
+
45
+ def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True):
46
+ if ddim_discr_method == 'uniform':
47
+ c = num_ddpm_timesteps // num_ddim_timesteps
48
+ ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
49
+ elif ddim_discr_method == 'quad':
50
+ ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int)
51
+ else:
52
+ raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"')
53
+
54
+ # assert ddim_timesteps.shape[0] == num_ddim_timesteps
55
+ # add one to get the final alpha values right (the ones from first scale to data during sampling)
56
+ steps_out = ddim_timesteps + 1
57
+ if verbose:
58
+ print(f'Selected timesteps for ddim sampler: {steps_out}')
59
+ return steps_out
60
+
61
+
62
+ def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
63
+ # select alphas for computing the variance schedule
64
+ alphas = alphacums[ddim_timesteps]
65
+ alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())
66
+
67
+    # according to the formula provided in https://arxiv.org/abs/2010.02502
68
+ sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))
69
+ if verbose:
70
+ print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}')
71
+ print(f'For the chosen value of eta, which is {eta}, '
72
+ f'this results in the following sigma_t schedule for ddim sampler {sigmas}')
73
+ return sigmas, alphas, alphas_prev
74
+
75
+
76
+ def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
77
+ """
78
+ Create a beta schedule that discretizes the given alpha_t_bar function,
79
+ which defines the cumulative product of (1-beta) over time from t = [0,1].
80
+ :param num_diffusion_timesteps: the number of betas to produce.
81
+ :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
82
+ produces the cumulative product of (1-beta) up to that
83
+ part of the diffusion process.
84
+ :param max_beta: the maximum beta to use; use values lower than 1 to
85
+ prevent singularities.
86
+ """
87
+ betas = []
88
+ for i in range(num_diffusion_timesteps):
89
+ t1 = i / num_diffusion_timesteps
90
+ t2 = (i + 1) / num_diffusion_timesteps
91
+ betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
92
+ return np.array(betas)
93
+
94
+
95
+ def extract_into_tensor(a, t, x_shape):
96
+ b, *_ = t.shape
97
+ out = a.gather(-1, t)
98
+ return out.reshape(b, *((1,) * (len(x_shape) - 1)))
99
+
100
+
101
+ def checkpoint(func, inputs, params, flag):
102
+ """
103
+ Evaluate a function without caching intermediate activations, allowing for
104
+ reduced memory at the expense of extra compute in the backward pass.
105
+ :param func: the function to evaluate.
106
+ :param inputs: the argument sequence to pass to `func`.
107
+ :param params: a sequence of parameters `func` depends on but does not
108
+ explicitly take as arguments.
109
+ :param flag: if False, disable gradient checkpointing.
110
+ """
111
+ if flag:
112
+ args = tuple(inputs) + tuple(params)
113
+ return CheckpointFunction.apply(func, len(inputs), *args)
114
+ else:
115
+ return func(*inputs)
116
+
117
+
118
+ class CheckpointFunction(torch.autograd.Function):
119
+ @staticmethod
120
+ def forward(ctx, run_function, length, *args):
121
+ ctx.run_function = run_function
122
+ ctx.input_tensors = list(args[:length])
123
+ ctx.input_params = list(args[length:])
124
+ ctx.gpu_autocast_kwargs = {"enabled": torch.is_autocast_enabled(),
125
+ "dtype": torch.get_autocast_gpu_dtype(),
126
+ "cache_enabled": torch.is_autocast_cache_enabled()}
127
+ with torch.no_grad():
128
+ output_tensors = ctx.run_function(*ctx.input_tensors)
129
+ return output_tensors
130
+
131
+ @staticmethod
132
+ def backward(ctx, *output_grads):
133
+ input_tensors = []
134
+ input_tensor_index = []
135
+ for i, input_tensor in enumerate(ctx.input_tensors):
136
+ if isinstance(input_tensor, torch.Tensor):
137
+ input_tensors.append(input_tensor.detach().requires_grad_(True))
138
+ else:
139
+ input_tensors.append(input_tensor)
140
+ input_tensor_index.append(i)
141
+ ctx.input_tensors = input_tensors
142
+
143
+ length_input_tensors = len(input_tensors)
144
+ with torch.enable_grad(), \
145
+ torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs):
146
+ # Fixes a bug where the first op in run_function modifies the
147
+ # Tensor storage in place, which is not allowed for detach()'d
148
+ # Tensors.
149
+ shallow_copies = []
150
+ for input_tensor in ctx.input_tensors:
151
+ try:
152
+ shallow_copies.append(input_tensor.view_as(input_tensor))
153
+                except Exception:  # non-tensor inputs do not support .view_as()
154
+ shallow_copies.append(input_tensor)
155
+ # shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
156
+ output_tensors = ctx.run_function(*shallow_copies)
157
+ # print(len(input_tensors))
158
+ # pdb.set_trace()
159
+ num_non_tensor = len(input_tensor_index)
160
+ for num in range(num_non_tensor):
161
+ index = input_tensor_index[num_non_tensor-1-num]
162
+ ctx.input_tensors.pop(index)
163
+
164
+ input_params = []
165
+ input_params_index = []
166
+ for i, input_param in enumerate(ctx.input_params):
167
+ if input_param.requires_grad == True:
168
+ input_params.append(input_param)
169
+ else:
170
+ input_params_index.append(i)
171
+ # pdb.set_trace()
172
+ input_grads = torch.autograd.grad(output_tensors,ctx.input_tensors + input_params,output_grads,allow_unused=True,)
173
+ # print(len(input_grads))
174
+ # pdb.set_trace()
175
+ input_grads = list(input_grads)
176
+ for index in input_tensor_index:
177
+ input_grads.insert(index, None)
178
+ if input_params_index == []:
179
+ pass
180
+ else:
181
+ for param_index in input_params_index:
182
+ input_grads.insert(length_input_tensors+param_index, None)
183
+ input_grads = tuple(input_grads)
184
+ del ctx.input_tensors
185
+ del ctx.input_params
186
+ del output_tensors
187
+ return (None, None) + input_grads
188
+
189
+ def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
190
+ """
191
+ Create sinusoidal timestep embeddings.
192
+ :param timesteps: a 1-D Tensor of N indices, one per batch element.
193
+ These may be fractional.
194
+ :param dim: the dimension of the output.
195
+ :param max_period: controls the minimum frequency of the embeddings.
196
+ :return: an [N x dim] Tensor of positional embeddings.
197
+ """
198
+ if not repeat_only:
199
+ half = dim // 2
200
+ freqs = torch.exp(
201
+ -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
202
+ ).to(device=timesteps.device)
203
+ args = timesteps[:, None].float() * freqs[None]
204
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
205
+ if dim % 2:
206
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
207
+ else:
208
+ embedding = repeat(timesteps, 'b -> b d', d=dim)
209
+ return embedding
210
+
211
+
212
+ def zero_module(module):
213
+ """
214
+ Zero out the parameters of a module and return it.
215
+ """
216
+ for p in module.parameters():
217
+ p.detach().zero_()
218
+ return module
219
+
220
+
221
+ def scale_module(module, scale):
222
+ """
223
+ Scale the parameters of a module and return it.
224
+ """
225
+ for p in module.parameters():
226
+ p.detach().mul_(scale)
227
+ return module
228
+
229
+
230
+ def mean_flat(tensor):
231
+ """
232
+ Take the mean over all non-batch dimensions.
233
+ """
234
+ return tensor.mean(dim=list(range(1, len(tensor.shape))))
235
+
236
+
237
+ def normalization(channels):
238
+ """
239
+ Make a standard normalization layer.
240
+ :param channels: number of input channels.
241
+ :return: an nn.Module for normalization.
242
+ """
243
+ return GroupNorm32(32, channels)
244
+
245
+
246
+ # PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
247
+ class SiLU(nn.Module):
248
+ def forward(self, x):
249
+ return x * torch.sigmoid(x)
250
+
251
+
252
+ class GroupNorm32(nn.GroupNorm):
253
+ def forward(self, x):
254
+ return super().forward(x.float()).type(x.dtype)
255
+
256
+ def conv_nd(dims, *args, **kwargs):
257
+ """
258
+ Create a 1D, 2D, or 3D convolution module.
259
+ """
260
+ if dims == 1:
261
+ return nn.Conv1d(*args, **kwargs)
262
+ elif dims == 2:
263
+ return nn.Conv2d(*args, **kwargs)
264
+ elif dims == 3:
265
+ return nn.Conv3d(*args, **kwargs)
266
+ raise ValueError(f"unsupported dimensions: {dims}")
267
+
268
+
269
+ def linear(*args, **kwargs):
270
+ """
271
+ Create a linear module.
272
+ """
273
+ return nn.Linear(*args, **kwargs)
274
+
275
+
276
+ def avg_pool_nd(dims, *args, **kwargs):
277
+ """
278
+ Create a 1D, 2D, or 3D average pooling module.
279
+ """
280
+ if dims == 1:
281
+ return nn.AvgPool1d(*args, **kwargs)
282
+ elif dims == 2:
283
+ return nn.AvgPool2d(*args, **kwargs)
284
+ elif dims == 3:
285
+ return nn.AvgPool3d(*args, **kwargs)
286
+ raise ValueError(f"unsupported dimensions: {dims}")
287
+
288
+
289
+ class HybridConditioner(nn.Module):
290
+
291
+ def __init__(self, c_concat_config, c_crossattn_config):
292
+ super().__init__()
293
+ self.concat_conditioner = instantiate_from_config(c_concat_config)
294
+ self.crossattn_conditioner = instantiate_from_config(c_crossattn_config)
295
+
296
+ def forward(self, c_concat, c_crossattn):
297
+ c_concat = self.concat_conditioner(c_concat)
298
+ c_crossattn = self.crossattn_conditioner(c_crossattn)
299
+ return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]}
300
+
301
+
302
+ def noise_like(shape, device, repeat=False):
303
+ repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
304
+ noise = lambda: torch.randn(shape, device=device)
305
+ return repeat_noise() if repeat else noise()
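Two short usage sketches for the helpers above (module and shapes are illustrative): `timestep_embedding` produces the sinusoidal embedding later passed through `time_embed`, and `checkpoint` recomputes the wrapped function during the backward pass instead of storing its activations.

import torch
import torch.nn as nn

# Sinusoidal timestep embeddings: one row per timestep index.
emb = timestep_embedding(torch.tensor([0, 250, 999]), dim=320)   # shape [3, 320]

# Gradient checkpointing: activations of `run` are recomputed in backward().
lin = nn.Linear(16, 16)
x = torch.randn(2, 16, requires_grad=True)

def run(h):
    return lin(h).relu()

y = checkpoint(run, (x,), list(lin.parameters()), flag=True)
y.sum().backward()          # gradients reach both x and lin's parameters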
model_lib/ControlNet/ldm/modules/distributions/__init__.py ADDED
File without changes
model_lib/ControlNet/ldm/modules/distributions/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (233 Bytes). View file
 
model_lib/ControlNet/ldm/modules/distributions/__pycache__/distributions.cpython-39.pyc ADDED
Binary file (3.85 kB). View file
 
model_lib/ControlNet/ldm/modules/distributions/distributions.py ADDED
@@ -0,0 +1,92 @@
1
+ import torch
2
+ import numpy as np
3
+
4
+
5
+ class AbstractDistribution:
6
+ def sample(self):
7
+ raise NotImplementedError()
8
+
9
+ def mode(self):
10
+ raise NotImplementedError()
11
+
12
+
13
+ class DiracDistribution(AbstractDistribution):
14
+ def __init__(self, value):
15
+ self.value = value
16
+
17
+ def sample(self):
18
+ return self.value
19
+
20
+ def mode(self):
21
+ return self.value
22
+
23
+
24
+ class DiagonalGaussianDistribution(object):
25
+ def __init__(self, parameters, deterministic=False):
26
+ self.parameters = parameters
27
+ self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
28
+ self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
29
+ self.deterministic = deterministic
30
+ self.std = torch.exp(0.5 * self.logvar)
31
+ self.var = torch.exp(self.logvar)
32
+ if self.deterministic:
33
+ self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
34
+
35
+ def sample(self):
36
+ x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)
37
+ return x
38
+
39
+ def kl(self, other=None):
40
+ if self.deterministic:
41
+ return torch.Tensor([0.])
42
+ else:
43
+ if other is None:
44
+ return 0.5 * torch.sum(torch.pow(self.mean, 2)
45
+ + self.var - 1.0 - self.logvar,
46
+ dim=[1, 2, 3])
47
+ else:
48
+ return 0.5 * torch.sum(
49
+ torch.pow(self.mean - other.mean, 2) / other.var
50
+ + self.var / other.var - 1.0 - self.logvar + other.logvar,
51
+ dim=[1, 2, 3])
52
+
53
+ def nll(self, sample, dims=[1,2,3]):
54
+ if self.deterministic:
55
+ return torch.Tensor([0.])
56
+ logtwopi = np.log(2.0 * np.pi)
57
+ return 0.5 * torch.sum(
58
+ logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
59
+ dim=dims)
60
+
61
+ def mode(self):
62
+ return self.mean
63
+
64
+
65
+ def normal_kl(mean1, logvar1, mean2, logvar2):
66
+ """
67
+ source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
68
+ Compute the KL divergence between two gaussians.
69
+ Shapes are automatically broadcasted, so batches can be compared to
70
+ scalars, among other use cases.
71
+ """
72
+ tensor = None
73
+ for obj in (mean1, logvar1, mean2, logvar2):
74
+ if isinstance(obj, torch.Tensor):
75
+ tensor = obj
76
+ break
77
+ assert tensor is not None, "at least one argument must be a Tensor"
78
+
79
+ # Force variances to be Tensors. Broadcasting helps convert scalars to
80
+ # Tensors, but it does not work for torch.exp().
81
+ logvar1, logvar2 = [
82
+ x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
83
+ for x in (logvar1, logvar2)
84
+ ]
85
+
86
+ return 0.5 * (
87
+ -1.0
88
+ + logvar2
89
+ - logvar1
90
+ + torch.exp(logvar1 - logvar2)
91
+ + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
92
+ )
model_lib/ControlNet/ldm/modules/ema.py ADDED
@@ -0,0 +1,80 @@
1
+ import torch
2
+ from torch import nn
3
+
4
+
5
+ class LitEma(nn.Module):
6
+ def __init__(self, model, decay=0.9999, use_num_upates=True):
7
+ super().__init__()
8
+ if decay < 0.0 or decay > 1.0:
9
+ raise ValueError('Decay must be between 0 and 1')
10
+
11
+ self.m_name2s_name = {}
12
+ self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32))
13
+ self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int) if use_num_upates
14
+ else torch.tensor(-1, dtype=torch.int))
15
+
16
+ for name, p in model.named_parameters():
17
+ if p.requires_grad:
18
+                # remove '.' since it is not allowed in buffer names
19
+ s_name = name.replace('.', '')
20
+ self.m_name2s_name.update({name: s_name})
21
+ self.register_buffer(s_name, p.clone().detach().data)
22
+
23
+ self.collected_params = []
24
+
25
+ def reset_num_updates(self):
26
+ del self.num_updates
27
+ self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int))
28
+
29
+ def forward(self, model):
30
+ decay = self.decay
31
+
32
+ if self.num_updates >= 0:
33
+ self.num_updates += 1
34
+ decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates))
35
+
36
+ one_minus_decay = 1.0 - decay
37
+
38
+ with torch.no_grad():
39
+ m_param = dict(model.named_parameters())
40
+ shadow_params = dict(self.named_buffers())
41
+
42
+ for key in m_param:
43
+ if m_param[key].requires_grad:
44
+ sname = self.m_name2s_name[key]
45
+ shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
46
+ shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key]))
47
+ else:
48
+ assert not key in self.m_name2s_name
49
+
50
+ def copy_to(self, model):
51
+ m_param = dict(model.named_parameters())
52
+ shadow_params = dict(self.named_buffers())
53
+ for key in m_param:
54
+ if m_param[key].requires_grad:
55
+ m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
56
+ else:
57
+ assert not key in self.m_name2s_name
58
+
59
+ def store(self, parameters):
60
+ """
61
+ Save the current parameters for restoring later.
62
+ Args:
63
+ parameters: Iterable of `torch.nn.Parameter`; the parameters to be
64
+ temporarily stored.
65
+ """
66
+ self.collected_params = [param.clone() for param in parameters]
67
+
68
+ def restore(self, parameters):
69
+ """
70
+ Restore the parameters stored with the `store` method.
71
+ Useful to validate the model with EMA parameters without affecting the
72
+ original optimization process. Store the parameters before the
73
+ `copy_to` method. After validation (or model saving), use this to
74
+ restore the former parameters.
75
+ Args:
76
+ parameters: Iterable of `torch.nn.Parameter`; the parameters to be
77
+ updated with the stored parameters.
78
+ """
79
+ for c_param, param in zip(self.collected_params, parameters):
80
+ param.data.copy_(c_param.data)
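A minimal usage sketch of `LitEma` (model, optimizer, and step count are illustrative): call the EMA object after each optimizer step to update the shadow weights, then use `store`/`copy_to`/`restore` to evaluate or save with EMA weights without losing the training weights.

import torch
import torch.nn as nn

model = nn.Linear(8, 8)
opt = torch.optim.SGD(model.parameters(), lr=1e-2)
ema = LitEma(model, decay=0.9999)

for _ in range(5):
    loss = model(torch.randn(4, 8)).pow(2).mean()
    opt.zero_grad()
    loss.backward()
    opt.step()
    ema(model)                        # update the shadow (EMA) parameters

ema.store(model.parameters())         # stash the raw training weights
ema.copy_to(model)                    # model now holds the EMA weights
ema.restore(model.parameters())       # put the training weights back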
model_lib/ControlNet/ldm/modules/encoders/__init__.py ADDED
File without changes
model_lib/ControlNet/ldm/modules/encoders/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (228 Bytes). View file