pt-sk committed on
Commit
161e2ab
·
verified ·
1 Parent(s): 7b61874

Upload 11 files

Browse files
Files changed (11) hide show
  1. SD/attention.py +122 -0
  2. SD/clip.py +96 -0
  3. SD/ddpm.py +123 -0
  4. SD/decoder.py +177 -0
  5. SD/diffusion.py +349 -0
  6. SD/encoder.py +103 -0
  7. SD/model_converter.py +0 -0
  8. SD/model_loader.py +28 -0
  9. SD/pipeline.py +170 -0
  10. SD/run.py +64 -0
  11. SD/sd_demo.ipynb +0 -0
SD/attention.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import math
5
+
6
class SelfAttention(nn.Module):
    """Multi-head self-attention using one fused projection for Q, K and V."""

    def __init__(self, n_heads, d_embed, in_proj_bias=True, out_proj_bias=True):
        super().__init__()
        # A single linear layer holds Wq, Wk and Wv stacked along the output axis.
        self.in_proj = nn.Linear(d_embed, 3 * d_embed, bias=in_proj_bias)
        # Output projection (Wo), applied once the heads have been merged again.
        self.out_proj = nn.Linear(d_embed, d_embed, bias=out_proj_bias)
        self.n_heads = n_heads
        self.d_head = d_embed // n_heads

    def forward(self, x, causal_mask=False):
        """Attend over the sequence axis of ``x``.

        x: (Batch_Size, Seq_Len, Dim). When ``causal_mask`` is true, every
        position is prevented from attending to later positions.
        Returns a tensor of shape (Batch_Size, Seq_Len, Dim).
        """
        original_shape = x.shape
        n_batch, n_tokens, _ = original_shape
        per_head_shape = (n_batch, n_tokens, self.n_heads, self.d_head)

        # (B, S, Dim) -> (B, S, 3 * Dim) -> three (B, S, Dim) chunks, each
        # reshaped to (B, S, H, Dim / H) and transposed to (B, H, S, Dim / H).
        q, k, v = (
            chunk.view(per_head_shape).transpose(1, 2)
            for chunk in self.in_proj(x).chunk(3, dim=-1)
        )

        # (B, H, S, Dim / H) @ (B, H, Dim / H, S) -> (B, H, S, S)
        scores = q @ k.transpose(-1, -2)

        if causal_mask:
            # True strictly above the main diagonal, i.e. on "future" positions.
            future = torch.ones_like(scores, dtype=torch.bool).triu(1)
            scores.masked_fill_(future, -torch.inf)

        # Scale by sqrt(d_head); masked entries remain -inf after the division.
        scores /= math.sqrt(self.d_head)
        probabilities = F.softmax(scores, dim=-1)

        # (B, H, S, S) @ (B, H, S, Dim / H) -> (B, H, S, Dim / H)
        # -> (B, S, H, Dim / H) -> (B, S, Dim)
        merged = (probabilities @ v).transpose(1, 2).reshape(original_shape)

        # (B, S, Dim) -> (B, S, Dim)
        return self.out_proj(merged)
66
+
67
class CrossAttention(nn.Module):
    """Multi-head attention whose queries come from ``x`` and keys/values from ``y``."""

    def __init__(self, n_heads, d_embed, d_cross, in_proj_bias=True, out_proj_bias=True):
        super().__init__()
        # Queries are projected from the latent; keys and values from the context.
        self.q_proj = nn.Linear(d_embed, d_embed, bias=in_proj_bias)
        self.k_proj = nn.Linear(d_cross, d_embed, bias=in_proj_bias)
        self.v_proj = nn.Linear(d_cross, d_embed, bias=in_proj_bias)
        self.out_proj = nn.Linear(d_embed, d_embed, bias=out_proj_bias)
        self.n_heads = n_heads
        self.d_head = d_embed // n_heads

    def forward(self, x, y):
        """Attend from ``x`` to ``y``.

        x (latent):  (Batch_Size, Seq_Len_Q, Dim_Q)
        y (context): (Batch_Size, Seq_Len_KV, Dim_KV)
        Returns a tensor of shape (Batch_Size, Seq_Len_Q, Dim_Q).
        """
        query_shape = x.shape
        n_batch, _, _ = query_shape

        def split_heads(tensor):
            # (B, S, Dim_Q) -> (B, S, H, Dim_Q / H) -> (B, H, S, Dim_Q / H)
            return tensor.view((n_batch, -1, self.n_heads, self.d_head)).transpose(1, 2)

        q = split_heads(self.q_proj(x))
        k = split_heads(self.k_proj(y))
        v = split_heads(self.v_proj(y))

        # (B, H, Seq_Len_Q, Dim_Q / H) @ (B, H, Dim_Q / H, Seq_Len_KV)
        # -> (B, H, Seq_Len_Q, Seq_Len_KV)
        scores = q @ k.transpose(-1, -2)
        scores /= math.sqrt(self.d_head)
        probabilities = F.softmax(scores, dim=-1)

        # (B, H, Seq_Len_Q, Seq_Len_KV) @ (B, H, Seq_Len_KV, Dim_Q / H)
        # -> (B, H, Seq_Len_Q, Dim_Q / H) -> (B, Seq_Len_Q, H, Dim_Q / H)
        merged = (probabilities @ v).transpose(1, 2).contiguous()

        # (B, Seq_Len_Q, H, Dim_Q / H) -> (B, Seq_Len_Q, Dim_Q)
        merged = merged.view(query_shape)

        return self.out_proj(merged)
SD/clip.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from attention import SelfAttention
5
+
6
class CLIPEmbedding(nn.Module):
    """Token embedding plus a learned positional embedding table."""

    def __init__(self, n_vocab: int, n_embd: int, n_token: int):
        super().__init__()

        self.token_embedding = nn.Embedding(n_vocab, n_embd)
        # A learnable weight matrix encodes the position information for each token
        self.position_embedding = nn.Parameter(torch.zeros((n_token, n_embd)))

    def forward(self, tokens):
        """Embed ``tokens`` and add the positional embedding.

        tokens: (Batch_Size, Seq_Len) integer ids.
        Returns a tensor of shape (Batch_Size, Seq_Len, Dim).
        """
        # (Batch_Size, Seq_Len) -> (Batch_Size, Seq_Len, Dim)
        x = self.token_embedding(tokens)

        # Use only the rows for the positions actually present, so sequences
        # shorter than n_token work too; a full-length sequence behaves exactly
        # as before. The addition is out of place so the embedding output is
        # never modified in place.
        positions = self.position_embedding[: tokens.shape[-1]]

        # (Batch_Size, Seq_Len, Dim) + (Seq_Len, Dim) -> (Batch_Size, Seq_Len, Dim)
        return x + positions
21
+
22
class CLIPLayer(nn.Module):
    """One pre-norm transformer layer: causal self-attention, then a QuickGELU MLP."""

    def __init__(self, n_head: int, n_embd: int):
        super().__init__()

        # Norm applied before the attention sub-layer.
        self.layernorm_1 = nn.LayerNorm(n_embd)
        self.attention = SelfAttention(n_head, n_embd)
        # Norm applied before the feed-forward sub-layer.
        self.layernorm_2 = nn.LayerNorm(n_embd)
        # Feed-forward network with a hidden width of 4 * n_embd.
        self.linear_1 = nn.Linear(n_embd, 4 * n_embd)
        self.linear_2 = nn.Linear(4 * n_embd, n_embd)

    def forward(self, x):
        """Map (Batch_Size, Seq_Len, Dim) -> (Batch_Size, Seq_Len, Dim)."""
        # --- Self-attention sub-layer with skip connection ---
        skip = x
        attended = self.attention(self.layernorm_1(x), causal_mask=True)
        attended += skip

        # --- Feed-forward sub-layer with skip connection ---
        skip = attended
        # (Batch_Size, Seq_Len, Dim) -> (Batch_Size, Seq_Len, 4 * Dim)
        hidden = self.linear_1(self.layernorm_2(attended))
        # QuickGELU activation: x * sigmoid(1.702 * x)
        hidden = hidden * torch.sigmoid(1.702 * hidden)
        # (Batch_Size, Seq_Len, 4 * Dim) -> (Batch_Size, Seq_Len, Dim)
        out = self.linear_2(hidden)
        out += skip

        return out
71
+
72
class CLIP(nn.Module):
    """Text encoder: embedding, 12 ``CLIPLayer`` blocks, and a final layer norm."""

    def __init__(self):
        super().__init__()
        # Vocabulary of 49408 tokens, embedding width 768, 77 positions.
        self.embedding = CLIPEmbedding(49408, 768, 77)
        # 12 layers, each with 12 heads over a width of 768.
        self.layers = nn.ModuleList([CLIPLayer(12, 768) for _ in range(12)])
        self.layernorm = nn.LayerNorm(768)

    def forward(self, tokens: torch.LongTensor) -> torch.FloatTensor:
        """Encode token ids (Batch_Size, Seq_Len) into (Batch_Size, Seq_Len, Dim)."""
        token_ids = tokens.type(torch.long)

        # (Batch_Size, Seq_Len) -> (Batch_Size, Seq_Len, Dim)
        hidden = self.embedding(token_ids)

        # Each layer maps (Batch_Size, Seq_Len, Dim) -> (Batch_Size, Seq_Len, Dim).
        for block in self.layers:
            hidden = block(hidden)

        # (Batch_Size, Seq_Len, Dim) -> (Batch_Size, Seq_Len, Dim)
        return self.layernorm(hidden)
SD/ddpm.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+
4
class DDPMSampler:
    """DDPM sampler: builds the noise schedule, denoises step by step, and adds noise.

    For the naming conventions, refer to the DDPM paper
    (https://arxiv.org/pdf/2006.11239.pdf).
    """

    def __init__(self, generator: torch.Generator, num_training_steps=1000, beta_start: float = 0.00085, beta_end: float = 0.0120):
        # Params "beta_start" and "beta_end" taken from: https://github.com/CompVis/stable-diffusion/blob/21f890f9da3cfbeaba8e2ac3c425ee9e998d5229/configs/stable-diffusion/v1-inference.yaml#L5C8-L5C8
        # The betas are linearly spaced in sqrt-space and then squared.
        self.betas = torch.linspace(beta_start ** 0.5, beta_end ** 0.5, num_training_steps, dtype=torch.float32) ** 2
        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
        self.one = torch.tensor(1.0)

        self.generator = generator

        self.num_train_timesteps = num_training_steps
        # Until set_inference_timesteps() is called the schedule covers every
        # training step, so the matching inference step count is the training
        # step count (stride 1). Without this default, step(), _get_variance()
        # and set_strength() would raise AttributeError.
        self.num_inference_steps = num_training_steps
        self.start_step = 0
        self.timesteps = torch.from_numpy(np.arange(0, num_training_steps)[::-1].copy())

    def set_inference_timesteps(self, num_inference_steps=50):
        """Select ``num_inference_steps`` evenly strided timesteps, in descending order.

        Raises:
            ValueError: if ``num_inference_steps`` is not in
                ``1..num_train_timesteps`` (the stride would be undefined or zero).
        """
        if not 0 < num_inference_steps <= self.num_train_timesteps:
            raise ValueError(
                f"num_inference_steps must be between 1 and {self.num_train_timesteps}, "
                f"got {num_inference_steps}"
            )
        self.num_inference_steps = num_inference_steps
        step_ratio = self.num_train_timesteps // self.num_inference_steps
        timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
        self.timesteps = torch.from_numpy(timesteps)

    def _get_previous_timestep(self, timestep: int) -> int:
        """Return the timestep visited after ``timestep`` (one stride lower; may be negative)."""
        return timestep - self.num_train_timesteps // self.num_inference_steps

    def _get_variance(self, timestep: int) -> torch.Tensor:
        """Return the posterior variance for ``timestep``, clamped away from zero."""
        prev_t = self._get_previous_timestep(timestep)

        alpha_prod_t = self.alphas_cumprod[timestep]
        # Stepping below t = 0 means "no noise", i.e. a cumulative alpha of 1.
        alpha_prod_t_prev = self.alphas_cumprod[prev_t] if prev_t >= 0 else self.one
        current_beta_t = 1 - alpha_prod_t / alpha_prod_t_prev

        # For t > 0, compute predicted variance βt (see formula (6) and (7) from https://arxiv.org/pdf/2006.11239.pdf)
        # and sample from it to get previous sample
        # x_{t-1} ~ N(pred_prev_sample, variance) == add variance to pred_sample
        variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * current_beta_t

        # we always take the log of variance, so clamp it to ensure it's not 0
        return torch.clamp(variance, min=1e-20)

    def set_strength(self, strength=1):
        """
        Set how much noise to add to the input image.
        More noise (strength ~ 1) means that the output will be further from the input image.
        Less noise (strength ~ 0) means that the output will be closer to the input image.
        """
        # start_step is the number of noise levels to skip
        start_step = self.num_inference_steps - int(self.num_inference_steps * strength)
        self.timesteps = self.timesteps[start_step:]
        self.start_step = start_step

    def step(self, timestep: int, latents: torch.Tensor, model_output: torch.Tensor):
        """Run one reverse-diffusion step.

        Args:
            timestep: current timestep ``t``.
            latents: current noisy sample ``x_t``.
            model_output: noise predicted by the model for ``x_t``.

        Returns:
            The sample for the previous timestep, same shape as ``latents``.
        """
        t = timestep
        prev_t = self._get_previous_timestep(t)

        # 1. compute alphas, betas
        alpha_prod_t = self.alphas_cumprod[t]
        alpha_prod_t_prev = self.alphas_cumprod[prev_t] if prev_t >= 0 else self.one
        beta_prod_t = 1 - alpha_prod_t
        beta_prod_t_prev = 1 - alpha_prod_t_prev
        current_alpha_t = alpha_prod_t / alpha_prod_t_prev
        current_beta_t = 1 - current_alpha_t

        # 2. compute predicted original sample from predicted noise also called
        # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
        pred_original_sample = (latents - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)

        # 3. Compute coefficients for pred_original_sample x_0 and current sample x_t
        # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
        pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * current_beta_t) / beta_prod_t
        current_sample_coeff = current_alpha_t ** (0.5) * beta_prod_t_prev / beta_prod_t

        # 4. Compute predicted previous sample µ_t
        # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
        pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * latents

        # 5. Add noise (none at the final step, t == 0)
        noise_term = 0
        if t > 0:
            noise = torch.randn(
                model_output.shape,
                generator=self.generator,
                device=model_output.device,
                dtype=model_output.dtype,
            )
            # sample from N(mu, sigma) = X can be obtained by X = mu + sigma * N(0, 1);
            # noise_term is sigma * N(0, 1), with sigma^2 from formula (7).
            noise_term = (self._get_variance(t) ** 0.5) * noise

        return pred_prev_sample + noise_term

    @staticmethod
    def _expand_to_rank(coefficients: torch.Tensor, rank: int) -> torch.Tensor:
        """Flatten ``coefficients`` and append trailing singleton axes up to ``rank`` dims."""
        coefficients = coefficients.flatten()
        while len(coefficients.shape) < rank:
            coefficients = coefficients.unsqueeze(-1)
        return coefficients

    def add_noise(
        self,
        original_samples: torch.FloatTensor,
        timesteps: torch.IntTensor,
    ) -> torch.FloatTensor:
        """Noise ``original_samples`` to the noise level of ``timesteps``.

        Samples from q(x_t | x_0) as in equation (4) of
        https://arxiv.org/pdf/2006.11239.pdf, drawing the noise from
        ``self.generator``.
        """
        alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
        timesteps = timesteps.to(original_samples.device)
        rank = len(original_samples.shape)

        sqrt_alpha_prod = self._expand_to_rank(alphas_cumprod[timesteps] ** 0.5, rank)
        sqrt_one_minus_alpha_prod = self._expand_to_rank((1 - alphas_cumprod[timesteps]) ** 0.5, rank)

        # Because N(mu, sigma) = X can be obtained by X = mu + sigma * N(0, 1)
        # here mu = sqrt_alpha_prod * original_samples and sigma = sqrt_one_minus_alpha_prod
        noise = torch.randn(original_samples.shape, generator=self.generator, device=original_samples.device, dtype=original_samples.dtype)
        return sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
120
+
121
+
122
+
123
+
SD/decoder.py ADDED
@@ -0,0 +1,177 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from attention import SelfAttention
5
+
6
class VAE_AttentionBlock(nn.Module):
    """Group-normalised, single-head self-attention over pixels, with a skip connection."""

    def __init__(self, channels):
        super().__init__()
        self.groupnorm = nn.GroupNorm(32, channels)
        self.attention = SelfAttention(1, channels)

    def forward(self, x):
        """Map (Batch_Size, Features, Height, Width) to a tensor of the same shape."""
        skip = x

        # (Batch_Size, Features, Height, Width) -> same shape
        normed = self.groupnorm(x)
        n, c, h, w = normed.shape

        # Flatten the spatial grid so each pixel becomes one token of size
        # "Features": (N, C, H, W) -> (N, C, H * W) -> (N, H * W, C)
        tokens = normed.view((n, c, h * w)).transpose(-1, -2)

        # Self-attention WITHOUT mask: (N, H * W, C) -> (N, H * W, C)
        attended = self.attention(tokens)

        # Restore the image layout: (N, H * W, C) -> (N, C, H * W) -> (N, C, H, W)
        out = attended.transpose(-1, -2).view((n, c, h, w))

        out += skip
        return out
43
+
44
class VAE_ResidualBlock(nn.Module):
    """Two (GroupNorm -> SiLU -> 3x3 conv) stages plus a skip connection."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.groupnorm_1 = nn.GroupNorm(32, in_channels)
        self.conv_1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)

        self.groupnorm_2 = nn.GroupNorm(32, out_channels)
        self.conv_2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)

        # The skip path needs a 1x1 conv only when the channel count changes.
        self.residual_layer = (
            nn.Identity()
            if in_channels == out_channels
            else nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0)
        )

    def forward(self, x):
        """Map (Batch_Size, In_Channels, H, W) -> (Batch_Size, Out_Channels, H, W)."""
        # (N, In_Channels, H, W) -> (N, Out_Channels, H, W)
        hidden = self.conv_1(F.silu(self.groupnorm_1(x)))

        # (N, Out_Channels, H, W) -> (N, Out_Channels, H, W)
        hidden = self.conv_2(F.silu(self.groupnorm_2(hidden)))

        return hidden + self.residual_layer(x)
83
+
84
class VAE_Decoder(nn.Sequential):
    """Decode a 4-channel latent at 1/8 resolution into a 3-channel image.

    The layers are registered positionally through ``nn.Sequential``, so their
    order must not change: parameter names depend on each layer's index.
    """

    def __init__(self):
        super().__init__(
            # (Batch_Size, 4, Height / 8, Width / 8) -> (Batch_Size, 4, Height / 8, Width / 8)
            nn.Conv2d(4, 4, kernel_size=1, padding=0),

            # (Batch_Size, 4, Height / 8, Width / 8) -> (Batch_Size, 512, Height / 8, Width / 8)
            nn.Conv2d(4, 512, kernel_size=3, padding=1),

            # (Batch_Size, 512, Height / 8, Width / 8) -> (Batch_Size, 512, Height / 8, Width / 8)
            VAE_ResidualBlock(512, 512),
            VAE_AttentionBlock(512),
            VAE_ResidualBlock(512, 512),
            VAE_ResidualBlock(512, 512),
            VAE_ResidualBlock(512, 512),
            VAE_ResidualBlock(512, 512),

            # Repeats the rows and columns of the data by scale_factor (like when you resize an image by doubling its size).
            # (Batch_Size, 512, Height / 8, Width / 8) -> (Batch_Size, 512, Height / 4, Width / 4)
            nn.Upsample(scale_factor=2),

            # (Batch_Size, 512, Height / 4, Width / 4) -> (Batch_Size, 512, Height / 4, Width / 4)
            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            VAE_ResidualBlock(512, 512),
            VAE_ResidualBlock(512, 512),
            VAE_ResidualBlock(512, 512),

            # (Batch_Size, 512, Height / 4, Width / 4) -> (Batch_Size, 512, Height / 2, Width / 2)
            nn.Upsample(scale_factor=2),

            # (Batch_Size, 512, Height / 2, Width / 2) -> (Batch_Size, 512, Height / 2, Width / 2)
            nn.Conv2d(512, 512, kernel_size=3, padding=1),

            # (Batch_Size, 512, Height / 2, Width / 2) -> (Batch_Size, 256, Height / 2, Width / 2)
            VAE_ResidualBlock(512, 256),

            # (Batch_Size, 256, Height / 2, Width / 2) -> (Batch_Size, 256, Height / 2, Width / 2)
            VAE_ResidualBlock(256, 256),
            VAE_ResidualBlock(256, 256),

            # (Batch_Size, 256, Height / 2, Width / 2) -> (Batch_Size, 256, Height, Width)
            nn.Upsample(scale_factor=2),

            # (Batch_Size, 256, Height, Width) -> (Batch_Size, 256, Height, Width)
            nn.Conv2d(256, 256, kernel_size=3, padding=1),

            # (Batch_Size, 256, Height, Width) -> (Batch_Size, 128, Height, Width)
            VAE_ResidualBlock(256, 128),

            # (Batch_Size, 128, Height, Width) -> (Batch_Size, 128, Height, Width)
            VAE_ResidualBlock(128, 128),
            VAE_ResidualBlock(128, 128),

            # (Batch_Size, 128, Height, Width) -> (Batch_Size, 128, Height, Width)
            nn.GroupNorm(32, 128),
            nn.SiLU(),

            # (Batch_Size, 128, Height, Width) -> (Batch_Size, 3, Height, Width)
            nn.Conv2d(128, 3, kernel_size=3, padding=1),
        )

    def forward(self, x):
        """Decode ``x`` of shape (Batch_Size, 4, Height / 8, Width / 8).

        Returns a tensor of shape (Batch_Size, 3, Height, Width). The input
        tensor is left unmodified.
        """
        # Remove the scaling added by the Encoder. The division is out of place
        # so the caller's latent tensor is not mutated; an in-place divide also
        # fails on leaf tensors that require grad.
        x = x / 0.18215

        for module in self:
            x = module(x)

        # (Batch_Size, 3, Height, Width)
        return x
SD/diffusion.py ADDED
@@ -0,0 +1,349 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from attention import SelfAttention, CrossAttention
5
+
6
class TimeEmbedding(nn.Module):
    """Two-layer MLP that widens the time embedding from n_embd to 4 * n_embd."""

    def __init__(self, n_embd):
        super().__init__()
        self.linear_1 = nn.Linear(n_embd, 4 * n_embd)
        self.linear_2 = nn.Linear(4 * n_embd, 4 * n_embd)

    def forward(self, x):
        """Map (1, n_embd) -> (1, 4 * n_embd), e.g. (1, 320) -> (1, 1280)."""
        # Linear -> SiLU -> Linear
        hidden = F.silu(self.linear_1(x))
        return self.linear_2(hidden)
25
+
26
class UNET_ResidualBlock(nn.Module):
    """Residual conv block that injects the time embedding between its two convs."""

    def __init__(self, in_channels, out_channels, n_time=1280):
        super().__init__()
        self.groupnorm_feature = nn.GroupNorm(32, in_channels)
        self.conv_feature = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        # Projects the time embedding to one value per output channel.
        self.linear_time = nn.Linear(n_time, out_channels)

        self.groupnorm_merged = nn.GroupNorm(32, out_channels)
        self.conv_merged = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)

        # The skip path needs a 1x1 conv only when the channel count changes.
        self.residual_layer = (
            nn.Identity()
            if in_channels == out_channels
            else nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0)
        )

    def forward(self, feature, time):
        """Combine image features with the time embedding.

        feature: (Batch_Size, In_Channels, Height, Width)
        time:    (1, n_time)
        Returns a tensor of shape (Batch_Size, Out_Channels, Height, Width).
        """
        # (N, In_Channels, H, W) -> (N, Out_Channels, H, W)
        hidden = self.conv_feature(F.silu(self.groupnorm_feature(feature)))

        # (1, n_time) -> (1, Out_Channels)
        time_bias = self.linear_time(F.silu(time))

        # Append width and height axes to the time term so it broadcasts:
        # (N, Out_Channels, H, W) + (1, Out_Channels, 1, 1)
        merged = hidden + time_bias.unsqueeze(-1).unsqueeze(-1)

        # (N, Out_Channels, H, W) -> (N, Out_Channels, H, W)
        merged = self.conv_merged(F.silu(self.groupnorm_merged(merged)))

        return merged + self.residual_layer(feature)
77
+
78
class UNET_AttentionBlock(nn.Module):
    """Self-attention, cross-attention and a GeGLU feed-forward over image tokens."""

    def __init__(self, n_head: int, n_embd: int, d_context=768):
        super().__init__()
        channels = n_head * n_embd

        self.groupnorm = nn.GroupNorm(32, channels, eps=1e-6)
        self.conv_input = nn.Conv2d(channels, channels, kernel_size=1, padding=0)

        self.layernorm_1 = nn.LayerNorm(channels)
        self.attention_1 = SelfAttention(n_head, channels, in_proj_bias=False)
        self.layernorm_2 = nn.LayerNorm(channels)
        self.attention_2 = CrossAttention(n_head, channels, d_context, in_proj_bias=False)
        self.layernorm_3 = nn.LayerNorm(channels)
        # The first GeGLU projection produces value and gate halves, hence "* 2".
        self.linear_geglu_1 = nn.Linear(channels, 4 * channels * 2)
        self.linear_geglu_2 = nn.Linear(4 * channels, channels)

        self.conv_output = nn.Conv2d(channels, channels, kernel_size=1, padding=0)

    def forward(self, x, context):
        """Attend within the image and to the context.

        x:       (Batch_Size, Features, Height, Width)
        context: (Batch_Size, Seq_Len, Dim)
        Returns a tensor of shape (Batch_Size, Features, Height, Width).
        """
        long_skip = x

        # (N, C, H, W) -> (N, C, H, W)
        projected = self.conv_input(self.groupnorm(x))
        n, c, h, w = projected.shape

        # Pixels become tokens: (N, C, H, W) -> (N, C, H * W) -> (N, H * W, C)
        tokens = projected.view((n, c, h * w)).transpose(-1, -2)

        # Normalization + Self-Attention with skip connection
        short_skip = tokens
        tokens = self.attention_1(self.layernorm_1(tokens))
        tokens += short_skip

        # Normalization + Cross-Attention with skip connection
        short_skip = tokens
        tokens = self.attention_2(self.layernorm_2(tokens), context)
        tokens += short_skip

        # Normalization + FFN with GeGLU and skip connection
        short_skip = tokens
        # GeGLU as implemented in the original code: https://github.com/CompVis/stable-diffusion/blob/21f890f9da3cfbeaba8e2ac3c425ee9e998d5229/ldm/modules/attention.py#L37C10-L37C10
        # (N, H * W, C) -> two tensors of shape (N, H * W, C * 4)
        value, gate = self.linear_geglu_1(self.layernorm_3(tokens)).chunk(2, dim=-1)
        # Element-wise gating, then (N, H * W, C * 4) -> (N, H * W, C)
        tokens = self.linear_geglu_2(value * F.gelu(gate))
        tokens += short_skip

        # Back to image layout: (N, H * W, C) -> (N, C, H * W) -> (N, C, H, W)
        image = tokens.transpose(-1, -2).view((n, c, h, w))

        # Final skip connection between initial input and output of the block
        return self.conv_output(image) + long_skip
174
+
175
class Upsample(nn.Module):
    """Nearest-neighbour 2x upsampling followed by a 3x3 convolution."""

    def __init__(self, channels):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1)

    def forward(self, x):
        """Map (Batch_Size, Features, H, W) -> (Batch_Size, Features, H * 2, W * 2)."""
        doubled = F.interpolate(x, scale_factor=2, mode='nearest')
        return self.conv(doubled)
184
+
185
class SwitchSequential(nn.Sequential):
    """Sequential container that gives each layer the extra input its type expects."""

    def forward(self, x, context, time):
        """Run the layers in order.

        Attention blocks also receive ``context``, residual blocks also receive
        ``time``, and every other layer is called with ``x`` alone.
        """
        for layer in self:
            if isinstance(layer, UNET_AttentionBlock):
                extra_args = (context,)
            elif isinstance(layer, UNET_ResidualBlock):
                extra_args = (time,)
            else:
                extra_args = ()
            x = layer(x, *extra_args)
        return x
195
+
196
class UNET(nn.Module):
    """U-shaped backbone built from residual, attention, down- and up-sampling stages.

    The encoder path records the output of each of its 12 stages. The decoder
    path has 12 matching stages; before each one runs, the most recently
    recorded encoder output is concatenated onto the activation along the
    channel axis, which is why every decoder residual block takes more input
    channels than the previous stage produced.
    """

    def __init__(self):
        super().__init__()

        # Local aliases keep the layer tables below readable.
        res = UNET_ResidualBlock
        attn = UNET_AttentionBlock

        self.encoders = nn.ModuleList([
            # ---- Height / 8, Width / 8 ----
            # (Batch_Size, 4, H/8, W/8) -> (Batch_Size, 320, H/8, W/8)
            SwitchSequential(nn.Conv2d(4, 320, kernel_size=3, padding=1)),
            # 320 -> 320 channels, resolution unchanged
            SwitchSequential(res(320, 320), attn(8, 40)),
            SwitchSequential(res(320, 320), attn(8, 40)),

            # ---- Height / 16, Width / 16 ----
            # Strided convolution halves the resolution: (320, H/8, W/8) -> (320, H/16, W/16)
            SwitchSequential(nn.Conv2d(320, 320, kernel_size=3, stride=2, padding=1)),
            # 320 -> 640 channels
            SwitchSequential(res(320, 640), attn(8, 80)),
            SwitchSequential(res(640, 640), attn(8, 80)),

            # ---- Height / 32, Width / 32 ----
            # (640, H/16, W/16) -> (640, H/32, W/32)
            SwitchSequential(nn.Conv2d(640, 640, kernel_size=3, stride=2, padding=1)),
            # 640 -> 1280 channels
            SwitchSequential(res(640, 1280), attn(8, 160)),
            SwitchSequential(res(1280, 1280), attn(8, 160)),

            # ---- Height / 64, Width / 64 (residual blocks only) ----
            # (1280, H/32, W/32) -> (1280, H/64, W/64)
            SwitchSequential(nn.Conv2d(1280, 1280, kernel_size=3, stride=2, padding=1)),
            SwitchSequential(res(1280, 1280)),
            SwitchSequential(res(1280, 1280)),
        ])

        # (Batch_Size, 1280, H/64, W/64) in and out; residual -> attention -> residual.
        self.bottleneck = SwitchSequential(
            res(1280, 1280),
            attn(8, 160),
            res(1280, 1280),
        )

        self.decoders = nn.ModuleList([
            # ---- Height / 64, Width / 64 ----
            # (2560, H/64, W/64) -> (1280, H/64, W/64)
            SwitchSequential(res(2560, 1280)),
            SwitchSequential(res(2560, 1280)),
            # ... then upsample to (1280, H/32, W/32)
            SwitchSequential(res(2560, 1280), Upsample(1280)),

            # ---- Height / 32, Width / 32 ----
            # (2560, H/32, W/32) -> (1280, H/32, W/32)
            SwitchSequential(res(2560, 1280), attn(8, 160)),
            SwitchSequential(res(2560, 1280), attn(8, 160)),
            # (1920, H/32, W/32) -> (1280, H/32, W/32) -> upsample to (1280, H/16, W/16)
            SwitchSequential(res(1920, 1280), attn(8, 160), Upsample(1280)),

            # ---- Height / 16, Width / 16 ----
            # (1920, H/16, W/16) -> (640, H/16, W/16)
            SwitchSequential(res(1920, 640), attn(8, 80)),
            # (1280, H/16, W/16) -> (640, H/16, W/16)
            SwitchSequential(res(1280, 640), attn(8, 80)),
            # (960, H/16, W/16) -> (640, H/16, W/16) -> upsample to (640, H/8, W/8)
            SwitchSequential(res(960, 640), attn(8, 80), Upsample(640)),

            # ---- Height / 8, Width / 8 ----
            # (960, H/8, W/8) -> (320, H/8, W/8)
            SwitchSequential(res(960, 320), attn(8, 40)),
            # (640, H/8, W/8) -> (320, H/8, W/8)
            SwitchSequential(res(640, 320), attn(8, 40)),
            SwitchSequential(res(640, 320), attn(8, 40)),
        ])

    def forward(self, x, context, time):
        """Run the encoder, bottleneck and decoder paths.

        Args:
            x: (Batch_Size, 4, Height / 8, Width / 8)
            context: (Batch_Size, Seq_Len, Dim)
            time: (1, 1280)

        Returns:
            The activation produced by the last decoder stage.
        """
        # Encoder outputs, consumed last-in-first-out by the decoder.
        skips = []
        for stage in self.encoders:
            x = stage(x, context, time)
            skips.append(x)

        x = self.bottleneck(x, context, time)

        for stage in self.decoders:
            # Concatenating the matching encoder output grows the channel
            # count before the decoder stage sees the activation.
            skip = skips.pop()
            x = torch.cat((x, skip), dim=1)
            x = stage(x, context, time)

        return x
304
+
305
+
306
class UNET_OutputLayer(nn.Module):
    """Final projection: group normalisation, SiLU, then a 3x3 convolution.

    Args:
        in_channels: Channel count of the incoming feature map (must be divisible by 32,
            the number of normalisation groups).
        out_channels: Channel count produced by the convolution.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.groupnorm = nn.GroupNorm(32, in_channels)
        # kernel_size=3 with padding=1 keeps the spatial size unchanged.
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)

    def forward(self, x):
        """Map (Batch_Size, in_channels, H, W) to (Batch_Size, out_channels, H, W)."""
        # Normalise -> activate -> project; none of the steps alters H or W.
        activated = F.silu(self.groupnorm(x))
        return self.conv(activated)
326
+
327
class Diffusion(nn.Module):
    """Wire together the time embedding, the UNET backbone and the output layer."""

    def __init__(self):
        super().__init__()
        self.time_embedding = TimeEmbedding(320)
        self.unet = UNET()
        self.final = UNET_OutputLayer(320, 4)

    def forward(self, latent, context, time):
        """Run one pass of the model.

        Args:
            latent: (Batch_Size, 4, Height / 8, Width / 8)
            context: (Batch_Size, Seq_Len, Dim)
            time: (1, 320)

        Returns:
            Tensor of shape (Batch_Size, 4, Height / 8, Width / 8).
        """
        # (1, 320) -> (1, 1280)
        time_features = self.time_embedding(time)

        # (Batch, 4, Height / 8, Width / 8) -> (Batch, 320, Height / 8, Width / 8)
        features = self.unet(latent, context, time_features)

        # (Batch, 320, Height / 8, Width / 8) -> (Batch, 4, Height / 8, Width / 8)
        return self.final(features)
SD/encoder.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from decoder import VAE_AttentionBlock, VAE_ResidualBlock
5
+
6
class VAE_Encoder(nn.Sequential):
    """Convolutional encoder that turns an image into a sampled, scaled latent.

    The layer stack reduces the spatial resolution by 8 (three stride-2
    convolutions) and ends with 8 channels, which ``forward`` splits into a
    mean and a log-variance of 4 channels each.
    """

    def __init__(self):
        layers = [
            # ---- full resolution, 128 channels ----
            # (Batch_Size, Channel, Height, Width) -> (Batch_Size, 128, Height, Width)
            nn.Conv2d(3, 128, kernel_size=3, padding=1),
            VAE_ResidualBlock(128, 128),
            VAE_ResidualBlock(128, 128),

            # ---- Height / 2, Width / 2, 256 channels ----
            # padding=0 here: ``forward`` pads right/bottom by one pixel instead.
            nn.Conv2d(128, 128, kernel_size=3, stride=2, padding=0),
            VAE_ResidualBlock(128, 256),
            VAE_ResidualBlock(256, 256),

            # ---- Height / 4, Width / 4, 512 channels ----
            nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=0),
            VAE_ResidualBlock(256, 512),
            VAE_ResidualBlock(512, 512),

            # ---- Height / 8, Width / 8, 512 channels ----
            nn.Conv2d(512, 512, kernel_size=3, stride=2, padding=0),
            VAE_ResidualBlock(512, 512),
            VAE_ResidualBlock(512, 512),
            VAE_ResidualBlock(512, 512),
            VAE_AttentionBlock(512),
            VAE_ResidualBlock(512, 512),

            # Normalise and activate; shape stays (Batch_Size, 512, Height / 8, Width / 8).
            nn.GroupNorm(32, 512),
            nn.SiLU(),

            # A 3x3 kernel with padding=1 leaves height and width unchanged:
            # the one-pixel border on each side offsets the kernel's extent.
            # (Batch_Size, 512, Height / 8, Width / 8) -> (Batch_Size, 8, Height / 8, Width / 8)
            nn.Conv2d(512, 8, kernel_size=3, padding=1),

            # 1x1 convolution; (Batch_Size, 8, Height / 8, Width / 8) in and out.
            nn.Conv2d(8, 8, kernel_size=1, padding=0),
        ]
        super().__init__(*layers)

    def forward(self, x, noise):
        """Encode ``x`` and sample a latent using the supplied ``noise``.

        Args:
            x: (Batch_Size, Channel, Height, Width)
            noise: (Batch_Size, 4, Height / 8, Width / 8)

        Returns:
            (Batch_Size, 4, Height / 8, Width / 8) latent, multiplied by 0.18215.
        """
        for module in self:
            is_downsampler = getattr(module, 'stride', None) == (2, 2)
            if is_downsampler:
                # Asymmetric padding for the stride-2 convolutions: one zero
                # column on the right and one zero row at the bottom.
                # F.pad order is (left, right, top, bottom).
                x = F.pad(x, (0, 1, 0, 1))
            x = module(x)

        # (Batch_Size, 8, H/8, W/8) -> two tensors of shape (Batch_Size, 4, H/8, W/8)
        mean, log_variance = torch.chunk(x, 2, dim=1)

        # Clamping the log-variance to [-30, 20] keeps the variance roughly
        # within [1e-14, 1e8] before taking exp and sqrt.
        stdev = torch.clamp(log_variance, -30, 20).exp().sqrt()

        # Transform N(0, 1) noise into a sample with the predicted mean and stdev.
        latent = mean + stdev * noise

        # Scale by a constant.
        # Constant taken from: https://github.com/CompVis/stable-diffusion/blob/21f890f9da3cfbeaba8e2ac3c425ee9e998d5229/configs/stable-diffusion/v1-inference.yaml#L17C1-L17C1
        latent *= 0.18215

        return latent
SD/model_converter.py ADDED
The diff for this file is too large to render. See raw diff
 
SD/model_loader.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from clip import CLIP
2
+ from encoder import VAE_Encoder
3
+ from decoder import VAE_Decoder
4
+ from diffusion import Diffusion
5
+
6
+ import model_converter
7
+
8
def preload_models_from_standard_weights(ckpt_path, device):
    """Instantiate the four sub-models and load their weights from a checkpoint.

    Args:
        ckpt_path: Path handed to ``model_converter.load_from_standard_weights``.
        device: Device the models are moved to (also passed to the converter).

    Returns:
        Dict with keys ``'clip'``, ``'encoder'``, ``'decoder'`` and ``'diffusion'``
        mapping to the loaded models.
    """
    state_dict = model_converter.load_from_standard_weights(ckpt_path, device)

    def _build(model_cls, key):
        # strict=True: every parameter name must match the converted checkpoint.
        model = model_cls().to(device)
        model.load_state_dict(state_dict[key], strict=True)
        return model

    encoder = _build(VAE_Encoder, 'encoder')
    decoder = _build(VAE_Decoder, 'decoder')
    diffusion = _build(Diffusion, 'diffusion')
    clip = _build(CLIP, 'clip')

    return {
        'clip': clip,
        'encoder': encoder,
        'decoder': decoder,
        'diffusion': diffusion,
    }
SD/pipeline.py ADDED
@@ -0,0 +1,170 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ from tqdm import tqdm
4
+ from ddpm import DDPMSampler
5
+
6
# Size, in pixels, that input images are resized to.
WIDTH = 512
HEIGHT = 512

# Latent-space size: one eighth of the pixel size along each axis.
LATENTS_WIDTH = WIDTH // 8
LATENTS_HEIGHT = HEIGHT // 8
10
+
11
def _encode_prompt(clip, tokenizer, text, device):
    """Tokenize ``text`` (padded to 77 tokens) and return ``clip``'s output for it."""
    # Convert into a list of length Seq_Len=77
    token_ids = tokenizer.batch_encode_plus(
        [text], padding="max_length", max_length=77
    ).input_ids
    # (Batch_Size, Seq_Len)
    token_ids = torch.tensor(token_ids, dtype=torch.long, device=device)
    # (Batch_Size, Seq_Len) -> (Batch_Size, Seq_Len, Dim)
    return clip(token_ids)


def generate(
    prompt,
    uncond_prompt=None,
    input_image=None,
    strength=0.8,
    do_cfg=True,
    cfg_scale=7.5,
    sampler_name="ddpm",
    n_inference_steps=50,
    models=None,
    seed=None,
    device=None,
    idle_device=None,
    tokenizer=None,
):
    """Generate one image from a prompt, optionally starting from an input image.

    Args:
        prompt: Text to condition on.
        uncond_prompt: Text for the unconditional branch when ``do_cfg`` is
            true. ``None`` is treated as the empty string.
        input_image: Optional image object exposing ``resize``; when given, it
            is encoded and noised instead of starting from pure noise.
        strength: Value in (0, 1] forwarded to ``sampler.set_strength`` when an
            input image is supplied.
        do_cfg: Whether to run conditional and unconditional passes and
            combine them with ``cfg_scale``.
        cfg_scale: Weight applied to (conditional - unconditional) output.
        sampler_name: Only ``"ddpm"`` is supported.
        n_inference_steps: Passed to ``sampler.set_inference_timesteps``.
        models: Mapping with keys ``"clip"``, ``"encoder"``, ``"decoder"`` and
            ``"diffusion"``. ``None`` is treated as an empty mapping.
        seed: Seed for the random generator; ``None`` picks a random seed.
        device: Device the active model and tensors are placed on.
        idle_device: If truthy, each model is moved there once it is no
            longer needed.
        tokenizer: Object exposing ``batch_encode_plus``.

    Returns:
        A uint8 NumPy array of shape (Height, Width, Channel).

    Raises:
        ValueError: If ``strength`` is outside (0, 1] or ``sampler_name`` is unknown.
        KeyError: If a required entry is missing from ``models``.
    """
    if not 0 < strength <= 1:
        raise ValueError("strength must be between 0 and 1")

    # Avoid a shared mutable default; a missing model still raises KeyError below.
    if models is None:
        models = {}

    # The tokenizer cannot encode None, so fall back to an empty prompt.
    if uncond_prompt is None:
        uncond_prompt = ""

    def to_idle(module):
        # Park a finished model on the idle device, if one was requested.
        return module.to(idle_device) if idle_device else module

    with torch.no_grad():
        # Initialize random number generator according to the seed specified
        generator = torch.Generator(device=device)
        if seed is None:
            generator.seed()
        else:
            generator.manual_seed(seed)

        clip = models["clip"]
        clip.to(device)

        if do_cfg:
            cond_context = _encode_prompt(clip, tokenizer, prompt, device)
            uncond_context = _encode_prompt(clip, tokenizer, uncond_prompt, device)
            # (Batch_Size, Seq_Len, Dim) + (Batch_Size, Seq_Len, Dim) -> (2 * Batch_Size, Seq_Len, Dim)
            # Conditional first: the model output is split in the same order below.
            context = torch.cat([cond_context, uncond_context])
        else:
            # (Batch_Size, Seq_Len, Dim)
            context = _encode_prompt(clip, tokenizer, prompt, device)
        to_idle(clip)

        if sampler_name == "ddpm":
            sampler = DDPMSampler(generator)
            sampler.set_inference_timesteps(n_inference_steps)
        else:
            raise ValueError(f"Unknown sampler value {sampler_name}. ")

        latents_shape = (1, 4, LATENTS_HEIGHT, LATENTS_WIDTH)

        if input_image is not None:
            encoder = models["encoder"]
            encoder.to(device)

            input_image_tensor = input_image.resize((WIDTH, HEIGHT))
            # (Height, Width, Channel)
            input_image_tensor = np.array(input_image_tensor)
            # (Height, Width, Channel) -> (Height, Width, Channel)
            input_image_tensor = torch.tensor(input_image_tensor, dtype=torch.float32, device=device)
            # (Height, Width, Channel) -> (Height, Width, Channel)
            input_image_tensor = rescale(input_image_tensor, (0, 255), (-1, 1))
            # (Height, Width, Channel) -> (Batch_Size, Height, Width, Channel)
            input_image_tensor = input_image_tensor.unsqueeze(0)
            # (Batch_Size, Height, Width, Channel) -> (Batch_Size, Channel, Height, Width)
            input_image_tensor = input_image_tensor.permute(0, 3, 1, 2)

            # (Batch_Size, 4, Latents_Height, Latents_Width)
            encoder_noise = torch.randn(latents_shape, generator=generator, device=device)
            # (Batch_Size, 4, Latents_Height, Latents_Width)
            latents = encoder(input_image_tensor, encoder_noise)

            # Add noise to the latents (the encoded input image)
            # (Batch_Size, 4, Latents_Height, Latents_Width)
            sampler.set_strength(strength=strength)
            latents = sampler.add_noise(latents, sampler.timesteps[0])

            to_idle(encoder)
        else:
            # (Batch_Size, 4, Latents_Height, Latents_Width)
            latents = torch.randn(latents_shape, generator=generator, device=device)

        diffusion = models["diffusion"]
        diffusion.to(device)

        for timestep in tqdm(sampler.timesteps):
            # (1, 320)
            time_embedding = get_time_embedding(timestep).to(device)

            # (Batch_Size, 4, Latents_Height, Latents_Width)
            model_input = latents

            if do_cfg:
                # (Batch_Size, 4, Latents_Height, Latents_Width) -> (2 * Batch_Size, 4, Latents_Height, Latents_Width)
                model_input = model_input.repeat(2, 1, 1, 1)

            # model_output is the predicted noise
            # (Batch_Size, 4, Latents_Height, Latents_Width) -> (Batch_Size, 4, Latents_Height, Latents_Width)
            model_output = diffusion(model_input, context, time_embedding)

            if do_cfg:
                output_cond, output_uncond = model_output.chunk(2)
                model_output = cfg_scale * (output_cond - output_uncond) + output_uncond

            # (Batch_Size, 4, Latents_Height, Latents_Width) -> (Batch_Size, 4, Latents_Height, Latents_Width)
            latents = sampler.step(timestep, latents, model_output)

        to_idle(diffusion)

        decoder = models["decoder"]
        decoder.to(device)
        # (Batch_Size, 4, Latents_Height, Latents_Width) -> (Batch_Size, 3, Height, Width)
        images = decoder(latents)
        to_idle(decoder)

        images = rescale(images, (-1, 1), (0, 255), clamp=True)
        # (Batch_Size, Channel, Height, Width) -> (Batch_Size, Height, Width, Channel)
        images = images.permute(0, 2, 3, 1)
        images = images.to("cpu", torch.uint8).numpy()
        return images[0]
153
+
154
def rescale(x, old_range, new_range, clamp=False):
    """Linearly map values of ``x`` from ``old_range`` to ``new_range``.

    The input tensor is left untouched; a new tensor is returned. (The
    previous in-place arithmetic silently modified the caller's tensor and
    raised on integer tensors.)

    Args:
        x: Tensor to rescale.
        old_range: ``(min, max)`` of the source interval.
        new_range: ``(min, max)`` of the target interval.
        clamp: When true, clip the result to ``new_range``.

    Returns:
        The rescaled tensor.
    """
    old_min, old_max = old_range
    new_min, new_max = new_range
    scale = (new_max - new_min) / (old_max - old_min)
    # Same operation order as before (shift, scale, shift), but out-of-place.
    x = (x - old_min) * scale + new_min
    if clamp:
        x = x.clamp(new_min, new_max)
    return x
163
+
164
def get_time_embedding(timestep):
    """Return a (1, 320) sinusoidal embedding of ``timestep``.

    The first 160 entries are cosines and the last 160 are sines of
    ``timestep * 10000 ** (-i / 160)`` for ``i`` in ``0..159``.
    """
    half_dim = 160
    # Shape: (160,)
    exponents = -torch.arange(start=0, end=half_dim, dtype=torch.float32) / half_dim
    freqs = torch.pow(10000, exponents)
    # Shape: (1, 160)
    step = torch.tensor([timestep], dtype=torch.float32)
    angles = step[:, None] * freqs[None]
    # Shape: (1, 160 * 2)
    return torch.cat([torch.cos(angles), torch.sin(angles)], dim=-1)
SD/run.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import model_loader
2
+ import pipeline
3
+ from PIL import Image
4
+ from pathlib import Path
5
+ from transformers import CLIPTokenizer
6
+ import torch
7
+
8
+
9
# Device selection: default to CPU and upgrade only when the matching
# ALLOW_* flag is set and the backend reports itself available.
DEVICE = "cpu"

ALLOW_CUDA = True
ALLOW_MPS = False

# The flag is tested first so a disabled backend is never queried.
if ALLOW_CUDA and torch.cuda.is_available():
    DEVICE = "cuda"
elif ALLOW_MPS and torch.backends.mps.is_available():
    DEVICE = "mps"

print(f"Using device: {DEVICE}")

tokenizer = CLIPTokenizer("../data/tokenizer_vocab.json", merges_file="../data/tokenizer_merges.txt")
model_file = "../data/v1-5-pruned-emaonly.ckpt"
models = model_loader.preload_models_from_standard_weights(model_file, device=DEVICE)

## TEXT TO IMAGE

# prompt = "A dog with sunglasses, wearing comfy hat, looking at camera, highly detailed, ultra sharp, cinematic, 100mm lens, 8k resolution."
prompt = "A cat stretching on the floor, highly detailed, ultra sharp, cinematic, 100mm lens, 8k resolution."
uncond_prompt = ""  # Also known as negative prompt
do_cfg = True
cfg_scale = 8  # min: 1, max: 14

## IMAGE TO IMAGE

input_image = None
# Uncomment the Image.open line below to enable image to image
image_path = "../images/dog.jpg"
# input_image = Image.open(image_path)
# Higher values mean more noise is added to the input image, so the result will be further from the input image.
# Lower values mean less noise is added to the input image, so the output will be closer to the input image.
strength = 0.9

## SAMPLER

sampler = "ddpm"
num_inference_steps = 2
seed = 42

output_image = pipeline.generate(
    prompt=prompt,
    uncond_prompt=uncond_prompt,
    input_image=input_image,
    strength=strength,
    do_cfg=do_cfg,
    cfg_scale=cfg_scale,
    sampler_name=sampler,
    n_inference_steps=num_inference_steps,
    seed=seed,
    models=models,
    device=DEVICE,
    idle_device="cpu",
    tokenizer=tokenizer,
)

# Wrap the generated array in a PIL image.
# NOTE(review): as a bare expression this only displays in a notebook/REPL;
# when run as a script the image is discarded — save or show it if needed.
Image.fromarray(output_image)
SD/sd_demo.ipynb ADDED
The diff for this file is too large to render. See raw diff