BestWishYsh committed on
Commit 5ef7e81 · verified · 1 Parent(s): 6c2beee

Update models/transformer_consisid.py

Files changed (1)
  1. models/transformer_consisid.py +81 -227
models/transformer_consisid.py CHANGED
@@ -16,7 +16,7 @@ import glob
  import json
  import math
  import os
- from typing import Any, Dict, Optional, Tuple, Union

  import torch
  from torch import nn
@@ -24,11 +24,7 @@ from torch import nn
  from diffusers.configuration_utils import ConfigMixin, register_to_config
  from diffusers.loaders import PeftAdapterMixin
  from diffusers.models.attention import Attention, FeedForward
- from diffusers.models.attention_processor import (
-     AttentionProcessor,
-     CogVideoXAttnProcessor2_0,
-     FusedCogVideoXAttnProcessor2_0,
- )
  from diffusers.models.embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps
  from diffusers.models.modeling_outputs import Transformer2DModelOutput
  from diffusers.models.modeling_utils import ModelMixin
@@ -40,61 +36,10 @@ from diffusers.utils.torch_utils import maybe_allow_in_graph
  logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


- def ConsisIDFeedForward(dim, mult=4):
-     """
-     Creates a consistent ID feedforward block consisting of layer normalization, two linear layers, and a GELU
-     activation.
-
-     Args:
-         dim (int): The input dimension of the tensor.
-         mult (int, optional): Multiplier for the inner dimension. Default is 4.
-
-     Returns:
-         nn.Sequential: A sequence of layers comprising LayerNorm, Linear layers, and GELU.
-     """
-     inner_dim = int(dim * mult)
-     return nn.Sequential(
-         nn.LayerNorm(dim),
-         nn.Linear(dim, inner_dim, bias=False),
-         nn.GELU(),
-         nn.Linear(inner_dim, dim, bias=False),
-     )
-
-
- def reshape_tensor(x, heads):
-     """
-     Reshapes the input tensor for multi-head attention.
-
-     Args:
-         x (torch.Tensor): The input tensor with shape (batch_size, length, width).
-         heads (int): The number of attention heads.
-
-     Returns:
-         torch.Tensor: The reshaped tensor, with shape (batch_size, heads, length, width).
-     """
-     bs, length, width = x.shape
-     x = x.view(bs, length, heads, -1)
-     x = x.transpose(1, 2)
-     x = x.reshape(bs, heads, length, -1)
-     return x
-
-
  class PerceiverAttention(nn.Module):
-     """
-     Implements the Perceiver attention mechanism with multi-head attention.
-
-     This layer takes two inputs: 'x' (image features) and 'latents' (latent features), applying multi-head attention to
-     both and producing an output tensor with the same dimension as the input tensor 'x'.
-
-     Args:
-         dim (int): The input dimension.
-         dim_head (int, optional): The dimension of each attention head. Default is 64.
-         heads (int, optional): The number of attention heads. Default is 8.
-         kv_dim (int, optional): The key-value dimension. If None, `dim` is used for both keys and values.
-     """
-
-     def __init__(self, *, dim, dim_head=64, heads=8, kv_dim=None):
          super().__init__()
          self.scale = dim_head**-0.5
          self.dim_head = dim_head
          self.heads = heads
@@ -107,80 +52,58 @@ class PerceiverAttention(nn.Module):
          self.to_kv = nn.Linear(dim if kv_dim is None else kv_dim, inner_dim * 2, bias=False)
          self.to_out = nn.Linear(inner_dim, dim, bias=False)

-     def forward(self, x, latents):
-         """
-         Forward pass for Perceiver attention.
-
-         Args:
-             x (torch.Tensor): Image features tensor with shape (batch_size, num_pixels, D).
-             latents (torch.Tensor): Latent features tensor with shape (batch_size, num_latents, D).
-
-         Returns:
-             torch.Tensor: Output tensor after applying attention and transformation.
-         """
          # Apply normalization
-         x = self.norm1(x)
          latents = self.norm2(latents)

-         b, seq_len, _ = latents.shape  # Get batch size and sequence length

          # Compute query, key, and value matrices
-         q = self.to_q(latents)
-         kv_input = torch.cat((x, latents), dim=-2)
-         k, v = self.to_kv(kv_input).chunk(2, dim=-1)

          # Reshape the tensors for multi-head attention
-         q = reshape_tensor(q, self.heads)
-         k = reshape_tensor(k, self.heads)
-         v = reshape_tensor(v, self.heads)

          # attention
          scale = 1 / math.sqrt(math.sqrt(self.dim_head))
-         weight = (q * scale) @ (k * scale).transpose(-2, -1)  # More stable with f16 than dividing afterwards
          weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
-         out = weight @ v

          # Reshape and return the final output
-         out = out.permute(0, 2, 1, 3).reshape(b, seq_len, -1)

-         return self.to_out(out)


  class LocalFacialExtractor(nn.Module):
      def __init__(
          self,
-         id_dim=1280,
-         vit_dim=1024,
-         depth=10,
-         dim_head=64,
-         heads=16,
-         num_id_token=5,
-         num_queries=32,
-         output_dim=2048,
-         ff_mult=4,
      ):
-         """
-         Initializes the LocalFacialExtractor class.
-
-         Parameters:
-         - id_dim (int): The dimensionality of id features.
-         - vit_dim (int): The dimensionality of vit features.
-         - depth (int): Total number of PerceiverAttention and ConsisIDFeedForward layers.
-         - dim_head (int): Dimensionality of each attention head.
-         - heads (int): Number of attention heads.
-         - num_id_token (int): Number of tokens used for identity features.
-         - num_queries (int): Number of query tokens for the latent representation.
-         - output_dim (int): Output dimension after projection.
-         - ff_mult (int): Multiplier for the feed-forward network hidden dimension.
-         """
          super().__init__()

          # Storing identity token and query information
          self.num_id_token = num_id_token
          self.vit_dim = vit_dim
          self.num_queries = num_queries
-         assert depth % 5 == 0
-         self.depth = depth // 5
          scale = vit_dim**-0.5

          # Learnable latent query embeddings
@@ -195,13 +118,18 @@ class LocalFacialExtractor(nn.Module):
                  nn.ModuleList(
                      [
                          PerceiverAttention(dim=vit_dim, dim_head=dim_head, heads=heads),  # Perceiver Attention layer
-                         ConsisIDFeedForward(dim=vit_dim, mult=ff_mult),  # ConsisIDFeedForward layer
                      ]
                  )
              )

          # Mappings for each of the 5 different ViT features
-         for i in range(5):
              setattr(
                  self,
                  f"mapping_{i}",
@@ -227,32 +155,21 @@ class LocalFacialExtractor(nn.Module):
              nn.Linear(vit_dim, vit_dim * num_id_token),
          )

-     def forward(self, x, y):
-         """
-         Forward pass for LocalFacialExtractor.
-
-         Parameters:
-         - x (Tensor): The input identity embedding tensor of shape (batch_size, id_dim).
-         - y (list of Tensor): A list of 5 visual feature tensors each of shape (batch_size, vit_dim).
-
-         Returns:
-         - Tensor: The extracted latent features of shape (batch_size, num_queries, output_dim).
-         """
-
          # Repeat latent queries for the batch size
-         latents = self.latents.repeat(x.size(0), 1, 1)

          # Map the identity embedding to tokens
-         x = self.id_embedding_mapping(x)
-         x = x.reshape(-1, self.num_id_token, self.vit_dim)

          # Concatenate identity tokens with the latent queries
-         latents = torch.cat((latents, x), dim=1)

-         # Process each of the 5 visual feature inputs
-         for i in range(5):
-             vit_feature = getattr(self, f"mapping_{i}")(y[i])
-             ctx_feature = torch.cat((x, vit_feature), dim=1)

              # Pass through the PerceiverAttention and ConsisIDFeedForward layers
              for attn, ff in self.layers[i * self.depth : (i + 1) * self.depth]:
@@ -267,26 +184,9 @@ class LocalFacialExtractor(nn.Module):


  class PerceiverCrossAttention(nn.Module):
-     """
-
-     Args:
-         dim (int): Dimension of the input latent and output. Default is 3072.
-         dim_head (int): Dimension of each attention head. Default is 128.
-         heads (int): Number of attention heads. Default is 16.
-         kv_dim (int): Dimension of the key/value input, allowing flexible cross-attention. Default is 2048.
-
-     Attributes:
-         scale (float): Scaling factor used in dot-product attention for numerical stability.
-         norm1 (nn.LayerNorm): Layer normalization applied to the input image features.
-         norm2 (nn.LayerNorm): Layer normalization applied to the latent features.
-         to_q (nn.Linear): Linear layer for projecting the latent features into queries.
-         to_kv (nn.Linear): Linear layer for projecting the input features into keys and values.
-         to_out (nn.Linear): Linear layer for outputting the final result after attention.
-
-     """
-
-     def __init__(self, *, dim=3072, dim_head=128, heads=16, kv_dim=2048):
          super().__init__()
          self.scale = dim_head**-0.5
          self.dim_head = dim_head
          self.heads = heads
@@ -301,47 +201,32 @@ class PerceiverCrossAttention(nn.Module):
          self.to_kv = nn.Linear(dim if kv_dim is None else kv_dim, inner_dim * 2, bias=False)
          self.to_out = nn.Linear(inner_dim, dim, bias=False)

-     def forward(self, x, latents):
-         """
-
-         Args:
-             x (torch.Tensor): Input image features with shape (batch_size, n1, D), where:
-                 - batch_size (b): Number of samples in the batch.
-                 - n1: Sequence length (e.g., number of patches or tokens).
-                 - D: Feature dimension.
-
-             latents (torch.Tensor): Latent feature representations with shape (batch_size, n2, D), where:
-                 - n2: Number of latent elements.
-
-         Returns:
-             torch.Tensor: Attention-modulated features with shape (batch_size, n2, D).
-
-         """
          # Apply layer normalization to the input image and latent features
-         x = self.norm1(x)
-         latents = self.norm2(latents)

-         b, seq_len, _ = latents.shape

          # Compute queries, keys, and values
-         q = self.to_q(latents)
-         k, v = self.to_kv(x).chunk(2, dim=-1)

          # Reshape tensors to split into attention heads
-         q = reshape_tensor(q, self.heads)
-         k = reshape_tensor(k, self.heads)
-         v = reshape_tensor(v, self.heads)

          # Compute attention weights
          scale = 1 / math.sqrt(math.sqrt(self.dim_head))
-         weight = (q * scale) @ (k * scale).transpose(-2, -1)  # More stable scaling than post-division
          weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)

          # Compute the output via weighted combination of values
-         out = weight @ v

          # Reshape and permute to prepare for final linear transformation
-         out = out.permute(0, 2, 1, 3).reshape(b, seq_len, -1)

          return self.to_out(out)

@@ -567,6 +452,9 @@ class ConsisIDTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
              The multiplication factor applied to the feed-forward network's hidden layer size in the Local Facial
              Extractor (LFE). A higher value increases the model's capacity to learn more complex facial feature
              transformations, but also increases the computation and memory requirements.
          local_face_scale (`float`, defaults to `1.0`):
              A scaling factor used to adjust the importance of local facial features in the model. This can influence
              how strongly the model focuses on high frequency face-related content.
@@ -616,6 +504,7 @@ class ConsisIDTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
          LFE_num_querie: int = 32,
          LFE_output_dim: int = 2048,
          LFE_ff_mult: int = 4,
          local_face_scale: float = 1.0,
      ):
          super().__init__()
@@ -680,8 +569,6 @@ class ConsisIDTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
          )
          self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels)

-         self.gradient_checkpointing = False
-
          self.is_train_face = is_train_face
          self.is_kps = is_kps
@@ -697,6 +584,7 @@ class ConsisIDTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
          self.LFE_num_querie = LFE_num_querie
          self.LFE_output_dim = LFE_output_dim
          self.LFE_ff_mult = LFE_ff_mult
          # cross configs
          self.inner_dim = inner_dim
          self.cross_attn_interval = cross_attn_interval
@@ -708,6 +596,8 @@ class ConsisIDTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
          # face modules
          self._init_face_inputs()

      def _set_gradient_checkpointing(self, module, value=False):
          self.gradient_checkpointing = value

@@ -724,8 +614,8 @@ class ConsisIDTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
              num_queries=self.LFE_num_querie,
              output_dim=self.LFE_output_dim,
              ff_mult=self.LFE_ff_mult,
-         )
-         self.local_facial_extractor.to(device, dtype=weight_dtype)
          self.perceiver_cross_attention = nn.ModuleList(
              [
                  PerceiverCrossAttention(
@@ -811,46 +701,6 @@ class ConsisIDTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
          for name, module in self.named_children():
              fn_recursive_attn_processor(name, module, processor)

-     # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedCogVideoXAttnProcessor2_0
-     def fuse_qkv_projections(self):
-         """
-         Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
-         are fused. For cross-attention modules, key and value projection matrices are fused.
-
-         <Tip warning={true}>
-
-         This API is 🧪 experimental.
-
-         </Tip>
-         """
-         self.original_attn_processors = None
-
-         for _, attn_processor in self.attn_processors.items():
-             if "Added" in str(attn_processor.__class__.__name__):
-                 raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
-
-         self.original_attn_processors = self.attn_processors
-
-         for module in self.modules():
-             if isinstance(module, Attention):
-                 module.fuse_projections(fuse=True)
-
-         self.set_attn_processor(FusedCogVideoXAttnProcessor2_0())
-
-     # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
-     def unfuse_qkv_projections(self):
-         """Disables the fused QKV projection if enabled.
-
-         <Tip warning={true}>
-
-         This API is 🧪 experimental.
-
-         </Tip>
-
-         """
-         if self.original_attn_processors is not None:
-             self.set_attn_processor(self.original_attn_processors)
-
      def forward(
          self,
          hidden_states: torch.Tensor,
@@ -863,13 +713,6 @@ class ConsisIDTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
          id_vit_hidden: Optional[torch.Tensor] = None,
          return_dict: bool = True,
      ):
-         # fuse clip and insightface
-         if self.is_train_face:
-             assert id_cond is not None and id_vit_hidden is not None
-             valid_face_emb = self.local_facial_extractor(
-                 id_cond, id_vit_hidden
-             )  # torch.Size([1, 1280]), list[5](torch.Size([1, 577, 1024])) -> torch.Size([1, 32, 2048])
-
          if attention_kwargs is not None:
              attention_kwargs = attention_kwargs.copy()
              lora_scale = attention_kwargs.pop("scale", 1.0)
@@ -885,6 +728,17 @@ class ConsisIDTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
                  "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
              )

          batch_size, num_frames, channels, height, width = hidden_states.shape

          # 1. Time embedding
@@ -1086,4 +940,4 @@ if __name__ == '__main__':
          id_cond=id_cond if id_cond is not None else None,
      )[0]

-     print(model_output)
 
  import json
  import math
  import os
+ from typing import Any, List, Dict, Optional, Tuple, Union

  import torch
  from torch import nn
 
  from diffusers.configuration_utils import ConfigMixin, register_to_config
  from diffusers.loaders import PeftAdapterMixin
  from diffusers.models.attention import Attention, FeedForward
+ from diffusers.models.attention_processor import AttentionProcessor, CogVideoXAttnProcessor2_0
  from diffusers.models.embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps
  from diffusers.models.modeling_outputs import Transformer2DModelOutput
  from diffusers.models.modeling_utils import ModelMixin
 
  logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


  class PerceiverAttention(nn.Module):
+     def __init__(self, dim: int, dim_head: int = 64, heads: int = 8, kv_dim: Optional[int] = None):
          super().__init__()
+
          self.scale = dim_head**-0.5
          self.dim_head = dim_head
          self.heads = heads
 
          self.to_kv = nn.Linear(dim if kv_dim is None else kv_dim, inner_dim * 2, bias=False)
          self.to_out = nn.Linear(inner_dim, dim, bias=False)

+     def forward(self, image_embeds: torch.Tensor, latents: torch.Tensor) -> torch.Tensor:
          # Apply normalization
+         image_embeds = self.norm1(image_embeds)
          latents = self.norm2(latents)

+         batch_size, seq_len, _ = latents.shape  # Get batch size and sequence length

          # Compute query, key, and value matrices
+         query = self.to_q(latents)
+         kv_input = torch.cat((image_embeds, latents), dim=-2)
+         key, value = self.to_kv(kv_input).chunk(2, dim=-1)

          # Reshape the tensors for multi-head attention
+         query = query.reshape(query.size(0), -1, self.heads, self.dim_head).transpose(1, 2)
+         key = key.reshape(key.size(0), -1, self.heads, self.dim_head).transpose(1, 2)
+         value = value.reshape(value.size(0), -1, self.heads, self.dim_head).transpose(1, 2)

          # attention
          scale = 1 / math.sqrt(math.sqrt(self.dim_head))
+         weight = (query * scale) @ (key * scale).transpose(-2, -1)  # More stable with f16 than dividing afterwards
          weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
+         output = weight @ value

          # Reshape and return the final output
+         output = output.permute(0, 2, 1, 3).reshape(batch_size, seq_len, -1)

+         return self.to_out(output)
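A side note on the scale used above: multiplying the query and key each by dim_head ** -0.25 before the matmul gives the same logits as the conventional single division by sqrt(dim_head), but keeps the intermediate values smaller, which is what the "More stable with f16" comment refers to. A self-contained check using only torch (illustrative, not part of the commit):

import math

import torch

# Illustrative sketch: the split d**-0.25 scaling used in PerceiverAttention
# matches the usual 1/sqrt(d) scaling of the attention logits.
d = 64                                  # dim_head, as in the defaults above
q = torch.randn(2, 8, 16, d)            # (batch, heads, seq, dim_head)
k = torch.randn(2, 8, 16, d)

scale = 1 / math.sqrt(math.sqrt(d))     # d ** -0.25, applied to q and k separately
logits_split = (q * scale) @ (k * scale).transpose(-2, -1)
logits_post = (q @ k.transpose(-2, -1)) / math.sqrt(d)

print(torch.allclose(logits_split, logits_post, atol=1e-5))  # True, up to float error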
 
  class LocalFacialExtractor(nn.Module):
      def __init__(
          self,
+         id_dim: int = 1280,
+         vit_dim: int = 1024,
+         depth: int = 10,
+         dim_head: int = 64,
+         heads: int = 16,
+         num_id_token: int = 5,
+         num_queries: int = 32,
+         output_dim: int = 2048,
+         ff_mult: int = 4,
+         num_scale: int = 5,
      ):
          super().__init__()

          # Storing identity token and query information
          self.num_id_token = num_id_token
          self.vit_dim = vit_dim
          self.num_queries = num_queries
+         assert depth % num_scale == 0
+         self.depth = depth // num_scale
+         self.num_scale = num_scale
          scale = vit_dim**-0.5

          # Learnable latent query embeddings
 
                  nn.ModuleList(
                      [
                          PerceiverAttention(dim=vit_dim, dim_head=dim_head, heads=heads),  # Perceiver Attention layer
+                         nn.Sequential(
+                             nn.LayerNorm(vit_dim),
+                             nn.Linear(vit_dim, vit_dim * ff_mult, bias=False),
+                             nn.GELU(),
+                             nn.Linear(vit_dim * ff_mult, vit_dim, bias=False),
+                         ),  # ConsisIDFeedForward layer
                      ]
                  )
              )

          # Mappings for each of the 5 different ViT features
+         for i in range(num_scale):
              setattr(
                  self,
                  f"mapping_{i}",
 
              nn.Linear(vit_dim, vit_dim * num_id_token),
          )

+     def forward(self, id_embeds: torch.Tensor, vit_hidden_states: List[torch.Tensor]) -> torch.Tensor:
          # Repeat latent queries for the batch size
+         latents = self.latents.repeat(id_embeds.size(0), 1, 1)

          # Map the identity embedding to tokens
+         id_embeds = self.id_embedding_mapping(id_embeds)
+         id_embeds = id_embeds.reshape(-1, self.num_id_token, self.vit_dim)

          # Concatenate identity tokens with the latent queries
+         latents = torch.cat((latents, id_embeds), dim=1)

+         # Process each of the num_scale visual feature inputs
+         for i in range(self.num_scale):
+             vit_feature = getattr(self, f"mapping_{i}")(vit_hidden_states[i])
+             ctx_feature = torch.cat((id_embeds, vit_feature), dim=1)

              # Pass through the PerceiverAttention and ConsisIDFeedForward layers
              for attn, ff in self.layers[i * self.depth : (i + 1) * self.depth]:


  class PerceiverCrossAttention(nn.Module):
+     def __init__(self, dim: int = 3072, dim_head: int = 128, heads: int = 16, kv_dim: int = 2048):
          super().__init__()
+
          self.scale = dim_head**-0.5
          self.dim_head = dim_head
          self.heads = heads
 
          self.to_kv = nn.Linear(dim if kv_dim is None else kv_dim, inner_dim * 2, bias=False)
          self.to_out = nn.Linear(inner_dim, dim, bias=False)

+     def forward(self, image_embeds: torch.Tensor, hidden_states: torch.Tensor) -> torch.Tensor:
          # Apply layer normalization to the input image and latent features
+         image_embeds = self.norm1(image_embeds)
+         hidden_states = self.norm2(hidden_states)

+         batch_size, seq_len, _ = hidden_states.shape

          # Compute queries, keys, and values
+         query = self.to_q(hidden_states)
+         key, value = self.to_kv(image_embeds).chunk(2, dim=-1)

          # Reshape tensors to split into attention heads
+         query = query.reshape(query.size(0), -1, self.heads, self.dim_head).transpose(1, 2)
+         key = key.reshape(key.size(0), -1, self.heads, self.dim_head).transpose(1, 2)
+         value = value.reshape(value.size(0), -1, self.heads, self.dim_head).transpose(1, 2)

          # Compute attention weights
          scale = 1 / math.sqrt(math.sqrt(self.dim_head))
+         weight = (query * scale) @ (key * scale).transpose(-2, -1)  # More stable scaling than post-division
          weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)

          # Compute the output via weighted combination of values
+         out = weight @ value

          # Reshape and permute to prepare for final linear transformation
+         out = out.permute(0, 2, 1, 3).reshape(batch_size, seq_len, -1)

          return self.to_out(out)
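For orientation, a small usage sketch of the block above. The shapes are illustrative, the import path is assumed from this repo's layout, and the residual wiring with local_face_scale is an assumption about how the perceiver_cross_attention modules are applied inside the transformer blocks (that call site is outside this diff):

import torch

from models.transformer_consisid import PerceiverCrossAttention  # import path assumed

pca = PerceiverCrossAttention(dim=3072, dim_head=128, heads=16, kv_dim=2048)
face_emb = torch.randn(1, 32, 2048)         # (batch, n1, kv_dim), e.g. LocalFacialExtractor output
hidden_states = torch.randn(1, 226, 3072)   # (batch, n2, dim); 226 tokens chosen arbitrarily

out = pca(face_emb, hidden_states)          # -> torch.Size([1, 226, 3072])
local_face_scale = 1.0
hidden_states = hidden_states + local_face_scale * out  # residual injection of face features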
 
 
              The multiplication factor applied to the feed-forward network's hidden layer size in the Local Facial
              Extractor (LFE). A higher value increases the model's capacity to learn more complex facial feature
              transformations, but also increases the computation and memory requirements.
+         LFE_num_scale (`int`, optional, defaults to `5`):
+             The number of different scales of visual features. A higher value increases the model's capacity to learn
+             more complex facial feature transformations, but also increases the computation and memory requirements.
          local_face_scale (`float`, defaults to `1.0`):
              A scaling factor used to adjust the importance of local facial features in the model. This can influence
              how strongly the model focuses on high frequency face-related content.
 
          LFE_num_querie: int = 32,
          LFE_output_dim: int = 2048,
          LFE_ff_mult: int = 4,
+         LFE_num_scale: int = 5,
          local_face_scale: float = 1.0,
      ):
          super().__init__()
 
          )
          self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels)

          self.is_train_face = is_train_face
          self.is_kps = is_kps
 
          self.LFE_num_querie = LFE_num_querie
          self.LFE_output_dim = LFE_output_dim
          self.LFE_ff_mult = LFE_ff_mult
+         self.LFE_num_scale = LFE_num_scale
          # cross configs
          self.inner_dim = inner_dim
          self.cross_attn_interval = cross_attn_interval
 
          # face modules
          self._init_face_inputs()

+         self.gradient_checkpointing = False
+
      def _set_gradient_checkpointing(self, module, value=False):
          self.gradient_checkpointing = value
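For context, _set_gradient_checkpointing is the hook that diffusers' ModelMixin.enable_gradient_checkpointing() drives. A hedged usage sketch (import path and default construction assumed; it also assumes the class declares _supports_gradient_checkpointing = True, as is standard for diffusers models that define this hook):

from models.transformer_consisid import ConsisIDTransformer3DModel  # import path assumed

model = ConsisIDTransformer3DModel()   # default config; illustrative only
model.enable_gradient_checkpointing()  # invokes _set_gradient_checkpointing(value=True) on submodules
print(model.gradient_checkpointing)    # True (flipped via the hook above)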
 
 
              num_queries=self.LFE_num_querie,
              output_dim=self.LFE_output_dim,
              ff_mult=self.LFE_ff_mult,
+             num_scale=self.LFE_num_scale,
+         ).to(device, dtype=weight_dtype)
          self.perceiver_cross_attention = nn.ModuleList(
              [
                  PerceiverCrossAttention(
 
          for name, module in self.named_children():
              fn_recursive_attn_processor(name, module, processor)

      def forward(
          self,
          hidden_states: torch.Tensor,
 
          id_vit_hidden: Optional[torch.Tensor] = None,
          return_dict: bool = True,
      ):
          if attention_kwargs is not None:
              attention_kwargs = attention_kwargs.copy()
              lora_scale = attention_kwargs.pop("scale", 1.0)
 
                  "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
              )

+         # fuse clip and insightface
+         valid_face_emb = None
+         if self.is_train_face:
+             id_cond = id_cond.to(device=hidden_states.device, dtype=hidden_states.dtype)
+             id_vit_hidden = [
+                 tensor.to(device=hidden_states.device, dtype=hidden_states.dtype) for tensor in id_vit_hidden
+             ]
+             valid_face_emb = self.local_facial_extractor(
+                 id_cond, id_vit_hidden
+             )  # torch.Size([1, 1280]), list[5](torch.Size([1, 577, 1024])) -> torch.Size([1, 32, 2048])
+
          batch_size, num_frames, channels, height, width = hidden_states.shape

          # 1. Time embedding
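The shape comment on the added local_facial_extractor call documents its input/output contract; a minimal smoke test of that contract, with the import path assumed from this repo's layout and dummy tensors that simply mirror the shapes in the comment:

import torch

from models.transformer_consisid import LocalFacialExtractor  # import path assumed

lfe = LocalFacialExtractor()                                   # defaults: id_dim=1280, vit_dim=1024, num_scale=5, ...
id_cond = torch.randn(1, 1280)                                 # fused CLIP/InsightFace identity embedding
id_vit_hidden = [torch.randn(1, 577, 1024) for _ in range(5)]  # five multi-scale ViT hidden states

face_emb = lfe(id_cond, id_vit_hidden)
print(face_emb.shape)                                          # expected: torch.Size([1, 32, 2048])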
 
          id_cond=id_cond if id_cond is not None else None,
      )[0]

+     print(model_output)