jadechoghari committed on
Commit
f819bef
·
verified ·
1 Parent(s): 3cb89bb

Update diffloss.py

Browse files
Files changed (1) hide show
  1. diffloss.py +5 -20
diffloss.py CHANGED
@@ -5,7 +5,6 @@ import math
5
 
6
  from .diffusion import create_diffusion
7
 
8
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
9
 
10
  class DiffLoss(nn.Module):
11
  """Diffusion Loss"""
@@ -36,12 +35,12 @@ class DiffLoss(nn.Module):
36
  def sample(self, z, temperature=1.0, cfg=1.0):
37
  # diffusion loss sampling
38
  if not cfg == 1.0:
39
- noise = torch.randn(z.shape[0] // 2, self.in_channels).to(device)
40
  noise = torch.cat([noise, noise], dim=0)
41
  model_kwargs = dict(c=z, cfg_scale=cfg)
42
  sample_fn = self.net.forward_with_cfg
43
  else:
44
- noise = torch.randn(z.shape[0], self.in_channels).to(device)
45
  model_kwargs = dict(c=z)
46
  sample_fn = self.net.forward
47
 
@@ -91,23 +90,9 @@ class TimestepEmbedder(nn.Module):
91
  embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
92
  return embedding
93
 
94
- # def forward(self, t):
95
- # t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
96
- # t_emb = self.mlp(t_freq)
97
- # return t_emb
98
  def forward(self, t):
99
-
100
- device = next(self.mlp.parameters()).device
101
-
102
- t = t.to(device)
103
-
104
  t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
105
-
106
- t_freq = t_freq.to(device)
107
-
108
  t_emb = self.mlp(t_freq)
109
-
110
-
111
  return t_emb
112
 
113
 
@@ -145,7 +130,7 @@ class ResBlock(nn.Module):
145
 
146
  class FinalLayer(nn.Module):
147
  """
148
- The final layer of DiT.
149
  """
150
  def __init__(self, model_channels, out_channels):
151
  super().__init__()
@@ -232,10 +217,10 @@ class SimpleMLPAdaLN(nn.Module):
232
  def forward(self, x, t, c):
233
  """
234
  Apply the model to an input batch.
235
- :param x: an [N x C x ...] Tensor of inputs.
236
  :param t: a 1-D batch of timesteps.
237
  :param c: conditioning from AR transformer.
238
- :return: an [N x C x ...] Tensor of outputs.
239
  """
240
  x = self.input_proj(x)
241
  t = self.time_embed(t)
 
5
 
6
  from .diffusion import create_diffusion
7
 
 
8
 
9
  class DiffLoss(nn.Module):
10
  """Diffusion Loss"""
 
35
  def sample(self, z, temperature=1.0, cfg=1.0):
36
  # diffusion loss sampling
37
  if not cfg == 1.0:
38
+ noise = torch.randn(z.shape[0] // 2, self.in_channels).cuda()
39
  noise = torch.cat([noise, noise], dim=0)
40
  model_kwargs = dict(c=z, cfg_scale=cfg)
41
  sample_fn = self.net.forward_with_cfg
42
  else:
43
+ noise = torch.randn(z.shape[0], self.in_channels).cuda()
44
  model_kwargs = dict(c=z)
45
  sample_fn = self.net.forward
46
 
 
90
  embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
91
  return embedding
92
 
 
 
 
 
93
  def forward(self, t):
 
 
 
 
 
94
  t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
 
 
 
95
  t_emb = self.mlp(t_freq)
 
 
96
  return t_emb
97
 
98
 
 
130
 
131
  class FinalLayer(nn.Module):
132
  """
133
+ The final layer adopted from DiT.
134
  """
135
  def __init__(self, model_channels, out_channels):
136
  super().__init__()
 
217
  def forward(self, x, t, c):
218
  """
219
  Apply the model to an input batch.
220
+ :param x: an [N x C] Tensor of inputs.
221
  :param t: a 1-D batch of timesteps.
222
  :param c: conditioning from AR transformer.
223
+ :return: an [N x C] Tensor of outputs.
224
  """
225
  x = self.input_proj(x)
226
  t = self.time_embed(t)