aiqtech commited on
Commit
8ffc581
·
verified ·
1 Parent(s): 94318b2

Upload 2 files

Browse files
Files changed (2) hide show
  1. controlnet_union (1).py +1085 -0
  2. pipeline_fill_sd_xl (1).py +559 -0
controlnet_union (1).py ADDED
@@ -0,0 +1,1085 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from collections import OrderedDict
15
+ from dataclasses import dataclass
16
+ from typing import Any, Dict, List, Optional, Tuple, Union
17
+
18
+ import torch
19
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
20
+ from diffusers.loaders import FromOriginalModelMixin
21
+ from diffusers.models.attention_processor import (
22
+ ADDED_KV_ATTENTION_PROCESSORS,
23
+ CROSS_ATTENTION_PROCESSORS,
24
+ AttentionProcessor,
25
+ AttnAddedKVProcessor,
26
+ AttnProcessor,
27
+ )
28
+ from diffusers.models.embeddings import (
29
+ TextImageProjection,
30
+ TextImageTimeEmbedding,
31
+ TextTimeEmbedding,
32
+ TimestepEmbedding,
33
+ Timesteps,
34
+ )
35
+ from diffusers.models.modeling_utils import ModelMixin
36
+ from diffusers.models.unets.unet_2d_blocks import (
37
+ CrossAttnDownBlock2D,
38
+ DownBlock2D,
39
+ UNetMidBlock2DCrossAttn,
40
+ get_down_block,
41
+ )
42
+ from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
43
+ from diffusers.utils import BaseOutput, logging
44
+ from torch import nn
45
+ from torch.nn import functional as F
46
+
47
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
48
+
49
+
50
+ # Transformer Block
51
+ # Used to exchange info between different conditions and input image
52
+ # With reference to https://github.com/TencentARC/T2I-Adapter/blob/SD/ldm/modules/encoders/adapter.py#L147
53
+ class QuickGELU(nn.Module):
54
+ def forward(self, x: torch.Tensor):
55
+ return x * torch.sigmoid(1.702 * x)
56
+
57
+
58
+ class LayerNorm(nn.LayerNorm):
59
+ """Subclass torch's LayerNorm to handle fp16."""
60
+
61
+ def forward(self, x: torch.Tensor):
62
+ orig_type = x.dtype
63
+ ret = super().forward(x)
64
+ return ret.type(orig_type)
65
+
66
+
67
+ class ResidualAttentionBlock(nn.Module):
68
+ def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
69
+ super().__init__()
70
+
71
+ self.attn = nn.MultiheadAttention(d_model, n_head)
72
+ self.ln_1 = LayerNorm(d_model)
73
+ self.mlp = nn.Sequential(
74
+ OrderedDict(
75
+ [
76
+ ("c_fc", nn.Linear(d_model, d_model * 4)),
77
+ ("gelu", QuickGELU()),
78
+ ("c_proj", nn.Linear(d_model * 4, d_model)),
79
+ ]
80
+ )
81
+ )
82
+ self.ln_2 = LayerNorm(d_model)
83
+ self.attn_mask = attn_mask
84
+
85
+ def attention(self, x: torch.Tensor):
86
+ self.attn_mask = (
87
+ self.attn_mask.to(dtype=x.dtype, device=x.device)
88
+ if self.attn_mask is not None
89
+ else None
90
+ )
91
+ return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
92
+
93
+ def forward(self, x: torch.Tensor):
94
+ x = x + self.attention(self.ln_1(x))
95
+ x = x + self.mlp(self.ln_2(x))
96
+ return x
97
+
98
+
99
+ # -----------------------------------------------------------------------------------------------------
100
+
101
+
102
+ @dataclass
103
+ class ControlNetOutput(BaseOutput):
104
+ """
105
+ The output of [`ControlNetModel`].
106
+
107
+ Args:
108
+ down_block_res_samples (`tuple[torch.Tensor]`):
109
+ A tuple of downsample activations at different resolutions for each downsampling block. Each tensor should
110
+ be of shape `(batch_size, channel * resolution, height //resolution, width // resolution)`. Output can be
111
+ used to condition the original UNet's downsampling activations.
112
+ mid_down_block_re_sample (`torch.Tensor`):
113
+ The activation of the midde block (the lowest sample resolution). Each tensor should be of shape
114
+ `(batch_size, channel * lowest_resolution, height // lowest_resolution, width // lowest_resolution)`.
115
+ Output can be used to condition the original UNet's middle block activation.
116
+ """
117
+
118
+ down_block_res_samples: Tuple[torch.Tensor]
119
+ mid_block_res_sample: torch.Tensor
120
+
121
+
122
+ class ControlNetConditioningEmbedding(nn.Module):
123
+ """
124
+ Quoting from https://arxiv.org/abs/2302.05543: "Stable Diffusion uses a pre-processing method similar to VQ-GAN
125
+ [11] to convert the entire dataset of 512 × 512 images into smaller 64 × 64 “latent images” for stabilized
126
+ training. This requires ControlNets to convert image-based conditions to 64 × 64 feature space to match the
127
+ convolution size. We use a tiny network E(·) of four convolution layers with 4 × 4 kernels and 2 × 2 strides
128
+ (activated by ReLU, channels are 16, 32, 64, 128, initialized with Gaussian weights, trained jointly with the full
129
+ model) to encode image-space conditions ... into feature maps ..."
130
+ """
131
+
132
+ # original setting is (16, 32, 96, 256)
133
+ def __init__(
134
+ self,
135
+ conditioning_embedding_channels: int,
136
+ conditioning_channels: int = 3,
137
+ block_out_channels: Tuple[int] = (48, 96, 192, 384),
138
+ ):
139
+ super().__init__()
140
+
141
+ self.conv_in = nn.Conv2d(
142
+ conditioning_channels, block_out_channels[0], kernel_size=3, padding=1
143
+ )
144
+
145
+ self.blocks = nn.ModuleList([])
146
+
147
+ for i in range(len(block_out_channels) - 1):
148
+ channel_in = block_out_channels[i]
149
+ channel_out = block_out_channels[i + 1]
150
+ self.blocks.append(
151
+ nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1)
152
+ )
153
+ self.blocks.append(
154
+ nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2)
155
+ )
156
+
157
+ self.conv_out = zero_module(
158
+ nn.Conv2d(
159
+ block_out_channels[-1],
160
+ conditioning_embedding_channels,
161
+ kernel_size=3,
162
+ padding=1,
163
+ )
164
+ )
165
+
166
+ def forward(self, conditioning):
167
+ embedding = self.conv_in(conditioning)
168
+ embedding = F.silu(embedding)
169
+
170
+ for block in self.blocks:
171
+ embedding = block(embedding)
172
+ embedding = F.silu(embedding)
173
+
174
+ embedding = self.conv_out(embedding)
175
+
176
+ return embedding
177
+
178
+
179
+ class ControlNetModel_Union(ModelMixin, ConfigMixin, FromOriginalModelMixin):
180
+ """
181
+ A ControlNet model.
182
+
183
+ Args:
184
+ in_channels (`int`, defaults to 4):
185
+ The number of channels in the input sample.
186
+ flip_sin_to_cos (`bool`, defaults to `True`):
187
+ Whether to flip the sin to cos in the time embedding.
188
+ freq_shift (`int`, defaults to 0):
189
+ The frequency shift to apply to the time embedding.
190
+ down_block_types (`tuple[str]`, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
191
+ The tuple of downsample blocks to use.
192
+ only_cross_attention (`Union[bool, Tuple[bool]]`, defaults to `False`):
193
+ block_out_channels (`tuple[int]`, defaults to `(320, 640, 1280, 1280)`):
194
+ The tuple of output channels for each block.
195
+ layers_per_block (`int`, defaults to 2):
196
+ The number of layers per block.
197
+ downsample_padding (`int`, defaults to 1):
198
+ The padding to use for the downsampling convolution.
199
+ mid_block_scale_factor (`float`, defaults to 1):
200
+ The scale factor to use for the mid block.
201
+ act_fn (`str`, defaults to "silu"):
202
+ The activation function to use.
203
+ norm_num_groups (`int`, *optional*, defaults to 32):
204
+ The number of groups to use for the normalization. If None, normalization and activation layers is skipped
205
+ in post-processing.
206
+ norm_eps (`float`, defaults to 1e-5):
207
+ The epsilon to use for the normalization.
208
+ cross_attention_dim (`int`, defaults to 1280):
209
+ The dimension of the cross attention features.
210
+ transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
211
+ The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
212
+ [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
213
+ [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
214
+ encoder_hid_dim (`int`, *optional*, defaults to None):
215
+ If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
216
+ dimension to `cross_attention_dim`.
217
+ encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
218
+ If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
219
+ embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`.
220
+ attention_head_dim (`Union[int, Tuple[int]]`, defaults to 8):
221
+ The dimension of the attention heads.
222
+ use_linear_projection (`bool`, defaults to `False`):
223
+ class_embed_type (`str`, *optional*, defaults to `None`):
224
+ The type of class embedding to use which is ultimately summed with the time embeddings. Choose from None,
225
+ `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
226
+ addition_embed_type (`str`, *optional*, defaults to `None`):
227
+ Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
228
+ "text". "text" will use the `TextTimeEmbedding` layer.
229
+ num_class_embeds (`int`, *optional*, defaults to 0):
230
+ Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
231
+ class conditioning with `class_embed_type` equal to `None`.
232
+ upcast_attention (`bool`, defaults to `False`):
233
+ resnet_time_scale_shift (`str`, defaults to `"default"`):
234
+ Time scale shift config for ResNet blocks (see `ResnetBlock2D`). Choose from `default` or `scale_shift`.
235
+ projection_class_embeddings_input_dim (`int`, *optional*, defaults to `None`):
236
+ The dimension of the `class_labels` input when `class_embed_type="projection"`. Required when
237
+ `class_embed_type="projection"`.
238
+ controlnet_conditioning_channel_order (`str`, defaults to `"rgb"`):
239
+ The channel order of conditional image. Will convert to `rgb` if it's `bgr`.
240
+ conditioning_embedding_out_channels (`tuple[int]`, *optional*, defaults to `(16, 32, 96, 256)`):
241
+ The tuple of output channel for each block in the `conditioning_embedding` layer.
242
+ global_pool_conditions (`bool`, defaults to `False`):
243
+ """
244
+
245
+ _supports_gradient_checkpointing = True
246
+
247
+ @register_to_config
248
+ def __init__(
249
+ self,
250
+ in_channels: int = 4,
251
+ conditioning_channels: int = 3,
252
+ flip_sin_to_cos: bool = True,
253
+ freq_shift: int = 0,
254
+ down_block_types: Tuple[str] = (
255
+ "CrossAttnDownBlock2D",
256
+ "CrossAttnDownBlock2D",
257
+ "CrossAttnDownBlock2D",
258
+ "DownBlock2D",
259
+ ),
260
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
261
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
262
+ layers_per_block: int = 2,
263
+ downsample_padding: int = 1,
264
+ mid_block_scale_factor: float = 1,
265
+ act_fn: str = "silu",
266
+ norm_num_groups: Optional[int] = 32,
267
+ norm_eps: float = 1e-5,
268
+ cross_attention_dim: int = 1280,
269
+ transformer_layers_per_block: Union[int, Tuple[int]] = 1,
270
+ encoder_hid_dim: Optional[int] = None,
271
+ encoder_hid_dim_type: Optional[str] = None,
272
+ attention_head_dim: Union[int, Tuple[int]] = 8,
273
+ num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
274
+ use_linear_projection: bool = False,
275
+ class_embed_type: Optional[str] = None,
276
+ addition_embed_type: Optional[str] = None,
277
+ addition_time_embed_dim: Optional[int] = None,
278
+ num_class_embeds: Optional[int] = None,
279
+ upcast_attention: bool = False,
280
+ resnet_time_scale_shift: str = "default",
281
+ projection_class_embeddings_input_dim: Optional[int] = None,
282
+ controlnet_conditioning_channel_order: str = "rgb",
283
+ conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256),
284
+ global_pool_conditions: bool = False,
285
+ addition_embed_type_num_heads=64,
286
+ num_control_type=6,
287
+ ):
288
+ super().__init__()
289
+
290
+ # If `num_attention_heads` is not defined (which is the case for most models)
291
+ # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
292
+ # The reason for this behavior is to correct for incorrectly named variables that were introduced
293
+ # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
294
+ # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
295
+ # which is why we correct for the naming here.
296
+ num_attention_heads = num_attention_heads or attention_head_dim
297
+
298
+ # Check inputs
299
+ if len(block_out_channels) != len(down_block_types):
300
+ raise ValueError(
301
+ f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
302
+ )
303
+
304
+ if not isinstance(only_cross_attention, bool) and len(
305
+ only_cross_attention
306
+ ) != len(down_block_types):
307
+ raise ValueError(
308
+ f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
309
+ )
310
+
311
+ if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(
312
+ down_block_types
313
+ ):
314
+ raise ValueError(
315
+ f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
316
+ )
317
+
318
+ if isinstance(transformer_layers_per_block, int):
319
+ transformer_layers_per_block = [transformer_layers_per_block] * len(
320
+ down_block_types
321
+ )
322
+
323
+ # input
324
+ conv_in_kernel = 3
325
+ conv_in_padding = (conv_in_kernel - 1) // 2
326
+ self.conv_in = nn.Conv2d(
327
+ in_channels,
328
+ block_out_channels[0],
329
+ kernel_size=conv_in_kernel,
330
+ padding=conv_in_padding,
331
+ )
332
+
333
+ # time
334
+ time_embed_dim = block_out_channels[0] * 4
335
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
336
+ timestep_input_dim = block_out_channels[0]
337
+ self.time_embedding = TimestepEmbedding(
338
+ timestep_input_dim,
339
+ time_embed_dim,
340
+ act_fn=act_fn,
341
+ )
342
+
343
+ if encoder_hid_dim_type is None and encoder_hid_dim is not None:
344
+ encoder_hid_dim_type = "text_proj"
345
+ self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
346
+ logger.info(
347
+ "encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined."
348
+ )
349
+
350
+ if encoder_hid_dim is None and encoder_hid_dim_type is not None:
351
+ raise ValueError(
352
+ f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
353
+ )
354
+
355
+ if encoder_hid_dim_type == "text_proj":
356
+ self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
357
+ elif encoder_hid_dim_type == "text_image_proj":
358
+ # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
359
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
360
+ # case when `addition_embed_type == "text_image_proj"` (Kadinsky 2.1)`
361
+ self.encoder_hid_proj = TextImageProjection(
362
+ text_embed_dim=encoder_hid_dim,
363
+ image_embed_dim=cross_attention_dim,
364
+ cross_attention_dim=cross_attention_dim,
365
+ )
366
+
367
+ elif encoder_hid_dim_type is not None:
368
+ raise ValueError(
369
+ f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
370
+ )
371
+ else:
372
+ self.encoder_hid_proj = None
373
+
374
+ # class embedding
375
+ if class_embed_type is None and num_class_embeds is not None:
376
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
377
+ elif class_embed_type == "timestep":
378
+ self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
379
+ elif class_embed_type == "identity":
380
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
381
+ elif class_embed_type == "projection":
382
+ if projection_class_embeddings_input_dim is None:
383
+ raise ValueError(
384
+ "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
385
+ )
386
+ # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
387
+ # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
388
+ # 2. it projects from an arbitrary input dimension.
389
+ #
390
+ # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
391
+ # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
392
+ # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
393
+ self.class_embedding = TimestepEmbedding(
394
+ projection_class_embeddings_input_dim, time_embed_dim
395
+ )
396
+ else:
397
+ self.class_embedding = None
398
+
399
+ if addition_embed_type == "text":
400
+ if encoder_hid_dim is not None:
401
+ text_time_embedding_from_dim = encoder_hid_dim
402
+ else:
403
+ text_time_embedding_from_dim = cross_attention_dim
404
+
405
+ self.add_embedding = TextTimeEmbedding(
406
+ text_time_embedding_from_dim,
407
+ time_embed_dim,
408
+ num_heads=addition_embed_type_num_heads,
409
+ )
410
+ elif addition_embed_type == "text_image":
411
+ # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
412
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
413
+ # case when `addition_embed_type == "text_image"` (Kadinsky 2.1)`
414
+ self.add_embedding = TextImageTimeEmbedding(
415
+ text_embed_dim=cross_attention_dim,
416
+ image_embed_dim=cross_attention_dim,
417
+ time_embed_dim=time_embed_dim,
418
+ )
419
+ elif addition_embed_type == "text_time":
420
+ self.add_time_proj = Timesteps(
421
+ addition_time_embed_dim, flip_sin_to_cos, freq_shift
422
+ )
423
+ self.add_embedding = TimestepEmbedding(
424
+ projection_class_embeddings_input_dim, time_embed_dim
425
+ )
426
+
427
+ elif addition_embed_type is not None:
428
+ raise ValueError(
429
+ f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'."
430
+ )
431
+
432
+ # control net conditioning embedding
433
+ self.controlnet_cond_embedding = ControlNetConditioningEmbedding(
434
+ conditioning_embedding_channels=block_out_channels[0],
435
+ block_out_channels=conditioning_embedding_out_channels,
436
+ conditioning_channels=conditioning_channels,
437
+ )
438
+
439
+ # Copyright by Qi Xin(2024/07/06)
440
+ # Condition Transformer(fuse single/multi conditions with input image)
441
+ # The Condition Transformer augment the feature representation of conditions
442
+ # The overall design is somewhat like resnet. The output of Condition Transformer is used to predict a condition bias adding to the original condition feature.
443
+ # num_control_type = 6
444
+ num_trans_channel = 320
445
+ num_trans_head = 8
446
+ num_trans_layer = 1
447
+ num_proj_channel = 320
448
+ task_scale_factor = num_trans_channel**0.5
449
+
450
+ self.task_embedding = nn.Parameter(
451
+ task_scale_factor * torch.randn(num_control_type, num_trans_channel)
452
+ )
453
+ self.transformer_layes = nn.Sequential(
454
+ *[
455
+ ResidualAttentionBlock(num_trans_channel, num_trans_head)
456
+ for _ in range(num_trans_layer)
457
+ ]
458
+ )
459
+ self.spatial_ch_projs = zero_module(
460
+ nn.Linear(num_trans_channel, num_proj_channel)
461
+ )
462
+ # -----------------------------------------------------------------------------------------------------
463
+
464
+ # Copyright by Qi Xin(2024/07/06)
465
+ # Control Encoder to distinguish different control conditions
466
+ # A simple but effective module, consists of an embedding layer and a linear layer, to inject the control info to time embedding.
467
+ self.control_type_proj = Timesteps(
468
+ addition_time_embed_dim, flip_sin_to_cos, freq_shift
469
+ )
470
+ self.control_add_embedding = TimestepEmbedding(
471
+ addition_time_embed_dim * num_control_type, time_embed_dim
472
+ )
473
+ # -----------------------------------------------------------------------------------------------------
474
+
475
+ self.down_blocks = nn.ModuleList([])
476
+ self.controlnet_down_blocks = nn.ModuleList([])
477
+
478
+ if isinstance(only_cross_attention, bool):
479
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
480
+
481
+ if isinstance(attention_head_dim, int):
482
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
483
+
484
+ if isinstance(num_attention_heads, int):
485
+ num_attention_heads = (num_attention_heads,) * len(down_block_types)
486
+
487
+ # down
488
+ output_channel = block_out_channels[0]
489
+
490
+ controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
491
+ controlnet_block = zero_module(controlnet_block)
492
+ self.controlnet_down_blocks.append(controlnet_block)
493
+
494
+ for i, down_block_type in enumerate(down_block_types):
495
+ input_channel = output_channel
496
+ output_channel = block_out_channels[i]
497
+ is_final_block = i == len(block_out_channels) - 1
498
+
499
+ down_block = get_down_block(
500
+ down_block_type,
501
+ num_layers=layers_per_block,
502
+ transformer_layers_per_block=transformer_layers_per_block[i],
503
+ in_channels=input_channel,
504
+ out_channels=output_channel,
505
+ temb_channels=time_embed_dim,
506
+ add_downsample=not is_final_block,
507
+ resnet_eps=norm_eps,
508
+ resnet_act_fn=act_fn,
509
+ resnet_groups=norm_num_groups,
510
+ cross_attention_dim=cross_attention_dim,
511
+ num_attention_heads=num_attention_heads[i],
512
+ attention_head_dim=attention_head_dim[i]
513
+ if attention_head_dim[i] is not None
514
+ else output_channel,
515
+ downsample_padding=downsample_padding,
516
+ use_linear_projection=use_linear_projection,
517
+ only_cross_attention=only_cross_attention[i],
518
+ upcast_attention=upcast_attention,
519
+ resnet_time_scale_shift=resnet_time_scale_shift,
520
+ )
521
+ self.down_blocks.append(down_block)
522
+
523
+ for _ in range(layers_per_block):
524
+ controlnet_block = nn.Conv2d(
525
+ output_channel, output_channel, kernel_size=1
526
+ )
527
+ controlnet_block = zero_module(controlnet_block)
528
+ self.controlnet_down_blocks.append(controlnet_block)
529
+
530
+ if not is_final_block:
531
+ controlnet_block = nn.Conv2d(
532
+ output_channel, output_channel, kernel_size=1
533
+ )
534
+ controlnet_block = zero_module(controlnet_block)
535
+ self.controlnet_down_blocks.append(controlnet_block)
536
+
537
+ # mid
538
+ mid_block_channel = block_out_channels[-1]
539
+
540
+ controlnet_block = nn.Conv2d(
541
+ mid_block_channel, mid_block_channel, kernel_size=1
542
+ )
543
+ controlnet_block = zero_module(controlnet_block)
544
+ self.controlnet_mid_block = controlnet_block
545
+
546
+ self.mid_block = UNetMidBlock2DCrossAttn(
547
+ transformer_layers_per_block=transformer_layers_per_block[-1],
548
+ in_channels=mid_block_channel,
549
+ temb_channels=time_embed_dim,
550
+ resnet_eps=norm_eps,
551
+ resnet_act_fn=act_fn,
552
+ output_scale_factor=mid_block_scale_factor,
553
+ resnet_time_scale_shift=resnet_time_scale_shift,
554
+ cross_attention_dim=cross_attention_dim,
555
+ num_attention_heads=num_attention_heads[-1],
556
+ resnet_groups=norm_num_groups,
557
+ use_linear_projection=use_linear_projection,
558
+ upcast_attention=upcast_attention,
559
+ )
560
+
561
+ @classmethod
562
+ def from_unet(
563
+ cls,
564
+ unet: UNet2DConditionModel,
565
+ controlnet_conditioning_channel_order: str = "rgb",
566
+ conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256),
567
+ load_weights_from_unet: bool = True,
568
+ ):
569
+ r"""
570
+ Instantiate a [`ControlNetModel`] from [`UNet2DConditionModel`].
571
+
572
+ Parameters:
573
+ unet (`UNet2DConditionModel`):
574
+ The UNet model weights to copy to the [`ControlNetModel`]. All configuration options are also copied
575
+ where applicable.
576
+ """
577
+ transformer_layers_per_block = (
578
+ unet.config.transformer_layers_per_block
579
+ if "transformer_layers_per_block" in unet.config
580
+ else 1
581
+ )
582
+ encoder_hid_dim = (
583
+ unet.config.encoder_hid_dim if "encoder_hid_dim" in unet.config else None
584
+ )
585
+ encoder_hid_dim_type = (
586
+ unet.config.encoder_hid_dim_type
587
+ if "encoder_hid_dim_type" in unet.config
588
+ else None
589
+ )
590
+ addition_embed_type = (
591
+ unet.config.addition_embed_type
592
+ if "addition_embed_type" in unet.config
593
+ else None
594
+ )
595
+ addition_time_embed_dim = (
596
+ unet.config.addition_time_embed_dim
597
+ if "addition_time_embed_dim" in unet.config
598
+ else None
599
+ )
600
+
601
+ controlnet = cls(
602
+ encoder_hid_dim=encoder_hid_dim,
603
+ encoder_hid_dim_type=encoder_hid_dim_type,
604
+ addition_embed_type=addition_embed_type,
605
+ addition_time_embed_dim=addition_time_embed_dim,
606
+ transformer_layers_per_block=transformer_layers_per_block,
607
+ # transformer_layers_per_block=[1, 2, 5],
608
+ in_channels=unet.config.in_channels,
609
+ flip_sin_to_cos=unet.config.flip_sin_to_cos,
610
+ freq_shift=unet.config.freq_shift,
611
+ down_block_types=unet.config.down_block_types,
612
+ only_cross_attention=unet.config.only_cross_attention,
613
+ block_out_channels=unet.config.block_out_channels,
614
+ layers_per_block=unet.config.layers_per_block,
615
+ downsample_padding=unet.config.downsample_padding,
616
+ mid_block_scale_factor=unet.config.mid_block_scale_factor,
617
+ act_fn=unet.config.act_fn,
618
+ norm_num_groups=unet.config.norm_num_groups,
619
+ norm_eps=unet.config.norm_eps,
620
+ cross_attention_dim=unet.config.cross_attention_dim,
621
+ attention_head_dim=unet.config.attention_head_dim,
622
+ num_attention_heads=unet.config.num_attention_heads,
623
+ use_linear_projection=unet.config.use_linear_projection,
624
+ class_embed_type=unet.config.class_embed_type,
625
+ num_class_embeds=unet.config.num_class_embeds,
626
+ upcast_attention=unet.config.upcast_attention,
627
+ resnet_time_scale_shift=unet.config.resnet_time_scale_shift,
628
+ projection_class_embeddings_input_dim=unet.config.projection_class_embeddings_input_dim,
629
+ controlnet_conditioning_channel_order=controlnet_conditioning_channel_order,
630
+ conditioning_embedding_out_channels=conditioning_embedding_out_channels,
631
+ )
632
+
633
+ if load_weights_from_unet:
634
+ controlnet.conv_in.load_state_dict(unet.conv_in.state_dict())
635
+ controlnet.time_proj.load_state_dict(unet.time_proj.state_dict())
636
+ controlnet.time_embedding.load_state_dict(unet.time_embedding.state_dict())
637
+
638
+ if controlnet.class_embedding:
639
+ controlnet.class_embedding.load_state_dict(
640
+ unet.class_embedding.state_dict()
641
+ )
642
+
643
+ controlnet.down_blocks.load_state_dict(
644
+ unet.down_blocks.state_dict(), strict=False
645
+ )
646
+ controlnet.mid_block.load_state_dict(
647
+ unet.mid_block.state_dict(), strict=False
648
+ )
649
+
650
+ return controlnet
651
+
652
+ @property
653
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
654
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
655
+ r"""
656
+ Returns:
657
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
658
+ indexed by its weight name.
659
+ """
660
+ # set recursively
661
+ processors = {}
662
+
663
+ def fn_recursive_add_processors(
664
+ name: str,
665
+ module: torch.nn.Module,
666
+ processors: Dict[str, AttentionProcessor],
667
+ ):
668
+ if hasattr(module, "get_processor"):
669
+ processors[f"{name}.processor"] = module.get_processor(
670
+ return_deprecated_lora=True
671
+ )
672
+
673
+ for sub_name, child in module.named_children():
674
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
675
+
676
+ return processors
677
+
678
+ for name, module in self.named_children():
679
+ fn_recursive_add_processors(name, module, processors)
680
+
681
+ return processors
682
+
683
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
684
+ def set_attn_processor(
685
+ self,
686
+ processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]],
687
+ _remove_lora=False,
688
+ ):
689
+ r"""
690
+ Sets the attention processor to use to compute attention.
691
+
692
+ Parameters:
693
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
694
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
695
+ for **all** `Attention` layers.
696
+
697
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
698
+ processor. This is strongly recommended when setting trainable attention processors.
699
+
700
+ """
701
+ count = len(self.attn_processors.keys())
702
+
703
+ if isinstance(processor, dict) and len(processor) != count:
704
+ raise ValueError(
705
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
706
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
707
+ )
708
+
709
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
710
+ if hasattr(module, "set_processor"):
711
+ if not isinstance(processor, dict):
712
+ module.set_processor(processor, _remove_lora=_remove_lora)
713
+ else:
714
+ module.set_processor(
715
+ processor.pop(f"{name}.processor"), _remove_lora=_remove_lora
716
+ )
717
+
718
+ for sub_name, child in module.named_children():
719
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
720
+
721
+ for name, module in self.named_children():
722
+ fn_recursive_attn_processor(name, module, processor)
723
+
724
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
725
+ def set_default_attn_processor(self):
726
+ """
727
+ Disables custom attention processors and sets the default attention implementation.
728
+ """
729
+ if all(
730
+ proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS
731
+ for proc in self.attn_processors.values()
732
+ ):
733
+ processor = AttnAddedKVProcessor()
734
+ elif all(
735
+ proc.__class__ in CROSS_ATTENTION_PROCESSORS
736
+ for proc in self.attn_processors.values()
737
+ ):
738
+ processor = AttnProcessor()
739
+ else:
740
+ raise ValueError(
741
+ f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
742
+ )
743
+
744
+ self.set_attn_processor(processor, _remove_lora=True)
745
+
746
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice
747
+ def set_attention_slice(self, slice_size):
748
+ r"""
749
+ Enable sliced attention computation.
750
+
751
+ When this option is enabled, the attention module splits the input tensor in slices to compute attention in
752
+ several steps. This is useful for saving some memory in exchange for a small decrease in speed.
753
+
754
+ Args:
755
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
756
+ When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
757
+ `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
758
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
759
+ must be a multiple of `slice_size`.
760
+ """
761
+ sliceable_head_dims = []
762
+
763
+ def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
764
+ if hasattr(module, "set_attention_slice"):
765
+ sliceable_head_dims.append(module.sliceable_head_dim)
766
+
767
+ for child in module.children():
768
+ fn_recursive_retrieve_sliceable_dims(child)
769
+
770
+ # retrieve number of attention layers
771
+ for module in self.children():
772
+ fn_recursive_retrieve_sliceable_dims(module)
773
+
774
+ num_sliceable_layers = len(sliceable_head_dims)
775
+
776
+ if slice_size == "auto":
777
+ # half the attention head size is usually a good trade-off between
778
+ # speed and memory
779
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
780
+ elif slice_size == "max":
781
+ # make smallest slice possible
782
+ slice_size = num_sliceable_layers * [1]
783
+
784
+ slice_size = (
785
+ num_sliceable_layers * [slice_size]
786
+ if not isinstance(slice_size, list)
787
+ else slice_size
788
+ )
789
+
790
+ if len(slice_size) != len(sliceable_head_dims):
791
+ raise ValueError(
792
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
793
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
794
+ )
795
+
796
+ for i in range(len(slice_size)):
797
+ size = slice_size[i]
798
+ dim = sliceable_head_dims[i]
799
+ if size is not None and size > dim:
800
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
801
+
802
+ # Recursively walk through all the children.
803
+ # Any children which exposes the set_attention_slice method
804
+ # gets the message
805
+ def fn_recursive_set_attention_slice(
806
+ module: torch.nn.Module, slice_size: List[int]
807
+ ):
808
+ if hasattr(module, "set_attention_slice"):
809
+ module.set_attention_slice(slice_size.pop())
810
+
811
+ for child in module.children():
812
+ fn_recursive_set_attention_slice(child, slice_size)
813
+
814
+ reversed_slice_size = list(reversed(slice_size))
815
+ for module in self.children():
816
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
817
+
818
+ def _set_gradient_checkpointing(self, module, value=False):
819
+ if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D)):
820
+ module.gradient_checkpointing = value
821
+
822
+ def forward(
823
+ self,
824
+ sample: torch.FloatTensor,
825
+ timestep: Union[torch.Tensor, float, int],
826
+ encoder_hidden_states: torch.Tensor,
827
+ controlnet_cond_list: torch.FloatTensor,
828
+ conditioning_scale: float = 1.0,
829
+ class_labels: Optional[torch.Tensor] = None,
830
+ timestep_cond: Optional[torch.Tensor] = None,
831
+ attention_mask: Optional[torch.Tensor] = None,
832
+ added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
833
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
834
+ guess_mode: bool = False,
835
+ return_dict: bool = True,
836
+ ) -> Union[ControlNetOutput, Tuple]:
837
+ """
838
+ The [`ControlNetModel`] forward method.
839
+
840
+ Args:
841
+ sample (`torch.FloatTensor`):
842
+ The noisy input tensor.
843
+ timestep (`Union[torch.Tensor, float, int]`):
844
+ The number of timesteps to denoise an input.
845
+ encoder_hidden_states (`torch.Tensor`):
846
+ The encoder hidden states.
847
+ controlnet_cond (`torch.FloatTensor`):
848
+ The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`.
849
+ conditioning_scale (`float`, defaults to `1.0`):
850
+ The scale factor for ControlNet outputs.
851
+ class_labels (`torch.Tensor`, *optional*, defaults to `None`):
852
+ Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
853
+ timestep_cond (`torch.Tensor`, *optional*, defaults to `None`):
854
+ Additional conditional embeddings for timestep. If provided, the embeddings will be summed with the
855
+ timestep_embedding passed through the `self.time_embedding` layer to obtain the final timestep
856
+ embeddings.
857
+ attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
858
+ An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
859
+ is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
860
+ negative values to the attention scores corresponding to "discard" tokens.
861
+ added_cond_kwargs (`dict`):
862
+ Additional conditions for the Stable Diffusion XL UNet.
863
+ cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`):
864
+ A kwargs dictionary that if specified is passed along to the `AttnProcessor`.
865
+ guess_mode (`bool`, defaults to `False`):
866
+ In this mode, the ControlNet encoder tries its best to recognize the input content of the input even if
867
+ you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended.
868
+ return_dict (`bool`, defaults to `True`):
869
+ Whether or not to return a [`~models.controlnet.ControlNetOutput`] instead of a plain tuple.
870
+
871
+ Returns:
872
+ [`~models.controlnet.ControlNetOutput`] **or** `tuple`:
873
+ If `return_dict` is `True`, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise a tuple is
874
+ returned where the first element is the sample tensor.
875
+ """
876
+ # check channel order
877
+ channel_order = self.config.controlnet_conditioning_channel_order
878
+
879
+ if channel_order == "rgb":
880
+ # in rgb order by default
881
+ ...
882
+ # elif channel_order == "bgr":
883
+ # controlnet_cond = torch.flip(controlnet_cond, dims=[1])
884
+ else:
885
+ raise ValueError(
886
+ f"unknown `controlnet_conditioning_channel_order`: {channel_order}"
887
+ )
888
+
889
+ # prepare attention_mask
890
+ if attention_mask is not None:
891
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
892
+ attention_mask = attention_mask.unsqueeze(1)
893
+
894
+ # 1. time
895
+ timesteps = timestep
896
+ if not torch.is_tensor(timesteps):
897
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
898
+ # This would be a good case for the `match` statement (Python 3.10+)
899
+ is_mps = sample.device.type == "mps"
900
+ if isinstance(timestep, float):
901
+ dtype = torch.float32 if is_mps else torch.float64
902
+ else:
903
+ dtype = torch.int32 if is_mps else torch.int64
904
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
905
+ elif len(timesteps.shape) == 0:
906
+ timesteps = timesteps[None].to(sample.device)
907
+
908
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
909
+ timesteps = timesteps.expand(sample.shape[0])
910
+
911
+ t_emb = self.time_proj(timesteps)
912
+
913
+ # timesteps does not contain any weights and will always return f32 tensors
914
+ # but time_embedding might actually be running in fp16. so we need to cast here.
915
+ # there might be better ways to encapsulate this.
916
+ t_emb = t_emb.to(dtype=sample.dtype)
917
+
918
+ emb = self.time_embedding(t_emb, timestep_cond)
919
+ aug_emb = None
920
+
921
+ if self.class_embedding is not None:
922
+ if class_labels is None:
923
+ raise ValueError(
924
+ "class_labels should be provided when num_class_embeds > 0"
925
+ )
926
+
927
+ if self.config.class_embed_type == "timestep":
928
+ class_labels = self.time_proj(class_labels)
929
+
930
+ class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
931
+ emb = emb + class_emb
932
+
933
+ if self.config.addition_embed_type is not None:
934
+ if self.config.addition_embed_type == "text":
935
+ aug_emb = self.add_embedding(encoder_hidden_states)
936
+
937
+ elif self.config.addition_embed_type == "text_time":
938
+ if "text_embeds" not in added_cond_kwargs:
939
+ raise ValueError(
940
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
941
+ )
942
+ text_embeds = added_cond_kwargs.get("text_embeds")
943
+ if "time_ids" not in added_cond_kwargs:
944
+ raise ValueError(
945
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
946
+ )
947
+ time_ids = added_cond_kwargs.get("time_ids")
948
+ time_embeds = self.add_time_proj(time_ids.flatten())
949
+ time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
950
+
951
+ add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
952
+ add_embeds = add_embeds.to(emb.dtype)
953
+ aug_emb = self.add_embedding(add_embeds)
954
+
955
+ # Copyright by Qi Xin(2024/07/06)
956
+ # inject control type info to time embedding to distinguish different control conditions
957
+ control_type = added_cond_kwargs.get("control_type")
958
+ control_embeds = self.control_type_proj(control_type.flatten())
959
+ control_embeds = control_embeds.reshape((t_emb.shape[0], -1))
960
+ control_embeds = control_embeds.to(emb.dtype)
961
+ control_emb = self.control_add_embedding(control_embeds)
962
+ emb = emb + control_emb
963
+ # ---------------------------------------------------------------------------------
964
+
965
+ emb = emb + aug_emb if aug_emb is not None else emb
966
+
967
+ # 2. pre-process
968
+ sample = self.conv_in(sample)
969
+ indices = torch.nonzero(control_type[0])
970
+
971
+ # Copyright by Qi Xin(2024/07/06)
972
+ # add single/multi conditons to input image.
973
+ # Condition Transformer provides an easy and effective way to fuse different features naturally
974
+ inputs = []
975
+ condition_list = []
976
+
977
+ for idx in range(indices.shape[0] + 1):
978
+ if idx == indices.shape[0]:
979
+ controlnet_cond = sample
980
+ feat_seq = torch.mean(controlnet_cond, dim=(2, 3)) # N * C
981
+ else:
982
+ controlnet_cond = self.controlnet_cond_embedding(
983
+ controlnet_cond_list[indices[idx][0]]
984
+ )
985
+ feat_seq = torch.mean(controlnet_cond, dim=(2, 3)) # N * C
986
+ feat_seq = feat_seq + self.task_embedding[indices[idx][0]]
987
+
988
+ inputs.append(feat_seq.unsqueeze(1))
989
+ condition_list.append(controlnet_cond)
990
+
991
+ x = torch.cat(inputs, dim=1) # NxLxC
992
+ x = self.transformer_layes(x)
993
+
994
+ controlnet_cond_fuser = sample * 0.0
995
+ for idx in range(indices.shape[0]):
996
+ alpha = self.spatial_ch_projs(x[:, idx])
997
+ alpha = alpha.unsqueeze(-1).unsqueeze(-1)
998
+ controlnet_cond_fuser += condition_list[idx] + alpha
999
+
1000
+ sample = sample + controlnet_cond_fuser
1001
+ # -------------------------------------------------------------------------------------------
1002
+
1003
+ # 3. down
1004
+ down_block_res_samples = (sample,)
1005
+ for downsample_block in self.down_blocks:
1006
+ if (
1007
+ hasattr(downsample_block, "has_cross_attention")
1008
+ and downsample_block.has_cross_attention
1009
+ ):
1010
+ sample, res_samples = downsample_block(
1011
+ hidden_states=sample,
1012
+ temb=emb,
1013
+ encoder_hidden_states=encoder_hidden_states,
1014
+ attention_mask=attention_mask,
1015
+ cross_attention_kwargs=cross_attention_kwargs,
1016
+ )
1017
+ else:
1018
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
1019
+
1020
+ down_block_res_samples += res_samples
1021
+
1022
+ # 4. mid
1023
+ if self.mid_block is not None:
1024
+ sample = self.mid_block(
1025
+ sample,
1026
+ emb,
1027
+ encoder_hidden_states=encoder_hidden_states,
1028
+ attention_mask=attention_mask,
1029
+ cross_attention_kwargs=cross_attention_kwargs,
1030
+ )
1031
+
1032
+ # 5. Control net blocks
1033
+
1034
+ controlnet_down_block_res_samples = ()
1035
+
1036
+ for down_block_res_sample, controlnet_block in zip(
1037
+ down_block_res_samples, self.controlnet_down_blocks
1038
+ ):
1039
+ down_block_res_sample = controlnet_block(down_block_res_sample)
1040
+ controlnet_down_block_res_samples = controlnet_down_block_res_samples + (
1041
+ down_block_res_sample,
1042
+ )
1043
+
1044
+ down_block_res_samples = controlnet_down_block_res_samples
1045
+
1046
+ mid_block_res_sample = self.controlnet_mid_block(sample)
1047
+
1048
+ # 6. scaling
1049
+ if guess_mode and not self.config.global_pool_conditions:
1050
+ scales = torch.logspace(
1051
+ -1, 0, len(down_block_res_samples) + 1, device=sample.device
1052
+ ) # 0.1 to 1.0
1053
+ scales = scales * conditioning_scale
1054
+ down_block_res_samples = [
1055
+ sample * scale for sample, scale in zip(down_block_res_samples, scales)
1056
+ ]
1057
+ mid_block_res_sample = mid_block_res_sample * scales[-1] # last one
1058
+ else:
1059
+ down_block_res_samples = [
1060
+ sample * conditioning_scale for sample in down_block_res_samples
1061
+ ]
1062
+ mid_block_res_sample = mid_block_res_sample * conditioning_scale
1063
+
1064
+ if self.config.global_pool_conditions:
1065
+ down_block_res_samples = [
1066
+ torch.mean(sample, dim=(2, 3), keepdim=True)
1067
+ for sample in down_block_res_samples
1068
+ ]
1069
+ mid_block_res_sample = torch.mean(
1070
+ mid_block_res_sample, dim=(2, 3), keepdim=True
1071
+ )
1072
+
1073
+ if not return_dict:
1074
+ return (down_block_res_samples, mid_block_res_sample)
1075
+
1076
+ return ControlNetOutput(
1077
+ down_block_res_samples=down_block_res_samples,
1078
+ mid_block_res_sample=mid_block_res_sample,
1079
+ )
1080
+
1081
+
1082
+ def zero_module(module):
1083
+ for p in module.parameters():
1084
+ nn.init.zeros_(p)
1085
+ return module
pipeline_fill_sd_xl (1).py ADDED
@@ -0,0 +1,559 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import List, Optional, Union
16
+
17
+ import cv2
18
+ import PIL.Image
19
+ import torch
20
+ import torch.nn.functional as F
21
+ from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
22
+ from diffusers.models import AutoencoderKL, UNet2DConditionModel
23
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
24
+ from diffusers.schedulers import KarrasDiffusionSchedulers
25
+ from diffusers.utils.torch_utils import randn_tensor
26
+ from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
27
+
28
+ from controlnet_union import ControlNetModel_Union
29
+
30
+
31
+ def latents_to_rgb(latents):
32
+ weights = ((60, -60, 25, -70), (60, -5, 15, -50), (60, 10, -5, -35))
33
+
34
+ weights_tensor = torch.t(
35
+ torch.tensor(weights, dtype=latents.dtype).to(latents.device)
36
+ )
37
+ biases_tensor = torch.tensor((150, 140, 130), dtype=latents.dtype).to(
38
+ latents.device
39
+ )
40
+ rgb_tensor = torch.einsum(
41
+ "...lxy,lr -> ...rxy", latents, weights_tensor
42
+ ) + biases_tensor.unsqueeze(-1).unsqueeze(-1)
43
+ image_array = rgb_tensor.clamp(0, 255)[0].byte().cpu().numpy()
44
+ image_array = image_array.transpose(1, 2, 0) # Change the order of dimensions
45
+
46
+ denoised_image = cv2.fastNlMeansDenoisingColored(image_array, None, 10, 10, 7, 21)
47
+ blurred_image = cv2.GaussianBlur(denoised_image, (5, 5), 0)
48
+ final_image = PIL.Image.fromarray(blurred_image)
49
+
50
+ width, height = final_image.size
51
+ final_image = final_image.resize(
52
+ (width * 8, height * 8), PIL.Image.Resampling.LANCZOS
53
+ )
54
+
55
+ return final_image
56
+
57
+
58
+ def retrieve_timesteps(
59
+ scheduler,
60
+ num_inference_steps: Optional[int] = None,
61
+ device: Optional[Union[str, torch.device]] = None,
62
+ **kwargs,
63
+ ):
64
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
65
+ timesteps = scheduler.timesteps
66
+
67
+ return timesteps, num_inference_steps
68
+
69
+
70
+ class StableDiffusionXLFillPipeline(DiffusionPipeline, StableDiffusionMixin):
71
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae"
72
+ _optional_components = [
73
+ "tokenizer",
74
+ "tokenizer_2",
75
+ "text_encoder",
76
+ "text_encoder_2",
77
+ ]
78
+
79
+ def __init__(
80
+ self,
81
+ vae: AutoencoderKL,
82
+ text_encoder: CLIPTextModel,
83
+ text_encoder_2: CLIPTextModelWithProjection,
84
+ tokenizer: CLIPTokenizer,
85
+ tokenizer_2: CLIPTokenizer,
86
+ unet: UNet2DConditionModel,
87
+ controlnet: ControlNetModel_Union,
88
+ scheduler: KarrasDiffusionSchedulers,
89
+ force_zeros_for_empty_prompt: bool = True,
90
+ ):
91
+ super().__init__()
92
+
93
+ self.register_modules(
94
+ vae=vae,
95
+ text_encoder=text_encoder,
96
+ text_encoder_2=text_encoder_2,
97
+ tokenizer=tokenizer,
98
+ tokenizer_2=tokenizer_2,
99
+ unet=unet,
100
+ controlnet=controlnet,
101
+ scheduler=scheduler,
102
+ )
103
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
104
+ self.image_processor = VaeImageProcessor(
105
+ vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True
106
+ )
107
+ self.control_image_processor = VaeImageProcessor(
108
+ vae_scale_factor=self.vae_scale_factor,
109
+ do_convert_rgb=True,
110
+ do_normalize=False,
111
+ )
112
+
113
+ self.register_to_config(
114
+ force_zeros_for_empty_prompt=force_zeros_for_empty_prompt
115
+ )
116
+
117
+ def encode_prompt(
118
+ self,
119
+ prompt: str,
120
+ device: Optional[torch.device] = None,
121
+ do_classifier_free_guidance: bool = True,
122
+ ):
123
+ device = device or self._execution_device
124
+ prompt = [prompt] if isinstance(prompt, str) else prompt
125
+
126
+ if prompt is not None:
127
+ batch_size = len(prompt)
128
+
129
+ # Define tokenizers and text encoders
130
+ tokenizers = (
131
+ [self.tokenizer, self.tokenizer_2]
132
+ if self.tokenizer is not None
133
+ else [self.tokenizer_2]
134
+ )
135
+ text_encoders = (
136
+ [self.text_encoder, self.text_encoder_2]
137
+ if self.text_encoder is not None
138
+ else [self.text_encoder_2]
139
+ )
140
+
141
+ prompt_2 = prompt
142
+ prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
143
+
144
+ # textual inversion: process multi-vector tokens if necessary
145
+ prompt_embeds_list = []
146
+ prompts = [prompt, prompt_2]
147
+ for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):
148
+ text_inputs = tokenizer(
149
+ prompt,
150
+ padding="max_length",
151
+ max_length=tokenizer.model_max_length,
152
+ truncation=True,
153
+ return_tensors="pt",
154
+ )
155
+
156
+ text_input_ids = text_inputs.input_ids
157
+
158
+ prompt_embeds = text_encoder(
159
+ text_input_ids.to(device), output_hidden_states=True
160
+ )
161
+
162
+ # We are only ALWAYS interested in the pooled output of the final text encoder
163
+ pooled_prompt_embeds = prompt_embeds[0]
164
+ prompt_embeds = prompt_embeds.hidden_states[-2]
165
+ prompt_embeds_list.append(prompt_embeds)
166
+
167
+ prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
168
+
169
+ # get unconditional embeddings for classifier free guidance
170
+ zero_out_negative_prompt = True
171
+ negative_prompt_embeds = None
172
+ negative_pooled_prompt_embeds = None
173
+
174
+ if do_classifier_free_guidance and zero_out_negative_prompt:
175
+ negative_prompt_embeds = torch.zeros_like(prompt_embeds)
176
+ negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
177
+ elif do_classifier_free_guidance and negative_prompt_embeds is None:
178
+ negative_prompt = ""
179
+ negative_prompt_2 = negative_prompt
180
+
181
+ # normalize str to list
182
+ negative_prompt = (
183
+ batch_size * [negative_prompt]
184
+ if isinstance(negative_prompt, str)
185
+ else negative_prompt
186
+ )
187
+ negative_prompt_2 = (
188
+ batch_size * [negative_prompt_2]
189
+ if isinstance(negative_prompt_2, str)
190
+ else negative_prompt_2
191
+ )
192
+
193
+ uncond_tokens: List[str]
194
+ if prompt is not None and type(prompt) is not type(negative_prompt):
195
+ raise TypeError(
196
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
197
+ f" {type(prompt)}."
198
+ )
199
+ elif batch_size != len(negative_prompt):
200
+ raise ValueError(
201
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
202
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
203
+ " the batch size of `prompt`."
204
+ )
205
+ else:
206
+ uncond_tokens = [negative_prompt, negative_prompt_2]
207
+
208
+ negative_prompt_embeds_list = []
209
+ for negative_prompt, tokenizer, text_encoder in zip(
210
+ uncond_tokens, tokenizers, text_encoders
211
+ ):
212
+ max_length = prompt_embeds.shape[1]
213
+ uncond_input = tokenizer(
214
+ negative_prompt,
215
+ padding="max_length",
216
+ max_length=max_length,
217
+ truncation=True,
218
+ return_tensors="pt",
219
+ )
220
+
221
+ negative_prompt_embeds = text_encoder(
222
+ uncond_input.input_ids.to(device),
223
+ output_hidden_states=True,
224
+ )
225
+ # We are only ALWAYS interested in the pooled output of the final text encoder
226
+ negative_pooled_prompt_embeds = negative_prompt_embeds[0]
227
+ negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
228
+
229
+ negative_prompt_embeds_list.append(negative_prompt_embeds)
230
+
231
+ negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
232
+
233
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
234
+
235
+ bs_embed, seq_len, _ = prompt_embeds.shape
236
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
237
+ prompt_embeds = prompt_embeds.repeat(1, 1, 1)
238
+ prompt_embeds = prompt_embeds.view(bs_embed * 1, seq_len, -1)
239
+
240
+ if do_classifier_free_guidance:
241
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
242
+ seq_len = negative_prompt_embeds.shape[1]
243
+
244
+ if self.text_encoder_2 is not None:
245
+ negative_prompt_embeds = negative_prompt_embeds.to(
246
+ dtype=self.text_encoder_2.dtype, device=device
247
+ )
248
+ else:
249
+ negative_prompt_embeds = negative_prompt_embeds.to(
250
+ dtype=self.unet.dtype, device=device
251
+ )
252
+
253
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, 1, 1)
254
+ negative_prompt_embeds = negative_prompt_embeds.view(
255
+ batch_size * 1, seq_len, -1
256
+ )
257
+
258
+ pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, 1).view(bs_embed * 1, -1)
259
+ if do_classifier_free_guidance:
260
+ negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(
261
+ 1, 1
262
+ ).view(bs_embed * 1, -1)
263
+
264
+ return (
265
+ prompt_embeds,
266
+ negative_prompt_embeds,
267
+ pooled_prompt_embeds,
268
+ negative_pooled_prompt_embeds,
269
+ )
270
+
271
+ def check_inputs(
272
+ self,
273
+ prompt_embeds,
274
+ negative_prompt_embeds,
275
+ pooled_prompt_embeds,
276
+ negative_pooled_prompt_embeds,
277
+ image,
278
+ controlnet_conditioning_scale=1.0,
279
+ ):
280
+ if prompt_embeds is None:
281
+ raise ValueError(
282
+ "Provide `prompt_embeds`. Cannot leave `prompt_embeds` undefined."
283
+ )
284
+
285
+ if negative_prompt_embeds is None:
286
+ raise ValueError(
287
+ "Provide `negative_prompt_embeds`. Cannot leave `negative_prompt_embeds` undefined."
288
+ )
289
+
290
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
291
+ raise ValueError(
292
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
293
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
294
+ f" {negative_prompt_embeds.shape}."
295
+ )
296
+
297
+ if prompt_embeds is not None and pooled_prompt_embeds is None:
298
+ raise ValueError(
299
+ "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
300
+ )
301
+
302
+ if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
303
+ raise ValueError(
304
+ "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
305
+ )
306
+
307
+ # Check `image`
308
+ is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance(
309
+ self.controlnet, torch._dynamo.eval_frame.OptimizedModule
310
+ )
311
+ if (
312
+ isinstance(self.controlnet, ControlNetModel_Union)
313
+ or is_compiled
314
+ and isinstance(self.controlnet._orig_mod, ControlNetModel_Union)
315
+ ):
316
+ if not isinstance(image, PIL.Image.Image):
317
+ raise TypeError(
318
+ f"image must be passed and has to be a PIL image, but is {type(image)}"
319
+ )
320
+
321
+ else:
322
+ assert False
323
+
324
+ # Check `controlnet_conditioning_scale`
325
+ if (
326
+ isinstance(self.controlnet, ControlNetModel_Union)
327
+ or is_compiled
328
+ and isinstance(self.controlnet._orig_mod, ControlNetModel_Union)
329
+ ):
330
+ if not isinstance(controlnet_conditioning_scale, float):
331
+ raise TypeError(
332
+ "For single controlnet: `controlnet_conditioning_scale` must be type `float`."
333
+ )
334
+ else:
335
+ assert False
336
+
337
+ def prepare_image(self, image, device, dtype, do_classifier_free_guidance=False):
338
+ image = self.control_image_processor.preprocess(image).to(dtype=torch.float32)
339
+
340
+ image_batch_size = image.shape[0]
341
+
342
+ image = image.repeat_interleave(image_batch_size, dim=0)
343
+ image = image.to(device=device, dtype=dtype)
344
+
345
+ if do_classifier_free_guidance:
346
+ image = torch.cat([image] * 2)
347
+
348
+ return image
349
+
350
+ def prepare_latents(
351
+ self, batch_size, num_channels_latents, height, width, dtype, device
352
+ ):
353
+ shape = (
354
+ batch_size,
355
+ num_channels_latents,
356
+ int(height) // self.vae_scale_factor,
357
+ int(width) // self.vae_scale_factor,
358
+ )
359
+
360
+ latents = randn_tensor(shape, device=device, dtype=dtype)
361
+
362
+ # scale the initial noise by the standard deviation required by the scheduler
363
+ latents = latents * self.scheduler.init_noise_sigma
364
+ return latents
365
+
366
+ @property
367
+ def guidance_scale(self):
368
+ return self._guidance_scale
369
+
370
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
371
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
372
+ # corresponds to doing no classifier free guidance.
373
+ @property
374
+ def do_classifier_free_guidance(self):
375
+ return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
376
+
377
+ @property
378
+ def num_timesteps(self):
379
+ return self._num_timesteps
380
+
381
+ @torch.no_grad()
382
+ def __call__(
383
+ self,
384
+ prompt_embeds: torch.Tensor,
385
+ negative_prompt_embeds: torch.Tensor,
386
+ pooled_prompt_embeds: torch.Tensor,
387
+ negative_pooled_prompt_embeds: torch.Tensor,
388
+ image: PipelineImageInput = None,
389
+ num_inference_steps: int = 8,
390
+ guidance_scale: float = 1.5,
391
+ controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
392
+ ):
393
+ # 1. Check inputs. Raise error if not correct
394
+ self.check_inputs(
395
+ prompt_embeds,
396
+ negative_prompt_embeds,
397
+ pooled_prompt_embeds,
398
+ negative_pooled_prompt_embeds,
399
+ image,
400
+ controlnet_conditioning_scale,
401
+ )
402
+
403
+ self._guidance_scale = guidance_scale
404
+
405
+ # 2. Define call parameters
406
+ batch_size = 1
407
+ device = self._execution_device
408
+
409
+ # 4. Prepare image
410
+ if isinstance(self.controlnet, ControlNetModel_Union):
411
+ image = self.prepare_image(
412
+ image=image,
413
+ device=device,
414
+ dtype=self.controlnet.dtype,
415
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
416
+ )
417
+ height, width = image.shape[-2:]
418
+ else:
419
+ assert False
420
+
421
+ # 5. Prepare timesteps
422
+ timesteps, num_inference_steps = retrieve_timesteps(
423
+ self.scheduler, num_inference_steps, device
424
+ )
425
+ self._num_timesteps = len(timesteps)
426
+
427
+ # 6. Prepare latent variables
428
+ num_channels_latents = self.unet.config.in_channels
429
+ latents = self.prepare_latents(
430
+ batch_size,
431
+ num_channels_latents,
432
+ height,
433
+ width,
434
+ prompt_embeds.dtype,
435
+ device,
436
+ )
437
+
438
+ # 7 Prepare added time ids & embeddings
439
+ add_text_embeds = pooled_prompt_embeds
440
+
441
+ add_time_ids = negative_add_time_ids = torch.tensor(
442
+ image.shape[-2:] + torch.Size([0, 0]) + image.shape[-2:]
443
+ ).unsqueeze(0)
444
+
445
+ if self.do_classifier_free_guidance:
446
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
447
+ add_text_embeds = torch.cat(
448
+ [negative_pooled_prompt_embeds, add_text_embeds], dim=0
449
+ )
450
+ add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
451
+
452
+ prompt_embeds = prompt_embeds.to(device)
453
+ add_text_embeds = add_text_embeds.to(device)
454
+ add_time_ids = add_time_ids.to(device).repeat(batch_size, 1)
455
+
456
+ controlnet_image_list = [0, 0, 0, 0, 0, 0, image, 0]
457
+ union_control_type = (
458
+ torch.Tensor([0, 0, 0, 0, 0, 0, 1, 0])
459
+ .to(device, dtype=prompt_embeds.dtype)
460
+ .repeat(batch_size * 2, 1)
461
+ )
462
+
463
+ added_cond_kwargs = {
464
+ "text_embeds": add_text_embeds,
465
+ "time_ids": add_time_ids,
466
+ "control_type": union_control_type,
467
+ }
468
+
469
+ controlnet_prompt_embeds = prompt_embeds
470
+ controlnet_added_cond_kwargs = added_cond_kwargs
471
+
472
+ # 8. Denoising loop
473
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
474
+
475
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
476
+ for i, t in enumerate(timesteps):
477
+ # expand the latents if we are doing classifier free guidance
478
+ latent_model_input = (
479
+ torch.cat([latents] * 2)
480
+ if self.do_classifier_free_guidance
481
+ else latents
482
+ )
483
+ latent_model_input = self.scheduler.scale_model_input(
484
+ latent_model_input, t
485
+ )
486
+
487
+ # controlnet(s) inference
488
+ control_model_input = latent_model_input
489
+
490
+ down_block_res_samples, mid_block_res_sample = self.controlnet(
491
+ control_model_input,
492
+ t,
493
+ encoder_hidden_states=controlnet_prompt_embeds,
494
+ controlnet_cond_list=controlnet_image_list,
495
+ conditioning_scale=controlnet_conditioning_scale,
496
+ guess_mode=False,
497
+ added_cond_kwargs=controlnet_added_cond_kwargs,
498
+ return_dict=False,
499
+ )
500
+
501
+ # predict the noise residual
502
+ noise_pred = self.unet(
503
+ latent_model_input,
504
+ t,
505
+ encoder_hidden_states=prompt_embeds,
506
+ timestep_cond=None,
507
+ cross_attention_kwargs={},
508
+ down_block_additional_residuals=down_block_res_samples,
509
+ mid_block_additional_residual=mid_block_res_sample,
510
+ added_cond_kwargs=added_cond_kwargs,
511
+ return_dict=False,
512
+ )[0]
513
+
514
+ # perform guidance
515
+ if self.do_classifier_free_guidance:
516
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
517
+ noise_pred = noise_pred_uncond + guidance_scale * (
518
+ noise_pred_text - noise_pred_uncond
519
+ )
520
+
521
+ # compute the previous noisy sample x_t -> x_t-1
522
+ latents = self.scheduler.step(
523
+ noise_pred, t, latents, return_dict=False
524
+ )[0]
525
+
526
+ if i == 2:
527
+ prompt_embeds = prompt_embeds[-1:]
528
+ add_text_embeds = add_text_embeds[-1:]
529
+ add_time_ids = add_time_ids[-1:]
530
+ union_control_type = union_control_type[-1:]
531
+
532
+ added_cond_kwargs = {
533
+ "text_embeds": add_text_embeds,
534
+ "time_ids": add_time_ids,
535
+ "control_type": union_control_type,
536
+ }
537
+
538
+ controlnet_prompt_embeds = prompt_embeds
539
+ controlnet_added_cond_kwargs = added_cond_kwargs
540
+
541
+ image = image[-1:]
542
+ controlnet_image_list = [0, 0, 0, 0, 0, 0, image, 0]
543
+
544
+ self._guidance_scale = 0.0
545
+
546
+ if i == len(timesteps) - 1 or (
547
+ (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
548
+ ):
549
+ progress_bar.update()
550
+ yield latents_to_rgb(latents)
551
+
552
+ latents = latents / self.vae.config.scaling_factor
553
+ image = self.vae.decode(latents, return_dict=False)[0]
554
+ image = self.image_processor.postprocess(image)[0]
555
+
556
+ # Offload all models
557
+ self.maybe_free_model_hooks()
558
+
559
+ yield image