franciszzj committed on
Commit e9b3585 · 1 Parent(s): 04d5d6b

change to float16

Files changed (3)
  1. app.py +5 -2
  2. leffa/model.py +23 -11
  3. leffa/pipeline.py +0 -1
app.py CHANGED
@@ -40,18 +40,21 @@ class LeffaPredictor(object):
         vt_model_hd = LeffaModel(
             pretrained_model_name_or_path="./ckpts/stable-diffusion-inpainting",
             pretrained_model="./ckpts/virtual_tryon.pth",
+            dtype="float16",
         )
         self.vt_inference_hd = LeffaInference(model=vt_model_hd)
 
         vt_model_dc = LeffaModel(
             pretrained_model_name_or_path="./ckpts/stable-diffusion-inpainting",
             pretrained_model="./ckpts/virtual_tryon_dc.pth",
+            dtype="float16",
         )
         self.vt_inference_dc = LeffaInference(model=vt_model_dc)
 
         pt_model = LeffaModel(
             pretrained_model_name_or_path="./ckpts/stable-diffusion-xl-1.0-inpainting-0.1",
             pretrained_model="./ckpts/pose_transfer.pth",
+            dtype="float16",
         )
         self.pt_inference = LeffaInference(model=pt_model)
 
@@ -248,7 +251,7 @@ if __name__ == "__main__":
                     )
 
                     vt_step = gr.Number(
-                        label="Inference Steps", minimum=30, maximum=100, step=1, value=50)
+                        label="Inference Steps", minimum=30, maximum=100, step=1, value=30)
 
                     vt_scale = gr.Number(
                         label="Guidance Scale", minimum=0.1, maximum=5.0, step=0.1, value=2.5)
@@ -325,7 +328,7 @@ if __name__ == "__main__":
                     )
 
                     pt_step = gr.Number(
-                        label="Inference Steps", minimum=30, maximum=100, step=1, value=50)
+                        label="Inference Steps", minimum=30, maximum=100, step=1, value=30)
 
                     pt_scale = gr.Number(
                         label="Guidance Scale", minimum=0.1, maximum=5.0, step=0.1, value=2.5)
leffa/model.py CHANGED
@@ -23,6 +23,7 @@ class LeffaModel(nn.Module):
         new_in_channels: int = 12,  # noisy_image: 4, mask: 1, masked_image: 4, densepose: 3
         height: int = 1024,
         width: int = 768,
+        dtype: str = "float16",
     ):
         super().__init__()
 
@@ -35,6 +36,9 @@ class LeffaModel(nn.Module):
             new_in_channels,
         )
 
+        if dtype == "float16":
+            self.half()
+
     def build_models(
         self,
         pretrained_model_name_or_path: str = "",
@@ -60,14 +64,16 @@ class LeffaModel(nn.Module):
             return_unused_kwargs=True,
         )
         self.vae = AutoencoderKL.from_config(vae_config, **vae_kwargs)
-        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+        self.vae_scale_factor = 2 ** (
+            len(self.vae.config.block_out_channels) - 1)
         # Reference UNet
         unet_config, unet_kwargs = ReferenceUNet.load_config(
             pretrained_model_name_or_path,
             subfolder="unet",
             return_unused_kwargs=True,
         )
-        self.unet_encoder = ReferenceUNet.from_config(unet_config, **unet_kwargs)
+        self.unet_encoder = ReferenceUNet.from_config(
+            unet_config, **unet_kwargs)
         self.unet_encoder.config.addition_embed_type = None
         # Generative UNet
         unet_config, unet_kwargs = GenerativeUNet.load_config(
@@ -80,7 +86,8 @@ class LeffaModel(nn.Module):
         # Change Generative UNet conv_in and conv_out
         unet_conv_in_channel_changed = self.unet.config.in_channels != new_in_channels
         if unet_conv_in_channel_changed:
-            self.unet.conv_in = self.replace_conv_in_layer(self.unet, new_in_channels)
+            self.unet.conv_in = self.replace_conv_in_layer(
+                self.unet, new_in_channels)
             self.unet.config.in_channels = new_in_channels
         unet_conv_out_channel_changed = (
             self.unet.config.out_channels != self.vae.config.latent_channels
@@ -114,8 +121,10 @@ class LeffaModel(nn.Module):
 
         # Load pretrained model
         if pretrained_model != "" and pretrained_model is not None:
-            self.load_state_dict(torch.load(pretrained_model, map_location="cpu"))
-            logger.info("Load pretrained model from {}".format(pretrained_model))
+            self.load_state_dict(torch.load(
+                pretrained_model, map_location="cpu"))
+            logger.info(
+                "Load pretrained model from {}".format(pretrained_model))
 
     def replace_conv_in_layer(self, unet_model, new_in_channels):
         original_conv_in = unet_model.conv_in
@@ -168,7 +177,8 @@ class LeffaModel(nn.Module):
         return new_conv_out
 
     def vae_encode(self, pixel_values):
-        pixel_values = pixel_values.to(device=self.vae.device, dtype=self.vae.dtype)
+        pixel_values = pixel_values.to(
+            device=self.vae.device, dtype=self.vae.dtype)
         with torch.no_grad():
             latent = self.vae.encode(pixel_values).latent_dist.sample()
             latent = latent * self.vae.config.scaling_factor
@@ -208,7 +218,8 @@ def remove_cross_attention(
         hidden_size = unet.config.block_out_channels[-1]
     elif name.startswith("up_blocks"):
         block_id = int(name[len("up_blocks.")])
-        hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
+        hidden_size = list(reversed(unet.config.block_out_channels))[
+            block_id]
     elif name.startswith("down_blocks"):
         block_id = int(name[len("down_blocks.")])
         hidden_size = unet.config.block_out_channels[block_id]
@@ -239,7 +250,6 @@ def remove_cross_attention(
     return adapter_modules
 
 
-
 class AttnProcessor2_0(torch.nn.Module):
     r"""
     Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
@@ -315,10 +325,12 @@ class AttnProcessor2_0(torch.nn.Module):
         inner_dim = key.shape[-1]
         head_dim = inner_dim // attn.heads
 
-        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        query = query.view(batch_size, -1, attn.heads,
+                           head_dim).transpose(1, 2)
 
         key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, attn.heads,
+                           head_dim).transpose(1, 2)
 
         # the output of sdp = (batch, num_heads, seq_len, head_dim)
         # TODO: add support for attn.scale when we move to Torch 2.1
@@ -346,4 +358,4 @@ class AttnProcessor2_0(torch.nn.Module):
 
         hidden_states = hidden_states / attn.rescale_output_factor
 
-        return hidden_states
+        return hidden_states
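
Functionally, the only behavioral change in model.py is the new dtype argument: when it is "float16", __init__ casts the whole module with self.half() after build_models has created the VAE and both UNets; the remaining hunks are line re-wrapping with no change in behavior. A minimal sketch of the pattern, using a hypothetical stand-in module rather than the real LeffaModel:

import torch
import torch.nn as nn


class HalfPrecisionModel(nn.Module):
    """Hypothetical stand-in illustrating the dtype handling added to LeffaModel."""

    def __init__(self, dtype: str = "float16"):
        super().__init__()
        self.backbone = nn.Linear(4, 4)  # stand-in for the VAE/UNet stack
        if dtype == "float16":
            self.half()  # same cast LeffaModel.__init__ now applies

    @torch.no_grad()
    def encode(self, pixel_values: torch.Tensor) -> torch.Tensor:
        # Mirror of the vae_encode pattern: cast incoming tensors to the
        # module's device/dtype so float32 inputs still work after half().
        param = next(self.parameters())
        return self.backbone(pixel_values.to(device=param.device, dtype=param.dtype))


model = HalfPrecisionModel(dtype="float16")
print(next(model.parameters()).dtype)  # torch.float16

Half precision roughly halves the models' memory footprint, and because inputs are cast at the boundary (as vae_encode already did before this commit), callers can keep passing float32 tensors.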
leffa/pipeline.py CHANGED
@@ -106,7 +106,6 @@ class LeffaPipeline(object):
         )
         reference_features = list(reference_features)
 
-
         with tqdm.tqdm(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
                 # expand the latent if we are doing classifier free guidance