Commit 074c857
Parent(s): 1d7be4e
Upload 198 files
This view is limited to 50 files because it contains too many changes; see the raw diff for the full set.
- configuration.py +183 -0
- deforum-stable-diffusion/Deforum_Stable_Diffusion.ipynb +580 -0
- deforum-stable-diffusion/Deforum_Stable_Diffusion.py +536 -0
- deforum-stable-diffusion/LICENSE +0 -0
- deforum-stable-diffusion/configs/v1-inference.yaml +70 -0
- deforum-stable-diffusion/configs/v2-inference-v.yaml +68 -0
- deforum-stable-diffusion/configs/v2-inference.yaml +67 -0
- deforum-stable-diffusion/configs/v2-inpainting-inference.yaml +158 -0
- deforum-stable-diffusion/configs/v2-midas-inference.yaml +74 -0
- deforum-stable-diffusion/configs/x4-upscaling.yaml +76 -0
- deforum-stable-diffusion/helpers/__init__.py +9 -0
- deforum-stable-diffusion/helpers/__pycache__/__init__.cpython-38.pyc +0 -0
- deforum-stable-diffusion/helpers/__pycache__/__init__.cpython-39.pyc +0 -0
- deforum-stable-diffusion/helpers/__pycache__/aesthetics.cpython-38.pyc +0 -0
- deforum-stable-diffusion/helpers/__pycache__/animation.cpython-38.pyc +0 -0
- deforum-stable-diffusion/helpers/__pycache__/callback.cpython-38.pyc +0 -0
- deforum-stable-diffusion/helpers/__pycache__/colors.cpython-38.pyc +0 -0
- deforum-stable-diffusion/helpers/__pycache__/conditioning.cpython-38.pyc +0 -0
- deforum-stable-diffusion/helpers/__pycache__/depth.cpython-38.pyc +0 -0
- deforum-stable-diffusion/helpers/__pycache__/generate.cpython-38.pyc +0 -0
- deforum-stable-diffusion/helpers/__pycache__/generate.cpython-39.pyc +0 -0
- deforum-stable-diffusion/helpers/__pycache__/k_samplers.cpython-38.pyc +0 -0
- deforum-stable-diffusion/helpers/__pycache__/load_images.cpython-38.pyc +0 -0
- deforum-stable-diffusion/helpers/__pycache__/model_load.cpython-38.pyc +0 -0
- deforum-stable-diffusion/helpers/__pycache__/model_wrap.cpython-38.pyc +0 -0
- deforum-stable-diffusion/helpers/__pycache__/prompt.cpython-38.pyc +0 -0
- deforum-stable-diffusion/helpers/__pycache__/render.cpython-38.pyc +0 -0
- deforum-stable-diffusion/helpers/__pycache__/render.cpython-39.pyc +0 -0
- deforum-stable-diffusion/helpers/__pycache__/save_images.cpython-38.pyc +0 -0
- deforum-stable-diffusion/helpers/__pycache__/save_images.cpython-39.pyc +0 -0
- deforum-stable-diffusion/helpers/__pycache__/settings.cpython-38.pyc +0 -0
- deforum-stable-diffusion/helpers/__pycache__/settings.cpython-39.pyc +0 -0
- deforum-stable-diffusion/helpers/__pycache__/simulacra_fit_linear_model.cpython-38.pyc +0 -0
- deforum-stable-diffusion/helpers/aesthetics.py +48 -0
- deforum-stable-diffusion/helpers/animation.py +338 -0
- deforum-stable-diffusion/helpers/callback.py +124 -0
- deforum-stable-diffusion/helpers/colors.py +16 -0
- deforum-stable-diffusion/helpers/conditioning.py +262 -0
- deforum-stable-diffusion/helpers/depth.py +175 -0
- deforum-stable-diffusion/helpers/generate.py +282 -0
- deforum-stable-diffusion/helpers/k_samplers.py +124 -0
- deforum-stable-diffusion/helpers/load_images.py +99 -0
- deforum-stable-diffusion/helpers/model_load.py +257 -0
- deforum-stable-diffusion/helpers/model_wrap.py +226 -0
- deforum-stable-diffusion/helpers/prompt.py +130 -0
- deforum-stable-diffusion/helpers/rank_images.py +69 -0
- deforum-stable-diffusion/helpers/render.py +472 -0
- deforum-stable-diffusion/helpers/save_images.py +60 -0
- deforum-stable-diffusion/helpers/settings.py +34 -0
- deforum-stable-diffusion/helpers/simulacra_compute_embeddings.py +96 -0
configuration.py
ADDED
@@ -0,0 +1,183 @@
import os

def Root():
    models_path = "models" #@param {type:"string"}
    configs_path = "configs" #@param {type:"string"}
    output_path = "output" #@param {type:"string"}
    mount_google_drive = False #@param {type:"boolean"}
    models_path_gdrive = "/content/drive/MyDrive/AI/models" #@param {type:"string"}
    output_path_gdrive = "/content/drive/MyDrive/AI/StableDiffusion" #@param {type:"string"}

    #@markdown **Model Setup**
    model_config = "v1-inference.yaml" #@param ["custom","v1-inference.yaml"]
    model_checkpoint = "v1-5-pruned-emaonly.ckpt" #@param ["custom","v1-5-pruned.ckpt","v1-5-pruned-emaonly.ckpt","sd-v1-4-full-ema.ckpt","sd-v1-4.ckpt","sd-v1-3-full-ema.ckpt","sd-v1-3.ckpt","sd-v1-2-full-ema.ckpt","sd-v1-2.ckpt","sd-v1-1-full-ema.ckpt","sd-v1-1.ckpt", "robo-diffusion-v1.ckpt","wd-v1-3-float16.ckpt"]
    custom_config_path = "" #@param {type:"string"}
    custom_checkpoint_path = "" #@param {type:"string"}
    half_precision = True
    return locals()


def DeforumAnimArgs():
    animation_mode = "3D" #@param ['None', '2D', '3D', 'Video Input', 'Interpolation'] {type:'string'}
    max_frames = 200 #@param {type:"number"}
    border = 'wrap' #@param ['wrap', 'replicate'] {type:'string'}

    #@markdown ####**Motion Parameters:**
    angle = "0:(0)" #@param {type:"string"}
    zoom = "0:(1.04)" #@param {type:"string"}
    translation_x = "0:(0)" #@param {type:"string"}
    translation_y = "0:(0)" #@param {type:"string"}
    translation_z = "0:(0)" #@param {type:"string"}
    rotation_3d_x = "0:(0)" #@param {type:"string"}
    rotation_3d_y = "0:(0)" #@param {type:"string"}
    rotation_3d_z = "0:(0)" #@param {type:"string"}
    flip_2d_perspective = False #@param {type:"boolean"}
    perspective_flip_theta = "0:(0)" #@param {type:"string"}
    perspective_flip_phi = "0:(t%15)" #@param {type:"string"}
    perspective_flip_gamma = "0:(0)" #@param {type:"string"}
    perspective_flip_fv = "0:(0)" #@param {type:"string"}
    noise_schedule = "0:(0.02)" #@param {type:"string"}
    strength_schedule = "0:(0.65)" #@param {type:"string"}
    contrast_schedule = "0:(1.0)" #@param {type:"string"}

    #@markdown ####**Coherence:**
    color_coherence = "Match Frame 0 LAB" #@param ['None', 'Match Frame 0 HSV', 'Match Frame 0 LAB', 'Match Frame 0 RGB'] {type:'string'}
    diffusion_cadence = "3" #@param ['1','2','3','4','5','6','7','8'] {type:'string'}

    #@markdown #### 3D Depth Warping
    use_depth_warping = True #@param {type:"boolean"}
    midas_weight = 0.3 #@param {type:"number"}
    near_plane = 200
    far_plane = 10000
    fov = 40 #@param {type:"number"}
    padding_mode = "border" #@param ['border', 'reflection', 'zeros'] {type:'string'}
    sampling_mode = "bicubic" #@param ['bicubic', 'bilinear', 'nearest'] {type:'string'}
    save_depth_maps = False #@param {type:"boolean"}

    #@markdown ####**Video Input:**
    video_init_path = "./input/video_in.mp4" #@param {type:"string"}
    extract_nth_frame = 1 #@param {type:"number"}
    overwrite_extracted_frames = True #@param {type:"boolean"}
    use_mask_video = False #@param {type:"boolean"}
    video_mask_path = "" #@param {type:"string"}

    #@markdown ####**Interpolation:**
    interpolate_key_frames = False #@param {type:"boolean"}
    interpolate_x_frames = 4 #@param {type:"number"}

    #@markdown ####**Resume Animation:**
    resume_from_timestring = False #@param {type:"boolean"}
    resume_timestring = "20220829210106" #@param {type:"string"}
    return locals()


def DeforumArgs():
    #@markdown **Image Settings**
    W = 512 #@param
    H = 512 #@param
    W, H = map(lambda x: x - x % 64, (W, H)) # resize to integer multiple of 64

    #@markdown **Sampling Settings**
    seed = 2022 #@param
    sampler = "klms" #@param ["klms","dpm2","dpm2_ancestral","heun","euler","euler_ancestral","plms", "ddim", "dpm_fast", "dpm_adaptive", "dpmpp_2s_a", "dpmpp_2m"]
    steps = 50 #@param
    scale = 7 #@param
    ddim_eta = 0.0 #@param
    dynamic_threshold = None
    static_threshold = None

    #@markdown **Save & Display Settings**
    save_samples = True #@param {type:"boolean"}
    save_settings = True #@param {type:"boolean"}
    display_samples = True #@param {type:"boolean"}
    save_sample_per_step = False #@param {type:"boolean"}
    show_sample_per_step = False #@param {type:"boolean"}

    #@markdown **Prompt Settings**
    prompt_weighting = True #@param {type:"boolean"}
    normalize_prompt_weights = True #@param {type:"boolean"}
    log_weighted_subprompts = False #@param {type:"boolean"}

    #@markdown **Batch Settings**
    n_batch = 1 #@param
    batch_name = "data" #@param {type:"string"}
    filename_format = "{timestring}_{index}_{prompt}.png" #@param ["{timestring}_{index}_{seed}.png","{timestring}_{index}_{prompt}.png"]
    seed_behavior = "iter" #@param ["iter","fixed","random"]
    make_grid = False #@param {type:"boolean"}
    grid_rows = 2 #@param
    outdir = "./outputs"

    #@markdown **Init Settings**
    use_init = False #@param {type:"boolean"}
    strength = 0.0 #@param {type:"number"}
    strength_0_no_init = True # Set the strength to 0 automatically when no init image is used
    init_image = "" #@param {type:"string"}
    # Whiter areas of the mask are areas that change more
    use_mask = False #@param {type:"boolean"}
    use_alpha_as_mask = False # use the alpha channel of the init image as the mask
    mask_file = "" #@param {type:"string"}
    invert_mask = False #@param {type:"boolean"}
    # Adjust mask image, 1.0 is no adjustment. Should be positive numbers.
    mask_brightness_adjust = 1.0 #@param {type:"number"}
    mask_contrast_adjust = 1.0 #@param {type:"number"}

    # Overlay the masked image at the end of the generation so it does not get degraded by encoding and decoding
    overlay_mask = True # {type:"boolean"}
    # Blur edges of final overlay mask, if used. Minimum = 0 (no blur)
    mask_overlay_blur = 5 # {type:"number"}

    #@markdown **Exposure/Contrast Conditional Settings**
    mean_scale = 0 #@param {type:"number"}
    var_scale = 0 #@param {type:"number"}
    exposure_scale = 0 #@param {type:"number"}
    exposure_target = 0.5 #@param {type:"number"}

    #@markdown **Color Match Conditional Settings**
    colormatch_scale = 0 #@param {type:"number"}
    colormatch_image = "" #@param {type:"string"}
    colormatch_n_colors = 4 #@param {type:"number"}
    ignore_sat_weight = 0 #@param {type:"number"}

    #@markdown **CLIP\Aesthetics Conditional Settings**
    clip_name = "ViT-L/14" #@param ['ViT-L/14', 'ViT-L/14@336px', 'ViT-B/16', 'ViT-B/32']
    clip_scale = 0 #@param {type:"number"}
    aesthetics_scale = 0 #@param {type:"number"}
    cutn = 1 #@param {type:"number"}
    cut_pow = 0.0001 #@param {type:"number"}

    #@markdown **Other Conditional Settings**
    init_mse_scale = 0 #@param {type:"number"}
    init_mse_image = "" #@param {type:"string"}

    blue_scale = 1 #@param {type:"number"}

    #@markdown **Conditional Gradient Settings**
    gradient_wrt = "x0_pred" #@param ["x", "x0_pred"]
    gradient_add_to = "both" #@param ["cond", "uncond", "both"]
    decode_method = "linear" #@param ["autoencoder","linear"]
    grad_threshold_type = "dynamic" #@param ["dynamic", "static", "mean", "schedule"]
    clamp_grad_threshold = 0.2 #@param {type:"number"}
    clamp_start = 0.2 #@param
    clamp_stop = 0.01 #@param
    grad_inject_timing = list(range(1,10)) #@param

    #@markdown **Speed vs VRAM Settings**
    cond_uncond_sync = True #@param {type:"boolean"}

    n_samples = 1 # doesnt do anything
    precision = 'autocast'
    C = 4
    f = 8

    prompt = ""
    timestring = ""
    init_latent = None
    init_sample = None
    init_sample_raw = None
    mask_sample = None
    init_c = None

    return locals()
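Each of the three functions above ends in return locals(), so every #@param setting becomes a key in a plain dict. Downstream code (the notebook and script below) wraps these dicts in types.SimpleNamespace for attribute access. A minimal sketch of that consumption pattern, assuming configuration.py is importable from the repo root (the import path is an assumption, not shown in this commit):

from types import SimpleNamespace
# assumed import path for the configuration.py added above
from configuration import Root, DeforumArgs, DeforumAnimArgs

root = SimpleNamespace(**Root())                   # paths, checkpoint, precision
args = SimpleNamespace(**DeforumArgs())            # per-image generation settings
anim_args = SimpleNamespace(**DeforumAnimArgs())   # animation/motion settings

print(args.W, args.H, args.sampler)                # e.g. 512 512 klms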
deforum-stable-diffusion/Deforum_Stable_Diffusion.ipynb
ADDED
@@ -0,0 +1,580 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "ByGXyiHZWM_q"
   },
   "source": [
    "# **Deforum Stable Diffusion v0.6**\n",
    "[Stable Diffusion](https://github.com/CompVis/stable-diffusion) by Robin Rombach, Andreas Blattmann, Dominik Lorenz, Patrick Esser, Bj\u00f6rn Ommer and the [Stability.ai](https://stability.ai/) Team. [K Diffusion](https://github.com/crowsonkb/k-diffusion) by [Katherine Crowson](https://twitter.com/RiversHaveWings).\n",
    "\n",
    "[Quick Guide](https://docs.google.com/document/d/1RrQv7FntzOuLg4ohjRZPVL7iptIyBhwwbcEYEW2OfcI/edit?usp=sharing) to Deforum v0.6\n",
    "\n",
    "Notebook by [deforum](https://discord.gg/upmXXsrwZc)"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "cellView": "form",
    "id": "IJjzzkKlWM_s"
   },
   "source": [
    "#@markdown **NVIDIA GPU**\n",
    "import subprocess, os, sys\n",
    "sub_p_res = subprocess.run(['nvidia-smi', '--query-gpu=name,memory.total,memory.free', '--format=csv,noheader'], stdout=subprocess.PIPE).stdout.decode('utf-8')\n",
    "print(f\"{sub_p_res[:-1]}\")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "UA8-efH-WM_t"
   },
   "source": [
    "# Setup"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "cellView": "form",
    "id": "0D2HQO-PWM_t"
   },
   "source": [
    "\n",
    "import subprocess, time, gc, os, sys\n",
    "\n",
    "def setup_environment():\n",
    "    print_subprocess = False\n",
    "    use_xformers_for_colab = True\n",
    "    try:\n",
    "        ipy = get_ipython()\n",
    "    except:\n",
    "        ipy = 'could not get_ipython'\n",
    "    if 'google.colab' in str(ipy):\n",
    "        print(\"..setting up environment\")\n",
    "        start_time = time.time()\n",
    "        all_process = [\n",
    "            ['pip', 'install', 'torch==1.12.1+cu113', 'torchvision==0.13.1+cu113', '--extra-index-url', 'https://download.pytorch.org/whl/cu113'],\n",
    "            ['pip', 'install', 'omegaconf==2.2.3', 'einops==0.4.1', 'pytorch-lightning==1.7.4', 'torchmetrics==0.9.3', 'torchtext==0.13.1', 'transformers==4.21.2', 'kornia==0.6.7'],\n",
    "            ['git', 'clone', 'https://github.com/deforum-art/deforum-stable-diffusion'],\n",
    "            ['pip', 'install', 'accelerate', 'ftfy', 'jsonmerge', 'matplotlib', 'resize-right', 'timm', 'torchdiffeq','scikit-learn'],\n",
    "        ]\n",
    "        for process in all_process:\n",
    "            running = subprocess.run(process,stdout=subprocess.PIPE).stdout.decode('utf-8')\n",
    "            if print_subprocess:\n",
    "                print(running)\n",
    "        with open('deforum-stable-diffusion/src/k_diffusion/__init__.py', 'w') as f:\n",
    "            f.write('')\n",
    "        sys.path.extend([\n",
    "            'deforum-stable-diffusion/',\n",
    "            'deforum-stable-diffusion/src',\n",
    "        ])\n",
    "        end_time = time.time()\n",
    "\n",
    "        if use_xformers_for_colab:\n",
    "\n",
    "            print(\"..installing xformers\")\n",
    "\n",
    "            all_process = [['pip', 'install', 'triton==2.0.0.dev20220701']]\n",
    "            for process in all_process:\n",
    "                running = subprocess.run(process,stdout=subprocess.PIPE).stdout.decode('utf-8')\n",
    "                if print_subprocess:\n",
    "                    print(running)\n",
    "    \n",
    "            v_card_name = subprocess.run(['nvidia-smi', '--query-gpu=name', '--format=csv,noheader'], stdout=subprocess.PIPE).stdout.decode('utf-8')\n",
    "            if 't4' in v_card_name.lower():\n",
    "                name_to_download = 'T4'\n",
    "            elif 'v100' in v_card_name.lower():\n",
    "                name_to_download = 'V100'\n",
    "            elif 'a100' in v_card_name.lower():\n",
    "                name_to_download = 'A100'\n",
    "            elif 'p100' in v_card_name.lower():\n",
    "                name_to_download = 'P100'\n",
    "            else:\n",
    "                print(v_card_name + ' is currently not supported with xformers flash attention in deforum!')\n",
    "\n",
    "            x_ver = 'xformers-0.0.13.dev0-py3-none-any.whl'\n",
    "            x_link = 'https://github.com/TheLastBen/fast-stable-diffusion/raw/main/precompiled/' + name_to_download + '/' + x_ver\n",
    "    \n",
    "            all_process = [\n",
    "                ['wget', x_link],\n",
    "                ['pip', 'install', x_ver],\n",
    "                ['mv', 'deforum-stable-diffusion/src/ldm/modules/attention.py', 'deforum-stable-diffusion/src/ldm/modules/attention_backup.py'],\n",
    "                ['mv', 'deforum-stable-diffusion/src/ldm/modules/attention_xformers.py', 'deforum-stable-diffusion/src/ldm/modules/attention.py']\n",
    "            ]\n",
    "\n",
    "            for process in all_process:\n",
    "                running = subprocess.run(process,stdout=subprocess.PIPE).stdout.decode('utf-8')\n",
    "                if print_subprocess:\n",
    "                    print(running)\n",
    "\n",
    "        print(f\"Environment set up in {end_time-start_time:.0f} seconds\")\n",
    "    else:\n",
    "        sys.path.extend([\n",
    "            'src'\n",
    "        ])\n",
    "    return\n",
    "\n",
    "setup_environment()\n",
    "\n",
    "import torch\n",
    "import random\n",
    "import clip\n",
    "from IPython import display\n",
    "from types import SimpleNamespace\n",
    "from helpers.save_images import get_output_folder\n",
    "from helpers.settings import load_args\n",
    "from helpers.render import render_animation, render_input_video, render_image_batch, render_interpolation\n",
    "from helpers.model_load import make_linear_decode, load_model, get_model_output_paths\n",
    "from helpers.aesthetics import load_aesthetics_model\n",
    "\n",
    "#@markdown **Path Setup**\n",
    "\n",
    "def Root():\n",
    "    models_path = \"models\" #@param {type:\"string\"}\n",
    "    configs_path = \"configs\" #@param {type:\"string\"}\n",
    "    output_path = \"output\" #@param {type:\"string\"}\n",
    "    mount_google_drive = True #@param {type:\"boolean\"}\n",
    "    models_path_gdrive = \"/content/drive/MyDrive/AI/models\" #@param {type:\"string\"}\n",
    "    output_path_gdrive = \"/content/drive/MyDrive/AI/StableDiffusion\" #@param {type:\"string\"}\n",
    "\n",
    "    #@markdown **Model Setup**\n",
    "    model_config = \"v1-inference.yaml\" #@param [\"custom\",\"v1-inference.yaml\"]\n",
    "    model_checkpoint = \"v1-5-pruned-emaonly.ckpt\" #@param [\"custom\",\"v1-5-pruned.ckpt\",\"v1-5-pruned-emaonly.ckpt\",\"sd-v1-4-full-ema.ckpt\",\"sd-v1-4.ckpt\",\"sd-v1-3-full-ema.ckpt\",\"sd-v1-3.ckpt\",\"sd-v1-2-full-ema.ckpt\",\"sd-v1-2.ckpt\",\"sd-v1-1-full-ema.ckpt\",\"sd-v1-1.ckpt\", \"robo-diffusion-v1.ckpt\",\"wd-v1-3-float16.ckpt\"]\n",
    "    custom_config_path = \"\" #@param {type:\"string\"}\n",
    "    custom_checkpoint_path = \"\" #@param {type:\"string\"}\n",
    "    half_precision = True\n",
    "    return locals()\n",
    "\n",
    "root = Root()\n",
    "root = SimpleNamespace(**root)\n",
    "\n",
    "root.models_path, root.output_path = get_model_output_paths(root)\n",
    "root.model, root.device = load_model(root, \n",
    "    load_on_run_all=True\n",
    "    , \n",
    "    check_sha256=True\n",
    "    )"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "6JxwhBwtWM_t"
   },
   "source": [
    "# Settings"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "cellView": "form",
    "id": "E0tJVYA4WM_u"
   },
   "source": [
    "def DeforumAnimArgs():\n",
    "\n",
    "    #@markdown ####**Animation:**\n",
    "    animation_mode = 'None' #@param ['None', '2D', '3D', 'Video Input', 'Interpolation'] {type:'string'}\n",
    "    max_frames = 1000 #@param {type:\"number\"}\n",
    "    border = 'replicate' #@param ['wrap', 'replicate'] {type:'string'}\n",
    "\n",
    "    #@markdown ####**Motion Parameters:**\n",
    "    angle = \"0:(0)\"#@param {type:\"string\"}\n",
    "    zoom = \"0:(1.04)\"#@param {type:\"string\"}\n",
    "    translation_x = \"0:(10*sin(2*3.14*t/10))\"#@param {type:\"string\"}\n",
    "    translation_y = \"0:(0)\"#@param {type:\"string\"}\n",
    "    translation_z = \"0:(10)\"#@param {type:\"string\"}\n",
    "    rotation_3d_x = \"0:(0)\"#@param {type:\"string\"}\n",
    "    rotation_3d_y = \"0:(0)\"#@param {type:\"string\"}\n",
    "    rotation_3d_z = \"0:(0)\"#@param {type:\"string\"}\n",
    "    flip_2d_perspective = False #@param {type:\"boolean\"}\n",
    "    perspective_flip_theta = \"0:(0)\"#@param {type:\"string\"}\n",
    "    perspective_flip_phi = \"0:(t%15)\"#@param {type:\"string\"}\n",
    "    perspective_flip_gamma = \"0:(0)\"#@param {type:\"string\"}\n",
    "    perspective_flip_fv = \"0:(53)\"#@param {type:\"string\"}\n",
    "    noise_schedule = \"0: (0.02)\"#@param {type:\"string\"}\n",
    "    strength_schedule = \"0: (0.65)\"#@param {type:\"string\"}\n",
    "    contrast_schedule = \"0: (1.0)\"#@param {type:\"string\"}\n",
    "\n",
    "    #@markdown ####**Coherence:**\n",
    "    color_coherence = 'Match Frame 0 LAB' #@param ['None', 'Match Frame 0 HSV', 'Match Frame 0 LAB', 'Match Frame 0 RGB'] {type:'string'}\n",
    "    diffusion_cadence = '1' #@param ['1','2','3','4','5','6','7','8'] {type:'string'}\n",
    "\n",
    "    #@markdown ####**3D Depth Warping:**\n",
    "    use_depth_warping = True #@param {type:\"boolean\"}\n",
    "    midas_weight = 0.3#@param {type:\"number\"}\n",
    "    near_plane = 200\n",
    "    far_plane = 10000\n",
    "    fov = 40#@param {type:\"number\"}\n",
    "    padding_mode = 'border'#@param ['border', 'reflection', 'zeros'] {type:'string'}\n",
    "    sampling_mode = 'bicubic'#@param ['bicubic', 'bilinear', 'nearest'] {type:'string'}\n",
    "    save_depth_maps = False #@param {type:\"boolean\"}\n",
    "\n",
    "    #@markdown ####**Video Input:**\n",
    "    video_init_path ='/content/video_in.mp4'#@param {type:\"string\"}\n",
    "    extract_nth_frame = 1#@param {type:\"number\"}\n",
    "    overwrite_extracted_frames = True #@param {type:\"boolean\"}\n",
    "    use_mask_video = False #@param {type:\"boolean\"}\n",
    "    video_mask_path ='/content/video_in.mp4'#@param {type:\"string\"}\n",
    "\n",
    "    #@markdown ####**Interpolation:**\n",
    "    interpolate_key_frames = False #@param {type:\"boolean\"}\n",
    "    interpolate_x_frames = 4 #@param {type:\"number\"}\n",
    "    \n",
    "    #@markdown ####**Resume Animation:**\n",
    "    resume_from_timestring = False #@param {type:\"boolean\"}\n",
    "    resume_timestring = \"20220829210106\" #@param {type:\"string\"}\n",
    "\n",
    "    return locals()"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "i9fly1RIWM_u"
   },
   "source": [
    "prompts = [\n",
    "    \"a beautiful lake by Asher Brown Durand, trending on Artstation\", # the first prompt I want\n",
    "    \"a beautiful portrait of a woman by Artgerm, trending on Artstation\", # the second prompt I want\n",
    "    #\"this prompt I don't want it I commented it out\",\n",
    "    #\"a nousr robot, trending on Artstation\", # use \"nousr robot\" with the robot diffusion model (see model_checkpoint setting)\n",
    "    #\"touhou 1girl komeiji_koishi portrait, green hair\", # waifu diffusion prompts can use danbooru tag groups (see model_checkpoint)\n",
    "    #\"this prompt has weights if prompt weighting enabled:2 can also do negative:-2\", # (see prompt_weighting)\n",
    "]\n",
    "\n",
    "animation_prompts = {\n",
    "    0: \"a beautiful apple, trending on Artstation\",\n",
    "    20: \"a beautiful banana, trending on Artstation\",\n",
    "    30: \"a beautiful coconut, trending on Artstation\",\n",
    "    40: \"a beautiful durian, trending on Artstation\",\n",
    "}"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {
    "cellView": "form",
    "id": "XVzhbmizWM_u"
   },
   "source": [
    "#@markdown **Load Settings**\n",
    "override_settings_with_file = False #@param {type:\"boolean\"}\n",
    "settings_file = \"custom\" #@param [\"custom\", \"512x512_aesthetic_0.json\",\"512x512_aesthetic_1.json\",\"512x512_colormatch_0.json\",\"512x512_colormatch_1.json\",\"512x512_colormatch_2.json\",\"512x512_colormatch_3.json\"]\n",
    "custom_settings_file = \"/content/drive/MyDrive/Settings.txt\"#@param {type:\"string\"}\n",
    "\n",
    "def DeforumArgs():\n",
    "    #@markdown **Image Settings**\n",
    "    W = 512 #@param\n",
    "    H = 512 #@param\n",
    "    W, H = map(lambda x: x - x % 64, (W, H)) # resize to integer multiple of 64\n",
    "\n",
    "    #@markdown **Sampling Settings**\n",
    "    seed = -1 #@param\n",
    "    sampler = 'dpmpp_2s_a' #@param [\"klms\",\"dpm2\",\"dpm2_ancestral\",\"heun\",\"euler\",\"euler_ancestral\",\"plms\", \"ddim\", \"dpm_fast\", \"dpm_adaptive\", \"dpmpp_2s_a\", \"dpmpp_2m\"]\n",
    "    steps = 80 #@param\n",
    "    scale = 7 #@param\n",
    "    ddim_eta = 0.0 #@param\n",
    "    dynamic_threshold = None\n",
    "    static_threshold = None \n",
    "\n",
    "    #@markdown **Save & Display Settings**\n",
    "    save_samples = True #@param {type:\"boolean\"}\n",
    "    save_settings = True #@param {type:\"boolean\"}\n",
    "    display_samples = True #@param {type:\"boolean\"}\n",
    "    save_sample_per_step = False #@param {type:\"boolean\"}\n",
    "    show_sample_per_step = False #@param {type:\"boolean\"}\n",
    "\n",
    "    #@markdown **Prompt Settings**\n",
    "    prompt_weighting = True #@param {type:\"boolean\"}\n",
    "    normalize_prompt_weights = True #@param {type:\"boolean\"}\n",
    "    log_weighted_subprompts = False #@param {type:\"boolean\"}\n",
    "\n",
    "    #@markdown **Batch Settings**\n",
    "    n_batch = 1 #@param\n",
    "    batch_name = \"StableFun\" #@param {type:\"string\"}\n",
    "    filename_format = \"{timestring}_{index}_{prompt}.png\" #@param [\"{timestring}_{index}_{seed}.png\",\"{timestring}_{index}_{prompt}.png\"]\n",
    "    seed_behavior = \"iter\" #@param [\"iter\",\"fixed\",\"random\"]\n",
    "    make_grid = False #@param {type:\"boolean\"}\n",
    "    grid_rows = 2 #@param \n",
    "    outdir = get_output_folder(root.output_path, batch_name)\n",
    "\n",
    "    #@markdown **Init Settings**\n",
    "    use_init = False #@param {type:\"boolean\"}\n",
    "    strength = 0.0 #@param {type:\"number\"}\n",
    "    strength_0_no_init = True # Set the strength to 0 automatically when no init image is used\n",
    "    init_image = \"https://cdn.pixabay.com/photo/2022/07/30/13/10/green-longhorn-beetle-7353749_1280.jpg\" #@param {type:\"string\"}\n",
    "    # Whiter areas of the mask are areas that change more\n",
    "    use_mask = False #@param {type:\"boolean\"}\n",
    "    use_alpha_as_mask = False # use the alpha channel of the init image as the mask\n",
    "    mask_file = \"https://www.filterforge.com/wiki/images/archive/b/b7/20080927223728%21Polygonal_gradient_thumb.jpg\" #@param {type:\"string\"}\n",
    "    invert_mask = False #@param {type:\"boolean\"}\n",
    "    # Adjust mask image, 1.0 is no adjustment. Should be positive numbers.\n",
    "    mask_brightness_adjust = 1.0 #@param {type:\"number\"}\n",
    "    mask_contrast_adjust = 1.0 #@param {type:\"number\"}\n",
    "    # Overlay the masked image at the end of the generation so it does not get degraded by encoding and decoding\n",
    "    overlay_mask = True # {type:\"boolean\"}\n",
    "    # Blur edges of final overlay mask, if used. Minimum = 0 (no blur)\n",
    "    mask_overlay_blur = 5 # {type:\"number\"}\n",
    "\n",
    "    #@markdown **Exposure/Contrast Conditional Settings**\n",
    "    mean_scale = 0 #@param {type:\"number\"}\n",
    "    var_scale = 0 #@param {type:\"number\"}\n",
    "    exposure_scale = 0 #@param {type:\"number\"}\n",
    "    exposure_target = 0.5 #@param {type:\"number\"}\n",
    "\n",
    "    #@markdown **Color Match Conditional Settings**\n",
    "    colormatch_scale = 0 #@param {type:\"number\"}\n",
    "    colormatch_image = \"https://www.saasdesign.io/wp-content/uploads/2021/02/palette-3-min-980x588.png\" #@param {type:\"string\"}\n",
    "    colormatch_n_colors = 4 #@param {type:\"number\"}\n",
    "    ignore_sat_weight = 0 #@param {type:\"number\"}\n",
    "\n",
    "    #@markdown **CLIP\\Aesthetics Conditional Settings**\n",
    "    clip_name = 'ViT-L/14' #@param ['ViT-L/14', 'ViT-L/14@336px', 'ViT-B/16', 'ViT-B/32']\n",
    "    clip_scale = 0 #@param {type:\"number\"}\n",
    "    aesthetics_scale = 0 #@param {type:\"number\"}\n",
    "    cutn = 1 #@param {type:\"number\"}\n",
    "    cut_pow = 0.0001 #@param {type:\"number\"}\n",
    "\n",
    "    #@markdown **Other Conditional Settings**\n",
    "    init_mse_scale = 0 #@param {type:\"number\"}\n",
    "    init_mse_image = \"https://cdn.pixabay.com/photo/2022/07/30/13/10/green-longhorn-beetle-7353749_1280.jpg\" #@param {type:\"string\"}\n",
    "\n",
    "    blue_scale = 0 #@param {type:\"number\"}\n",
    "    \n",
    "    #@markdown **Conditional Gradient Settings**\n",
    "    gradient_wrt = 'x0_pred' #@param [\"x\", \"x0_pred\"]\n",
    "    gradient_add_to = 'both' #@param [\"cond\", \"uncond\", \"both\"]\n",
    "    decode_method = 'linear' #@param [\"autoencoder\",\"linear\"]\n",
    "    grad_threshold_type = 'dynamic' #@param [\"dynamic\", \"static\", \"mean\", \"schedule\"]\n",
    "    clamp_grad_threshold = 0.2 #@param {type:\"number\"}\n",
    "    clamp_start = 0.2 #@param\n",
    "    clamp_stop = 0.01 #@param\n",
    "    grad_inject_timing = list(range(1,10)) #@param\n",
    "\n",
    "    #@markdown **Speed vs VRAM Settings**\n",
    "    cond_uncond_sync = True #@param {type:\"boolean\"}\n",
    "\n",
    "    n_samples = 1 # doesnt do anything\n",
    "    precision = 'autocast' \n",
    "    C = 4\n",
    "    f = 8\n",
    "\n",
    "    prompt = \"\"\n",
    "    timestring = \"\"\n",
    "    init_latent = None\n",
    "    init_sample = None\n",
    "    init_sample_raw = None\n",
    "    mask_sample = None\n",
    "    init_c = None\n",
    "\n",
    "    return locals()\n",
    "\n",
    "args_dict = DeforumArgs()\n",
    "anim_args_dict = DeforumAnimArgs()\n",
    "\n",
    "if override_settings_with_file:\n",
    "    load_args(args_dict, anim_args_dict, settings_file, custom_settings_file, verbose=False)\n",
    "\n",
    "args = SimpleNamespace(**args_dict)\n",
    "anim_args = SimpleNamespace(**anim_args_dict)\n",
    "\n",
    "args.timestring = time.strftime('%Y%m%d%H%M%S')\n",
    "args.strength = max(0.0, min(1.0, args.strength))\n",
    "\n",
    "# Load clip model if using clip guidance\n",
    "if (args.clip_scale > 0) or (args.aesthetics_scale > 0):\n",
    "    root.clip_model = clip.load(args.clip_name, jit=False)[0].eval().requires_grad_(False).to(root.device)\n",
    "    if (args.aesthetics_scale > 0):\n",
    "        root.aesthetics_model = load_aesthetics_model(args, root)\n",
    "\n",
    "if args.seed == -1:\n",
    "    args.seed = random.randint(0, 2**32 - 1)\n",
    "if not args.use_init:\n",
    "    args.init_image = None\n",
    "if args.sampler == 'plms' and (args.use_init or anim_args.animation_mode != 'None'):\n",
    "    print(f\"Init images aren't supported with PLMS yet, switching to KLMS\")\n",
    "    args.sampler = 'klms'\n",
    "if args.sampler != 'ddim':\n",
    "    args.ddim_eta = 0\n",
    "\n",
    "if anim_args.animation_mode == 'None':\n",
    "    anim_args.max_frames = 1\n",
    "elif anim_args.animation_mode == 'Video Input':\n",
    "    args.use_init = True\n",
    "\n",
    "# clean up unused memory\n",
    "gc.collect()\n",
    "torch.cuda.empty_cache()\n",
    "\n",
    "# dispatch to appropriate renderer\n",
    "if anim_args.animation_mode == '2D' or anim_args.animation_mode == '3D':\n",
    "    render_animation(args, anim_args, animation_prompts, root)\n",
    "elif anim_args.animation_mode == 'Video Input':\n",
    "    render_input_video(args, anim_args, animation_prompts, root)\n",
    "elif anim_args.animation_mode == 'Interpolation':\n",
    "    render_interpolation(args, anim_args, animation_prompts, root)\n",
    "else:\n",
    "    render_image_batch(args, prompts, root)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "gJ88kZ2-WM_v"
   },
   "source": [
    "# Create Video From Frames"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "cellView": "form",
    "id": "XQGeqaGAWM_v"
   },
   "source": [
    "skip_video_for_run_all = True #@param {type: 'boolean'}\n",
    "fps = 12 #@param {type:\"number\"}\n",
    "#@markdown **Manual Settings**\n",
    "use_manual_settings = False #@param {type:\"boolean\"}\n",
    "image_path = \"/content/drive/MyDrive/AI/StableDiffusion/2022-09/20220903000939_%05d.png\" #@param {type:\"string\"}\n",
    "mp4_path = \"/content/drive/MyDrive/AI/StableDiffusion/2022-09/20220903000939.mp4\" #@param {type:\"string\"}\n",
    "render_steps = False #@param {type: 'boolean'}\n",
    "path_name_modifier = \"x0_pred\" #@param [\"x0_pred\",\"x\"]\n",
    "make_gif = False\n",
    "\n",
    "if skip_video_for_run_all == True:\n",
    "    print('Skipping video creation, uncheck skip_video_for_run_all if you want to run it')\n",
    "else:\n",
    "    import os\n",
    "    import subprocess\n",
    "    from base64 import b64encode\n",
    "\n",
    "    print(f\"{image_path} -> {mp4_path}\")\n",
    "\n",
    "    if use_manual_settings:\n",
    "        max_frames = \"200\" #@param {type:\"string\"}\n",
    "    else:\n",
    "        if render_steps: # render steps from a single image\n",
    "            fname = f\"{path_name_modifier}_%05d.png\"\n",
    "            all_step_dirs = [os.path.join(args.outdir, d) for d in os.listdir(args.outdir) if os.path.isdir(os.path.join(args.outdir,d))]\n",
    "            newest_dir = max(all_step_dirs, key=os.path.getmtime)\n",
    "            image_path = os.path.join(newest_dir, fname)\n",
    "            print(f\"Reading images from {image_path}\")\n",
    "            mp4_path = os.path.join(newest_dir, f\"{args.timestring}_{path_name_modifier}.mp4\")\n",
    "            max_frames = str(args.steps)\n",
    "        else: # render images for a video\n",
    "            image_path = os.path.join(args.outdir, f\"{args.timestring}_%05d.png\")\n",
    "            mp4_path = os.path.join(args.outdir, f\"{args.timestring}.mp4\")\n",
    "            max_frames = str(anim_args.max_frames)\n",
    "\n",
    "    # make video\n",
    "    cmd = [\n",
    "        'ffmpeg',\n",
    "        '-y',\n",
    "        '-vcodec', 'png',\n",
    "        '-r', str(fps),\n",
    "        '-start_number', str(0),\n",
    "        '-i', image_path,\n",
    "        '-frames:v', max_frames,\n",
    "        '-c:v', 'libx264',\n",
    "        '-vf',\n",
    "        f'fps={fps}',\n",
    "        '-pix_fmt', 'yuv420p',\n",
    "        '-crf', '17',\n",
    "        '-preset', 'veryfast',\n",
    "        '-pattern_type', 'sequence',\n",
    "        mp4_path\n",
    "    ]\n",
    "    process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)\n",
    "    stdout, stderr = process.communicate()\n",
    "    if process.returncode != 0:\n",
    "        print(stderr)\n",
    "        raise RuntimeError(stderr)\n",
    "\n",
    "    mp4 = open(mp4_path,'rb').read()\n",
    "    data_url = \"data:video/mp4;base64,\" + b64encode(mp4).decode()\n",
    "    display.display(display.HTML(f'<video controls loop><source src=\"{data_url}\" type=\"video/mp4\"></video>') )\n",
    "    \n",
    "    if make_gif:\n",
    "        gif_path = os.path.splitext(mp4_path)[0]+'.gif'\n",
    "        cmd_gif = [\n",
    "            'ffmpeg',\n",
    "            '-y',\n",
    "            '-i', mp4_path,\n",
    "            '-r', str(fps),\n",
    "            gif_path\n",
    "        ]\n",
    "        process_gif = subprocess.Popen(cmd_gif, stdout=subprocess.PIPE, stderr=subprocess.PIPE)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {
    "cellView": "form",
    "id": "MMpAcyrYWM_v"
   },
   "source": [
    "skip_disconnect_for_run_all = True #@param {type: 'boolean'}\n",
    "\n",
    "if skip_disconnect_for_run_all == True:\n",
    "    print('Skipping disconnect, uncheck skip_disconnect_for_run_all if you want to run it')\n",
    "else:\n",
    "    from google.colab import runtime\n",
    "    runtime.unassign()"
   ],
   "outputs": [],
   "execution_count": null
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.10.6 ('dsd')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.6"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "b7e04c8a9537645cbc77fa0cbde8069bc94e341b0d5ced104651213865b24e58"
   }
  },
  "colab": {
   "provenance": []
  },
  "accelerator": "GPU",
  "gpuClass": "standard"
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
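The motion parameters in the notebook above (angle, zoom, translation_x, noise_schedule, and so on) are keyframe schedule strings of the form "frame:(expression)", e.g. "0:(1.04)" or "0:(10*sin(2*3.14*t/10))", where t is the frame index. The repo's helpers/animation.py is what actually expands these; the following is only a rough, simplified sketch of the idea (it holds each keyframe's expression until the next keyframe instead of interpolating, and the function name is hypothetical):

import re
from math import sin, pi  # schedule expressions may reference sin(), pi, and t

def expand_schedule(schedule, max_frames):
    # "0:(1.04), 10:(1.00)" -> {0: "1.04", 10: "1.00"}
    keyframes = {int(frame): expr for frame, expr in re.findall(r"(\d+):\s*\((.*?)\)", schedule)}
    values, current = [], "0"
    for t in range(max_frames):
        current = keyframes.get(t, current)  # switch expression at each keyframe (no interpolation here)
        values.append(eval(current, {"sin": sin, "pi": pi, "t": t}))
    return values

expand_schedule("0:(1.04)", 3)                 # [1.04, 1.04, 1.04]
expand_schedule("0:(10*sin(2*3.14*t/10))", 3)  # [0.0, ~5.88, ~9.51]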
deforum-stable-diffusion/Deforum_Stable_Diffusion.py
ADDED
@@ -0,0 +1,536 @@
# %%
# !! {"metadata":{
# !! "id": "ByGXyiHZWM_q"
# !! }}
"""
# **Deforum Stable Diffusion v0.6**
[Stable Diffusion](https://github.com/CompVis/stable-diffusion) by Robin Rombach, Andreas Blattmann, Dominik Lorenz, Patrick Esser, Björn Ommer and the [Stability.ai](https://stability.ai/) Team. [K Diffusion](https://github.com/crowsonkb/k-diffusion) by [Katherine Crowson](https://twitter.com/RiversHaveWings).

[Quick Guide](https://docs.google.com/document/d/1RrQv7FntzOuLg4ohjRZPVL7iptIyBhwwbcEYEW2OfcI/edit?usp=sharing) to Deforum v0.6

Notebook by [deforum](https://discord.gg/upmXXsrwZc)
"""

# %%
# !! {"metadata":{
# !! "cellView": "form",
# !! "id": "IJjzzkKlWM_s"
# !! }}
#@markdown **NVIDIA GPU**
import subprocess, os, sys
sub_p_res = subprocess.run(['nvidia-smi', '--query-gpu=name,memory.total,memory.free', '--format=csv,noheader'], stdout=subprocess.PIPE).stdout.decode('utf-8')
print(f"{sub_p_res[:-1]}")

# %%
# !! {"metadata":{
# !! "id": "UA8-efH-WM_t"
# !! }}
"""
# Setup
"""

# %%
# !! {"metadata":{
# !! "cellView": "form",
# !! "id": "0D2HQO-PWM_t"
# !! }}

import subprocess, time, gc, os, sys

def setup_environment():
    print_subprocess = False
    use_xformers_for_colab = True
    try:
        ipy = get_ipython()
    except:
        ipy = 'could not get_ipython'
    if 'google.colab' in str(ipy):
        print("..setting up environment")
        start_time = time.time()
        all_process = [
            ['pip', 'install', 'torch==1.12.1+cu113', 'torchvision==0.13.1+cu113', '--extra-index-url', 'https://download.pytorch.org/whl/cu113'],
            ['pip', 'install', 'omegaconf==2.2.3', 'einops==0.4.1', 'pytorch-lightning==1.7.4', 'torchmetrics==0.9.3', 'torchtext==0.13.1', 'transformers==4.21.2', 'kornia==0.6.7'],
            ['git', 'clone', 'https://github.com/deforum-art/deforum-stable-diffusion'],
            ['pip', 'install', 'accelerate', 'ftfy', 'jsonmerge', 'matplotlib', 'resize-right', 'timm', 'torchdiffeq','scikit-learn'],
        ]
        for process in all_process:
            running = subprocess.run(process,stdout=subprocess.PIPE).stdout.decode('utf-8')
            if print_subprocess:
                print(running)
        with open('deforum-stable-diffusion/src/k_diffusion/__init__.py', 'w') as f:
            f.write('')
        sys.path.extend([
            'deforum-stable-diffusion/',
            'deforum-stable-diffusion/src',
        ])
        end_time = time.time()

        if use_xformers_for_colab:

            print("..installing xformers")

            all_process = [['pip', 'install', 'triton==2.0.0.dev20220701']]
            for process in all_process:
                running = subprocess.run(process,stdout=subprocess.PIPE).stdout.decode('utf-8')
                if print_subprocess:
                    print(running)

            v_card_name = subprocess.run(['nvidia-smi', '--query-gpu=name', '--format=csv,noheader'], stdout=subprocess.PIPE).stdout.decode('utf-8')
            if 't4' in v_card_name.lower():
                name_to_download = 'T4'
            elif 'v100' in v_card_name.lower():
                name_to_download = 'V100'
            elif 'a100' in v_card_name.lower():
                name_to_download = 'A100'
            elif 'p100' in v_card_name.lower():
                name_to_download = 'P100'
            else:
                print(v_card_name + ' is currently not supported with xformers flash attention in deforum!')

            x_ver = 'xformers-0.0.13.dev0-py3-none-any.whl'
            x_link = 'https://github.com/TheLastBen/fast-stable-diffusion/raw/main/precompiled/' + name_to_download + '/' + x_ver

            all_process = [
                ['wget', x_link],
                ['pip', 'install', x_ver],
                ['mv', 'deforum-stable-diffusion/src/ldm/modules/attention.py', 'deforum-stable-diffusion/src/ldm/modules/attention_backup.py'],
                ['mv', 'deforum-stable-diffusion/src/ldm/modules/attention_xformers.py', 'deforum-stable-diffusion/src/ldm/modules/attention.py']
            ]

            for process in all_process:
                running = subprocess.run(process,stdout=subprocess.PIPE).stdout.decode('utf-8')
                if print_subprocess:
                    print(running)

        print(f"Environment set up in {end_time-start_time:.0f} seconds")
    else:
        sys.path.extend([
            'src'
        ])
    return

setup_environment()

import torch
import random
import clip
from IPython import display
from types import SimpleNamespace
from helpers.save_images import get_output_folder
from helpers.settings import load_args
from helpers.render import render_animation, render_input_video, render_image_batch, render_interpolation
from helpers.model_load import make_linear_decode, load_model, get_model_output_paths
from helpers.aesthetics import load_aesthetics_model

#@markdown **Path Setup**

def Root():
    models_path = "models" #@param {type:"string"}
    configs_path = "configs" #@param {type:"string"}
    output_path = "output" #@param {type:"string"}
    mount_google_drive = True #@param {type:"boolean"}
    models_path_gdrive = "/content/drive/MyDrive/AI/models" #@param {type:"string"}
    output_path_gdrive = "/content/drive/MyDrive/AI/StableDiffusion" #@param {type:"string"}

    #@markdown **Model Setup**
    model_config = "v1-inference.yaml" #@param ["custom","v1-inference.yaml"]
    model_checkpoint = "v1-5-pruned-emaonly.ckpt" #@param ["custom","v1-5-pruned.ckpt","v1-5-pruned-emaonly.ckpt","sd-v1-4-full-ema.ckpt","sd-v1-4.ckpt","sd-v1-3-full-ema.ckpt","sd-v1-3.ckpt","sd-v1-2-full-ema.ckpt","sd-v1-2.ckpt","sd-v1-1-full-ema.ckpt","sd-v1-1.ckpt", "robo-diffusion-v1.ckpt","wd-v1-3-float16.ckpt"]
    custom_config_path = "" #@param {type:"string"}
    custom_checkpoint_path = "" #@param {type:"string"}
    half_precision = True
    return locals()

root = Root()
root = SimpleNamespace(**root)

root.models_path, root.output_path = get_model_output_paths(root)
root.model, root.device = load_model(root,
    load_on_run_all=True
    ,
    check_sha256=True
    )

# %%
# !! {"metadata":{
# !! "id": "6JxwhBwtWM_t"
# !! }}
"""
# Settings
"""

# %%
# !! {"metadata":{
# !! "cellView": "form",
# !! "id": "E0tJVYA4WM_u"
# !! }}
def DeforumAnimArgs():

    #@markdown ####**Animation:**
    animation_mode = 'None' #@param ['None', '2D', '3D', 'Video Input', 'Interpolation'] {type:'string'}
    max_frames = 1000 #@param {type:"number"}
    border = 'replicate' #@param ['wrap', 'replicate'] {type:'string'}

    #@markdown ####**Motion Parameters:**
    angle = "0:(0)"#@param {type:"string"}
    zoom = "0:(1.04)"#@param {type:"string"}
    translation_x = "0:(10*sin(2*3.14*t/10))"#@param {type:"string"}
    translation_y = "0:(0)"#@param {type:"string"}
    translation_z = "0:(10)"#@param {type:"string"}
    rotation_3d_x = "0:(0)"#@param {type:"string"}
    rotation_3d_y = "0:(0)"#@param {type:"string"}
    rotation_3d_z = "0:(0)"#@param {type:"string"}
    flip_2d_perspective = False #@param {type:"boolean"}
    perspective_flip_theta = "0:(0)"#@param {type:"string"}
    perspective_flip_phi = "0:(t%15)"#@param {type:"string"}
    perspective_flip_gamma = "0:(0)"#@param {type:"string"}
    perspective_flip_fv = "0:(53)"#@param {type:"string"}
    noise_schedule = "0: (0.02)"#@param {type:"string"}
    strength_schedule = "0: (0.65)"#@param {type:"string"}
    contrast_schedule = "0: (1.0)"#@param {type:"string"}

    #@markdown ####**Coherence:**
    color_coherence = 'Match Frame 0 LAB' #@param ['None', 'Match Frame 0 HSV', 'Match Frame 0 LAB', 'Match Frame 0 RGB'] {type:'string'}
    diffusion_cadence = '1' #@param ['1','2','3','4','5','6','7','8'] {type:'string'}

    #@markdown ####**3D Depth Warping:**
    use_depth_warping = True #@param {type:"boolean"}
    midas_weight = 0.3#@param {type:"number"}
    near_plane = 200
    far_plane = 10000
    fov = 40#@param {type:"number"}
    padding_mode = 'border'#@param ['border', 'reflection', 'zeros'] {type:'string'}
    sampling_mode = 'bicubic'#@param ['bicubic', 'bilinear', 'nearest'] {type:'string'}
    save_depth_maps = False #@param {type:"boolean"}

    #@markdown ####**Video Input:**
    video_init_path ='/content/video_in.mp4'#@param {type:"string"}
    extract_nth_frame = 1#@param {type:"number"}
    overwrite_extracted_frames = True #@param {type:"boolean"}
    use_mask_video = False #@param {type:"boolean"}
    video_mask_path ='/content/video_in.mp4'#@param {type:"string"}

    #@markdown ####**Interpolation:**
    interpolate_key_frames = False #@param {type:"boolean"}
    interpolate_x_frames = 4 #@param {type:"number"}

    #@markdown ####**Resume Animation:**
    resume_from_timestring = False #@param {type:"boolean"}
    resume_timestring = "20220829210106" #@param {type:"string"}

    return locals()

# %%
# !! {"metadata":{
# !! "id": "i9fly1RIWM_u"
# !! }}
prompts = [
    "a beautiful lake by Asher Brown Durand, trending on Artstation", # the first prompt I want
    "a beautiful portrait of a woman by Artgerm, trending on Artstation", # the second prompt I want
    #"this prompt I don't want it I commented it out",
    #"a nousr robot, trending on Artstation", # use "nousr robot" with the robot diffusion model (see model_checkpoint setting)
    #"touhou 1girl komeiji_koishi portrait, green hair", # waifu diffusion prompts can use danbooru tag groups (see model_checkpoint)
    #"this prompt has weights if prompt weighting enabled:2 can also do negative:-2", # (see prompt_weighting)
]

animation_prompts = {
    0: "a beautiful apple, trending on Artstation",
    20: "a beautiful banana, trending on Artstation",
    30: "a beautiful coconut, trending on Artstation",
    40: "a beautiful durian, trending on Artstation",
}

# %%
# !! {"metadata":{
# !! "cellView": "form",
# !! "id": "XVzhbmizWM_u"
# !! }}
#@markdown **Load Settings**
override_settings_with_file = False #@param {type:"boolean"}
settings_file = "custom" #@param ["custom", "512x512_aesthetic_0.json","512x512_aesthetic_1.json","512x512_colormatch_0.json","512x512_colormatch_1.json","512x512_colormatch_2.json","512x512_colormatch_3.json"]
custom_settings_file = "/content/drive/MyDrive/Settings.txt"#@param {type:"string"}

def DeforumArgs():
    #@markdown **Image Settings**
    W = 512 #@param
    H = 512 #@param
    W, H = map(lambda x: x - x % 64, (W, H)) # resize to integer multiple of 64

    #@markdown **Sampling Settings**
    seed = -1 #@param
    sampler = 'dpmpp_2s_a' #@param ["klms","dpm2","dpm2_ancestral","heun","euler","euler_ancestral","plms", "ddim", "dpm_fast", "dpm_adaptive", "dpmpp_2s_a", "dpmpp_2m"]
    steps = 80 #@param
    scale = 7 #@param
    ddim_eta = 0.0 #@param
    dynamic_threshold = None
    static_threshold = None

    #@markdown **Save & Display Settings**
    save_samples = True #@param {type:"boolean"}
    save_settings = True #@param {type:"boolean"}
    display_samples = True #@param {type:"boolean"}
    save_sample_per_step = False #@param {type:"boolean"}
    show_sample_per_step = False #@param {type:"boolean"}

    #@markdown **Prompt Settings**
    prompt_weighting = True #@param {type:"boolean"}
    normalize_prompt_weights = True #@param {type:"boolean"}
    log_weighted_subprompts = False #@param {type:"boolean"}

    #@markdown **Batch Settings**
    n_batch = 1 #@param
    batch_name = "StableFun" #@param {type:"string"}
    filename_format = "{timestring}_{index}_{prompt}.png" #@param ["{timestring}_{index}_{seed}.png","{timestring}_{index}_{prompt}.png"]
    seed_behavior = "iter" #@param ["iter","fixed","random"]
    make_grid = False #@param {type:"boolean"}
    grid_rows = 2 #@param
    outdir = get_output_folder(root.output_path, batch_name)

    #@markdown **Init Settings**
    use_init = False #@param {type:"boolean"}
    strength = 0.0 #@param {type:"number"}
    strength_0_no_init = True # Set the strength to 0 automatically when no init image is used
    init_image = "https://cdn.pixabay.com/photo/2022/07/30/13/10/green-longhorn-beetle-7353749_1280.jpg" #@param {type:"string"}
    # Whiter areas of the mask are areas that change more
    use_mask = False #@param {type:"boolean"}
    use_alpha_as_mask = False # use the alpha channel of the init image as the mask
    mask_file = "https://www.filterforge.com/wiki/images/archive/b/b7/20080927223728%21Polygonal_gradient_thumb.jpg" #@param {type:"string"}
    invert_mask = False #@param {type:"boolean"}
    # Adjust mask image, 1.0 is no adjustment. Should be positive numbers.
    mask_brightness_adjust = 1.0 #@param {type:"number"}
    mask_contrast_adjust = 1.0 #@param {type:"number"}
    # Overlay the masked image at the end of the generation so it does not get degraded by encoding and decoding
    overlay_mask = True # {type:"boolean"}
    # Blur edges of final overlay mask, if used. Minimum = 0 (no blur)
    mask_overlay_blur = 5 # {type:"number"}

    #@markdown **Exposure/Contrast Conditional Settings**
    mean_scale = 0 #@param {type:"number"}
    var_scale = 0 #@param {type:"number"}
    exposure_scale = 0 #@param {type:"number"}
    exposure_target = 0.5 #@param {type:"number"}

    #@markdown **Color Match Conditional Settings**
    colormatch_scale = 0 #@param {type:"number"}
    colormatch_image = "https://www.saasdesign.io/wp-content/uploads/2021/02/palette-3-min-980x588.png" #@param {type:"string"}
colormatch_image = "https://www.saasdesign.io/wp-content/uploads/2021/02/palette-3-min-980x588.png" #@param {type:"string"}
|
315 |
+
colormatch_n_colors = 4 #@param {type:"number"}
|
316 |
+
ignore_sat_weight = 0 #@param {type:"number"}
|
317 |
+
|
318 |
+
#@markdown **CLIP\Aesthetics Conditional Settings**
|
319 |
+
clip_name = 'ViT-L/14' #@param ['ViT-L/14', 'ViT-L/14@336px', 'ViT-B/16', 'ViT-B/32']
|
320 |
+
clip_scale = 0 #@param {type:"number"}
|
321 |
+
aesthetics_scale = 0 #@param {type:"number"}
|
322 |
+
cutn = 1 #@param {type:"number"}
|
323 |
+
cut_pow = 0.0001 #@param {type:"number"}
|
324 |
+
|
325 |
+
#@markdown **Other Conditional Settings**
|
326 |
+
init_mse_scale = 0 #@param {type:"number"}
|
327 |
+
init_mse_image = "https://cdn.pixabay.com/photo/2022/07/30/13/10/green-longhorn-beetle-7353749_1280.jpg" #@param {type:"string"}
|
328 |
+
|
329 |
+
blue_scale = 0 #@param {type:"number"}
|
330 |
+
|
331 |
+
#@markdown **Conditional Gradient Settings**
|
332 |
+
gradient_wrt = 'x0_pred' #@param ["x", "x0_pred"]
|
333 |
+
gradient_add_to = 'both' #@param ["cond", "uncond", "both"]
|
334 |
+
decode_method = 'linear' #@param ["autoencoder","linear"]
|
335 |
+
grad_threshold_type = 'dynamic' #@param ["dynamic", "static", "mean", "schedule"]
|
336 |
+
clamp_grad_threshold = 0.2 #@param {type:"number"}
|
337 |
+
clamp_start = 0.2 #@param
|
338 |
+
clamp_stop = 0.01 #@param
|
339 |
+
grad_inject_timing = list(range(1,10)) #@param
|
340 |
+
|
341 |
+
#@markdown **Speed vs VRAM Settings**
|
342 |
+
cond_uncond_sync = True #@param {type:"boolean"}
|
343 |
+
|
344 |
+
n_samples = 1 # doesnt do anything
|
345 |
+
precision = 'autocast'
|
346 |
+
C = 4
|
347 |
+
f = 8
|
348 |
+
|
349 |
+
prompt = ""
|
350 |
+
timestring = ""
|
351 |
+
init_latent = None
|
352 |
+
init_sample = None
|
353 |
+
init_sample_raw = None
|
354 |
+
mask_sample = None
|
355 |
+
init_c = None
|
356 |
+
|
357 |
+
return locals()
|
358 |
+
|
359 |
+
args_dict = DeforumArgs()
anim_args_dict = DeforumAnimArgs()

if override_settings_with_file:
    load_args(args_dict, anim_args_dict, settings_file, custom_settings_file, verbose=False)

args = SimpleNamespace(**args_dict)
anim_args = SimpleNamespace(**anim_args_dict)

args.timestring = time.strftime('%Y%m%d%H%M%S')
args.strength = max(0.0, min(1.0, args.strength))

# Load clip model if using clip guidance
if (args.clip_scale > 0) or (args.aesthetics_scale > 0):
    root.clip_model = clip.load(args.clip_name, jit=False)[0].eval().requires_grad_(False).to(root.device)
    if (args.aesthetics_scale > 0):
        root.aesthetics_model = load_aesthetics_model(args, root)

if args.seed == -1:
    args.seed = random.randint(0, 2**32 - 1)
if not args.use_init:
    args.init_image = None
if args.sampler == 'plms' and (args.use_init or anim_args.animation_mode != 'None'):
    print(f"Init images aren't supported with PLMS yet, switching to KLMS")
    args.sampler = 'klms'
if args.sampler != 'ddim':
    args.ddim_eta = 0

if anim_args.animation_mode == 'None':
    anim_args.max_frames = 1
elif anim_args.animation_mode == 'Video Input':
    args.use_init = True

# clean up unused memory
gc.collect()
torch.cuda.empty_cache()

# dispatch to appropriate renderer
if anim_args.animation_mode == '2D' or anim_args.animation_mode == '3D':
    render_animation(args, anim_args, animation_prompts, root)
elif anim_args.animation_mode == 'Video Input':
    render_input_video(args, anim_args, animation_prompts, root)
elif anim_args.animation_mode == 'Interpolation':
    render_interpolation(args, anim_args, animation_prompts, root)
else:
    render_image_batch(args, prompts, root)

# %%
# !! {"metadata":{
# !!   "id": "gJ88kZ2-WM_v"
# !! }}
"""
# Create Video From Frames
"""

# %%
# !! {"metadata":{
# !!   "cellView": "form",
# !!   "id": "XQGeqaGAWM_v"
# !! }}
skip_video_for_run_all = True #@param {type: 'boolean'}
fps = 12 #@param {type:"number"}
#@markdown **Manual Settings**
use_manual_settings = False #@param {type:"boolean"}
image_path = "/content/drive/MyDrive/AI/StableDiffusion/2022-09/20220903000939_%05d.png" #@param {type:"string"}
mp4_path = "/content/drive/MyDrive/AI/StableDiffusion/2022-09/20220903000939.mp4" #@param {type:"string"}
render_steps = False #@param {type: 'boolean'}
path_name_modifier = "x0_pred" #@param ["x0_pred","x"]
make_gif = False

if skip_video_for_run_all == True:
    print('Skipping video creation, uncheck skip_video_for_run_all if you want to run it')
else:
    import os
    import subprocess
    from base64 import b64encode

    print(f"{image_path} -> {mp4_path}")

    if use_manual_settings:
        max_frames = "200" #@param {type:"string"}
    else:
        if render_steps: # render steps from a single image
            fname = f"{path_name_modifier}_%05d.png"
            all_step_dirs = [os.path.join(args.outdir, d) for d in os.listdir(args.outdir) if os.path.isdir(os.path.join(args.outdir,d))]
            newest_dir = max(all_step_dirs, key=os.path.getmtime)
            image_path = os.path.join(newest_dir, fname)
            print(f"Reading images from {image_path}")
            mp4_path = os.path.join(newest_dir, f"{args.timestring}_{path_name_modifier}.mp4")
            max_frames = str(args.steps)
        else: # render images for a video
            image_path = os.path.join(args.outdir, f"{args.timestring}_%05d.png")
            mp4_path = os.path.join(args.outdir, f"{args.timestring}.mp4")
            max_frames = str(anim_args.max_frames)

    # make video
    cmd = [
        'ffmpeg',
        '-y',
        '-vcodec', 'png',
        '-r', str(fps),
        '-start_number', str(0),
        '-i', image_path,
        '-frames:v', max_frames,
        '-c:v', 'libx264',
        '-vf',
        f'fps={fps}',
        '-pix_fmt', 'yuv420p',
        '-crf', '17',
        '-preset', 'veryfast',
        '-pattern_type', 'sequence',
        mp4_path
    ]
    process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    stdout, stderr = process.communicate()
    if process.returncode != 0:
        print(stderr)
        raise RuntimeError(stderr)

    mp4 = open(mp4_path, 'rb').read()
    data_url = "data:video/mp4;base64," + b64encode(mp4).decode()
    display.display(display.HTML(f'<video controls loop><source src="{data_url}" type="video/mp4"></video>'))

    if make_gif:
        gif_path = os.path.splitext(mp4_path)[0] + '.gif'
        cmd_gif = [
            'ffmpeg',
            '-y',
            '-i', mp4_path,
            '-r', str(fps),
            gif_path
        ]
        process_gif = subprocess.Popen(cmd_gif, stdout=subprocess.PIPE, stderr=subprocess.PIPE)

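# --- Editor's note: the cmd list above amounts to roughly this shell invocation
# (illustrative only; paths, fps, and frame counts are filled in at runtime):
#   ffmpeg -y -vcodec png -r 12 -start_number 0 -i <outdir>/<timestring>_%05d.png \
#          -frames:v <max_frames> -c:v libx264 -vf fps=12 -pix_fmt yuv420p \
#          -crf 17 -preset veryfast -pattern_type sequence <outdir>/<timestring>.mp4
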
# %%
# !! {"metadata":{
# !!   "cellView": "form",
# !!   "id": "MMpAcyrYWM_v"
# !! }}
skip_disconnect_for_run_all = True #@param {type: 'boolean'}

if skip_disconnect_for_run_all == True:
    print('Skipping disconnect, uncheck skip_disconnect_for_run_all if you want to run it')
else:
    from google.colab import runtime
    runtime.unassign()

# %%
# !! {"main_metadata":{
# !!   "kernelspec": {
# !!     "display_name": "Python 3.10.6 ('dsd')",
# !!     "language": "python",
# !!     "name": "python3"
# !!   },
# !!   "language_info": {
# !!     "codemirror_mode": {
# !!       "name": "ipython",
# !!       "version": 3
# !!     },
# !!     "file_extension": ".py",
# !!     "mimetype": "text/x-python",
# !!     "name": "python",
# !!     "nbconvert_exporter": "python",
# !!     "pygments_lexer": "ipython3",
# !!     "version": "3.10.6"
# !!   },
# !!   "orig_nbformat": 4,
# !!   "vscode": {
# !!     "interpreter": {
# !!       "hash": "b7e04c8a9537645cbc77fa0cbde8069bc94e341b0d5ced104651213865b24e58"
# !!     }
# !!   },
# !!   "colab": {
# !!     "provenance": []
# !!   },
# !!   "accelerator": "GPU",
# !!   "gpuClass": "standard"
# !! }}

deforum-stable-diffusion/LICENSE
ADDED
The diff for this file is too large to render.
See raw diff
deforum-stable-diffusion/configs/v1-inference.yaml
ADDED
@@ -0,0 +1,70 @@
model:
  base_learning_rate: 1.0e-04
  target: ldm.models.diffusion.ddpm.LatentDiffusion
  params:
    linear_start: 0.00085
    linear_end: 0.0120
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: "jpg"
    cond_stage_key: "txt"
    image_size: 64
    channels: 4
    cond_stage_trainable: false   # Note: different from the one we trained before
    conditioning_key: crossattn
    monitor: val/loss_simple_ema
    scale_factor: 0.18215
    use_ema: False

    scheduler_config: # 10000 warmup steps
      target: ldm.lr_scheduler.LambdaLinearScheduler
      params:
        warm_up_steps: [ 10000 ]
        cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
        f_start: [ 1.e-6 ]
        f_max: [ 1. ]
        f_min: [ 1. ]

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        image_size: 32 # unused
        in_channels: 4
        out_channels: 4
        model_channels: 320
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_heads: 8
        use_spatial_transformer: True
        transformer_depth: 1
        context_dim: 768
        use_checkpoint: True
        legacy: False

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.FrozenCLIPEmbedder

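A minimal sketch of how a config like this is typically consumed (assuming the ldm package from the stable-diffusion codebase is importable and a checkpoint has already been downloaded; the repo's own loader lives in helpers/model_load.py, which is not shown in this view, and the checkpoint path below is hypothetical):

import torch
from omegaconf import OmegaConf
from ldm.util import instantiate_from_config

config = OmegaConf.load("deforum-stable-diffusion/configs/v1-inference.yaml")
model = instantiate_from_config(config.model)                 # builds LatentDiffusion from the 'model' block
ckpt = torch.load("models/sd-v1-4.ckpt", map_location="cpu")  # hypothetical checkpoint path
model.load_state_dict(ckpt["state_dict"], strict=False)
model = model.eval().to("cuda")
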
deforum-stable-diffusion/configs/v2-inference-v.yaml
ADDED
@@ -0,0 +1,68 @@
model:
  base_learning_rate: 1.0e-4
  target: ldm.models.diffusion.ddpm.LatentDiffusion
  params:
    parameterization: "v"
    linear_start: 0.00085
    linear_end: 0.0120
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: "jpg"
    cond_stage_key: "txt"
    image_size: 64
    channels: 4
    cond_stage_trainable: false
    conditioning_key: crossattn
    monitor: val/loss_simple_ema
    scale_factor: 0.18215
    use_ema: False # we set this to false because this is an inference only config

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        use_checkpoint: True
        use_fp16: True
        image_size: 32 # unused
        in_channels: 4
        out_channels: 4
        model_channels: 320
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_head_channels: 64 # need to fix for flash-attn
        use_spatial_transformer: True
        use_linear_in_transformer: True
        transformer_depth: 1
        context_dim: 1024
        legacy: False

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          #attn_type: "vanilla-xformers"
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
      params:
        freeze: True
        layer: "penultimate"

deforum-stable-diffusion/configs/v2-inference.yaml
ADDED
@@ -0,0 +1,67 @@
model:
  base_learning_rate: 1.0e-4
  target: ldm.models.diffusion.ddpm.LatentDiffusion
  params:
    linear_start: 0.00085
    linear_end: 0.0120
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: "jpg"
    cond_stage_key: "txt"
    image_size: 64
    channels: 4
    cond_stage_trainable: false
    conditioning_key: crossattn
    monitor: val/loss_simple_ema
    scale_factor: 0.18215
    use_ema: False # we set this to false because this is an inference only config

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        use_checkpoint: True
        use_fp16: True
        image_size: 32 # unused
        in_channels: 4
        out_channels: 4
        model_channels: 320
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_head_channels: 64 # need to fix for flash-attn
        use_spatial_transformer: True
        use_linear_in_transformer: True
        transformer_depth: 1
        context_dim: 1024
        legacy: False

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          #attn_type: "vanilla-xformers"
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
      params:
        freeze: True
        layer: "penultimate"

deforum-stable-diffusion/configs/v2-inpainting-inference.yaml
ADDED
@@ -0,0 +1,158 @@
model:
  base_learning_rate: 5.0e-05
  target: ldm.models.diffusion.ddpm.LatentInpaintDiffusion
  params:
    linear_start: 0.00085
    linear_end: 0.0120
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: "jpg"
    cond_stage_key: "txt"
    image_size: 64
    channels: 4
    cond_stage_trainable: false
    conditioning_key: hybrid
    scale_factor: 0.18215
    monitor: val/loss_simple_ema
    finetune_keys: null
    use_ema: False

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        use_checkpoint: True
        image_size: 32 # unused
        in_channels: 9
        out_channels: 4
        model_channels: 320
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_head_channels: 64 # need to fix for flash-attn
        use_spatial_transformer: True
        use_linear_in_transformer: True
        transformer_depth: 1
        context_dim: 1024
        legacy: False

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          #attn_type: "vanilla-xformers"
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          - 4
          num_res_blocks: 2
          attn_resolutions: [ ]
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
      params:
        freeze: True
        layer: "penultimate"


data:
  target: ldm.data.laion.WebDataModuleFromConfig
  params:
    tar_base: null  # for concat as in LAION-A
    p_unsafe_threshold: 0.1
    filter_word_list: "data/filters.yaml"
    max_pwatermark: 0.45
    batch_size: 8
    num_workers: 6
    multinode: True
    min_size: 512
    train:
      shards:
        - "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-0/{00000..18699}.tar -"
        - "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-1/{00000..18699}.tar -"
        - "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-2/{00000..18699}.tar -"
        - "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-3/{00000..18699}.tar -"
        - "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-4/{00000..18699}.tar -"  #{00000-94333}.tar"
      shuffle: 10000
      image_key: jpg
      image_transforms:
      - target: torchvision.transforms.Resize
        params:
          size: 512
          interpolation: 3
      - target: torchvision.transforms.RandomCrop
        params:
          size: 512
      postprocess:
        target: ldm.data.laion.AddMask
        params:
          mode: "512train-large"
          p_drop: 0.25
    # NOTE use enough shards to avoid empty validation loops in workers
    validation:
      shards:
        - "pipe:aws s3 cp s3://deep-floyd-s3/datasets/laion_cleaned-part5/{93001..94333}.tar - "
      shuffle: 0
      image_key: jpg
      image_transforms:
      - target: torchvision.transforms.Resize
        params:
          size: 512
          interpolation: 3
      - target: torchvision.transforms.CenterCrop
        params:
          size: 512
      postprocess:
        target: ldm.data.laion.AddMask
        params:
          mode: "512train-large"
          p_drop: 0.25

lightning:
  find_unused_parameters: True
  modelcheckpoint:
    params:
      every_n_train_steps: 5000

  callbacks:
    metrics_over_trainsteps_checkpoint:
      params:
        every_n_train_steps: 10000

    image_logger:
      target: main.ImageLogger
      params:
        enable_autocast: False
        disabled: False
        batch_frequency: 1000
        max_images: 4
        increase_log_steps: False
        log_first_step: False
        log_images_kwargs:
          use_ema_scope: False
          inpaint: False
          plot_progressive_rows: False
          plot_diffusion_rows: False
          N: 4
          unconditional_guidance_scale: 5.0
          unconditional_guidance_label: [""]
          ddim_steps: 50  # todo check these out for depth2img,
          ddim_eta: 0.0  # todo check these out for depth2img,

  trainer:
    benchmark: True
    val_check_interval: 5000000
    num_sanity_val_steps: 0
    accumulate_grad_batches: 1

deforum-stable-diffusion/configs/v2-midas-inference.yaml
ADDED
@@ -0,0 +1,74 @@
model:
  base_learning_rate: 5.0e-07
  target: ldm.models.diffusion.ddpm.LatentDepth2ImageDiffusion
  params:
    linear_start: 0.00085
    linear_end: 0.0120
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: "jpg"
    cond_stage_key: "txt"
    image_size: 64
    channels: 4
    cond_stage_trainable: false
    conditioning_key: hybrid
    scale_factor: 0.18215
    monitor: val/loss_simple_ema
    finetune_keys: null
    use_ema: False

    depth_stage_config:
      target: ldm.modules.midas.api.MiDaSInference
      params:
        model_type: "dpt_hybrid"

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        use_checkpoint: True
        image_size: 32 # unused
        in_channels: 5
        out_channels: 4
        model_channels: 320
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_head_channels: 64 # need to fix for flash-attn
        use_spatial_transformer: True
        use_linear_in_transformer: True
        transformer_depth: 1
        context_dim: 1024
        legacy: False

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          #attn_type: "vanilla-xformers"
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          - 4
          num_res_blocks: 2
          attn_resolutions: [ ]
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
      params:
        freeze: True
        layer: "penultimate"

deforum-stable-diffusion/configs/x4-upscaling.yaml
ADDED
@@ -0,0 +1,76 @@
model:
  base_learning_rate: 1.0e-04
  target: ldm.models.diffusion.ddpm.LatentUpscaleDiffusion
  params:
    parameterization: "v"
    low_scale_key: "lr"
    linear_start: 0.0001
    linear_end: 0.02
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: "jpg"
    cond_stage_key: "txt"
    image_size: 128
    channels: 4
    cond_stage_trainable: false
    conditioning_key: "hybrid-adm"
    monitor: val/loss_simple_ema
    scale_factor: 0.08333
    use_ema: False

    low_scale_config:
      target: ldm.modules.diffusionmodules.upscaling.ImageConcatWithNoiseAugmentation
      params:
        noise_schedule_config: # image space
          linear_start: 0.0001
          linear_end: 0.02
        max_noise_level: 350

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        use_checkpoint: True
        num_classes: 1000 # timesteps for noise conditioning (here constant, just need one)
        image_size: 128
        in_channels: 7
        out_channels: 4
        model_channels: 256
        attention_resolutions: [ 2,4,8]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 2, 4]
        disable_self_attentions: [True, True, True, False]
        disable_middle_self_attn: False
        num_heads: 8
        use_spatial_transformer: True
        transformer_depth: 1
        context_dim: 1024
        legacy: False
        use_linear_in_transformer: True

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        ddconfig:
          # attn_type: "vanilla-xformers" this model needs efficient attention to be feasible on HR data, also the decoder seems to break in half precision (UNet is fine though)
          double_z: True
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult: [ 1,2,4 ]  # num_down = len(ch_mult)-1
          num_res_blocks: 2
          attn_resolutions: [ ]
          dropout: 0.0

        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
      params:
        freeze: True
        layer: "penultimate"

deforum-stable-diffusion/helpers/__init__.py
ADDED
@@ -0,0 +1,9 @@
"""
from .save_images import save_samples, get_output_folder
from .k_samplers import sampler_fn, make_inject_timing_fn
from .depth import DepthModel
from .prompt import sanitize
from .animation import construct_RotationMatrixHomogenous, getRotationMatrixManual, getPoints_for_PerspectiveTranformEstimation, warpMatrix, anim_frame_warp
from .generate import add_noise, load_img, load_mask_latent, prepare_mask
from .load_images import load_img, load_mask_latent, prepare_mask, prepare_overlay_mask
"""

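Note that the whole body of __init__.py is wrapped in a docstring, so none of these re-exports are active; the notebook code appears to import the submodules directly instead. A minimal sketch of that style (assuming the repo root is on sys.path):

from helpers.animation import DeformAnimKeys, anim_frame_warp
from helpers.aesthetics import load_aesthetics_model
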
deforum-stable-diffusion/helpers/__pycache__/__init__.cpython-38.pyc  ADDED  Binary file (644 Bytes)
deforum-stable-diffusion/helpers/__pycache__/__init__.cpython-39.pyc  ADDED  Binary file (685 Bytes)
deforum-stable-diffusion/helpers/__pycache__/aesthetics.cpython-38.pyc  ADDED  Binary file (1.65 kB)
deforum-stable-diffusion/helpers/__pycache__/animation.cpython-38.pyc  ADDED  Binary file (10.6 kB)
deforum-stable-diffusion/helpers/__pycache__/callback.cpython-38.pyc  ADDED  Binary file (4.47 kB)
deforum-stable-diffusion/helpers/__pycache__/colors.cpython-38.pyc  ADDED  Binary file (730 Bytes)
deforum-stable-diffusion/helpers/__pycache__/conditioning.cpython-38.pyc  ADDED  Binary file (9.93 kB)
deforum-stable-diffusion/helpers/__pycache__/depth.cpython-38.pyc  ADDED  Binary file (5.39 kB)
deforum-stable-diffusion/helpers/__pycache__/generate.cpython-38.pyc  ADDED  Binary file (7.78 kB)
deforum-stable-diffusion/helpers/__pycache__/generate.cpython-39.pyc  ADDED  Binary file (7.91 kB)
deforum-stable-diffusion/helpers/__pycache__/k_samplers.cpython-38.pyc  ADDED  Binary file (4.45 kB)
deforum-stable-diffusion/helpers/__pycache__/load_images.cpython-38.pyc  ADDED  Binary file (2.48 kB)
deforum-stable-diffusion/helpers/__pycache__/model_load.cpython-38.pyc  ADDED  Binary file (7.53 kB)
deforum-stable-diffusion/helpers/__pycache__/model_wrap.cpython-38.pyc  ADDED  Binary file (6.46 kB)
deforum-stable-diffusion/helpers/__pycache__/prompt.cpython-38.pyc  ADDED  Binary file (4.76 kB)
deforum-stable-diffusion/helpers/__pycache__/render.cpython-38.pyc  ADDED  Binary file (10.4 kB)
deforum-stable-diffusion/helpers/__pycache__/render.cpython-39.pyc  ADDED  Binary file (10.7 kB)
deforum-stable-diffusion/helpers/__pycache__/save_images.cpython-38.pyc  ADDED  Binary file (1.85 kB)
deforum-stable-diffusion/helpers/__pycache__/save_images.cpython-39.pyc  ADDED  Binary file (1.88 kB)
deforum-stable-diffusion/helpers/__pycache__/settings.cpython-38.pyc  ADDED  Binary file (1.23 kB)
deforum-stable-diffusion/helpers/__pycache__/settings.cpython-39.pyc  ADDED  Binary file (1.3 kB)
deforum-stable-diffusion/helpers/__pycache__/simulacra_fit_linear_model.cpython-38.pyc  ADDED  Binary file (2.32 kB)

deforum-stable-diffusion/helpers/aesthetics.py
ADDED
@@ -0,0 +1,48 @@
import os
import torch
from .simulacra_fit_linear_model import AestheticMeanPredictionLinearModel
import requests

def wget(url, outputdir):
    filename = url.split("/")[-1]

    ckpt_request = requests.get(url)
    request_status = ckpt_request.status_code

    # inform user of errors
    if request_status == 403:
        raise ConnectionRefusedError("You have not accepted the license for this model.")
    elif request_status == 404:
        raise ConnectionError("Could not make contact with server")
    elif request_status != 200:
        raise ConnectionError(f"Some other error has ocurred - response code: {request_status}")

    # write to model path
    with open(os.path.join(outputdir, filename), 'wb') as model_file:
        model_file.write(ckpt_request.content)


def load_aesthetics_model(args,root):

    clip_size = {
        "ViT-B/32": 512,
        "ViT-B/16": 512,
        "ViT-L/14": 768,
        "ViT-L/14@336px": 768,
    }

    model_name = {
        "ViT-B/32": "sac_public_2022_06_29_vit_b_32_linear.pth",
        "ViT-B/16": "sac_public_2022_06_29_vit_b_16_linear.pth",
        "ViT-L/14": "sac_public_2022_06_29_vit_l_14_linear.pth",
    }

    if not os.path.exists(os.path.join(root.models_path,model_name[args.clip_name])):
        print("Downloading aesthetics model...")
        os.makedirs(root.models_path, exist_ok=True)
        wget("https://github.com/crowsonkb/simulacra-aesthetic-models/raw/master/models/"+model_name[args.clip_name], root.models_path)

    aesthetics_model = AestheticMeanPredictionLinearModel(clip_size[args.clip_name])
    aesthetics_model.load_state_dict(torch.load(os.path.join(root.models_path,model_name[args.clip_name])))

    return aesthetics_model.to(root.device)

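A minimal usage sketch for load_aesthetics_model (the arguments mirror how configuration.py calls it; the SimpleNamespace objects below stand in for the real args/root and their paths are hypothetical):

from types import SimpleNamespace
from helpers.aesthetics import load_aesthetics_model

args = SimpleNamespace(clip_name="ViT-L/14")
root = SimpleNamespace(models_path="models", device="cuda")
aesthetics_model = load_aesthetics_model(args, root)  # downloads the linear head on first use
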
deforum-stable-diffusion/helpers/animation.py
ADDED
@@ -0,0 +1,338 @@
import numpy as np
import cv2
from functools import reduce
import math
import py3d_tools as p3d
import torch
from einops import rearrange
import re
import pathlib
import os
import pandas as pd

def check_is_number(value):
    float_pattern = r'^(?=.)([+-]?([0-9]*)(\.([0-9]+))?)$'
    return re.match(float_pattern, value)

def sample_from_cv2(sample: np.ndarray) -> torch.Tensor:
    sample = ((sample.astype(float) / 255.0) * 2) - 1
    sample = sample[None].transpose(0, 3, 1, 2).astype(np.float16)
    sample = torch.from_numpy(sample)
    return sample

def sample_to_cv2(sample: torch.Tensor, type=np.uint8) -> np.ndarray:
    sample_f32 = rearrange(sample.squeeze().cpu().numpy(), "c h w -> h w c").astype(np.float32)
    sample_f32 = ((sample_f32 * 0.5) + 0.5).clip(0, 1)
    sample_int8 = (sample_f32 * 255)
    return sample_int8.astype(type)

def construct_RotationMatrixHomogenous(rotation_angles):
    assert(type(rotation_angles)==list and len(rotation_angles)==3)
    RH = np.eye(4,4)
    cv2.Rodrigues(np.array(rotation_angles), RH[0:3, 0:3])
    return RH

def vid2frames(video_path, frames_path, n=1, overwrite=True):
    if not os.path.exists(frames_path) or overwrite:
        try:
            for f in pathlib.Path(frames_path).glob('*.jpg'):
                f.unlink()
        except:
            pass
        assert os.path.exists(video_path), f"Video input {video_path} does not exist"

        vidcap = cv2.VideoCapture(video_path)
        success,image = vidcap.read()
        count = 0
        t=1
        success = True
        while success:
            if count % n == 0:
                cv2.imwrite(frames_path + os.path.sep + f"{t:05}.jpg" , image) # save frame as JPEG file
                t += 1
            success,image = vidcap.read()
            count += 1
        print("Converted %d frames" % count)
    else: print("Frames already unpacked")

# https://en.wikipedia.org/wiki/Rotation_matrix
def getRotationMatrixManual(rotation_angles):

    rotation_angles = [np.deg2rad(x) for x in rotation_angles]

    phi = rotation_angles[0] # around x
    gamma = rotation_angles[1] # around y
    theta = rotation_angles[2] # around z

    # X rotation
    Rphi = np.eye(4,4)
    sp = np.sin(phi)
    cp = np.cos(phi)
    Rphi[1,1] = cp
    Rphi[2,2] = Rphi[1,1]
    Rphi[1,2] = -sp
    Rphi[2,1] = sp

    # Y rotation
    Rgamma = np.eye(4,4)
    sg = np.sin(gamma)
    cg = np.cos(gamma)
    Rgamma[0,0] = cg
    Rgamma[2,2] = Rgamma[0,0]
    Rgamma[0,2] = sg
    Rgamma[2,0] = -sg

    # Z rotation (in-image-plane)
    Rtheta = np.eye(4,4)
    st = np.sin(theta)
    ct = np.cos(theta)
    Rtheta[0,0] = ct
    Rtheta[1,1] = Rtheta[0,0]
    Rtheta[0,1] = -st
    Rtheta[1,0] = st

    R = reduce(lambda x,y : np.matmul(x,y), [Rphi, Rgamma, Rtheta])

    return R

def getPoints_for_PerspectiveTranformEstimation(ptsIn, ptsOut, W, H, sidelength):

    ptsIn2D = ptsIn[0,:]
    ptsOut2D = ptsOut[0,:]
    ptsOut2Dlist = []
    ptsIn2Dlist = []

    for i in range(0,4):
        ptsOut2Dlist.append([ptsOut2D[i,0], ptsOut2D[i,1]])
        ptsIn2Dlist.append([ptsIn2D[i,0], ptsIn2D[i,1]])

    pin = np.array(ptsIn2Dlist) + [W/2.,H/2.]
    pout = (np.array(ptsOut2Dlist) + [1.,1.]) * (0.5*sidelength)
    pin = pin.astype(np.float32)
    pout = pout.astype(np.float32)

    return pin, pout


def warpMatrix(W, H, theta, phi, gamma, scale, fV):

    # M is to be estimated
    M = np.eye(4, 4)

    fVhalf = np.deg2rad(fV/2.)
    d = np.sqrt(W*W+H*H)
    sideLength = scale*d/np.cos(fVhalf)
    h = d/(2.0*np.sin(fVhalf))
    n = h-(d/2.0)
    f = h+(d/2.0)

    # Translation along Z-axis by -h
    T = np.eye(4,4)
    T[2,3] = -h

    # Rotation matrices around x,y,z
    R = getRotationMatrixManual([phi, gamma, theta])


    # Projection Matrix
    P = np.eye(4,4)
    P[0,0] = 1.0/np.tan(fVhalf)
    P[1,1] = P[0,0]
    P[2,2] = -(f+n)/(f-n)
    P[2,3] = -(2.0*f*n)/(f-n)
    P[3,2] = -1.0

    # pythonic matrix multiplication
    F = reduce(lambda x,y : np.matmul(x,y), [P, T, R])

    # shape should be 1,4,3 for ptsIn and ptsOut since perspectiveTransform() expects data in this way.
    # In C++, this can be achieved by Mat ptsIn(1,4,CV_64FC3);
    ptsIn = np.array([[
        [-W/2., H/2., 0.],[ W/2., H/2., 0.],[ W/2.,-H/2., 0.],[-W/2.,-H/2., 0.]
    ]])
    ptsOut = np.array(np.zeros((ptsIn.shape), dtype=ptsIn.dtype))
    ptsOut = cv2.perspectiveTransform(ptsIn, F)

    ptsInPt2f, ptsOutPt2f = getPoints_for_PerspectiveTranformEstimation(ptsIn, ptsOut, W, H, sideLength)

    # check float32 otherwise OpenCV throws an error
    assert(ptsInPt2f.dtype == np.float32)
    assert(ptsOutPt2f.dtype == np.float32)
    M33 = cv2.getPerspectiveTransform(ptsInPt2f,ptsOutPt2f)

    return M33, sideLength

def anim_frame_warp(prev, args, anim_args, keys, frame_idx, depth_model=None, depth=None, device='cuda'):
    if isinstance(prev, np.ndarray):
        prev_img_cv2 = prev
    else:
        prev_img_cv2 = sample_to_cv2(prev)

    if anim_args.use_depth_warping:
        if depth is None and depth_model is not None:
            depth = depth_model.predict(prev_img_cv2, anim_args)
    else:
        depth = None

    if anim_args.animation_mode == '2D':
        prev_img = anim_frame_warp_2d(prev_img_cv2, args, anim_args, keys, frame_idx)
    else: # '3D'
        prev_img = anim_frame_warp_3d(device, prev_img_cv2, depth, anim_args, keys, frame_idx)

    return prev_img, depth

def anim_frame_warp_2d(prev_img_cv2, args, anim_args, keys, frame_idx):
    angle = keys.angle_series[frame_idx]
    zoom = keys.zoom_series[frame_idx]
    translation_x = keys.translation_x_series[frame_idx]
    translation_y = keys.translation_y_series[frame_idx]

    center = (args.W // 2, args.H // 2)
    trans_mat = np.float32([[1, 0, translation_x], [0, 1, translation_y]])
    rot_mat = cv2.getRotationMatrix2D(center, angle, zoom)
    trans_mat = np.vstack([trans_mat, [0,0,1]])
    rot_mat = np.vstack([rot_mat, [0,0,1]])
    if anim_args.flip_2d_perspective:
        perspective_flip_theta = keys.perspective_flip_theta_series[frame_idx]
        perspective_flip_phi = keys.perspective_flip_phi_series[frame_idx]
        perspective_flip_gamma = keys.perspective_flip_gamma_series[frame_idx]
        perspective_flip_fv = keys.perspective_flip_fv_series[frame_idx]
        M,sl = warpMatrix(args.W, args.H, perspective_flip_theta, perspective_flip_phi, perspective_flip_gamma, 1., perspective_flip_fv);
        post_trans_mat = np.float32([[1, 0, (args.W-sl)/2], [0, 1, (args.H-sl)/2]])
        post_trans_mat = np.vstack([post_trans_mat, [0,0,1]])
        bM = np.matmul(M, post_trans_mat)
        xform = np.matmul(bM, rot_mat, trans_mat)
    else:
        xform = np.matmul(rot_mat, trans_mat)

    return cv2.warpPerspective(
        prev_img_cv2,
        xform,
        (prev_img_cv2.shape[1], prev_img_cv2.shape[0]),
        borderMode=cv2.BORDER_WRAP if anim_args.border == 'wrap' else cv2.BORDER_REPLICATE
    )

def anim_frame_warp_3d(device, prev_img_cv2, depth, anim_args, keys, frame_idx):
    TRANSLATION_SCALE = 1.0/200.0 # matches Disco
    translate_xyz = [
        -keys.translation_x_series[frame_idx] * TRANSLATION_SCALE,
        keys.translation_y_series[frame_idx] * TRANSLATION_SCALE,
        -keys.translation_z_series[frame_idx] * TRANSLATION_SCALE
    ]
    rotate_xyz = [
        math.radians(keys.rotation_3d_x_series[frame_idx]),
        math.radians(keys.rotation_3d_y_series[frame_idx]),
        math.radians(keys.rotation_3d_z_series[frame_idx])
    ]
    rot_mat = p3d.euler_angles_to_matrix(torch.tensor(rotate_xyz, device=device), "XYZ").unsqueeze(0)
    result = transform_image_3d(device, prev_img_cv2, depth, rot_mat, translate_xyz, anim_args)
    torch.cuda.empty_cache()
    return result

def transform_image_3d(device, prev_img_cv2, depth_tensor, rot_mat, translate, anim_args):
    # adapted and optimized version of transform_image_3d from Disco Diffusion https://github.com/alembics/disco-diffusion
    w, h = prev_img_cv2.shape[1], prev_img_cv2.shape[0]

    aspect_ratio = float(w)/float(h)
    near, far, fov_deg = anim_args.near_plane, anim_args.far_plane, anim_args.fov
    persp_cam_old = p3d.FoVPerspectiveCameras(near, far, aspect_ratio, fov=fov_deg, degrees=True, device=device)
    persp_cam_new = p3d.FoVPerspectiveCameras(near, far, aspect_ratio, fov=fov_deg, degrees=True, R=rot_mat, T=torch.tensor([translate]), device=device)

    # range of [-1,1] is important to torch grid_sample's padding handling
    y,x = torch.meshgrid(torch.linspace(-1.,1.,h,dtype=torch.float32,device=device),torch.linspace(-1.,1.,w,dtype=torch.float32,device=device))
    if depth_tensor is None:
        z = torch.ones_like(x)
    else:
        z = torch.as_tensor(depth_tensor, dtype=torch.float32, device=device)
    xyz_old_world = torch.stack((x.flatten(), y.flatten(), z.flatten()), dim=1)

    xyz_old_cam_xy = persp_cam_old.get_full_projection_transform().transform_points(xyz_old_world)[:,0:2]
    xyz_new_cam_xy = persp_cam_new.get_full_projection_transform().transform_points(xyz_old_world)[:,0:2]

    offset_xy = xyz_new_cam_xy - xyz_old_cam_xy
    # affine_grid theta param expects a batch of 2D mats. Each is 2x3 to do rotation+translation.
    identity_2d_batch = torch.tensor([[1.,0.,0.],[0.,1.,0.]], device=device).unsqueeze(0)
    # coords_2d will have shape (N,H,W,2).. which is also what grid_sample needs.
    coords_2d = torch.nn.functional.affine_grid(identity_2d_batch, [1,1,h,w], align_corners=False)
    offset_coords_2d = coords_2d - torch.reshape(offset_xy, (h,w,2)).unsqueeze(0)

    image_tensor = rearrange(torch.from_numpy(prev_img_cv2.astype(np.float32)), 'h w c -> c h w').to(device)
    new_image = torch.nn.functional.grid_sample(
        image_tensor.add(1/512 - 0.0001).unsqueeze(0),
        offset_coords_2d,
        mode=anim_args.sampling_mode,
        padding_mode=anim_args.padding_mode,
        align_corners=False
    )

    # convert back to cv2 style numpy array
    result = rearrange(
        new_image.squeeze().clamp(0,255),
        'c h w -> h w c'
    ).cpu().numpy().astype(prev_img_cv2.dtype)
    return result

class DeformAnimKeys():
    def __init__(self, anim_args):
        self.angle_series = get_inbetweens(parse_key_frames(anim_args.angle), anim_args.max_frames)
        self.zoom_series = get_inbetweens(parse_key_frames(anim_args.zoom), anim_args.max_frames)
        self.translation_x_series = get_inbetweens(parse_key_frames(anim_args.translation_x), anim_args.max_frames)
        self.translation_y_series = get_inbetweens(parse_key_frames(anim_args.translation_y), anim_args.max_frames)
        self.translation_z_series = get_inbetweens(parse_key_frames(anim_args.translation_z), anim_args.max_frames)
        self.rotation_3d_x_series = get_inbetweens(parse_key_frames(anim_args.rotation_3d_x), anim_args.max_frames)
        self.rotation_3d_y_series = get_inbetweens(parse_key_frames(anim_args.rotation_3d_y), anim_args.max_frames)
        self.rotation_3d_z_series = get_inbetweens(parse_key_frames(anim_args.rotation_3d_z), anim_args.max_frames)
        self.perspective_flip_theta_series = get_inbetweens(parse_key_frames(anim_args.perspective_flip_theta), anim_args.max_frames)
        self.perspective_flip_phi_series = get_inbetweens(parse_key_frames(anim_args.perspective_flip_phi), anim_args.max_frames)
        self.perspective_flip_gamma_series = get_inbetweens(parse_key_frames(anim_args.perspective_flip_gamma), anim_args.max_frames)
        self.perspective_flip_fv_series = get_inbetweens(parse_key_frames(anim_args.perspective_flip_fv), anim_args.max_frames)
        self.noise_schedule_series = get_inbetweens(parse_key_frames(anim_args.noise_schedule), anim_args.max_frames)
        self.strength_schedule_series = get_inbetweens(parse_key_frames(anim_args.strength_schedule), anim_args.max_frames)
        self.contrast_schedule_series = get_inbetweens(parse_key_frames(anim_args.contrast_schedule), anim_args.max_frames)

def get_inbetweens(key_frames, max_frames, integer=False, interp_method='Linear'):
    import numexpr
    key_frame_series = pd.Series([np.nan for a in range(max_frames)])

    for i in range(0, max_frames):
        if i in key_frames:
            value = key_frames[i]
            value_is_number = check_is_number(value)
            # if it's only a number, leave the rest for the default interpolation
            if value_is_number:
                t = i
                key_frame_series[i] = value
            if not value_is_number:
                t = i
                key_frame_series[i] = numexpr.evaluate(value)
    key_frame_series = key_frame_series.astype(float)

    if interp_method == 'Cubic' and len(key_frames.items()) <= 3:
        interp_method = 'Quadratic'
    if interp_method == 'Quadratic' and len(key_frames.items()) <= 2:
        interp_method = 'Linear'

    key_frame_series[0] = key_frame_series[key_frame_series.first_valid_index()]
    key_frame_series[max_frames-1] = key_frame_series[key_frame_series.last_valid_index()]
    key_frame_series = key_frame_series.interpolate(method=interp_method.lower(), limit_direction='both')
    if integer:
        return key_frame_series.astype(int)
    return key_frame_series

def parse_key_frames(string, prompt_parser=None):
    # because math functions (i.e. sin(t)) can utilize brackets
    # it extracts the value in form of some stuff
    # which has previously been enclosed with brackets and
    # with a comma or end of line existing after the closing one
    pattern = r'((?P<frame>[0-9]+):[\s]*\((?P<param>[\S\s]*?)\)([,][\s]?|[\s]?$))'
    frames = dict()
    for match_object in re.finditer(pattern, string):
        frame = int(match_object.groupdict()['frame'])
        param = match_object.groupdict()['param']
        if prompt_parser:
            frames[frame] = prompt_parser(param)
        else:
            frames[frame] = param
    if frames == {} and len(string) != 0:
        raise RuntimeError('Key Frame string not correctly formatted')
    return frames

deforum-stable-diffusion/helpers/callback.py
ADDED
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch
import os
import torchvision.transforms.functional as TF
from torchvision.utils import make_grid
import numpy as np
from IPython import display

#
# Callback functions
#
class SamplerCallback(object):
    # Creates the callback function to be passed into the samplers for each step
    def __init__(self, args, root, mask=None, init_latent=None, sigmas=None, sampler=None,
                 verbose=False):
        self.model = root.model
        self.device = root.device
        self.sampler_name = args.sampler
        self.dynamic_threshold = args.dynamic_threshold
        self.static_threshold = args.static_threshold
        self.mask = mask
        self.init_latent = init_latent
        self.sigmas = sigmas
        self.sampler = sampler
        self.verbose = verbose
        self.batch_size = args.n_samples
        self.save_sample_per_step = args.save_sample_per_step
        self.show_sample_per_step = args.show_sample_per_step
        self.paths_to_image_steps = [os.path.join(args.outdir, f"{args.timestring}_{index:02}_{args.seed}") for index in range(args.n_samples)]

        if self.save_sample_per_step:
            for path in self.paths_to_image_steps:
                os.makedirs(path, exist_ok=True)

        self.step_index = 0

        self.noise = None
        if init_latent is not None:
            self.noise = torch.randn_like(init_latent, device=self.device)

        self.mask_schedule = None
        if sigmas is not None and len(sigmas) > 0:
            self.mask_schedule, _ = torch.sort(sigmas/torch.max(sigmas))
        elif len(sigmas) == 0:
            self.mask = None  # no mask needed if no steps (usually happens because strength==1.0)

        if self.sampler_name in ["plms", "ddim"]:
            if mask is not None:
                assert sampler is not None, "Callback function for stable-diffusion samplers requires sampler variable"

        if self.sampler_name in ["plms", "ddim"]:
            # Callback function formatted for CompVis latent diffusion samplers
            self.callback = self.img_callback_
        else:
            # Default callback function uses k-diffusion sampler variables
            self.callback = self.k_callback_

        self.verbose_print = print if verbose else lambda *args, **kwargs: None

    def display_images(self, images):
        images = images.double().cpu().add(1).div(2).clamp(0, 1)
        images = torch.tensor(np.array(images))
        grid = make_grid(images, 4).cpu()
        display.clear_output(wait=True)
        display.display(TF.to_pil_image(grid))
        return

    def view_sample_step(self, latents, path_name_modifier=''):
        if self.save_sample_per_step:
            samples = self.model.decode_first_stage(latents)
            fname = f'{path_name_modifier}_{self.step_index:05}.png'
            for i, sample in enumerate(samples):
                sample = sample.double().cpu().add(1).div(2).clamp(0, 1)
                sample = torch.tensor(np.array(sample))
                grid = make_grid(sample, 4).cpu()
                TF.to_pil_image(grid).save(os.path.join(self.paths_to_image_steps[i], fname))
        if self.show_sample_per_step:
            samples = self.model.linear_decode(latents)
            print(path_name_modifier)
            self.display_images(samples)
        return

    # The callback function is applied to the image at each step
    def dynamic_thresholding_(self, img, threshold):
        # Dynamic thresholding from the Imagen paper (May 2022)
        s = np.percentile(np.abs(img.cpu()), threshold, axis=tuple(range(1, img.ndim)))
        s = np.max(np.append(s, 1.0))
        torch.clamp_(img, -1*s, s)
        torch.FloatTensor.div_(img, s)

    # Callback for samplers in the k-diffusion repo, called thus:
    #   callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
    def k_callback_(self, args_dict):
        self.step_index = args_dict['i']
        if self.dynamic_threshold is not None:
            self.dynamic_thresholding_(args_dict['x'], self.dynamic_threshold)
        if self.static_threshold is not None:
            torch.clamp_(args_dict['x'], -1*self.static_threshold, self.static_threshold)
        if self.mask is not None:
            init_noise = self.init_latent + self.noise * args_dict['sigma']
            is_masked = torch.logical_and(self.mask >= self.mask_schedule[args_dict['i']], self.mask != 0)
            new_img = init_noise * torch.where(is_masked, 1, 0) + args_dict['x'] * torch.where(is_masked, 0, 1)
            args_dict['x'].copy_(new_img)

        self.view_sample_step(args_dict['denoised'], "x0_pred")
        self.view_sample_step(args_dict['x'], "x")

    # Callback for CompVis samplers
    # Function that is called on the image (img) and step (i) at each step
    def img_callback_(self, img, pred_x0, i):
        self.step_index = i
        # Thresholding functions
        if self.dynamic_threshold is not None:
            self.dynamic_thresholding_(img, self.dynamic_threshold)
        if self.static_threshold is not None:
            torch.clamp_(img, -1*self.static_threshold, self.static_threshold)
        if self.mask is not None:
            i_inv = len(self.sigmas) - i - 1
            init_noise = self.sampler.stochastic_encode(self.init_latent, torch.tensor([i_inv]*self.batch_size).to(self.device), noise=self.noise)
            is_masked = torch.logical_and(self.mask >= self.mask_schedule[i], self.mask != 0)
            new_img = init_noise * torch.where(is_masked, 1, 0) + img * torch.where(is_masked, 0, 1)
            img.copy_(new_img)

        self.view_sample_step(pred_x0, "x0_pred")
        self.view_sample_step(img, "x")
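As a quick illustration of the dynamic-thresholding rule used in dynamic_thresholding_ above, here is a minimal standalone sketch applied to a random latent-shaped tensor; the tensor shape and the 99.5 percentile are illustrative assumptions, not values taken from this repo.

import numpy as np
import torch

# Hypothetical latent batch standing in for args_dict['x'] during sampling.
img = torch.randn(2, 4, 64, 64) * 3.0
threshold = 99.5  # illustrative percentile for dynamic_threshold

# Same rule as dynamic_thresholding_: clamp to the per-batch percentile, then rescale.
s = np.percentile(np.abs(img.numpy()), threshold, axis=tuple(range(1, img.ndim)))
s = np.max(np.append(s, 1.0))
torch.clamp_(img, -1 * float(s), float(s))
img.div_(float(s))

print(img.abs().max())  # bounded by 1.0 after thresholding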
deforum-stable-diffusion/helpers/colors.py
ADDED
@@ -0,0 +1,16 @@
from skimage.exposure import match_histograms
import cv2

def maintain_colors(prev_img, color_match_sample, mode):
    if mode == 'Match Frame 0 RGB':
        return match_histograms(prev_img, color_match_sample, multichannel=True)
    elif mode == 'Match Frame 0 HSV':
        prev_img_hsv = cv2.cvtColor(prev_img, cv2.COLOR_RGB2HSV)
        color_match_hsv = cv2.cvtColor(color_match_sample, cv2.COLOR_RGB2HSV)
        matched_hsv = match_histograms(prev_img_hsv, color_match_hsv, multichannel=True)
        return cv2.cvtColor(matched_hsv, cv2.COLOR_HSV2RGB)
    else:  # Match Frame 0 LAB
        prev_img_lab = cv2.cvtColor(prev_img, cv2.COLOR_RGB2LAB)
        color_match_lab = cv2.cvtColor(color_match_sample, cv2.COLOR_RGB2LAB)
        matched_lab = match_histograms(prev_img_lab, color_match_lab, multichannel=True)
        return cv2.cvtColor(matched_lab, cv2.COLOR_LAB2RGB)
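A minimal usage sketch for maintain_colors, assuming the helpers package is importable and using random arrays as stand-ins for real animation frames; it relies on the same skimage behavior the file above assumes.

import numpy as np
from helpers.colors import maintain_colors  # assumes the helpers package is on the import path

# Two hypothetical uint8 RGB frames standing in for real animation frames.
frame0 = (np.random.rand(256, 256, 3) * 255).astype(np.uint8)
frame1 = (np.random.rand(256, 256, 3) * 255).astype(np.uint8)

# Anchor frame1's palette to frame0, as the animation loop would do every frame.
corrected = maintain_colors(frame1, frame0, 'Match Frame 0 RGB')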
deforum-stable-diffusion/helpers/conditioning.py
ADDED
@@ -0,0 +1,262 @@
import torch
import torch.nn as nn
from torch.nn import functional as F
import clip
from torchvision.transforms import Normalize as Normalize
from torchvision.utils import make_grid
import numpy as np
from IPython import display
from sklearn.cluster import KMeans
import torchvision.transforms.functional as TF

###
# Loss functions
###


## CLIP -----------------------------------------

class MakeCutouts(nn.Module):
    def __init__(self, cut_size, cutn, cut_pow=1.):
        super().__init__()
        self.cut_size = cut_size
        self.cutn = cutn
        self.cut_pow = cut_pow

    def forward(self, input):
        sideY, sideX = input.shape[2:4]
        max_size = min(sideX, sideY)
        min_size = min(sideX, sideY, self.cut_size)
        cutouts = []
        for _ in range(self.cutn):
            size = int(torch.rand([])**self.cut_pow * (max_size - min_size) + min_size)
            offsetx = torch.randint(0, sideX - size + 1, ())
            offsety = torch.randint(0, sideY - size + 1, ())
            cutout = input[:, :, offsety:offsety + size, offsetx:offsetx + size]
            cutouts.append(F.adaptive_avg_pool2d(cutout, self.cut_size))
        return torch.cat(cutouts)


def spherical_dist_loss(x, y):
    x = F.normalize(x, dim=-1)
    y = F.normalize(y, dim=-1)
    return (x - y).norm(dim=-1).div(2).arcsin().pow(2).mul(2)

def make_clip_loss_fn(root, args):
    clip_size = root.clip_model.visual.input_resolution  # for open_clip: clip_model.visual.image_size

    def parse_prompt(prompt):
        if prompt.startswith('http://') or prompt.startswith('https://'):
            vals = prompt.rsplit(':', 2)
            vals = [vals[0] + ':' + vals[1], *vals[2:]]
        else:
            vals = prompt.rsplit(':', 1)
        vals = vals + ['', '1'][len(vals):]
        return vals[0], float(vals[1])

    def parse_clip_prompts(clip_prompt):
        target_embeds, weights = [], []
        for prompt in clip_prompt:
            txt, weight = parse_prompt(prompt)
            target_embeds.append(root.clip_model.encode_text(clip.tokenize(txt).to(root.device)).float())
            weights.append(weight)
        target_embeds = torch.cat(target_embeds)
        weights = torch.tensor(weights, device=root.device)
        if weights.sum().abs() < 1e-3:
            raise RuntimeError('Clip prompt weights must not sum to 0.')
        weights /= weights.sum().abs()
        return target_embeds, weights

    normalize = Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                          std=[0.26862954, 0.26130258, 0.27577711])

    make_cutouts = MakeCutouts(clip_size, args.cutn, args.cut_pow)
    target_embeds, weights = parse_clip_prompts(args.clip_prompt)

    def clip_loss_fn(x, sigma, **kwargs):
        nonlocal target_embeds, weights, make_cutouts, normalize
        clip_in = normalize(make_cutouts(x.add(1).div(2)))
        image_embeds = root.clip_model.encode_image(clip_in).float()
        dists = spherical_dist_loss(image_embeds[:, None], target_embeds[None])
        dists = dists.view([args.cutn, 1, -1])
        losses = dists.mul(weights).sum(2).mean(0)
        return losses.sum()

    return clip_loss_fn

def make_aesthetics_loss_fn(root, args):
    clip_size = root.clip_model.visual.input_resolution  # for open_clip: clip_model.visual.image_size

    def aesthetics_cond_fn(x, sigma, **kwargs):
        clip_in = F.interpolate(x, (clip_size, clip_size))
        image_embeds = root.clip_model.encode_image(clip_in).float()
        losses = (10 - root.aesthetics_model(image_embeds)[0])
        return losses.sum()

    return aesthetics_cond_fn

## end CLIP -----------------------------------------

# blue loss from @johnowhitaker's tutorial on Grokking Stable Diffusion
def blue_loss_fn(x, sigma, **kwargs):
    # How far the blue channel values are from 0.9:
    error = torch.abs(x[:, -1, :, :] - 0.9).mean()
    return error

# MSE loss from init
def make_mse_loss(target):
    def mse_loss(x, sigma, **kwargs):
        return (x - target).square().mean()
    return mse_loss

# Exposure loss relative to a target value
def exposure_loss(target):
    def exposure_loss_fn(x, sigma, **kwargs):
        error = torch.abs(x - target).mean()
        return error
    return exposure_loss_fn

def mean_loss_fn(x, sigma, **kwargs):
    error = torch.abs(x).mean()
    return error

def var_loss_fn(x, sigma, **kwargs):
    error = x.var()
    return error

def get_color_palette(root, n_colors, target, verbose=False):
    def display_color_palette(color_list):
        # Expand to 64x64 grid of single color pixels
        images = color_list.unsqueeze(2).repeat(1, 1, 64).unsqueeze(3).repeat(1, 1, 1, 64)
        images = images.double().cpu().add(1).div(2).clamp(0, 1)
        images = torch.tensor(np.array(images))
        grid = make_grid(images, 8).cpu()
        display.display(TF.to_pil_image(grid))
        return

    # Create color palette
    kmeans = KMeans(n_clusters=n_colors, random_state=0).fit(torch.flatten(target[0], 1, 2).T.cpu().numpy())
    color_list = torch.Tensor(kmeans.cluster_centers_).to(root.device)
    if verbose:
        display_color_palette(color_list)
    # Get ratio of each color class in the target image
    color_indexes, color_counts = np.unique(kmeans.labels_, return_counts=True)
    # color_list = color_list[color_indexes]
    return color_list, color_counts

def make_rgb_color_match_loss(root, target, n_colors, ignore_sat_weight=None, img_shape=None, device='cuda:0'):
    """
    target (tensor): Image sample (values from -1 to 1) to extract the color palette
    n_colors (int): Number of colors in the color palette
    ignore_sat_weight (None or number>0): Scale to ignore color saturation in color comparison
    img_shape (None or (int, int)): shape (width, height) of sample that the conditioning gradient is applied to,
        if None then calculate the target color distribution during gradient calculation
        rather than once at the beginning
    """
    assert n_colors > 0, "Must use at least one color with color match loss"

    def adjust_saturation(sample, saturation_factor):
        # as in torchvision.transforms.functional.adjust_saturation, but for tensors with values from -1,1
        return blend(sample, TF.rgb_to_grayscale(sample), saturation_factor)

    def blend(img1, img2, ratio):
        return (ratio * img1 + (1.0 - ratio) * img2).clamp(-1, 1).to(img1.dtype)

    def color_distance_distributions(n_colors, img_shape, color_list, color_counts, n_images=1):
        # Get the target color distance distributions
        # Ensure color counts total the amount of pixels in the image
        n_pixels = img_shape[0]*img_shape[1]
        color_counts = (color_counts * n_pixels / sum(color_counts)).astype(int)

        # Make color distances for each color, sorted by distance
        color_distributions = torch.zeros((n_colors, n_images, n_pixels), device=device)
        for i_image in range(n_images):
            for ic, color0 in enumerate(color_list):
                i_dist = 0
                for jc, color1 in enumerate(color_list):
                    color_dist = torch.linalg.norm(color0 - color1)
                    color_distributions[ic, i_image, i_dist:i_dist+color_counts[jc]] = color_dist
                    i_dist += color_counts[jc]
        color_distributions, _ = torch.sort(color_distributions, dim=2)
        return color_distributions

    color_list, color_counts = get_color_palette(root, n_colors, target)
    color_distributions = None
    if img_shape is not None:
        color_distributions = color_distance_distributions(n_colors, img_shape, color_list, color_counts)

    def rgb_color_ratio_loss(x, sigma, **kwargs):
        nonlocal color_distributions
        all_color_norm_distances = torch.ones(len(color_list), x.shape[0], x.shape[2], x.shape[3]).to(device) * 6.0  # distance to a color won't exceed the max norm distance between -1 and 1 in 3 color dimensions

        for ic, color in enumerate(color_list):
            # Make a tensor of entirely one color
            color = color[None, :, None].repeat(1, 1, x.shape[2]).unsqueeze(3).repeat(1, 1, 1, x.shape[3])
            # Get the color distances
            if ignore_sat_weight is None:
                # Simple color distance
                color_distances = torch.linalg.norm(x - color, dim=1)
            else:
                # Color distance if the colors were saturated
                # This is to make color comparison ignore shadows and highlights, for example
                color_distances = torch.linalg.norm(adjust_saturation(x, ignore_sat_weight) - color, dim=1)

            all_color_norm_distances[ic] = color_distances
        all_color_norm_distances = torch.flatten(all_color_norm_distances, start_dim=2)

        if color_distributions is None:
            color_distributions = color_distance_distributions(n_colors,
                                                               (x.shape[2], x.shape[3]),
                                                               color_list,
                                                               color_counts,
                                                               n_images=x.shape[0])

        # Sort the color distances so we can compare them as if they were a cumulative distribution function
        all_color_norm_distances, _ = torch.sort(all_color_norm_distances, dim=2)

        color_norm_distribution_diff = all_color_norm_distances - color_distributions

        return color_norm_distribution_diff.square().mean()

    return rgb_color_ratio_loss


###
# Thresholding functions for grad
###
def threshold_by(threshold, threshold_type, clamp_schedule):

    def dynamic_thresholding(vals, sigma):
        # Dynamic thresholding from the Imagen paper (May 2022)
        s = np.percentile(np.abs(vals.cpu()), threshold, axis=tuple(range(1, vals.ndim)))
        s = np.max(np.append(s, 1.0))
        vals = torch.clamp(vals, -1*s, s)
        vals = torch.FloatTensor.div(vals, s)
        return vals

    def static_thresholding(vals, sigma):
        vals = torch.clamp(vals, -1*threshold, threshold)
        return vals

    def mean_thresholding(vals, sigma):  # Thresholding that appears in Jax and Disco
        magnitude = vals.square().mean(axis=(1, 2, 3), keepdims=True).sqrt()
        vals = vals * torch.where(magnitude > threshold, threshold / magnitude, 1.0)
        return vals

    def scheduling(vals, sigma):
        clamp_val = clamp_schedule[sigma.item()]
        magnitude = vals.square().mean().sqrt()
        vals = vals * magnitude.clamp(max=clamp_val) / magnitude
        #print(clamp_val)
        return vals

    if threshold_type == 'dynamic':
        return dynamic_thresholding
    elif threshold_type == 'static':
        return static_thresholding
    elif threshold_type == 'mean':
        return mean_thresholding
    elif threshold_type == 'schedule':
        return scheduling
    else:
        raise Exception(f"Thresholding type {threshold_type} not supported")
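To make the 'mean' branch of threshold_by concrete, here is a standalone sketch of the same RMS-clamping rule applied to a random gradient tensor; the tensor shape and the 0.2 threshold are assumptions for the example, not repo defaults.

import torch

# Hypothetical conditioning gradient for one latent, as passed to the thresholding fn.
grad = torch.randn(1, 4, 64, 64) * 5.0
clamp_threshold = 0.2  # illustrative value for clamp_grad_threshold

# Same rule as mean_thresholding above: rescale so the RMS magnitude never exceeds the threshold.
magnitude = grad.square().mean(dim=(1, 2, 3), keepdim=True).sqrt()
grad = grad * torch.where(magnitude > clamp_threshold, clamp_threshold / magnitude, torch.ones_like(magnitude))
print(grad.square().mean().sqrt())  # now <= clamp_threshold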
deforum-stable-diffusion/helpers/depth.py
ADDED
@@ -0,0 +1,175 @@
import cv2
import math
import numpy as np
import os
import requests
import torch
import torchvision.transforms as T
import torchvision.transforms.functional as TF

from einops import rearrange, repeat
from PIL import Image

from infer import InferenceHelper
from midas.dpt_depth import DPTDepthModel
from midas.transforms import Resize, NormalizeImage, PrepareForNet


def wget(url, outputdir):
    filename = url.split("/")[-1]

    ckpt_request = requests.get(url)
    request_status = ckpt_request.status_code

    # inform user of errors
    if request_status == 403:
        raise ConnectionRefusedError("You have not accepted the license for this model.")
    elif request_status == 404:
        raise ConnectionError("Could not make contact with server")
    elif request_status != 200:
        raise ConnectionError(f"Some other error has occurred - response code: {request_status}")

    # write to model path
    with open(os.path.join(outputdir, filename), 'wb') as model_file:
        model_file.write(ckpt_request.content)


class DepthModel():
    def __init__(self, device):
        self.adabins_helper = None
        self.depth_min = 1000
        self.depth_max = -1000
        self.device = device
        self.midas_model = None
        self.midas_transform = None

    def load_adabins(self, models_path):
        if not os.path.exists(os.path.join(models_path, 'AdaBins_nyu.pt')):
            print("Downloading AdaBins_nyu.pt...")
            os.makedirs(models_path, exist_ok=True)
            wget("https://cloudflare-ipfs.com/ipfs/Qmd2mMnDLWePKmgfS8m6ntAg4nhV5VkUyAydYBp8cWWeB7/AdaBins_nyu.pt", models_path)
        self.adabins_helper = InferenceHelper(models_path, dataset='nyu', device=self.device)

    def load_midas(self, models_path, half_precision=True):
        if not os.path.exists(os.path.join(models_path, 'dpt_large-midas-2f21e586.pt')):
            print("Downloading dpt_large-midas-2f21e586.pt...")
            wget("https://github.com/intel-isl/DPT/releases/download/1_0/dpt_large-midas-2f21e586.pt", models_path)

        self.midas_model = DPTDepthModel(
            path=os.path.join(models_path, "dpt_large-midas-2f21e586.pt"),
            backbone="vitl16_384",
            non_negative=True,
        )
        normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])

        self.midas_transform = T.Compose([
            Resize(
                384, 384,
                resize_target=None,
                keep_aspect_ratio=True,
                ensure_multiple_of=32,
                resize_method="minimal",
                image_interpolation_method=cv2.INTER_CUBIC,
            ),
            normalization,
            PrepareForNet()
        ])

        self.midas_model.eval()
        if half_precision and self.device == torch.device("cuda"):
            self.midas_model = self.midas_model.to(memory_format=torch.channels_last)
            self.midas_model = self.midas_model.half()
        self.midas_model.to(self.device)

    def predict(self, prev_img_cv2, anim_args) -> torch.Tensor:
        w, h = prev_img_cv2.shape[1], prev_img_cv2.shape[0]

        # predict depth with AdaBins
        use_adabins = anim_args.midas_weight < 1.0 and self.adabins_helper is not None
        if use_adabins:
            MAX_ADABINS_AREA = 500000
            MIN_ADABINS_AREA = 448*448

            # resize image if too large or too small
            img_pil = Image.fromarray(cv2.cvtColor(prev_img_cv2.astype(np.uint8), cv2.COLOR_RGB2BGR))
            image_pil_area = w*h
            resized = True
            if image_pil_area > MAX_ADABINS_AREA:
                scale = math.sqrt(MAX_ADABINS_AREA) / math.sqrt(image_pil_area)
                depth_input = img_pil.resize((int(w*scale), int(h*scale)), Image.LANCZOS)  # LANCZOS is good for downsampling
                print(f"  resized to {depth_input.width}x{depth_input.height}")
            elif image_pil_area < MIN_ADABINS_AREA:
                scale = math.sqrt(MIN_ADABINS_AREA) / math.sqrt(image_pil_area)
                depth_input = img_pil.resize((int(w*scale), int(h*scale)), Image.BICUBIC)
                print(f"  resized to {depth_input.width}x{depth_input.height}")
            else:
                depth_input = img_pil
                resized = False

            # predict depth and resize back to original dimensions
            try:
                with torch.no_grad():
                    _, adabins_depth = self.adabins_helper.predict_pil(depth_input)
                    if resized:
                        adabins_depth = TF.resize(
                            torch.from_numpy(adabins_depth),
                            torch.Size([h, w]),
                            interpolation=TF.InterpolationMode.BICUBIC
                        )
                adabins_depth = adabins_depth.cpu().numpy()
                adabins_depth = adabins_depth.squeeze()
            except:
                print("  exception encountered, falling back to pure MiDaS")
                use_adabins = False
            torch.cuda.empty_cache()

        if self.midas_model is not None:
            # convert image from 0->255 uint8 to 0->1 float for feeding to MiDaS
            img_midas = prev_img_cv2.astype(np.float32) / 255.0
            img_midas_input = self.midas_transform({"image": img_midas})["image"]

            # MiDaS depth estimation implementation
            sample = torch.from_numpy(img_midas_input).float().to(self.device).unsqueeze(0)
            if self.device == torch.device("cuda"):
                sample = sample.to(memory_format=torch.channels_last)
                sample = sample.half()
            with torch.no_grad():
                midas_depth = self.midas_model.forward(sample)
            midas_depth = torch.nn.functional.interpolate(
                midas_depth.unsqueeze(1),
                size=img_midas.shape[:2],
                mode="bicubic",
                align_corners=False,
            ).squeeze()
            midas_depth = midas_depth.cpu().numpy()
            torch.cuda.empty_cache()

            # MiDaS makes the near values greater, and the far values lesser. Let's reverse that and try to align with AdaBins a bit better.
            midas_depth = np.subtract(50.0, midas_depth)
            midas_depth = midas_depth / 19.0

            # blend between MiDaS and AdaBins predictions
            if use_adabins:
                depth_map = midas_depth*anim_args.midas_weight + adabins_depth*(1.0-anim_args.midas_weight)
            else:
                depth_map = midas_depth

            depth_map = np.expand_dims(depth_map, axis=0)
            depth_tensor = torch.from_numpy(depth_map).squeeze().to(self.device)
        else:
            depth_tensor = torch.ones((h, w), device=self.device)

        return depth_tensor

    def save(self, filename: str, depth: torch.Tensor):
        depth = depth.cpu().numpy()
        if len(depth.shape) == 2:
            depth = np.expand_dims(depth, axis=0)
        self.depth_min = min(self.depth_min, depth.min())
        self.depth_max = max(self.depth_max, depth.max())
        print(f"  depth min:{depth.min()} max:{depth.max()}")
        denom = max(1e-8, self.depth_max - self.depth_min)
        temp = rearrange((depth - self.depth_min) / denom * 255, 'c h w -> h w c')
        temp = repeat(temp, 'h w 1 -> h w c', c=3)
        Image.fromarray(temp.astype(np.uint8)).save(filename)
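A hedged usage sketch for DepthModel, assuming the AdaBins/MiDaS dependencies and the helpers package are importable and using hypothetical file names for the frame and output; it only mirrors the call order this file expects (load, predict, save).

import os
import cv2
import torch
from helpers.depth import DepthModel  # assumes the helpers package and its depth dependencies are on the path

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
os.makedirs('models', exist_ok=True)

depth_model = DepthModel(device)
depth_model.load_midas('models')       # downloads dpt_large-midas-2f21e586.pt into ./models if missing
# depth_model.load_adabins('models')   # optional, only used when anim_args.midas_weight < 1.0

class AnimArgs:                        # minimal stand-in for the anim_args namespace
    midas_weight = 1.0

frame = cv2.imread('frame_000.png')    # hypothetical input frame
depth = depth_model.predict(frame, AnimArgs())
depth_model.save('depth_000.png', depth)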
deforum-stable-diffusion/helpers/generate.py
ADDED
@@ -0,0 +1,282 @@
import torch
from PIL import Image
import requests
import numpy as np
import torchvision.transforms.functional as TF
from pytorch_lightning import seed_everything
import os
from ldm.models.diffusion.plms import PLMSSampler
from ldm.models.diffusion.ddim import DDIMSampler
from k_diffusion.external import CompVisDenoiser
from torch import autocast
from contextlib import nullcontext
from einops import rearrange, repeat

from .prompt import get_uc_and_c
from .k_samplers import sampler_fn, make_inject_timing_fn
from scipy.ndimage import gaussian_filter

from .callback import SamplerCallback

from .conditioning import exposure_loss, make_mse_loss, get_color_palette, make_clip_loss_fn
from .conditioning import make_rgb_color_match_loss, blue_loss_fn, threshold_by, make_aesthetics_loss_fn, mean_loss_fn, var_loss_fn
from .model_wrap import CFGDenoiserWithGrad
from .load_images import load_img, load_mask_latent, prepare_mask, prepare_overlay_mask

def add_noise(sample: torch.Tensor, noise_amt: float) -> torch.Tensor:
    return sample + torch.randn(sample.shape, device=sample.device) * noise_amt

def generate(args, root, frame=0, return_latent=False, return_sample=False, return_c=False):
    seed_everything(args.seed)
    os.makedirs(args.outdir, exist_ok=True)

    sampler = PLMSSampler(root.model) if args.sampler == 'plms' else DDIMSampler(root.model)
    model_wrap = CompVisDenoiser(root.model)
    batch_size = args.n_samples
    prompt = args.prompt
    assert prompt is not None
    data = [batch_size * [prompt]]
    precision_scope = autocast if args.precision == "autocast" else nullcontext

    init_latent = None
    mask_image = None
    init_image = None
    if args.init_latent is not None:
        init_latent = args.init_latent
    elif args.init_sample is not None:
        with precision_scope("cuda"):
            init_latent = root.model.get_first_stage_encoding(root.model.encode_first_stage(args.init_sample))
    elif args.use_init and args.init_image is not None and args.init_image != '':
        init_image, mask_image = load_img(args.init_image,
                                          shape=(args.W, args.H),
                                          use_alpha_as_mask=args.use_alpha_as_mask)
        init_image = init_image.to(root.device)
        init_image = repeat(init_image, '1 ... -> b ...', b=batch_size)
        with precision_scope("cuda"):
            init_latent = root.model.get_first_stage_encoding(root.model.encode_first_stage(init_image))  # move to latent space

    if not args.use_init and args.strength > 0 and args.strength_0_no_init:
        print("\nNo init image, but strength > 0. Strength has been auto set to 0, since use_init is False.")
        print("If you want to force strength > 0 with no init, please set strength_0_no_init to False.\n")
        args.strength = 0

    # Mask functions
    if args.use_mask:
        assert args.mask_file is not None or mask_image is not None, "use_mask==True: A mask image is required for a mask. Please enter a mask_file or use an init image with an alpha channel"
        assert args.use_init, "use_mask==True: use_init is required for a mask"
        assert init_latent is not None, "use_mask==True: A latent init image is required for a mask"

        mask = prepare_mask(args.mask_file if mask_image is None else mask_image,
                            init_latent.shape,
                            args.mask_contrast_adjust,
                            args.mask_brightness_adjust,
                            args.invert_mask)

        if (torch.all(mask == 0) or torch.all(mask == 1)) and args.use_alpha_as_mask:
            raise Warning("use_alpha_as_mask==True: Using the alpha channel from the init image as a mask, but the alpha channel is blank.")

        mask = mask.to(root.device)
        mask = repeat(mask, '1 ... -> b ...', b=batch_size)
    else:
        mask = None

    assert not ((args.use_mask and args.overlay_mask) and (args.init_sample is None and init_image is None)), "Need an init image when use_mask == True and overlay_mask == True"

    # Init MSE loss image
    init_mse_image = None
    if args.init_mse_scale and args.init_mse_image is not None and args.init_mse_image != '':
        init_mse_image, mask_image = load_img(args.init_mse_image,
                                              shape=(args.W, args.H),
                                              use_alpha_as_mask=args.use_alpha_as_mask)
        init_mse_image = init_mse_image.to(root.device)
        init_mse_image = repeat(init_mse_image, '1 ... -> b ...', b=batch_size)

    assert not (args.init_mse_scale != 0 and (args.init_mse_image is None or args.init_mse_image == '')), "Need an init image when init_mse_scale != 0"

    t_enc = int((1.0-args.strength) * args.steps)

    # Noise schedule for the k-diffusion samplers (used for masking)
    k_sigmas = model_wrap.get_sigmas(args.steps)
    args.clamp_schedule = dict(zip(k_sigmas.tolist(), np.linspace(args.clamp_start, args.clamp_stop, args.steps+1)))
    k_sigmas = k_sigmas[len(k_sigmas)-t_enc-1:]

    if args.sampler in ['plms', 'ddim']:
        sampler.make_schedule(ddim_num_steps=args.steps, ddim_eta=args.ddim_eta, ddim_discretize='fill', verbose=False)

    if args.colormatch_scale != 0:
        assert args.colormatch_image is not None, "If using color match loss, colormatch_image is needed"
        colormatch_image, _ = load_img(args.colormatch_image)
        colormatch_image = colormatch_image.to('cpu')
        del(_)
    else:
        colormatch_image = None

    # Loss functions
    if args.init_mse_scale != 0:
        if args.decode_method == "linear":
            mse_loss_fn = make_mse_loss(root.model.linear_decode(root.model.get_first_stage_encoding(root.model.encode_first_stage(init_mse_image.to(root.device)))))
        else:
            mse_loss_fn = make_mse_loss(init_mse_image)
    else:
        mse_loss_fn = None

    if args.colormatch_scale != 0:
        _, _ = get_color_palette(root, args.colormatch_n_colors, colormatch_image, verbose=True)  # display target color palette outside the latent space
        if args.decode_method == "linear":
            grad_img_shape = (int(args.W/args.f), int(args.H/args.f))
            colormatch_image = root.model.linear_decode(root.model.get_first_stage_encoding(root.model.encode_first_stage(colormatch_image.to(root.device))))
            colormatch_image = colormatch_image.to('cpu')
        else:
            grad_img_shape = (args.W, args.H)
        color_loss_fn = make_rgb_color_match_loss(root,
                                                  colormatch_image,
                                                  n_colors=args.colormatch_n_colors,
                                                  img_shape=grad_img_shape,
                                                  ignore_sat_weight=args.ignore_sat_weight)
    else:
        color_loss_fn = None

    if args.clip_scale != 0:
        clip_loss_fn = make_clip_loss_fn(root, args)
    else:
        clip_loss_fn = None

    if args.aesthetics_scale != 0:
        aesthetics_loss_fn = make_aesthetics_loss_fn(root, args)
    else:
        aesthetics_loss_fn = None

    if args.exposure_scale != 0:
        exposure_loss_fn = exposure_loss(args.exposure_target)
    else:
        exposure_loss_fn = None

    loss_fns_scales = [
        [clip_loss_fn,       args.clip_scale],
        [blue_loss_fn,       args.blue_scale],
        [mean_loss_fn,       args.mean_scale],
        [exposure_loss_fn,   args.exposure_scale],
        [var_loss_fn,        args.var_scale],
        [mse_loss_fn,        args.init_mse_scale],
        [color_loss_fn,      args.colormatch_scale],
        [aesthetics_loss_fn, args.aesthetics_scale]
    ]

    # Conditioning gradients not implemented for ddim or PLMS
    assert not (any([cond_fs[1] != 0 for cond_fs in loss_fns_scales]) and (args.sampler in ["ddim", "plms"])), "Conditioning gradients not implemented for ddim or plms. Please use a different sampler."

    callback = SamplerCallback(args=args,
                               root=root,
                               mask=mask,
                               init_latent=init_latent,
                               sigmas=k_sigmas,
                               sampler=sampler,
                               verbose=False).callback

    clamp_fn = threshold_by(threshold=args.clamp_grad_threshold, threshold_type=args.grad_threshold_type, clamp_schedule=args.clamp_schedule)

    grad_inject_timing_fn = make_inject_timing_fn(args.grad_inject_timing, model_wrap, args.steps)

    cfg_model = CFGDenoiserWithGrad(model_wrap,
                                    loss_fns_scales,
                                    clamp_fn,
                                    args.gradient_wrt,
                                    args.gradient_add_to,
                                    args.cond_uncond_sync,
                                    decode_method=args.decode_method,
                                    grad_inject_timing_fn=grad_inject_timing_fn,  # option to use grad in only a few of the steps
                                    grad_consolidate_fn=None,  # function to add grad to image fn(img, grad, sigma)
                                    verbose=False)

    results = []
    with torch.no_grad():
        with precision_scope("cuda"):
            with root.model.ema_scope():
                for prompts in data:
                    if isinstance(prompts, tuple):
                        prompts = list(prompts)
                    if args.prompt_weighting:
                        uc, c = get_uc_and_c(prompts, root.model, args, frame)
                    else:
                        uc = root.model.get_learned_conditioning(batch_size * [""])
                        c = root.model.get_learned_conditioning(prompts)

                    if args.scale == 1.0:
                        uc = None
                    if args.init_c is not None:
                        c = args.init_c

                    if args.sampler in ["klms", "dpm2", "dpm2_ancestral", "heun", "euler", "euler_ancestral", "dpm_fast", "dpm_adaptive", "dpmpp_2s_a", "dpmpp_2m"]:
                        samples = sampler_fn(
                            c=c,
                            uc=uc,
                            args=args,
                            model_wrap=cfg_model,
                            init_latent=init_latent,
                            t_enc=t_enc,
                            device=root.device,
                            cb=callback,
                            verbose=False)
                    else:
                        # args.sampler == 'plms' or args.sampler == 'ddim'
                        if init_latent is not None and args.strength > 0:
                            z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc]*batch_size).to(root.device))
                        else:
                            z_enc = torch.randn([args.n_samples, args.C, args.H // args.f, args.W // args.f], device=root.device)
                        if args.sampler == 'ddim':
                            samples = sampler.decode(z_enc,
                                                     c,
                                                     t_enc,
                                                     unconditional_guidance_scale=args.scale,
                                                     unconditional_conditioning=uc,
                                                     img_callback=callback)
                        elif args.sampler == 'plms':  # no "decode" function in plms, so use "sample"
                            shape = [args.C, args.H // args.f, args.W // args.f]
                            samples, _ = sampler.sample(S=args.steps,
                                                        conditioning=c,
                                                        batch_size=args.n_samples,
                                                        shape=shape,
                                                        verbose=False,
                                                        unconditional_guidance_scale=args.scale,
                                                        unconditional_conditioning=uc,
                                                        eta=args.ddim_eta,
                                                        x_T=z_enc,
                                                        img_callback=callback)
                        else:
                            raise Exception(f"Sampler {args.sampler} not recognised.")

                    if return_latent:
                        results.append(samples.clone())

                    x_samples = root.model.decode_first_stage(samples)

                    if args.use_mask and args.overlay_mask:
                        # Overlay the masked image after the image is generated
                        if args.init_sample_raw is not None:
                            img_original = args.init_sample_raw
                        elif init_image is not None:
                            img_original = init_image
                        else:
                            raise Exception("Cannot overlay the masked image without an init image to overlay")

                        if args.mask_sample is None:
                            args.mask_sample = prepare_overlay_mask(args, root, img_original.shape)

                        x_samples = img_original * args.mask_sample + x_samples * ((args.mask_sample * -1.0) + 1)

                    if return_sample:
                        results.append(x_samples.clone())

                    x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)

                    if return_c:
                        results.append(c.clone())

                    for x_sample in x_samples:
                        x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
                        image = Image.fromarray(x_sample.astype(np.uint8))
                        results.append(image)
    return results
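Two small standalone illustrations of the arithmetic used by generate(): the scaled-noise perturbation that add_noise performs, and how strength maps to the number of denoising steps (t_enc). The latent shape, noise amount, and strength value are assumptions for the example.

import torch

# Mirrors add_noise(): perturb a hypothetical latent with scaled Gaussian noise.
latent = torch.randn(1, 4, 64, 64)
noise_amt = 0.02
noisy_latent = latent + torch.randn(latent.shape, device=latent.device) * noise_amt

# strength controls how many of the sampler steps are actually run from the init.
steps, strength = 50, 0.65
t_enc = int((1.0 - strength) * steps)   # 17 of the 50 steps re-noise/denoise the init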
deforum-stable-diffusion/helpers/k_samplers.py
ADDED
@@ -0,0 +1,124 @@
from typing import Any, Callable, Optional
from k_diffusion.external import CompVisDenoiser
from k_diffusion import sampling
import torch


def sampler_fn(
    c: torch.Tensor,
    uc: torch.Tensor,
    args,
    model_wrap: CompVisDenoiser,
    init_latent: Optional[torch.Tensor] = None,
    t_enc: Optional[torch.Tensor] = None,
    device=torch.device("cpu")
    if not torch.cuda.is_available()
    else torch.device("cuda"),
    cb: Callable[[Any], None] = None,
    verbose: Optional[bool] = False,
) -> torch.Tensor:
    shape = [args.C, args.H // args.f, args.W // args.f]
    sigmas: torch.Tensor = model_wrap.get_sigmas(args.steps)
    sigmas = sigmas[len(sigmas) - t_enc - 1 :]
    if args.use_init:
        if len(sigmas) > 0:
            x = (
                init_latent
                + torch.randn([args.n_samples, *shape], device=device) * sigmas[0]
            )
        else:
            x = init_latent
    else:
        if len(sigmas) > 0:
            x = torch.randn([args.n_samples, *shape], device=device) * sigmas[0]
        else:
            x = torch.zeros([args.n_samples, *shape], device=device)
    sampler_args = {
        "model": model_wrap,
        "x": x,
        "sigmas": sigmas,
        "extra_args": {"cond": c, "uncond": uc, "cond_scale": args.scale},
        "disable": False,
        "callback": cb,
    }
    min = sigmas[0].item()
    max = min
    for i in sigmas:
        if i.item() < min and i.item() != 0.0:
            min = i.item()
    if args.sampler in ["dpm_fast"]:
        sampler_args = {
            "model": model_wrap,
            "x": x,
            "sigma_min": min,
            "sigma_max": max,
            "extra_args": {"cond": c, "uncond": uc, "cond_scale": args.scale},
            "disable": False,
            "callback": cb,
            "n": args.steps,
        }
    elif args.sampler in ["dpm_adaptive"]:
        sampler_args = {
            "model": model_wrap,
            "x": x,
            "sigma_min": min,
            "sigma_max": max,
            "extra_args": {"cond": c, "uncond": uc, "cond_scale": args.scale},
            "disable": False,
            "callback": cb,
        }
    sampler_map = {
        "klms": sampling.sample_lms,
        "dpm2": sampling.sample_dpm_2,
        "dpm2_ancestral": sampling.sample_dpm_2_ancestral,
        "heun": sampling.sample_heun,
        "euler": sampling.sample_euler,
        "euler_ancestral": sampling.sample_euler_ancestral,
        "dpm_fast": sampling.sample_dpm_fast,
        "dpm_adaptive": sampling.sample_dpm_adaptive,
        "dpmpp_2s_a": sampling.sample_dpmpp_2s_ancestral,
        "dpmpp_2m": sampling.sample_dpmpp_2m,
    }

    samples = sampler_map[args.sampler](**sampler_args)
    return samples


def make_inject_timing_fn(inject_timing, model, steps):
    """
    inject_timing (int or list of ints or list of floats between 0.0 and 1.0):
        int: compute every inject_timing steps
        list of floats: compute on these decimal fraction steps (eg, [0.5, 1.0] for 50 steps would be at steps 25 and 50)
        list of ints: compute on these steps
    model (CompVisDenoiser)
    steps (int): number of steps
    """
    all_sigmas = model.get_sigmas(steps)
    target_sigmas = torch.empty([0], device=all_sigmas.device)

    def timing_fn(sigma):
        is_conditioning_step = False
        if sigma in target_sigmas:
            is_conditioning_step = True
        return is_conditioning_step

    if inject_timing is None:
        timing_fn = lambda sigma: True
    elif isinstance(inject_timing, int) and inject_timing <= steps and inject_timing > 0:
        # Compute every nth step
        target_sigma_list = [sigma for i, sigma in enumerate(all_sigmas) if (i+1) % inject_timing == 0]
        target_sigmas = torch.Tensor(target_sigma_list).to(all_sigmas.device)
    elif all(isinstance(t, float) for t in inject_timing) and all(t >= 0.0 and t <= 1.0 for t in inject_timing):
        # Compute on these steps (expressed as a decimal fraction between 0.0 and 1.0)
        target_indices = [int(frac_step*steps) if frac_step < 1.0 else steps-1 for frac_step in inject_timing]
        target_sigma_list = [sigma for i, sigma in enumerate(all_sigmas) if i in target_indices]
        target_sigmas = torch.Tensor(target_sigma_list).to(all_sigmas.device)
    elif all(isinstance(t, int) for t in inject_timing) and all(t > 0 and t <= steps for t in inject_timing):
        # Compute on these steps
        target_sigma_list = [sigma for i, sigma in enumerate(all_sigmas) if i+1 in inject_timing]
        target_sigmas = torch.Tensor(target_sigma_list).to(all_sigmas.device)

    else:
        raise Exception(f"Not a valid input: inject_timing={inject_timing}\n" +
                        f"Must be an int, list of all ints (between step 1 and {steps}), or list of all floats between 0.0 and 1.0")
    return timing_fn
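A small sketch of make_inject_timing_fn in isolation, assuming k_diffusion and the helpers package are importable; the MockDenoiser and its sigma schedule are stand-ins invented for the example, not part of this repo.

import torch
from helpers.k_samplers import make_inject_timing_fn  # assumes the helpers package is on the import path

class MockDenoiser:
    # Minimal stand-in for CompVisDenoiser; only get_sigmas() is needed here.
    def get_sigmas(self, steps):
        return torch.linspace(14.6, 0.03, steps)

steps = 50
mock = MockDenoiser()
timing_fn = make_inject_timing_fn([0.5, 1.0], mock, steps)  # conditioning at steps 25 and 50

sigmas = mock.get_sigmas(steps)
print(timing_fn(sigmas[int(0.5 * steps)]))  # True: this sigma is a conditioning step
print(timing_fn(sigmas[0]))                 # False: gradient conditioning skipped here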
deforum-stable-diffusion/helpers/load_images.py
ADDED
@@ -0,0 +1,99 @@
import torch
import requests
from PIL import Image
import numpy as np
import torchvision.transforms.functional as TF
from einops import repeat
from scipy.ndimage import gaussian_filter

def load_img(path, shape=None, use_alpha_as_mask=False):
    # use_alpha_as_mask: Read the alpha channel of the image as the mask image
    if path.startswith('http://') or path.startswith('https://'):
        image = Image.open(requests.get(path, stream=True).raw)
    else:
        image = Image.open(path)

    if use_alpha_as_mask:
        image = image.convert('RGBA')
    else:
        image = image.convert('RGB')

    if shape is not None:
        image = image.resize(shape, resample=Image.LANCZOS)

    mask_image = None
    if use_alpha_as_mask:
        # Split alpha channel into a mask_image
        red, green, blue, alpha = Image.Image.split(image)
        mask_image = alpha.convert('L')
        image = image.convert('RGB')

    image = np.array(image).astype(np.float16) / 255.0
    image = image[None].transpose(0, 3, 1, 2)
    image = torch.from_numpy(image)
    image = 2.*image - 1.

    return image, mask_image

def load_mask_latent(mask_input, shape):
    # mask_input (str or PIL Image.Image): Path to the mask image or a PIL Image object
    # shape (list-like len(4)): shape of the image to match, usually latent_image.shape

    if isinstance(mask_input, str):  # mask input is probably a file name
        if mask_input.startswith('http://') or mask_input.startswith('https://'):
            mask_image = Image.open(requests.get(mask_input, stream=True).raw).convert('RGBA')
        else:
            mask_image = Image.open(mask_input).convert('RGBA')
    elif isinstance(mask_input, Image.Image):
        mask_image = mask_input
    else:
        raise Exception("mask_input must be a PIL image or a file name")

    mask_w_h = (shape[-1], shape[-2])
    mask = mask_image.resize(mask_w_h, resample=Image.LANCZOS)
    mask = mask.convert("L")
    return mask

def prepare_mask(mask_input, mask_shape, mask_brightness_adjust=1.0, mask_contrast_adjust=1.0, invert_mask=False):
    # mask_input (str or PIL Image.Image): Path to the mask image or a PIL Image object
    # mask_shape (list-like len(4)): shape of the image to match, usually latent_image.shape
    # mask_brightness_adjust (non-negative float): amount to adjust brightness of the image,
    #     0 is black, 1 is no adjustment, >1 is brighter
    # mask_contrast_adjust (non-negative float): amount to adjust contrast of the image,
    #     0 is a flat grey image, 1 is no adjustment, >1 is more contrast

    mask = load_mask_latent(mask_input, mask_shape)

    # Mask brightness/contrast adjustments
    if mask_brightness_adjust != 1:
        mask = TF.adjust_brightness(mask, mask_brightness_adjust)
    if mask_contrast_adjust != 1:
        mask = TF.adjust_contrast(mask, mask_contrast_adjust)

    # Mask image to array
    mask = np.array(mask).astype(np.float32) / 255.0
    mask = np.tile(mask, (4, 1, 1))
    mask = np.expand_dims(mask, axis=0)
    mask = torch.from_numpy(mask)

    if invert_mask:
        mask = ((mask - 0.5) * -1) + 0.5

    mask = np.clip(mask, 0, 1)
    return mask

def prepare_overlay_mask(args, root, mask_shape):
    mask_fullres = prepare_mask(args.mask_file,
                                mask_shape,
                                args.mask_contrast_adjust,
                                args.mask_brightness_adjust,
                                args.invert_mask)
    mask_fullres = mask_fullres[:, :3, :, :]
    mask_fullres = repeat(mask_fullres, '1 ... -> b ...', b=args.n_samples)

    mask_fullres[mask_fullres < mask_fullres.max()] = 0
    mask_fullres = gaussian_filter(mask_fullres, args.mask_overlay_blur)
    mask_fullres = torch.Tensor(mask_fullres).to(root.device)
    return mask_fullres
deforum-stable-diffusion/helpers/model_load.py
ADDED
@@ -0,0 +1,257 @@
1 |
+
import os
|
2 |
+
import torch
|
3 |
+
|
4 |
+
# Decodes the image without passing through the upscaler. The resulting image will be the same size as the latent
|
5 |
+
# Thanks to Kevin Turner (https://github.com/keturn) we have a shortcut to look at the decoded image!
|
6 |
+
def make_linear_decode(model_version, device='cuda:0'):
|
7 |
+
v1_4_rgb_latent_factors = [
|
8 |
+
# R G B
|
9 |
+
[ 0.298, 0.207, 0.208], # L1
|
10 |
+
[ 0.187, 0.286, 0.173], # L2
|
11 |
+
[-0.158, 0.189, 0.264], # L3
|
12 |
+
[-0.184, -0.271, -0.473], # L4
|
13 |
+
]
|
14 |
+
|
15 |
+
if model_version[:5] == "sd-v1":
|
16 |
+
rgb_latent_factors = torch.Tensor(v1_4_rgb_latent_factors).to(device)
|
17 |
+
else:
|
18 |
+
raise Exception(f"Model name {model_version} not recognized.")
|
19 |
+
|
20 |
+
def linear_decode(latent):
|
21 |
+
latent_image = latent.permute(0, 2, 3, 1) @ rgb_latent_factors
|
22 |
+
latent_image = latent_image.permute(0, 3, 1, 2)
|
23 |
+
return latent_image
|
24 |
+
|
25 |
+
return linear_decode
|
26 |
+
|
27 |
+
def load_model(root, load_on_run_all=True, check_sha256=True):
|
28 |
+
|
29 |
+
import requests
|
30 |
+
import torch
|
31 |
+
from ldm.util import instantiate_from_config
|
32 |
+
from omegaconf import OmegaConf
|
33 |
+
from transformers import logging
|
34 |
+
logging.set_verbosity_error()
|
35 |
+
|
36 |
+
try:
|
37 |
+
ipy = get_ipython()
|
38 |
+
except:
|
39 |
+
ipy = 'could not get_ipython'
|
40 |
+
|
41 |
+
if 'google.colab' in str(ipy):
|
42 |
+
path_extend = "deforum-stable-diffusion"
|
43 |
+
else:
|
44 |
+
path_extend = ""
|
45 |
+
|
46 |
+
model_map = {
|
47 |
+
"512-base-ema.ckpt": {
|
48 |
+
'sha256': 'd635794c1fedfdfa261e065370bea59c651fc9bfa65dc6d67ad29e11869a1824',
|
49 |
+
'url': 'https://huggingface.co/stabilityai/stable-diffusion-2-base/resolve/main/512-base-ema.ckpt',
|
50 |
+
'requires_login': True,
|
51 |
+
},
|
52 |
+
"v1-5-pruned.ckpt": {
|
53 |
+
'sha256': 'e1441589a6f3c5a53f5f54d0975a18a7feb7cdf0b0dee276dfc3331ae376a053',
|
54 |
+
'url': 'https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned.ckpt',
|
55 |
+
'requires_login': True,
|
56 |
+
},
|
57 |
+
"v1-5-pruned-emaonly.ckpt": {
|
58 |
+
'sha256': 'cc6cb27103417325ff94f52b7a5d2dde45a7515b25c255d8e396c90014281516',
|
59 |
+
'url': 'https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned-emaonly.ckpt',
|
60 |
+
'requires_login': True,
|
61 |
+
},
|
62 |
+
"sd-v1-4-full-ema.ckpt": {
|
63 |
+
'sha256': '14749efc0ae8ef0329391ad4436feb781b402f4fece4883c7ad8d10556d8a36a',
|
64 |
+
'url': 'https://huggingface.co/CompVis/stable-diffusion-v-1-2-original/blob/main/sd-v1-4-full-ema.ckpt',
|
65 |
+
'requires_login': True,
|
66 |
+
},
|
67 |
+
"sd-v1-4.ckpt": {
|
68 |
+
'sha256': 'fe4efff1e174c627256e44ec2991ba279b3816e364b49f9be2abc0b3ff3f8556',
|
69 |
+
'url': 'https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt',
|
70 |
+
'requires_login': True,
|
71 |
+
},
|
72 |
+
"sd-v1-3-full-ema.ckpt": {
|
73 |
+
'sha256': '54632c6e8a36eecae65e36cb0595fab314e1a1545a65209f24fde221a8d4b2ca',
|
74 |
+
'url': 'https://huggingface.co/CompVis/stable-diffusion-v-1-3-original/blob/main/sd-v1-3-full-ema.ckpt',
|
75 |
+
'requires_login': True,
|
76 |
+
},
|
77 |
+
"sd-v1-3.ckpt": {
|
78 |
+
'sha256': '2cff93af4dcc07c3e03110205988ff98481e86539c51a8098d4f2236e41f7f2f',
|
79 |
+
'url': 'https://huggingface.co/CompVis/stable-diffusion-v-1-3-original/resolve/main/sd-v1-3.ckpt',
|
80 |
+
'requires_login': True,
|
81 |
+
},
|
82 |
+
"sd-v1-2-full-ema.ckpt": {
|
83 |
+
'sha256': 'bc5086a904d7b9d13d2a7bccf38f089824755be7261c7399d92e555e1e9ac69a',
|
84 |
+
'url': 'https://huggingface.co/CompVis/stable-diffusion-v-1-2-original/blob/main/sd-v1-2-full-ema.ckpt',
|
85 |
+
'requires_login': True,
|
86 |
+
},
|
87 |
+
"sd-v1-2.ckpt": {
|
88 |
+
'sha256': '3b87d30facd5bafca1cbed71cfb86648aad75d1c264663c0cc78c7aea8daec0d',
|
89 |
+
'url': 'https://huggingface.co/CompVis/stable-diffusion-v-1-2-original/resolve/main/sd-v1-2.ckpt',
|
90 |
+
'requires_login': True,
|
91 |
+
},
|
92 |
+
"sd-v1-1-full-ema.ckpt": {
|
93 |
+
'sha256': 'efdeb5dc418a025d9a8cc0a8617e106c69044bc2925abecc8a254b2910d69829',
|
94 |
+
'url':'https://huggingface.co/CompVis/stable-diffusion-v-1-1-original/resolve/main/sd-v1-1-full-ema.ckpt',
|
95 |
+
'requires_login': True,
|
96 |
+
},
|
97 |
+
"sd-v1-1.ckpt": {
|
98 |
+
'sha256': '86cd1d3ccb044d7ba8db743d717c9bac603c4043508ad2571383f954390f3cea',
|
99 |
+
'url': 'https://huggingface.co/CompVis/stable-diffusion-v-1-1-original/resolve/main/sd-v1-1.ckpt',
|
100 |
+
'requires_login': True,
|
101 |
+
},
|
102 |
+
"robo-diffusion-v1.ckpt": {
|
103 |
+
'sha256': '244dbe0dcb55c761bde9c2ac0e9b46cc9705ebfe5f1f3a7cc46251573ea14e16',
|
104 |
+
'url': 'https://huggingface.co/nousr/robo-diffusion/resolve/main/models/robo-diffusion-v1.ckpt',
|
105 |
+
'requires_login': False,
|
106 |
+
},
|
107 |
+
"wd-v1-3-float16.ckpt": {
|
108 |
+
'sha256': '4afab9126057859b34d13d6207d90221d0b017b7580469ea70cee37757a29edd',
|
109 |
+
'url': 'https://huggingface.co/hakurei/waifu-diffusion-v1-3/resolve/main/wd-v1-3-float16.ckpt',
|
110 |
+
'requires_login': False,
|
111 |
+
},
|
112 |
+
}
|
113 |
+
|
114 |
+
# config path
|
115 |
+
ckpt_config_path = root.custom_config_path if root.model_config == "custom" else os.path.join(root.configs_path, root.model_config)
|
116 |
+
|
117 |
+
if os.path.exists(ckpt_config_path):
|
118 |
+
print(f"{ckpt_config_path} exists")
|
119 |
+
else:
|
120 |
+
print(f"Warning: {ckpt_config_path} does not exist.")
|
121 |
+
ckpt_config_path = os.path.join(path_extend,"configs",root.model_config)
|
122 |
+
print(f"Using {ckpt_config_path} instead.")
|
123 |
+
|
124 |
+
ckpt_config_path = os.path.abspath(ckpt_config_path)
|
125 |
+
|
126 |
+
# checkpoint path or download
|
127 |
+
ckpt_path = root.custom_checkpoint_path if root.model_checkpoint == "custom" else os.path.join(root.models_path, root.model_checkpoint)
|
128 |
+
ckpt_valid = True
|
129 |
+
|
130 |
+
if os.path.exists(ckpt_path):
|
131 |
+
pass
|
132 |
+
elif 'url' in model_map[root.model_checkpoint]:
|
133 |
+
url = model_map[root.model_checkpoint]['url']
|
134 |
+
|
135 |
+
# CLI dialogue to authenticate download
|
136 |
+
if model_map[root.model_checkpoint]['requires_login']:
|
137 |
+
print("This model requires an authentication token")
|
138 |
+
print("Please ensure you have accepted the terms of service before continuing.")
|
139 |
+
|
140 |
+
username = input("[What is your huggingface username?]: ")
|
141 |
+
token = input("[What is your huggingface token?]: ")
|
142 |
+
|
143 |
+
_, path = url.split("https://")
|
144 |
+
|
145 |
+
url = f"https://{username}:{token}@{path}"
|
146 |
+
|
147 |
+
# contact server for model
|
148 |
+
print(f"..attempting to download {root.model_checkpoint}...this may take a while")
|
149 |
+
ckpt_request = requests.get(url)
|
150 |
+
request_status = ckpt_request.status_code
|
151 |
+
|
152 |
+
# inform user of errors
|
153 |
+
if request_status == 403:
|
154 |
+
raise ConnectionRefusedError("You have not accepted the license for this model.")
|
155 |
+
elif request_status == 404:
|
156 |
+
raise ConnectionError("Could not make contact with server")
|
157 |
+
elif request_status != 200:
|
158 |
+
raise ConnectionError(f"Some other error has ocurred - response code: {request_status}")
|
159 |
+
|
160 |
+
# write to model path
|
161 |
+
with open(os.path.join(root.models_path, root.model_checkpoint), 'wb') as model_file:
|
162 |
+
model_file.write(ckpt_request.content)
|
163 |
+
else:
|
164 |
+
print(f"Please download model checkpoint and place in {os.path.join(root.models_path, root.model_checkpoint)}")
|
165 |
+
ckpt_valid = False
|
166 |
+
|
167 |
+
print(f"config_path: {ckpt_config_path}")
|
168 |
+
print(f"ckpt_path: {ckpt_path}")
|
169 |
+
|
170 |
+
if check_sha256 and root.model_checkpoint != "custom" and ckpt_valid:
|
171 |
+
try:
|
172 |
+
import hashlib
|
173 |
+
print("..checking sha256")
|
174 |
+
with open(ckpt_path, "rb") as f:
|
175 |
+
bytes = f.read()
|
176 |
+
hash = hashlib.sha256(bytes).hexdigest()
|
177 |
+
del bytes
|
178 |
+
if model_map[root.model_checkpoint]["sha256"] == hash:
|
179 |
+
print("..hash is correct")
|
180 |
+
else:
|
181 |
+
print("..hash in not correct")
|
182 |
+
ckpt_valid = False
|
183 |
+
except:
|
184 |
+
print("..could not verify model integrity")
|
185 |
+
|
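# A minimal stand-alone sketch of the same integrity check, hashing the file in chunks
# instead of reading it all into memory at once (the path below is a placeholder):
def sha256_of_file(path, chunk_size=1 << 20):
    import hashlib
    digest = hashlib.sha256()
    with open(path, "rb") as fh:
        for chunk in iter(lambda: fh.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()
# e.g. sha256_of_file("models/sd-v1-4.ckpt") == model_map["sd-v1-4.ckpt"]["sha256"]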
186 |
+
def load_model_from_config(config, ckpt, verbose=False, device='cuda', half_precision=True, print_flag=False):
|
187 |
+
map_location = "cuda" # ["cpu", "cuda"]
|
188 |
+
print(f"..loading model")
|
189 |
+
pl_sd = torch.load(ckpt, map_location=map_location)
|
190 |
+
if "global_step" in pl_sd:
|
191 |
+
if print_flag:
|
192 |
+
print(f"Global Step: {pl_sd['global_step']}")
|
193 |
+
sd = pl_sd["state_dict"]
|
194 |
+
model = instantiate_from_config(config.model)
|
195 |
+
m, u = model.load_state_dict(sd, strict=False)
|
196 |
+
if print_flag:
|
197 |
+
if len(m) > 0 and verbose:
|
198 |
+
print("missing keys:")
|
199 |
+
print(m)
|
200 |
+
if len(u) > 0 and verbose:
|
201 |
+
print("unexpected keys:")
|
202 |
+
print(u)
|
203 |
+
|
204 |
+
if half_precision:
|
205 |
+
model = model.half().to(device)
|
206 |
+
else:
|
207 |
+
model = model.to(device)
|
208 |
+
model.eval()
|
209 |
+
return model
|
210 |
+
|
211 |
+
if load_on_run_all and ckpt_valid:
|
212 |
+
local_config = OmegaConf.load(f"{ckpt_config_path}")
|
213 |
+
model = load_model_from_config(local_config, f"{ckpt_path}", half_precision=root.half_precision)
|
214 |
+
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
215 |
+
model = model.to(device)
|
216 |
+
|
217 |
+
autoencoder_version = "sd-v1" #TODO this will be different for different models
|
218 |
+
model.linear_decode = make_linear_decode(autoencoder_version, device)
|
219 |
+
|
220 |
+
return model, device
|
221 |
+
|
222 |
+
|
223 |
+
def get_model_output_paths(root):
|
224 |
+
|
225 |
+
models_path = root.models_path
|
226 |
+
output_path = root.output_path
|
227 |
+
|
228 |
+
#@markdown **Google Drive Path Variables (Optional)**
|
229 |
+
|
230 |
+
force_remount = False
|
231 |
+
|
232 |
+
try:
|
233 |
+
ipy = get_ipython()
|
234 |
+
except:
|
235 |
+
ipy = 'could not get_ipython'
|
236 |
+
|
237 |
+
if 'google.colab' in str(ipy):
|
238 |
+
if root.mount_google_drive:
|
239 |
+
from google.colab import drive # type: ignore
|
240 |
+
try:
|
241 |
+
drive_path = "/content/drive"
|
242 |
+
drive.mount(drive_path,force_remount=force_remount)
|
243 |
+
models_path = root.models_path_gdrive
|
244 |
+
output_path = root.output_path_gdrive
|
245 |
+
except:
|
246 |
+
print("..error mounting drive or with drive path variables")
|
247 |
+
print("..reverting to default path variables")
|
248 |
+
|
249 |
+
models_path = os.path.abspath(models_path)
|
250 |
+
output_path = os.path.abspath(output_path)
|
251 |
+
os.makedirs(models_path, exist_ok=True)
|
252 |
+
os.makedirs(output_path, exist_ok=True)
|
253 |
+
|
254 |
+
print(f"models_path: {models_path}")
|
255 |
+
print(f"output_path: {output_path}")
|
256 |
+
|
257 |
+
return models_path, output_path
|
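A minimal usage sketch for get_model_output_paths above, assuming it is called outside Colab with a bare namespace object; the attribute names mirror what the function reads, and the paths are placeholders:

from types import SimpleNamespace

root = SimpleNamespace(models_path="models", output_path="output",
                       mount_google_drive=False,
                       models_path_gdrive="/content/drive/MyDrive/AI/models",
                       output_path_gdrive="/content/drive/MyDrive/AI/StableDiffusion")
models_path, output_path = get_model_output_paths(root)  # creates both folders if missing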
deforum-stable-diffusion/helpers/model_wrap.py
ADDED
@@ -0,0 +1,226 @@
1 |
+
from torch import nn
|
2 |
+
from k_diffusion import utils as k_utils
|
3 |
+
import torch
|
4 |
+
from k_diffusion.external import CompVisDenoiser
|
5 |
+
from torchvision.utils import make_grid
|
6 |
+
from IPython import display
|
7 |
+
from torchvision.transforms.functional import to_pil_image
|
8 |
+
|
9 |
+
class CFGDenoiser(nn.Module):
|
10 |
+
def __init__(self, model):
|
11 |
+
super().__init__()
|
12 |
+
self.inner_model = model
|
13 |
+
|
14 |
+
def forward(self, x, sigma, uncond, cond, cond_scale):
|
15 |
+
x_in = torch.cat([x] * 2)
|
16 |
+
sigma_in = torch.cat([sigma] * 2)
|
17 |
+
cond_in = torch.cat([uncond, cond])
|
18 |
+
uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2)
|
19 |
+
return uncond + (cond - uncond) * cond_scale
|
20 |
+
|
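# The forward pass above is the classifier-free guidance combination: the model is run once
# on a doubled batch, the result is split into unconditional and conditional predictions, and
# the output is uncond + (cond - uncond) * cond_scale. Worked example: with cond_scale = 7.5,
# a latent value where uncond = 0.20 and cond = 0.30 becomes 0.20 + (0.30 - 0.20) * 7.5 = 0.95,
# i.e. the conditional direction is amplified.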
21 |
+
class CFGDenoiserWithGrad(CompVisDenoiser):
|
22 |
+
def __init__(self, model,
|
23 |
+
loss_fns_scales, # List of [cond_function, scale] pairs
|
24 |
+
clamp_func=None, # Gradient clamping function, clamp_func(grad, sigma)
|
25 |
+
gradient_wrt=None, # Calculate gradient with respect to ["x", "x0_pred", "both"]
|
26 |
+
gradient_add_to=None, # Add gradient to ["cond", "uncond", "both"]
|
27 |
+
cond_uncond_sync=True, # Calculates the cond and uncond simultaneously
|
28 |
+
decode_method=None, # Function used to decode the latent during gradient calculation
|
29 |
+
grad_inject_timing_fn=None, # Option to use grad in only a few of the steps
|
30 |
+
grad_consolidate_fn=None, # Function to add grad to image fn(img, grad, sigma)
|
31 |
+
verbose=False):
|
32 |
+
super().__init__(model.inner_model)
|
33 |
+
self.inner_model = model
|
34 |
+
self.cond_uncond_sync = cond_uncond_sync
|
35 |
+
|
36 |
+
# Initialize gradient calculation variables
|
37 |
+
self.clamp_func = clamp_func
|
38 |
+
self.gradient_add_to = gradient_add_to
|
39 |
+
if gradient_wrt is None:
|
40 |
+
self.gradient_wrt = 'x'
|
41 |
+
else: self.gradient_wrt = gradient_wrt
|
42 |
+
if decode_method is None:
|
43 |
+
decode_fn = lambda x: x
|
44 |
+
elif decode_method == "autoencoder":
|
45 |
+
decode_fn = model.inner_model.differentiable_decode_first_stage
|
46 |
+
elif decode_method == "linear":
|
47 |
+
decode_fn = model.inner_model.linear_decode
|
48 |
+
self.decode_fn = decode_fn
|
49 |
+
|
50 |
+
# Parse loss function-scale pairs
|
51 |
+
cond_fns = []
|
52 |
+
for loss_fn,scale in loss_fns_scales:
|
53 |
+
if scale != 0:
|
54 |
+
cond_fn = self.make_cond_fn(loss_fn, scale)
|
55 |
+
else:
|
56 |
+
cond_fn = None
|
57 |
+
cond_fns += [cond_fn]
|
58 |
+
self.cond_fns = cond_fns
|
59 |
+
|
60 |
+
if grad_inject_timing_fn is None:
|
61 |
+
self.grad_inject_timing_fn = lambda sigma: True
|
62 |
+
else:
|
63 |
+
self.grad_inject_timing_fn = grad_inject_timing_fn
|
64 |
+
if grad_consolidate_fn is None:
|
65 |
+
self.grad_consolidate_fn = lambda img, grad, sigma: img + grad * sigma
|
66 |
+
else:
|
67 |
+
self.grad_consolidate_fn = grad_consolidate_fn
|
68 |
+
|
69 |
+
self.verbose = verbose
|
70 |
+
self.verbose_print = print if self.verbose else lambda *args, **kwargs: None
|
71 |
+
|
72 |
+
|
73 |
+
# General denoising model with gradient conditioning
|
74 |
+
def cond_model_fn_(self, x, sigma, inner_model=None, **kwargs):
|
75 |
+
|
76 |
+
# inner_model: optionally use a different inner_model function or a wrapper function around inner_model, see self.forward._cfg_model
|
77 |
+
if inner_model is None:
|
78 |
+
inner_model = self.inner_model
|
79 |
+
|
80 |
+
total_cond_grad = torch.zeros_like(x)
|
81 |
+
for cond_fn in self.cond_fns:
|
82 |
+
if cond_fn is None: continue
|
83 |
+
|
84 |
+
# Gradient with respect to x
|
85 |
+
if self.gradient_wrt == 'x':
|
86 |
+
with torch.enable_grad():
|
87 |
+
x = x.detach().requires_grad_()
|
88 |
+
denoised = inner_model(x, sigma, **kwargs)
|
89 |
+
cond_grad = cond_fn(x, sigma, denoised=denoised, **kwargs).detach()
|
90 |
+
|
91 |
+
# Gradient wrt x0_pred, so save some compute: don't record grad until after denoised is calculated
|
92 |
+
elif self.gradient_wrt == 'x0_pred':
|
93 |
+
with torch.no_grad():
|
94 |
+
denoised = inner_model(x, sigma, **kwargs)
|
95 |
+
with torch.enable_grad():
|
96 |
+
cond_grad = cond_fn(x, sigma, denoised=denoised.detach().requires_grad_(), **kwargs).detach()
|
97 |
+
total_cond_grad += cond_grad
|
98 |
+
|
99 |
+
total_cond_grad = torch.nan_to_num(total_cond_grad, nan=0.0, posinf=float('inf'), neginf=-float('inf'))
|
100 |
+
|
101 |
+
# Clamp the gradient
|
102 |
+
total_cond_grad = self.clamp_grad_verbose(total_cond_grad, sigma)
|
103 |
+
|
104 |
+
# Add gradient to the image
|
105 |
+
if self.gradient_wrt == 'x':
|
106 |
+
x.copy_(self.grad_consolidate_fn(x.detach(), total_cond_grad, k_utils.append_dims(sigma, x.ndim)))
|
107 |
+
cond_denoised = inner_model(x, sigma, **kwargs)
|
108 |
+
elif self.gradient_wrt == 'x0_pred':
|
109 |
+
x.copy_(self.grad_consolidate_fn(x.detach(), total_cond_grad, k_utils.append_dims(sigma, x.ndim)))
|
110 |
+
cond_denoised = self.grad_consolidate_fn(denoised.detach(), total_cond_grad, k_utils.append_dims(sigma, x.ndim))
|
111 |
+
|
112 |
+
return cond_denoised
|
113 |
+
|
114 |
+
def forward(self, x, sigma, uncond, cond, cond_scale):
|
115 |
+
|
116 |
+
def _cfg_model(x, sigma, cond, **kwargs):
|
117 |
+
# Wrapper to add denoised cond and uncond as in a cfg model
|
118 |
+
# input "cond" is both cond and uncond weights: torch.cat([uncond, cond])
|
119 |
+
x_in = torch.cat([x] * 2)
|
120 |
+
sigma_in = torch.cat([sigma] * 2)
|
121 |
+
|
122 |
+
denoised = self.inner_model(x_in, sigma_in, cond=cond, **kwargs)
|
123 |
+
uncond_x0, cond_x0 = denoised.chunk(2)
|
124 |
+
x0_pred = uncond_x0 + (cond_x0 - uncond_x0) * cond_scale
|
125 |
+
return x0_pred
|
126 |
+
|
127 |
+
# Conditioning
|
128 |
+
if self.check_conditioning_schedule(sigma):
|
129 |
+
# Apply the conditioning gradient to the completed denoised (after both cond and uncond are combined into the diffused image)
|
130 |
+
if self.cond_uncond_sync:
|
131 |
+
# x0 = self.cfg_cond_model_fn_(x, sigma, uncond=uncond, cond=cond, cond_scale=cond_scale)
|
132 |
+
cond_in = torch.cat([uncond, cond])
|
133 |
+
x0 = self.cond_model_fn_(x, sigma, cond=cond_in, inner_model=_cfg_model)
|
134 |
+
|
135 |
+
# Calculate cond and uncond separately
|
136 |
+
else:
|
137 |
+
if self.gradient_add_to == "uncond":
|
138 |
+
uncond = self.cond_model_fn_(x, sigma, cond=uncond)
|
139 |
+
cond = self.inner_model(x, sigma, cond=cond)
|
140 |
+
x0 = uncond + (cond - uncond) * cond_scale
|
141 |
+
elif self.gradient_add_to == "cond":
|
142 |
+
uncond = self.inner_model(x, sigma, cond=uncond)
|
143 |
+
cond = self.cond_model_fn_(x, sigma, cond=cond)
|
144 |
+
x0 = uncond + (cond - uncond) * cond_scale
|
145 |
+
elif self.gradient_add_to == "both":
|
146 |
+
uncond = self.cond_model_fn_(x, sigma, cond=uncond)
|
147 |
+
cond = self.cond_model_fn_(x, sigma, cond=cond)
|
148 |
+
x0 = uncond + (cond - uncond) * cond_scale
|
149 |
+
else:
|
150 |
+
raise Exception(f"Unrecognised option for gradient_add_to: {self.gradient_add_to}")
|
151 |
+
|
152 |
+
# No conditioning
|
153 |
+
else:
|
154 |
+
# calculate cond and uncond simultaneously
|
155 |
+
if self.cond_uncond_sync:
|
156 |
+
cond_in = torch.cat([uncond, cond])
|
157 |
+
x0 = _cfg_model(x, sigma, cond=cond_in)
|
158 |
+
else:
|
159 |
+
uncond = self.inner_model(x, sigma, cond=uncond)
|
160 |
+
cond = self.inner_model(x, sigma, cond=cond)
|
161 |
+
x0 = uncond + (cond - uncond) * cond_scale
|
162 |
+
|
163 |
+
return x0
|
164 |
+
|
165 |
+
def make_cond_fn(self, loss_fn, scale):
|
166 |
+
# Turns a loss function into a cond function that is applied to the decoded RGB sample
|
167 |
+
# loss_fn (function): func(x, sigma, denoised) -> number
|
168 |
+
# scale (number): how much this loss is applied to the image
|
169 |
+
|
170 |
+
# Cond function with respect to x
|
171 |
+
def cond_fn(x, sigma, denoised, **kwargs):
|
172 |
+
with torch.enable_grad():
|
173 |
+
denoised_sample = self.decode_fn(denoised).requires_grad_()
|
174 |
+
loss = loss_fn(denoised_sample, sigma, **kwargs) * scale
|
175 |
+
grad = -torch.autograd.grad(loss, x)[0]
|
176 |
+
self.verbose_print('Loss:', loss.item())
|
177 |
+
return grad
|
178 |
+
|
179 |
+
# Cond function with respect to x0_pred
|
180 |
+
def cond_fn_pred(x, sigma, denoised, **kwargs):
|
181 |
+
with torch.enable_grad():
|
182 |
+
denoised_sample = self.decode_fn(denoised).requires_grad_()
|
183 |
+
loss = loss_fn(denoised_sample, sigma, **kwargs) * scale
|
184 |
+
grad = -torch.autograd.grad(loss, denoised)[0]
|
185 |
+
self.verbose_print('Loss:', loss.item())
|
186 |
+
return grad
|
187 |
+
|
188 |
+
if self.gradient_wrt == 'x':
|
189 |
+
return cond_fn
|
190 |
+
elif self.gradient_wrt == 'x0_pred':
|
191 |
+
return cond_fn_pred
|
192 |
+
else:
|
193 |
+
raise Exception(f"Variable gradient_wrt == {self.gradient_wrt} not recognised.")
|
194 |
+
|
195 |
+
def clamp_grad_verbose(self, grad, sigma):
|
196 |
+
if self.clamp_func is not None:
|
197 |
+
if self.verbose:
|
198 |
+
print("Grad before clamping:")
|
199 |
+
self.display_samples(torch.abs(grad*2.0) - 1.0)
|
200 |
+
grad = self.clamp_func(grad, sigma)
|
201 |
+
if self.verbose:
|
202 |
+
print("Conditioning gradient")
|
203 |
+
self.display_samples(torch.abs(grad*2.0) - 1.0)
|
204 |
+
return grad
|
205 |
+
|
206 |
+
def check_conditioning_schedule(self, sigma):
|
207 |
+
is_conditioning_step = False
|
208 |
+
|
209 |
+
if (self.cond_fns is not None and
|
210 |
+
any(cond_fn is not None for cond_fn in self.cond_fns)):
|
211 |
+
# Conditioning strength != 0
|
212 |
+
# Check if this is a conditioning step
|
213 |
+
if self.grad_inject_timing_fn(sigma):
|
214 |
+
is_conditioning_step = True
|
215 |
+
|
216 |
+
if self.verbose:
|
217 |
+
print(f"Conditioning step for sigma={sigma}")
|
218 |
+
|
219 |
+
return is_conditioning_step
|
220 |
+
|
221 |
+
def display_samples(self, images):
|
222 |
+
images = images.double().cpu().add(1).div(2).clamp(0, 1)
|
223 |
+
images = torch.tensor(images.numpy())
|
224 |
+
grid = make_grid(images, 4).cpu()
|
225 |
+
display.display(to_pil_image(grid))
|
226 |
+
return
|
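A sketch of how CFGDenoiserWithGrad might be wired around an existing k-diffusion wrapper. Everything here is illustrative and assumed rather than taken from this repo's entry points: model stands for an already-loaded Stable Diffusion model, blue_loss is a toy loss, the clamp function is arbitrary, and decode_method='linear' presumes a linear_decode attribute has been attached to the inner model (as configuration.py does above).

import torch
from k_diffusion.external import CompVisDenoiser

model_wrap = CompVisDenoiser(model)   # model: a loaded LDM (assumed)

def blue_loss(rgb_sample, sigma, **kwargs):
    # toy loss: reward blue pixels in the decoded sample
    return -rgb_sample[:, 2].mean()

cfg_with_grad = CFGDenoiserWithGrad(
    model_wrap,
    loss_fns_scales=[(blue_loss, 10.0)],
    clamp_func=lambda grad, sigma: grad.clamp(-0.1, 0.1),
    gradient_wrt='x0_pred',
    gradient_add_to='both',
    cond_uncond_sync=True,
    decode_method='linear',
    verbose=False,
)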
deforum-stable-diffusion/helpers/prompt.py
ADDED
@@ -0,0 +1,130 @@
1 |
+
import re
|
2 |
+
|
3 |
+
def sanitize(prompt):
|
4 |
+
whitelist = set('abcdefghijklmnopqrstuvwxyz ABCDEFGHIJKLMNOPQRSTUVWXYZ')
|
5 |
+
tmp = ''.join(filter(whitelist.__contains__, prompt))
|
6 |
+
return tmp.replace(' ', '_')
|
7 |
+
|
8 |
+
def check_is_number(value):
|
9 |
+
float_pattern = r'^(?=.)([+-]?([0-9]*)(\.([0-9]+))?)$'
|
10 |
+
return re.match(float_pattern, value)
|
11 |
+
|
12 |
+
# prompt weighting with colons and number coefficients (like 'bacon:0.75 eggs:0.25')
|
13 |
+
# borrowed from https://github.com/kylewlacy/stable-diffusion/blob/0a4397094eb6e875f98f9d71193e350d859c4220/ldm/dream/conditioning.py
|
14 |
+
# and https://github.com/raefu/stable-diffusion-automatic/blob/unstablediffusion/modules/processing.py
|
15 |
+
def get_uc_and_c(prompts, model, args, frame = 0):
|
16 |
+
prompt = prompts[0] # they are the same in a batch anyway
|
17 |
+
|
18 |
+
# get weighted sub-prompts
|
19 |
+
negative_subprompts, positive_subprompts = split_weighted_subprompts(
|
20 |
+
prompt, frame, not args.normalize_prompt_weights
|
21 |
+
)
|
22 |
+
|
23 |
+
uc = get_learned_conditioning(model, negative_subprompts, "", args, -1)
|
24 |
+
c = get_learned_conditioning(model, positive_subprompts, prompt, args, 1)
|
25 |
+
|
26 |
+
return (uc, c)
|
27 |
+
|
28 |
+
def get_learned_conditioning(model, weighted_subprompts, text, args, sign = 1):
|
29 |
+
if len(weighted_subprompts) < 1:
|
30 |
+
log_tokenization(text, model, args.log_weighted_subprompts, sign)
|
31 |
+
c = model.get_learned_conditioning(args.n_samples * [text])
|
32 |
+
else:
|
33 |
+
c = None
|
34 |
+
for subtext, subweight in weighted_subprompts:
|
35 |
+
log_tokenization(subtext, model, args.log_weighted_subprompts, sign * subweight)
|
36 |
+
if c is None:
|
37 |
+
c = model.get_learned_conditioning(args.n_samples * [subtext])
|
38 |
+
c *= subweight
|
39 |
+
else:
|
40 |
+
c.add_(model.get_learned_conditioning(args.n_samples * [subtext]), alpha=subweight)
|
41 |
+
|
42 |
+
return c
|
43 |
+
|
44 |
+
def parse_weight(match, frame = 0)->float:
|
45 |
+
import numexpr
|
46 |
+
w_raw = match.group("weight")
|
47 |
+
if w_raw == None:
|
48 |
+
return 1
|
49 |
+
if check_is_number(w_raw):
|
50 |
+
return float(w_raw)
|
51 |
+
else:
|
52 |
+
t = frame
|
53 |
+
if len(w_raw) < 3:
|
54 |
+
print('the value inside the backticks cannot be evaluated as a math expression')
|
55 |
+
return 1
|
56 |
+
return float(numexpr.evaluate(w_raw[1:-1]))
|
57 |
+
|
58 |
+
def normalize_prompt_weights(parsed_prompts):
|
59 |
+
if len(parsed_prompts) == 0:
|
60 |
+
return parsed_prompts
|
61 |
+
weight_sum = sum(map(lambda x: x[1], parsed_prompts))
|
62 |
+
if weight_sum == 0:
|
63 |
+
print(
|
64 |
+
"Warning: Subprompt weights add up to zero. Discarding and using even weights instead.")
|
65 |
+
equal_weight = 1 / max(len(parsed_prompts), 1)
|
66 |
+
return [(x[0], equal_weight) for x in parsed_prompts]
|
67 |
+
return [(x[0], x[1] / weight_sum) for x in parsed_prompts]
|
68 |
+
|
69 |
+
def split_weighted_subprompts(text, frame = 0, skip_normalize=False):
|
70 |
+
"""
|
71 |
+
grabs all text up to the first occurrence of ':'
|
72 |
+
uses the grabbed text as a sub-prompt, and takes the value following ':' as weight
|
73 |
+
if ':' has no value defined, defaults to 1.0
|
74 |
+
repeats until no text remaining
|
75 |
+
"""
|
76 |
+
prompt_parser = re.compile("""
|
77 |
+
(?P<prompt> # capture group for 'prompt'
|
78 |
+
(?:\\\:|[^:])+ # match one or more non ':' characters or escaped colons '\:'
|
79 |
+
) # end 'prompt'
|
80 |
+
(?: # non-capture group
|
81 |
+
:+ # match one or more ':' characters
|
82 |
+
(?P<weight>(( # capture group for 'weight'
|
83 |
+
-?\d+(?:\.\d+)? # match positive or negative integer or decimal number
|
84 |
+
)|( # or
|
85 |
+
`[\S\s]*?`# a math function
|
86 |
+
)))? # end weight capture group, make optional
|
87 |
+
\s* # strip spaces after weight
|
88 |
+
| # OR
|
89 |
+
$ # else, if no ':' then match end of line
|
90 |
+
) # end non-capture group
|
91 |
+
""", re.VERBOSE)
|
92 |
+
negative_prompts = []
|
93 |
+
positive_prompts = []
|
94 |
+
for match in re.finditer(prompt_parser, text):
|
95 |
+
w = parse_weight(match, frame)
|
96 |
+
if w < 0:
|
97 |
+
# negating the sign as we'll feed this to uc
|
98 |
+
negative_prompts.append((match.group("prompt").replace("\\:", ":"), -w))
|
99 |
+
elif w > 0:
|
100 |
+
positive_prompts.append((match.group("prompt").replace("\\:", ":"), w))
|
101 |
+
|
102 |
+
if skip_normalize:
|
103 |
+
return (negative_prompts, positive_prompts)
|
104 |
+
return (normalize_prompt_weights(negative_prompts), normalize_prompt_weights(positive_prompts))
|
105 |
+
|
106 |
+
# shows how the prompt is tokenized
|
107 |
+
# usually tokens have '</w>' to indicate end-of-word,
|
108 |
+
# but for readability it has been replaced with ' '
|
109 |
+
def log_tokenization(text, model, log=False, weight=1):
|
110 |
+
if not log:
|
111 |
+
return
|
112 |
+
tokens = model.cond_stage_model.tokenizer._tokenize(text)
|
113 |
+
tokenized = ""
|
114 |
+
discarded = ""
|
115 |
+
usedTokens = 0
|
116 |
+
totalTokens = len(tokens)
|
117 |
+
for i in range(0, totalTokens):
|
118 |
+
token = tokens[i].replace('</w>', ' ')
|
119 |
+
# alternate color
|
120 |
+
s = (usedTokens % 6) + 1
|
121 |
+
if i < model.cond_stage_model.max_length:
|
122 |
+
tokenized = tokenized + f"\x1b[0;3{s};40m{token}"
|
123 |
+
usedTokens += 1
|
124 |
+
else: # over max token length
|
125 |
+
discarded = discarded + f"\x1b[0;3{s};40m{token}"
|
126 |
+
print(f"\n>> Tokens ({usedTokens}), Weight ({weight:.2f}):\n{tokenized}\x1b[0m")
|
127 |
+
if discarded != "":
|
128 |
+
print(
|
129 |
+
f">> Tokens Discarded ({totalTokens-usedTokens}):\n{discarded}\x1b[0m"
|
130 |
+
)
|
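A worked example of the weighted-subprompt parsing above; the numbers follow directly from the normalization logic and need no model:

neg, pos = split_weighted_subprompts("a castle:1.5 fog:0.5 blurry:-1")
# pos == [('a castle', 0.75), ('fog', 0.25)]   positive weights normalized to sum to 1
# neg == [('blurry', 1.0)]                     negative weights are negated and fed to the unconditional side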
deforum-stable-diffusion/helpers/rank_images.py
ADDED
@@ -0,0 +1,69 @@
1 |
+
import os
|
2 |
+
from argparse import ArgumentParser
|
3 |
+
from tqdm import tqdm
|
4 |
+
from PIL import Image
|
5 |
+
from torch.nn import functional as F
|
6 |
+
from torchvision import transforms
|
7 |
+
from torchvision.transforms import functional as TF
|
8 |
+
import torch
|
9 |
+
from simulacra_fit_linear_model import AestheticMeanPredictionLinearModel
|
10 |
+
from CLIP import clip
|
11 |
+
|
12 |
+
parser = ArgumentParser()
|
13 |
+
parser.add_argument("directory")
|
14 |
+
parser.add_argument("-t", "--top-n", default=50)
|
15 |
+
args = parser.parse_args()
|
16 |
+
|
17 |
+
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
|
18 |
+
|
19 |
+
clip_model_name = 'ViT-B/16'
|
20 |
+
clip_model = clip.load(clip_model_name, jit=False, device=device)[0]
|
21 |
+
clip_model.eval().requires_grad_(False)
|
22 |
+
|
23 |
+
normalize = transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
|
24 |
+
std=[0.26862954, 0.26130258, 0.27577711])
|
25 |
+
|
26 |
+
# 512 is embed dimension for ViT-B/16 CLIP
|
27 |
+
model = AestheticMeanPredictionLinearModel(512)
|
28 |
+
model.load_state_dict(
|
29 |
+
torch.load("models/sac_public_2022_06_29_vit_b_16_linear.pth")
|
30 |
+
)
|
31 |
+
model = model.to(device)
|
32 |
+
|
33 |
+
def get_filepaths(parentpath, filepaths):
|
34 |
+
paths = []
|
35 |
+
for path in filepaths:
|
36 |
+
try:
|
37 |
+
new_parent = os.path.join(parentpath, path)
|
38 |
+
paths += get_filepaths(new_parent, os.listdir(new_parent))
|
39 |
+
except NotADirectoryError:
|
40 |
+
paths.append(os.path.join(parentpath, path))
|
41 |
+
return paths
|
42 |
+
|
43 |
+
filepaths = get_filepaths(args.directory, os.listdir(args.directory))
|
44 |
+
scores = []
|
45 |
+
for path in tqdm(filepaths):
|
46 |
+
# This is obviously a flawed way to check for an image but this is just
|
47 |
+
# a demo script anyway.
|
48 |
+
if path[-4:] not in (".png", ".jpg"):
|
49 |
+
continue
|
50 |
+
img = Image.open(path).convert('RGB')
|
51 |
+
img = TF.resize(img, 224, transforms.InterpolationMode.LANCZOS)
|
52 |
+
img = TF.center_crop(img, (224,224))
|
53 |
+
img = TF.to_tensor(img).to(device)
|
54 |
+
img = normalize(img)
|
55 |
+
clip_image_embed = F.normalize(
|
56 |
+
clip_model.encode_image(img[None, ...]).float(),
|
57 |
+
dim=-1)
|
58 |
+
score = model(clip_image_embed)
|
59 |
+
if len(scores) < args.top_n:
|
60 |
+
scores.append((score.item(),path))
|
61 |
+
scores.sort()
|
62 |
+
else:
|
63 |
+
if scores[0][0] < score:
|
64 |
+
scores.append((score.item(),path))
|
65 |
+
scores.sort(key=lambda x: x[0])
|
66 |
+
scores = scores[1:]
|
67 |
+
|
68 |
+
for score, path in scores:
|
69 |
+
print(f"{score}: {path}")
|
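The top-N bookkeeping in the loop above (append, sort, drop the lowest) can be expressed more directly with heapq; this is only an equivalent sketch, not something the script uses:

import heapq

def top_n(score_path_pairs, n=50):
    # returns the n highest-scoring (score, path) pairs, best first
    return heapq.nlargest(n, score_path_pairs, key=lambda pair: pair[0])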
deforum-stable-diffusion/helpers/render.py
ADDED
@@ -0,0 +1,472 @@
1 |
+
import os
|
2 |
+
import json
|
3 |
+
from IPython import display
|
4 |
+
import random
|
5 |
+
from torchvision.utils import make_grid
|
6 |
+
from einops import rearrange
|
7 |
+
import pandas as pd
|
8 |
+
import cv2
|
9 |
+
import numpy as np
|
10 |
+
from PIL import Image
|
11 |
+
import pathlib
|
12 |
+
import torchvision.transforms as T
|
13 |
+
|
14 |
+
from .generate import generate, add_noise
|
15 |
+
from .prompt import sanitize
|
16 |
+
from .animation import DeformAnimKeys, sample_from_cv2, sample_to_cv2, anim_frame_warp, vid2frames
|
17 |
+
from .depth import DepthModel
|
18 |
+
from .colors import maintain_colors
|
19 |
+
from .load_images import prepare_overlay_mask
|
20 |
+
|
21 |
+
def next_seed(args):
|
22 |
+
if args.seed_behavior == 'iter':
|
23 |
+
args.seed += 1
|
24 |
+
elif args.seed_behavior == 'fixed':
|
25 |
+
pass # always keep seed the same
|
26 |
+
else:
|
27 |
+
args.seed = random.randint(0, 2**32 - 1)
|
28 |
+
return args.seed
|
29 |
+
|
30 |
+
def render_image_batch(args, prompts, root):
|
31 |
+
args.prompts = {k: f"{v:05d}" for v, k in enumerate(prompts)}
|
32 |
+
|
33 |
+
# create output folder for the batch
|
34 |
+
os.makedirs(args.outdir, exist_ok=True)
|
35 |
+
if args.save_settings or args.save_samples:
|
36 |
+
print(f"Saving to {os.path.join(args.outdir, args.timestring)}_*")
|
37 |
+
|
38 |
+
# save settings for the batch
|
39 |
+
if args.save_settings:
|
40 |
+
filename = os.path.join(args.outdir, f"{args.timestring}_settings.txt")
|
41 |
+
with open(filename, "w+", encoding="utf-8") as f:
|
42 |
+
dictlist = dict(args.__dict__)
|
43 |
+
del dictlist['master_args']
|
44 |
+
del dictlist['root']
|
45 |
+
del dictlist['get_output_folder']
|
46 |
+
json.dump(dictlist, f, ensure_ascii=False, indent=4)
|
47 |
+
|
48 |
+
index = 0
|
49 |
+
|
50 |
+
# function for init image batching
|
51 |
+
init_array = []
|
52 |
+
if args.use_init:
|
53 |
+
if args.init_image == "":
|
54 |
+
raise FileNotFoundError("No path was given for init_image")
|
55 |
+
if args.init_image.startswith('http://') or args.init_image.startswith('https://'):
|
56 |
+
init_array.append(args.init_image)
|
57 |
+
elif not os.path.isfile(args.init_image):
|
58 |
+
if args.init_image[-1] != "/": # avoids path error by adding / to end if not there
|
59 |
+
args.init_image += "/"
|
60 |
+
for image in sorted(os.listdir(args.init_image)): # iterates dir and appends images to init_array
|
61 |
+
if image.split(".")[-1] in ("png", "jpg", "jpeg"):
|
62 |
+
init_array.append(args.init_image + image)
|
63 |
+
else:
|
64 |
+
init_array.append(args.init_image)
|
65 |
+
else:
|
66 |
+
init_array = [""]
|
67 |
+
|
68 |
+
# when doing large batches don't flood browser with images
|
69 |
+
clear_between_batches = args.n_batch >= 32
|
70 |
+
|
71 |
+
for iprompt, prompt in enumerate(prompts):
|
72 |
+
args.prompt = prompt
|
73 |
+
args.clip_prompt = prompt
|
74 |
+
print(f"Prompt {iprompt+1} of {len(prompts)}")
|
75 |
+
print(f"{args.prompt}")
|
76 |
+
|
77 |
+
all_images = []
|
78 |
+
|
79 |
+
for batch_index in range(args.n_batch):
|
80 |
+
if clear_between_batches and batch_index % 32 == 0:
|
81 |
+
display.clear_output(wait=True)
|
82 |
+
print(f"Batch {batch_index+1} of {args.n_batch}")
|
83 |
+
|
84 |
+
for image in init_array: # iterates the init images
|
85 |
+
args.init_image = image
|
86 |
+
results = generate(args, root)
|
87 |
+
for image in results:
|
88 |
+
if args.make_grid:
|
89 |
+
all_images.append(T.functional.pil_to_tensor(image))
|
90 |
+
if args.save_samples:
|
91 |
+
if args.filename_format == "{timestring}_{index}_{prompt}.png":
|
92 |
+
filename = f"{args.timestring}_{index:05}_{sanitize(prompt)[:160]}.png"
|
93 |
+
else:
|
94 |
+
filename = f"{args.timestring}_{index:05}_{args.seed}.png"
|
95 |
+
image.save(os.path.join(args.outdir, filename))
|
96 |
+
if args.display_samples:
|
97 |
+
display.display(image)
|
98 |
+
index += 1
|
99 |
+
args.seed = next_seed(args)
|
100 |
+
|
101 |
+
#print(len(all_images))
|
102 |
+
if args.make_grid:
|
103 |
+
grid = make_grid(all_images, nrow=int(len(all_images)/args.grid_rows))
|
104 |
+
grid = rearrange(grid, 'c h w -> h w c').cpu().numpy()
|
105 |
+
filename = f"{args.timestring}_{iprompt:05d}_grid_{args.seed}.png"
|
106 |
+
grid_image = Image.fromarray(grid.astype(np.uint8))
|
107 |
+
grid_image.save(os.path.join(args.outdir, filename))
|
108 |
+
display.clear_output(wait=True)
|
109 |
+
display.display(grid_image)
|
110 |
+
|
111 |
+
|
112 |
+
def render_animation(args, anim_args, animation_prompts, root):
|
113 |
+
# animations use key framed prompts
|
114 |
+
args.prompts = animation_prompts
|
115 |
+
|
116 |
+
# expand key frame strings to values
|
117 |
+
keys = DeformAnimKeys(anim_args)
|
118 |
+
|
119 |
+
# resume animation
|
120 |
+
start_frame = 0
|
121 |
+
if anim_args.resume_from_timestring:
|
122 |
+
for tmp in os.listdir(args.outdir):
|
123 |
+
if tmp.split("_")[0] == anim_args.resume_timestring:
|
124 |
+
start_frame += 1
|
125 |
+
start_frame = start_frame - 1
|
126 |
+
|
127 |
+
# create output folder for the batch
|
128 |
+
os.makedirs(args.outdir, exist_ok=True)
|
129 |
+
print(f"Saving animation frames to {args.outdir}")
|
130 |
+
|
131 |
+
# save settings for the batch
|
132 |
+
'''
|
133 |
+
settings_filename = os.path.join(args.outdir, f"{args.timestring}_settings.txt")
|
134 |
+
with open(settings_filename, "w+", encoding="utf-8") as f:
|
135 |
+
s = {**dict(args.__dict__), **dict(anim_args.__dict__)}
|
136 |
+
#DGSpitzer: run.py adds these three parameters
|
137 |
+
del s['master_args']
|
138 |
+
del s['opt']
|
139 |
+
del s['root']
|
140 |
+
del s['get_output_folder']
|
141 |
+
#print(s)
|
142 |
+
json.dump(s, f, ensure_ascii=False, indent=4)
|
143 |
+
'''
|
144 |
+
# resume from timestring
|
145 |
+
if anim_args.resume_from_timestring:
|
146 |
+
args.timestring = anim_args.resume_timestring
|
147 |
+
|
148 |
+
# expand prompts out to per-frame
|
149 |
+
prompt_series = pd.Series([np.nan for a in range(anim_args.max_frames)])
|
150 |
+
for i, prompt in animation_prompts.items():
|
151 |
+
prompt_series[int(i)] = prompt
|
152 |
+
prompt_series = prompt_series.ffill().bfill()
|
153 |
+
|
154 |
+
# check for video inits
|
155 |
+
using_vid_init = anim_args.animation_mode == 'Video Input'
|
156 |
+
|
157 |
+
# load depth model for 3D
|
158 |
+
predict_depths = (anim_args.animation_mode == '3D' and anim_args.use_depth_warping) or anim_args.save_depth_maps
|
159 |
+
if predict_depths:
|
160 |
+
depth_model = DepthModel(root.device)
|
161 |
+
depth_model.load_midas(root.models_path)
|
162 |
+
if anim_args.midas_weight < 1.0:
|
163 |
+
depth_model.load_adabins(root.models_path)
|
164 |
+
else:
|
165 |
+
depth_model = None
|
166 |
+
anim_args.save_depth_maps = False
|
167 |
+
|
168 |
+
# state for interpolating between diffusion steps
|
169 |
+
turbo_steps = 1 if using_vid_init else int(anim_args.diffusion_cadence)
|
170 |
+
turbo_prev_image, turbo_prev_frame_idx = None, 0
|
171 |
+
turbo_next_image, turbo_next_frame_idx = None, 0
|
172 |
+
|
173 |
+
# resume animation
|
174 |
+
prev_sample = None
|
175 |
+
color_match_sample = None
|
176 |
+
if anim_args.resume_from_timestring:
|
177 |
+
last_frame = start_frame-1
|
178 |
+
if turbo_steps > 1:
|
179 |
+
last_frame -= last_frame%turbo_steps
|
180 |
+
path = os.path.join(args.outdir,f"{args.timestring}_{last_frame:05}.png")
|
181 |
+
img = cv2.imread(path)
|
182 |
+
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
183 |
+
prev_sample = sample_from_cv2(img)
|
184 |
+
if anim_args.color_coherence != 'None':
|
185 |
+
color_match_sample = img
|
186 |
+
if turbo_steps > 1:
|
187 |
+
turbo_next_image, turbo_next_frame_idx = sample_to_cv2(prev_sample, type=np.float32), last_frame
|
188 |
+
turbo_prev_image, turbo_prev_frame_idx = turbo_next_image, turbo_next_frame_idx
|
189 |
+
start_frame = last_frame+turbo_steps
|
190 |
+
|
191 |
+
args.n_samples = 1
|
192 |
+
frame_idx = start_frame
|
193 |
+
while frame_idx < anim_args.max_frames:
|
194 |
+
print(f"Rendering animation frame {frame_idx} of {anim_args.max_frames}")
|
195 |
+
noise = keys.noise_schedule_series[frame_idx]
|
196 |
+
strength = keys.strength_schedule_series[frame_idx]
|
197 |
+
contrast = keys.contrast_schedule_series[frame_idx]
|
198 |
+
depth = None
|
199 |
+
|
200 |
+
# emit in-between frames
|
201 |
+
if turbo_steps > 1:
|
202 |
+
tween_frame_start_idx = max(0, frame_idx-turbo_steps)
|
203 |
+
for tween_frame_idx in range(tween_frame_start_idx, frame_idx):
|
204 |
+
tween = float(tween_frame_idx - tween_frame_start_idx + 1) / float(frame_idx - tween_frame_start_idx)
|
205 |
+
print(f" creating in between frame {tween_frame_idx} tween:{tween:0.2f}")
|
206 |
+
|
207 |
+
advance_prev = turbo_prev_image is not None and tween_frame_idx > turbo_prev_frame_idx
|
208 |
+
advance_next = tween_frame_idx > turbo_next_frame_idx
|
209 |
+
|
210 |
+
if depth_model is not None:
|
211 |
+
assert(turbo_next_image is not None)
|
212 |
+
depth = depth_model.predict(turbo_next_image, anim_args)
|
213 |
+
|
214 |
+
if advance_prev:
|
215 |
+
turbo_prev_image, _ = anim_frame_warp(turbo_prev_image, args, anim_args, keys, tween_frame_idx, depth_model, depth=depth, device=root.device)
|
216 |
+
if advance_next:
|
217 |
+
turbo_next_image, _ = anim_frame_warp(turbo_next_image, args, anim_args, keys, tween_frame_idx, depth_model, depth=depth, device=root.device)
|
218 |
+
# Transformed raw image before color coherence and noise. Used for mask overlay
|
219 |
+
if args.use_mask and args.overlay_mask:
|
220 |
+
# Apply transforms to the original image
|
221 |
+
init_image_raw, _ = anim_frame_warp(args.init_sample_raw, args, anim_args, keys, frame_idx, depth_model, depth, device=root.device)
|
222 |
+
if root.half_precision:
|
223 |
+
args.init_sample_raw = sample_from_cv2(init_image_raw).half().to(root.device)
|
224 |
+
else:
|
225 |
+
args.init_sample_raw = sample_from_cv2(init_image_raw).to(root.device)
|
226 |
+
|
227 |
+
#Transform the mask image
|
228 |
+
if args.use_mask:
|
229 |
+
if args.mask_sample is None:
|
230 |
+
args.mask_sample = prepare_overlay_mask(args, root, prev_sample.shape)
|
231 |
+
# Transform the mask
|
232 |
+
mask_image, _ = anim_frame_warp(args.mask_sample, args, anim_args, keys, frame_idx, depth_model, depth, device=root.device)
|
233 |
+
if root.half_precision:
|
234 |
+
args.mask_sample = sample_from_cv2(mask_image).half().to(root.device)
|
235 |
+
else:
|
236 |
+
args.mask_sample = sample_from_cv2(mask_image).to(root.device)
|
237 |
+
|
238 |
+
turbo_prev_frame_idx = turbo_next_frame_idx = tween_frame_idx
|
239 |
+
|
240 |
+
if turbo_prev_image is not None and tween < 1.0:
|
241 |
+
img = turbo_prev_image*(1.0-tween) + turbo_next_image*tween
|
242 |
+
else:
|
243 |
+
img = turbo_next_image
|
244 |
+
|
245 |
+
filename = f"{args.timestring}_{tween_frame_idx:05}.png"
|
246 |
+
cv2.imwrite(os.path.join(args.outdir, filename), cv2.cvtColor(img.astype(np.uint8), cv2.COLOR_RGB2BGR))
|
247 |
+
if anim_args.save_depth_maps:
|
248 |
+
depth_model.save(os.path.join(args.outdir, f"{args.timestring}_depth_{tween_frame_idx:05}.png"), depth)
|
249 |
+
if turbo_next_image is not None:
|
250 |
+
prev_sample = sample_from_cv2(turbo_next_image)
|
251 |
+
|
252 |
+
# apply transforms to previous frame
|
253 |
+
if prev_sample is not None:
|
254 |
+
prev_img, depth = anim_frame_warp(prev_sample, args, anim_args, keys, frame_idx, depth_model, depth=None, device=root.device)
|
255 |
+
|
256 |
+
# Transformed raw image before color coherence and noise. Used for mask overlay
|
257 |
+
if args.use_mask and args.overlay_mask:
|
258 |
+
# Apply transforms to the original image
|
259 |
+
init_image_raw, _ = anim_frame_warp(args.init_sample_raw, args, anim_args, keys, frame_idx, depth_model, depth, device=root.device)
|
260 |
+
|
261 |
+
if root.half_precision:
|
262 |
+
args.init_sample_raw = sample_from_cv2(init_image_raw).half().to(root.device)
|
263 |
+
else:
|
264 |
+
args.init_sample_raw = sample_from_cv2(init_image_raw).to(root.device)
|
265 |
+
|
266 |
+
#Transform the mask image
|
267 |
+
if args.use_mask:
|
268 |
+
if args.mask_sample is None:
|
269 |
+
args.mask_sample = prepare_overlay_mask(args, root, prev_sample.shape)
|
270 |
+
# Transform the mask
|
271 |
+
mask_sample, _ = anim_frame_warp(args.mask_sample, args, anim_args, keys, frame_idx, depth_model, depth, device=root.device)
|
272 |
+
|
273 |
+
if root.half_precision:
|
274 |
+
args.mask_sample = sample_from_cv2(mask_sample).half().to(root.device)
|
275 |
+
else:
|
276 |
+
args.mask_sample = sample_from_cv2(mask_sample).to(root.device)
|
277 |
+
|
278 |
+
# apply color matching
|
279 |
+
if anim_args.color_coherence != 'None':
|
280 |
+
if color_match_sample is None:
|
281 |
+
color_match_sample = prev_img.copy()
|
282 |
+
else:
|
283 |
+
prev_img = maintain_colors(prev_img, color_match_sample, anim_args.color_coherence)
|
284 |
+
|
285 |
+
# apply scaling
|
286 |
+
contrast_sample = prev_img * contrast
|
287 |
+
# apply frame noising
|
288 |
+
noised_sample = add_noise(sample_from_cv2(contrast_sample), noise)
|
289 |
+
|
290 |
+
# use transformed previous frame as init for current
|
291 |
+
args.use_init = True
|
292 |
+
if root.half_precision:
|
293 |
+
args.init_sample = noised_sample.half().to(root.device)
|
294 |
+
else:
|
295 |
+
args.init_sample = noised_sample.to(root.device)
|
296 |
+
args.strength = max(0.0, min(1.0, strength))
|
297 |
+
|
298 |
+
# grab prompt for current frame
|
299 |
+
args.prompt = prompt_series[frame_idx]
|
300 |
+
args.clip_prompt = args.prompt
|
301 |
+
print(f"{args.prompt} {args.seed}")
|
302 |
+
if not using_vid_init:
|
303 |
+
print(f"Angle: {keys.angle_series[frame_idx]} Zoom: {keys.zoom_series[frame_idx]}")
|
304 |
+
print(f"Tx: {keys.translation_x_series[frame_idx]} Ty: {keys.translation_y_series[frame_idx]} Tz: {keys.translation_z_series[frame_idx]}")
|
305 |
+
print(f"Rx: {keys.rotation_3d_x_series[frame_idx]} Ry: {keys.rotation_3d_y_series[frame_idx]} Rz: {keys.rotation_3d_z_series[frame_idx]}")
|
306 |
+
|
307 |
+
# grab init image for current frame
|
308 |
+
if using_vid_init:
|
309 |
+
init_frame = os.path.join(args.outdir, 'inputframes', f"{frame_idx+1:05}.jpg")
|
310 |
+
print(f"Using video init frame {init_frame}")
|
311 |
+
args.init_image = init_frame
|
312 |
+
if anim_args.use_mask_video:
|
313 |
+
mask_frame = os.path.join(args.outdir, 'maskframes', f"{frame_idx+1:05}.jpg")
|
314 |
+
args.mask_file = mask_frame
|
315 |
+
|
316 |
+
# sample the diffusion model
|
317 |
+
sample, image = generate(args, root, frame_idx, return_latent=False, return_sample=True)
|
318 |
+
# First image sample used for masking
|
319 |
+
if not using_vid_init:
|
320 |
+
prev_sample = sample
|
321 |
+
if args.use_mask and args.overlay_mask:
|
322 |
+
if args.init_sample_raw is None:
|
323 |
+
args.init_sample_raw = sample
|
324 |
+
|
325 |
+
if turbo_steps > 1:
|
326 |
+
turbo_prev_image, turbo_prev_frame_idx = turbo_next_image, turbo_next_frame_idx
|
327 |
+
turbo_next_image, turbo_next_frame_idx = sample_to_cv2(sample, type=np.float32), frame_idx
|
328 |
+
frame_idx += turbo_steps
|
329 |
+
else:
|
330 |
+
filename = f"{args.timestring}_{frame_idx:05}.png"
|
331 |
+
image.save(os.path.join(args.outdir, filename))
|
332 |
+
if anim_args.save_depth_maps:
|
333 |
+
depth = depth_model.predict(sample_to_cv2(sample), anim_args)
|
334 |
+
depth_model.save(os.path.join(args.outdir, f"{args.timestring}_depth_{frame_idx:05}.png"), depth)
|
335 |
+
frame_idx += 1
|
336 |
+
|
337 |
+
display.clear_output(wait=True)
|
338 |
+
display.display(image)
|
339 |
+
|
340 |
+
args.seed = next_seed(args)
|
341 |
+
|
342 |
+
def render_input_video(args, anim_args, animation_prompts, root):
|
343 |
+
# create a folder for the video input frames to live in
|
344 |
+
video_in_frame_path = os.path.join(args.outdir, 'inputframes')
|
345 |
+
os.makedirs(video_in_frame_path, exist_ok=True)
|
346 |
+
|
347 |
+
# save the video frames from input video
|
348 |
+
print(f"Exporting Video Frames (1 every {anim_args.extract_nth_frame}) frames to {video_in_frame_path}...")
|
349 |
+
vid2frames(anim_args.video_init_path, video_in_frame_path, anim_args.extract_nth_frame, anim_args.overwrite_extracted_frames)
|
350 |
+
|
351 |
+
# determine max frames from length of input frames
|
352 |
+
anim_args.max_frames = len([f for f in pathlib.Path(video_in_frame_path).glob('*.jpg')])
|
353 |
+
args.use_init = True
|
354 |
+
print(f"Loading {anim_args.max_frames} input frames from {video_in_frame_path} and saving video frames to {args.outdir}")
|
355 |
+
|
356 |
+
if anim_args.use_mask_video:
|
357 |
+
# create a folder for the mask video input frames to live in
|
358 |
+
mask_in_frame_path = os.path.join(args.outdir, 'maskframes')
|
359 |
+
os.makedirs(mask_in_frame_path, exist_ok=True)
|
360 |
+
|
361 |
+
# save the video frames from mask video
|
362 |
+
print(f"Exporting Video Frames (1 every {anim_args.extract_nth_frame}) frames to {mask_in_frame_path}...")
|
363 |
+
vid2frames(anim_args.video_mask_path, mask_in_frame_path, anim_args.extract_nth_frame, anim_args.overwrite_extracted_frames)
|
364 |
+
args.use_mask = True
|
365 |
+
args.overlay_mask = True
|
366 |
+
|
367 |
+
render_animation(args, anim_args, animation_prompts, root)
|
368 |
+
|
369 |
+
def render_interpolation(args, anim_args, animation_prompts, root):
|
370 |
+
# animations use key framed prompts
|
371 |
+
args.prompts = animation_prompts
|
372 |
+
|
373 |
+
# create output folder for the batch
|
374 |
+
os.makedirs(args.outdir, exist_ok=True)
|
375 |
+
print(f"Saving animation frames to {args.outdir}")
|
376 |
+
|
377 |
+
# save settings for the batch
|
378 |
+
settings_filename = os.path.join(args.outdir, f"{args.timestring}_settings.txt")
|
379 |
+
with open(settings_filename, "w+", encoding="utf-8") as f:
|
380 |
+
s = {**dict(args.__dict__), **dict(anim_args.__dict__)}
|
381 |
+
del s['master_args']
|
382 |
+
del s['opt']
|
383 |
+
del s['root']
|
384 |
+
del s['get_output_folder']
|
385 |
+
json.dump(s, f, ensure_ascii=False, indent=4)
|
386 |
+
|
387 |
+
# Interpolation Settings
|
388 |
+
args.n_samples = 1
|
389 |
+
args.seed_behavior = 'fixed' # force a fixed seed for now, because only one seed is available
|
390 |
+
prompts_c_s = [] # cache all the text embeddings
|
391 |
+
|
392 |
+
print(f"Preparing for interpolation of the following...")
|
393 |
+
|
394 |
+
for i, prompt in animation_prompts.items():
|
395 |
+
args.prompt = prompt
|
396 |
+
args.clip_prompt = args.prompt
|
397 |
+
|
398 |
+
# sample the diffusion model
|
399 |
+
results = generate(args, root, return_c=True)
|
400 |
+
c, image = results[0], results[1]
|
401 |
+
prompts_c_s.append(c)
|
402 |
+
|
403 |
+
# display.clear_output(wait=True)
|
404 |
+
display.display(image)
|
405 |
+
|
406 |
+
args.seed = next_seed(args)
|
407 |
+
|
408 |
+
display.clear_output(wait=True)
|
409 |
+
print(f"Interpolation start...")
|
410 |
+
|
411 |
+
frame_idx = 0
|
412 |
+
|
413 |
+
if anim_args.interpolate_key_frames:
|
414 |
+
for i in range(len(prompts_c_s)-1):
|
415 |
+
dist_frames = list(animation_prompts.items())[i+1][0] - list(animation_prompts.items())[i][0]
|
416 |
+
if dist_frames <= 0:
|
417 |
+
print("key frames duplicated or reversed. interpolation skipped.")
|
418 |
+
return
|
419 |
+
else:
|
420 |
+
for j in range(dist_frames):
|
421 |
+
# interpolate the text embedding
|
422 |
+
prompt1_c = prompts_c_s[i]
|
423 |
+
prompt2_c = prompts_c_s[i+1]
|
424 |
+
args.init_c = prompt1_c.add(prompt2_c.sub(prompt1_c).mul(j * 1/dist_frames))
|
425 |
+
|
426 |
+
# sample the diffusion model
|
427 |
+
results = generate(args, root)
|
428 |
+
image = results[0]
|
429 |
+
|
430 |
+
filename = f"{args.timestring}_{frame_idx:05}.png"
|
431 |
+
image.save(os.path.join(args.outdir, filename))
|
432 |
+
frame_idx += 1
|
433 |
+
|
434 |
+
display.clear_output(wait=True)
|
435 |
+
display.display(image)
|
436 |
+
|
437 |
+
args.seed = next_seed(args)
|
438 |
+
|
439 |
+
else:
|
440 |
+
for i in range(len(prompts_c_s)-1):
|
441 |
+
for j in range(anim_args.interpolate_x_frames+1):
|
442 |
+
# interpolate the text embedding
|
443 |
+
prompt1_c = prompts_c_s[i]
|
444 |
+
prompt2_c = prompts_c_s[i+1]
|
445 |
+
args.init_c = prompt1_c.add(prompt2_c.sub(prompt1_c).mul(j * 1/(anim_args.interpolate_x_frames+1)))
|
446 |
+
|
447 |
+
# sample the diffusion model
|
448 |
+
results = generate(args, root)
|
449 |
+
image = results[0]
|
450 |
+
|
451 |
+
filename = f"{args.timestring}_{frame_idx:05}.png"
|
452 |
+
image.save(os.path.join(args.outdir, filename))
|
453 |
+
frame_idx += 1
|
454 |
+
|
455 |
+
display.clear_output(wait=True)
|
456 |
+
display.display(image)
|
457 |
+
|
458 |
+
args.seed = next_seed(args)
|
459 |
+
|
460 |
+
# generate the last prompt
|
461 |
+
args.init_c = prompts_c_s[-1]
|
462 |
+
results = generate(args, root)
|
463 |
+
image = results[0]
|
464 |
+
filename = f"{args.timestring}_{frame_idx:05}.png"
|
465 |
+
image.save(os.path.join(args.outdir, filename))
|
466 |
+
|
467 |
+
display.clear_output(wait=True)
|
468 |
+
display.display(image)
|
469 |
+
args.seed = next_seed(args)
|
470 |
+
|
471 |
+
#clear init_c
|
472 |
+
args.init_c = None
|
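The diffusion-cadence logic in render_animation above only diffuses every Nth frame and blends the frames in between; the arithmetic, with toy numbers and no model involved, looks like this:

turbo_steps = 4                      # anim_args.diffusion_cadence
frame_idx = 8                        # the frame being diffused this iteration
tween_frame_start_idx = max(0, frame_idx - turbo_steps)          # -> 4
for tween_frame_idx in range(tween_frame_start_idx, frame_idx):  # frames 4..7
    tween = float(tween_frame_idx - tween_frame_start_idx + 1) / float(frame_idx - tween_frame_start_idx)
    # tween takes 0.25, 0.5, 0.75, 1.0; each in-between frame is
    # turbo_prev_image * (1 - tween) + turbo_next_image * tween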
deforum-stable-diffusion/helpers/save_images.py
ADDED
@@ -0,0 +1,60 @@
1 |
+
from typing import List, Tuple
|
2 |
+
from einops import rearrange
|
3 |
+
import numpy as np, os, torch
|
4 |
+
from PIL import Image
|
5 |
+
from torchvision.utils import make_grid
|
6 |
+
import time
|
7 |
+
|
8 |
+
|
9 |
+
def get_output_folder(output_path, batch_folder):
|
10 |
+
out_path = os.path.join(output_path,time.strftime('%Y-%m'))
|
11 |
+
if batch_folder != "":
|
12 |
+
out_path = os.path.join(out_path, batch_folder)
|
13 |
+
os.makedirs(out_path, exist_ok=True)
|
14 |
+
return out_path
|
15 |
+
|
16 |
+
|
17 |
+
def save_samples(
|
18 |
+
args, x_samples: torch.Tensor, seed: int, n_rows: int
|
19 |
+
) -> Tuple[Image.Image, List[Image.Image]]:
|
20 |
+
"""Function to save samples to disk.
|
21 |
+
Args:
|
22 |
+
args: Stable deforum diffusion arguments.
|
23 |
+
x_samples: Samples to save.
|
24 |
+
seed: Seed for the experiment.
|
25 |
+
n_rows: Number of rows in the grid.
|
26 |
+
Returns:
|
27 |
+
A tuple of the grid image and a list of the generated images.
|
28 |
+
( grid_image, generated_images )
|
29 |
+
"""
|
30 |
+
|
31 |
+
# save samples
|
32 |
+
images = []
|
33 |
+
grid_image = None
|
34 |
+
if args.display_samples or args.save_samples:
|
35 |
+
for index, x_sample in enumerate(x_samples):
|
36 |
+
x_sample = 255.0 * rearrange(x_sample.cpu().numpy(), "c h w -> h w c")
|
37 |
+
images.append(Image.fromarray(x_sample.astype(np.uint8)))
|
38 |
+
if args.save_samples:
|
39 |
+
images[-1].save(
|
40 |
+
os.path.join(
|
41 |
+
args.outdir, f"{args.timestring}_{index:02}_{seed}.png"
|
42 |
+
)
|
43 |
+
)
|
44 |
+
|
45 |
+
# save grid
|
46 |
+
if args.display_grid or args.save_grid:
|
47 |
+
grid = torch.stack([x_samples], 0)
|
48 |
+
grid = rearrange(grid, "n b c h w -> (n b) c h w")
|
49 |
+
grid = make_grid(grid, nrow=n_rows, padding=0)
|
50 |
+
|
51 |
+
# to image
|
52 |
+
grid = 255.0 * rearrange(grid, "c h w -> h w c").cpu().numpy()
|
53 |
+
grid_image = Image.fromarray(grid.astype(np.uint8))
|
54 |
+
if args.save_grid:
|
55 |
+
grid_image.save(
|
56 |
+
os.path.join(args.outdir, f"{args.timestring}_{seed}_grid.png")
|
57 |
+
)
|
58 |
+
|
59 |
+
# return grid_image and individual sample images
|
60 |
+
return grid_image, images
|
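A small usage sketch for get_output_folder; the date subfolder depends on when it runs, and the names are placeholders:

out_dir = get_output_folder("outputs", "my_batch")
# -> something like "outputs/2024-06/my_batch", created on disk if it did not exist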
deforum-stable-diffusion/helpers/settings.py
ADDED
@@ -0,0 +1,34 @@
1 |
+
import os
|
2 |
+
import json
|
3 |
+
|
4 |
+
def load_args(args_dict, anim_args_dict, settings_file, custom_settings_file, verbose=True):
|
5 |
+
default_settings_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir, 'settings'))
|
6 |
+
if settings_file.lower() == 'custom':
|
7 |
+
settings_filename = custom_settings_file
|
8 |
+
else:
|
9 |
+
settings_filename = os.path.join(default_settings_dir,settings_file)
|
10 |
+
print(f"Reading custom settings from {settings_filename}...")
|
11 |
+
if not os.path.isfile(settings_filename):
|
12 |
+
print('The settings file does not exist. The in-notebook settings will be used instead.')
|
13 |
+
else:
|
14 |
+
if not verbose:
|
15 |
+
print(f"Any settings not included in {settings_filename} will use the in-notebook settings by default.")
|
16 |
+
with open(settings_filename, "r") as f:
|
17 |
+
jdata = json.loads(f.read())
|
18 |
+
if jdata.get("prompts") is not None:
|
19 |
+
animation_prompts = jdata["prompts"]
|
20 |
+
for i, k in enumerate(args_dict):
|
21 |
+
if k in jdata:
|
22 |
+
args_dict[k] = jdata[k]
|
23 |
+
else:
|
24 |
+
if verbose:
|
25 |
+
print(f"key {k} doesn't exist in the custom settings data! using the default value of {args_dict[k]}")
|
26 |
+
for i, k in enumerate(anim_args_dict):
|
27 |
+
if k in jdata:
|
28 |
+
anim_args_dict[k] = jdata[k]
|
29 |
+
else:
|
30 |
+
if verbose:
|
31 |
+
print(f"key {k} doesn't exist in the custom settings data! using the default value of {anim_args_dict[k]}")
|
32 |
+
if verbose:
|
33 |
+
print(args_dict)
|
34 |
+
print(anim_args_dict)
|
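A usage sketch for load_args, assuming a hypothetical my_settings.json; the dict keys below are only examples of in-notebook defaults:

args_dict = {"W": 512, "H": 512, "seed": -1}
anim_args_dict = {"max_frames": 120, "animation_mode": "None"}
load_args(args_dict, anim_args_dict,
          settings_file="custom", custom_settings_file="my_settings.json",
          verbose=True)
# keys found in the JSON overwrite the dict entries in place; anything missing keeps
# its in-notebook default and is reported because verbose=True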
deforum-stable-diffusion/helpers/simulacra_compute_embeddings.py
ADDED
@@ -0,0 +1,96 @@
1 |
+
#!/usr/bin/env python3
|
2 |
+
|
3 |
+
"""Precomputes CLIP embeddings for Simulacra Aesthetic Captions."""
|
4 |
+
|
5 |
+
import argparse
|
6 |
+
import os
|
7 |
+
from pathlib import Path
|
8 |
+
import sqlite3
|
9 |
+
|
10 |
+
from PIL import Image
|
11 |
+
|
12 |
+
import torch
|
13 |
+
from torch import multiprocessing as mp
|
14 |
+
from torch.utils import data
|
15 |
+
import torchvision.transforms as transforms
|
16 |
+
from tqdm import tqdm
|
17 |
+
|
18 |
+
from CLIP import clip
|
19 |
+
|
20 |
+
|
21 |
+
class SimulacraDataset(data.Dataset):
|
22 |
+
"""Simulacra dataset
|
23 |
+
Args:
|
24 |
+
images_dir: directory
|
25 |
+
transform: preprocessing and augmentation of the training images
|
26 |
+
"""
|
27 |
+
|
28 |
+
def __init__(self, images_dir, db, transform=None):
|
29 |
+
self.images_dir = Path(images_dir)
|
30 |
+
self.transform = transform
|
31 |
+
self.conn = sqlite3.connect(db)
|
32 |
+
self.ratings = []
|
33 |
+
for row in self.conn.execute('SELECT generations.id, images.idx, paths.path, AVG(ratings.rating) FROM images JOIN generations ON images.gid=generations.id JOIN ratings ON images.id=ratings.iid JOIN paths ON images.id=paths.iid GROUP BY images.id'):
|
34 |
+
self.ratings.append(row)
|
35 |
+
|
36 |
+
def __len__(self):
|
37 |
+
return len(self.ratings)
|
38 |
+
|
39 |
+
def __getitem__(self, key):
|
40 |
+
gid, idx, filename, rating = self.ratings[key]
|
41 |
+
image = Image.open(self.images_dir / filename).convert('RGB')
|
42 |
+
if self.transform:
|
43 |
+
image = self.transform(image)
|
44 |
+
return image, torch.tensor(rating)
|
45 |
+
|
46 |
+
|
47 |
+
def main():
|
48 |
+
p = argparse.ArgumentParser(description=__doc__)
|
49 |
+
p.add_argument('--batch-size', '-bs', type=int, default=10,
|
50 |
+
help='the batch size')
|
51 |
+
p.add_argument('--clip-model', type=str, default='ViT-B/16',
|
52 |
+
help='the CLIP model')
|
53 |
+
p.add_argument('--db', type=str, required=True,
|
54 |
+
help='the database location')
|
55 |
+
p.add_argument('--device', type=str,
|
56 |
+
help='the device to use')
|
57 |
+
p.add_argument('--images-dir', type=str, required=True,
|
58 |
+
help='the dataset images directory')
|
59 |
+
p.add_argument('--num-workers', type=int, default=8,
|
60 |
+
help='the number of data loader workers')
|
61 |
+
p.add_argument('--output', type=str, required=True,
|
62 |
+
help='the output file')
|
63 |
+
p.add_argument('--start-method', type=str, default='spawn',
|
64 |
+
choices=['fork', 'forkserver', 'spawn'],
|
65 |
+
help='the multiprocessing start method')
|
66 |
+
args = p.parse_args()
|
67 |
+
|
68 |
+
mp.set_start_method(args.start_method)
|
69 |
+
if args.device:
|
70 |
+
device = torch.device(args.device)
|
71 |
+
else:
|
72 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
73 |
+
print('Using device:', device)
|
74 |
+
|
75 |
+
clip_model, clip_tf = clip.load(args.clip_model, device=device, jit=False)
|
76 |
+
clip_model = clip_model.eval().requires_grad_(False)
|
77 |
+
|
78 |
+
dataset = SimulacraDataset(args.images_dir, args.db, transform=clip_tf)
|
79 |
+
loader = data.DataLoader(dataset, args.batch_size, num_workers=args.num_workers)
|
80 |
+
|
81 |
+
embeds, ratings = [], []
|
82 |
+
|
83 |
+
for batch in tqdm(loader):
|
84 |
+
images_batch, ratings_batch = batch
|
85 |
+
embeds.append(clip_model.encode_image(images_batch.to(device)).cpu())
|
86 |
+
ratings.append(ratings_batch.clone())
|
87 |
+
|
88 |
+
obj = {'clip_model': args.clip_model,
|
89 |
+
'embeds': torch.cat(embeds),
|
90 |
+
'ratings': torch.cat(ratings)}
|
91 |
+
|
92 |
+
torch.save(obj, args.output)
|
93 |
+
|
94 |
+
|
95 |
+
if __name__ == '__main__':
|
96 |
+
main()
|