zhengchong commited on
Commit
905c952
·
1 Parent(s): 3e791eb

feat: Enhance CatVTON functionality with new pipelines and UI improvements

Browse files

- Added CatVTONPix2PixPipeline and FluxTryOnPipeline to support additional virtual try-on methods.
- Implemented new submit functions for mask-free and Flux-based try-on.
- Updated UI to include separate tabs for mask-based and mask-free options, enhancing user experience.
- Modified requirements.txt to include new dependencies and updated existing ones.
- Improved error handling and image processing in the submission functions.

app.py CHANGED
@@ -13,7 +13,8 @@ from huggingface_hub import snapshot_download
13
  from PIL import Image
14
  torch.jit.script = lambda f: f
15
  from model.cloth_masker import AutoMasker, vis_mask
16
- from model.pipeline import CatVTONPipeline
 
17
  from utils import init_weight_dtype, resize_and_crop, resize_and_padding
18
 
19
 
@@ -105,7 +106,10 @@ def image_grid(imgs, rows, cols):
105
 
106
 
107
  args = parse_args()
108
- repo_path = snapshot_download(repo_id=args.resume_path)
 
 
 
109
  # Pipeline
110
  pipeline = CatVTONPipeline(
111
  base_ckpt=args.base_model_path,
@@ -123,6 +127,30 @@ automasker = AutoMasker(
123
  device='cuda',
124
  )
125
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
126
  @spaces.GPU(duration=120)
127
  def submit_function(
128
  person_image,
@@ -202,10 +230,135 @@ def submit_function(
202
  new_result_image.paste(result_image, (condition_width + 5, 0))
203
  return new_result_image
204
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
205
 
206
  def person_example_fn(image_path):
207
  return image_path
208
 
 
209
  HEADER = """
210
  <h1 style="text-align: center;"> 🐈 CatVTON: Concatenation Is All You Need for Virtual Try-On with Diffusion Models </h1>
211
  <div style="display: flex; justify-content: center; align-items: center;">
@@ -241,136 +394,321 @@ HEADER = """
241
  def app_gradio():
242
  with gr.Blocks(title="CatVTON") as demo:
243
  gr.Markdown(HEADER)
244
- with gr.Row():
245
- with gr.Column(scale=1, min_width=350):
246
- with gr.Row():
247
- image_path = gr.Image(
248
- type="filepath",
249
- interactive=True,
250
- visible=False,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
251
  )
252
- person_image = gr.ImageEditor(
253
- interactive=True, label="Person Image", type="filepath"
 
254
  )
255
-
256
- with gr.Row():
257
- with gr.Column(scale=1, min_width=230):
258
- cloth_image = gr.Image(
259
- interactive=True, label="Condition Image", type="filepath"
 
 
260
  )
261
- with gr.Column(scale=1, min_width=120):
262
- gr.Markdown(
263
- '<span style="color: #808080; font-size: small;">Two ways to provide Mask:<br>1. Upload the person image and use the `🖌️` above to draw the Mask (higher priority)<br>2. Select the `Try-On Cloth Type` to generate automatically </span>'
264
  )
265
- cloth_type = gr.Radio(
266
- label="Try-On Cloth Type",
267
- choices=["upper", "lower", "overall"],
268
- value="upper",
269
  )
270
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
271
 
272
- submit = gr.Button("Submit")
273
- gr.Markdown(
274
- '<center><span style="color: #FF0000">!!! Click only Once, Wait for Delay !!!</span></center>'
275
  )
276
-
277
- gr.Markdown(
278
- '<span style="color: #808080; font-size: small;">Advanced options can adjust details:<br>1. `Inference Step` may enhance details;<br>2. `CFG` is highly correlated with saturation;<br>3. `Random seed` may improve pseudo-shadow.</span>'
 
 
 
 
 
 
 
 
 
 
279
  )
280
- with gr.Accordion("Advanced Options", open=False):
281
- num_inference_steps = gr.Slider(
282
- label="Inference Step", minimum=10, maximum=100, step=5, value=50
283
- )
284
- # Guidence Scale
285
- guidance_scale = gr.Slider(
286
- label="CFG Strenth", minimum=0.0, maximum=7.5, step=0.5, value=2.5
287
- )
288
- # Random Seed
289
- seed = gr.Slider(
290
- label="Seed", minimum=-1, maximum=10000, step=1, value=42
 
 
 
 
 
 
 
 
 
 
 
 
291
  )
292
- show_type = gr.Radio(
293
- label="Show Type",
294
- choices=["result only", "input & result", "input & mask & result"],
295
- value="input & mask & result",
296
  )
297
-
298
- with gr.Column(scale=2, min_width=500):
299
- result_image = gr.Image(interactive=False, label="Result")
300
- with gr.Row():
301
- # Photo Examples
302
- root_path = "resource/demo/example"
303
- with gr.Column():
304
- men_exm = gr.Examples(
305
- examples=[
306
- os.path.join(root_path, "person", "men", _)
307
- for _ in os.listdir(os.path.join(root_path, "person", "men"))
308
- ],
309
- examples_per_page=4,
310
- inputs=image_path,
311
- label="Person Examples ①",
312
  )
313
- women_exm = gr.Examples(
314
- examples=[
315
- os.path.join(root_path, "person", "women", _)
316
- for _ in os.listdir(os.path.join(root_path, "person", "women"))
317
- ],
318
- examples_per_page=4,
319
- inputs=image_path,
320
- label="Person Examples ",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
321
  )
322
- gr.Markdown(
323
- '<span style="color: #808080; font-size: small;">*Person examples come from the demos of <a href="https://huggingface.co/spaces/levihsu/OOTDiffusion">OOTDiffusion</a> and <a href="https://www.outfitanyone.org">OutfitAnyone</a>. </span>'
324
  )
325
- with gr.Column():
326
- condition_upper_exm = gr.Examples(
327
- examples=[
328
- os.path.join(root_path, "condition", "upper", _)
329
- for _ in os.listdir(os.path.join(root_path, "condition", "upper"))
330
- ],
331
- examples_per_page=4,
332
- inputs=cloth_image,
333
- label="Condition Upper Examples",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
334
  )
335
- condition_overall_exm = gr.Examples(
336
- examples=[
337
- os.path.join(root_path, "condition", "overall", _)
338
- for _ in os.listdir(os.path.join(root_path, "condition", "overall"))
339
- ],
340
- examples_per_page=4,
341
- inputs=cloth_image,
342
- label="Condition Overall Examples",
343
  )
344
- condition_person_exm = gr.Examples(
345
- examples=[
346
- os.path.join(root_path, "condition", "person", _)
347
- for _ in os.listdir(os.path.join(root_path, "condition", "person"))
348
- ],
349
- examples_per_page=4,
350
- inputs=cloth_image,
351
- label="Condition Reference Person Examples",
352
  )
353
- gr.Markdown(
354
- '<span style="color: #808080; font-size: small;">*Condition examples come from the Internet. </span>'
 
 
355
  )
356
 
357
- image_path.change(
358
- person_example_fn, inputs=image_path, outputs=person_image
359
- )
360
-
361
- submit.click(
362
- submit_function,
363
- [
364
- person_image,
365
- cloth_image,
366
- cloth_type,
367
- num_inference_steps,
368
- guidance_scale,
369
- seed,
370
- show_type,
371
- ],
372
- result_image,
373
- )
374
  demo.queue().launch(share=True, show_error=True)
375
 
376
 
 
13
  from PIL import Image
14
  torch.jit.script = lambda f: f
15
  from model.cloth_masker import AutoMasker, vis_mask
16
+ from model.pipeline import CatVTONPipeline, CatVTONPix2PixPipeline
17
+ from model.flux.pipeline_flux_tryon import FluxTryOnPipeline
18
  from utils import init_weight_dtype, resize_and_crop, resize_and_padding
19
 
20
 
 
106
 
107
 
108
  args = parse_args()
109
+
110
+ # Mask-based CatVTON
111
+ catvton_repo = "zhengchong/CatVTON"
112
+ repo_path = snapshot_download(repo_id=catvton_repo)
113
  # Pipeline
114
  pipeline = CatVTONPipeline(
115
  base_ckpt=args.base_model_path,
 
127
  device='cuda',
128
  )
129
 
130
+
131
+ # Flux-based CatVTON
132
+ flux_repo = "black-forest-labs/FLUX.1-Fill-dev"
133
+ pipeline_flux = FluxTryOnPipeline.from_pretrained(flux_repo)
134
+ pipeline_flux.load_lora_weights(
135
+ os.path.join(repo_path, "flux-lora"),
136
+ weight_name='pytorch_lora_weights.safetensors'
137
+ )
138
+ pipeline_flux.to("cuda", init_weight_dtype(args.mixed_precision))
139
+
140
+
141
+ # Mask-free CatVTON
142
+ catvton_mf_repo = "zhengchong/CatVTON-MaskFree"
143
+ repo_path_mf = snapshot_download(repo_id=catvton_mf_repo)
144
+ pipeline_p2p = CatVTONPix2PixPipeline(
145
+ base_ckpt=args.p2p_base_model_path,
146
+ attn_ckpt=repo_path,
147
+ attn_ckpt_version="mix-48k-1024",
148
+ weight_dtype=init_weight_dtype(args.mixed_precision),
149
+ use_tf32=args.allow_tf32,
150
+ device='cuda'
151
+ )
152
+
153
+
154
  @spaces.GPU(duration=120)
155
  def submit_function(
156
  person_image,
 
230
  new_result_image.paste(result_image, (condition_width + 5, 0))
231
  return new_result_image
232
 
233
+ @spaces.GPU(duration=120)
234
+ def submit_function_p2p(
235
+ person_image,
236
+ cloth_image,
237
+ num_inference_steps,
238
+ guidance_scale,
239
+ seed):
240
+ person_image= person_image["background"]
241
+
242
+ tmp_folder = args.output_dir
243
+ date_str = datetime.now().strftime("%Y%m%d%H%M%S")
244
+ result_save_path = os.path.join(tmp_folder, date_str[:8], date_str[8:] + ".png")
245
+ if not os.path.exists(os.path.join(tmp_folder, date_str[:8])):
246
+ os.makedirs(os.path.join(tmp_folder, date_str[:8]))
247
+
248
+ generator = None
249
+ if seed != -1:
250
+ generator = torch.Generator(device='cuda').manual_seed(seed)
251
+
252
+ person_image = Image.open(person_image).convert("RGB")
253
+ cloth_image = Image.open(cloth_image).convert("RGB")
254
+ person_image = resize_and_crop(person_image, (args.width, args.height))
255
+ cloth_image = resize_and_padding(cloth_image, (args.width, args.height))
256
+
257
+ # Inference
258
+ try:
259
+ result_image = pipeline_p2p(
260
+ image=person_image,
261
+ condition_image=cloth_image,
262
+ num_inference_steps=num_inference_steps,
263
+ guidance_scale=guidance_scale,
264
+ generator=generator
265
+ )[0]
266
+ except Exception as e:
267
+ raise gr.Error(
268
+ "An error occurred. Please try again later: {}".format(e)
269
+ )
270
+
271
+ # Post-process
272
+ save_result_image = image_grid([person_image, cloth_image, result_image], 1, 3)
273
+ save_result_image.save(result_save_path)
274
+ return result_image
275
+
276
+ @spaces.GPU(duration=120)
277
+ def submit_function_flux(
278
+ person_image,
279
+ cloth_image,
280
+ cloth_type,
281
+ resolution,
282
+ num_inference_steps,
283
+ guidance_scale,
284
+ seed,
285
+ show_type
286
+ ):
287
+ # Set height and width based on resolution
288
+ height = resolution
289
+ width = int(height * 0.75)
290
+ args.width = width
291
+ args.height = height
292
+
293
+ # Process image editor input
294
+ person_image, mask = person_image["background"], person_image["layers"][0]
295
+ mask = Image.open(mask).convert("L")
296
+ if len(np.unique(np.array(mask))) == 1:
297
+ mask = None
298
+ else:
299
+ mask = np.array(mask)
300
+ mask[mask > 0] = 255
301
+ mask = Image.fromarray(mask)
302
+
303
+ # Set random seed
304
+ generator = None
305
+ if seed != -1:
306
+ generator = torch.Generator(device='cuda').manual_seed(seed)
307
+
308
+ # Process input images
309
+ person_image = Image.open(person_image).convert("RGB")
310
+ cloth_image = Image.open(cloth_image).convert("RGB")
311
+
312
+ # Adjust image sizes
313
+ person_image = resize_and_crop(person_image, (args.width, args.height))
314
+ cloth_image = resize_and_padding(cloth_image, (args.width, args.height))
315
+
316
+ # Process mask
317
+ if mask is not None:
318
+ mask = resize_and_crop(mask, (args.width, args.height))
319
+ else:
320
+ mask = automasker(
321
+ person_image,
322
+ cloth_type
323
+ )['mask']
324
+ mask = mask_processor.blur(mask, blur_factor=9)
325
+
326
+ # Inference
327
+ result_image = pipeline_flux(
328
+ image=person_image,
329
+ condition_image=cloth_image,
330
+ mask=mask,
331
+ num_inference_steps=num_inference_steps,
332
+ guidance_scale=guidance_scale,
333
+ generator=generator
334
+ )[0]
335
+
336
+ # Post-processing
337
+ masked_person = vis_mask(person_image, mask)
338
+
339
+ # Return result based on show type
340
+ if show_type == "result only":
341
+ return result_image
342
+ else:
343
+ width, height = person_image.size
344
+ if show_type == "input & result":
345
+ condition_width = width // 2
346
+ conditions = image_grid([person_image, cloth_image], 2, 1)
347
+ else:
348
+ condition_width = width // 3
349
+ conditions = image_grid([person_image, masked_person, cloth_image], 3, 1)
350
+
351
+ conditions = conditions.resize((condition_width, height), Image.NEAREST)
352
+ new_result_image = Image.new("RGB", (width + condition_width + 5, height))
353
+ new_result_image.paste(conditions, (0, 0))
354
+ new_result_image.paste(result_image, (condition_width + 5, 0))
355
+ return new_result_image
356
+
357
 
358
  def person_example_fn(image_path):
359
  return image_path
360
 
361
+
362
  HEADER = """
363
  <h1 style="text-align: center;"> 🐈 CatVTON: Concatenation Is All You Need for Virtual Try-On with Diffusion Models </h1>
364
  <div style="display: flex; justify-content: center; align-items: center;">
 
394
  def app_gradio():
395
  with gr.Blocks(title="CatVTON") as demo:
396
  gr.Markdown(HEADER)
397
+ with gr.Tab("Mask-based & SD1.5"):
398
+ with gr.Row():
399
+ with gr.Column(scale=1, min_width=350):
400
+ with gr.Row():
401
+ image_path = gr.Image(
402
+ type="filepath",
403
+ interactive=True,
404
+ visible=False,
405
+ )
406
+ person_image = gr.ImageEditor(
407
+ interactive=True, label="Person Image", type="filepath"
408
+ )
409
+
410
+ with gr.Row():
411
+ with gr.Column(scale=1, min_width=230):
412
+ cloth_image = gr.Image(
413
+ interactive=True, label="Condition Image", type="filepath"
414
+ )
415
+ with gr.Column(scale=1, min_width=120):
416
+ gr.Markdown(
417
+ '<span style="color: #808080; font-size: small;">Two ways to provide Mask:<br>1. Upload the person image and use the `🖌️` above to draw the Mask (higher priority)<br>2. Select the `Try-On Cloth Type` to generate automatically </span>'
418
+ )
419
+ cloth_type = gr.Radio(
420
+ label="Try-On Cloth Type",
421
+ choices=["upper", "lower", "overall"],
422
+ value="upper",
423
+ )
424
+
425
+
426
+ submit = gr.Button("Submit")
427
+ gr.Markdown(
428
+ '<center><span style="color: #FF0000">!!! Click only Once, Wait for Delay !!!</span></center>'
429
  )
430
+
431
+ gr.Markdown(
432
+ '<span style="color: #808080; font-size: small;">Advanced options can adjust details:<br>1. `Inference Step` may enhance details;<br>2. `CFG` is highly correlated with saturation;<br>3. `Random seed` may improve pseudo-shadow.</span>'
433
  )
434
+ with gr.Accordion("Advanced Options", open=False):
435
+ num_inference_steps = gr.Slider(
436
+ label="Inference Step", minimum=10, maximum=100, step=5, value=50
437
+ )
438
+ # Guidence Scale
439
+ guidance_scale = gr.Slider(
440
+ label="CFG Strenth", minimum=0.0, maximum=7.5, step=0.5, value=2.5
441
  )
442
+ # Random Seed
443
+ seed = gr.Slider(
444
+ label="Seed", minimum=-1, maximum=10000, step=1, value=42
445
  )
446
+ show_type = gr.Radio(
447
+ label="Show Type",
448
+ choices=["result only", "input & result", "input & mask & result"],
449
+ value="input & mask & result",
450
  )
451
 
452
+ with gr.Column(scale=2, min_width=500):
453
+ result_image = gr.Image(interactive=False, label="Result")
454
+ with gr.Row():
455
+ # Photo Examples
456
+ root_path = "resource/demo/example"
457
+ with gr.Column():
458
+ men_exm = gr.Examples(
459
+ examples=[
460
+ os.path.join(root_path, "person", "men", _)
461
+ for _ in os.listdir(os.path.join(root_path, "person", "men"))
462
+ ],
463
+ examples_per_page=4,
464
+ inputs=image_path,
465
+ label="Person Examples ①",
466
+ )
467
+ women_exm = gr.Examples(
468
+ examples=[
469
+ os.path.join(root_path, "person", "women", _)
470
+ for _ in os.listdir(os.path.join(root_path, "person", "women"))
471
+ ],
472
+ examples_per_page=4,
473
+ inputs=image_path,
474
+ label="Person Examples ②",
475
+ )
476
+ gr.Markdown(
477
+ '<span style="color: #808080; font-size: small;">*Person examples come from the demos of <a href="https://huggingface.co/spaces/levihsu/OOTDiffusion">OOTDiffusion</a> and <a href="https://www.outfitanyone.org">OutfitAnyone</a>. </span>'
478
+ )
479
+ with gr.Column():
480
+ condition_upper_exm = gr.Examples(
481
+ examples=[
482
+ os.path.join(root_path, "condition", "upper", _)
483
+ for _ in os.listdir(os.path.join(root_path, "condition", "upper"))
484
+ ],
485
+ examples_per_page=4,
486
+ inputs=cloth_image,
487
+ label="Condition Upper Examples",
488
+ )
489
+ condition_overall_exm = gr.Examples(
490
+ examples=[
491
+ os.path.join(root_path, "condition", "overall", _)
492
+ for _ in os.listdir(os.path.join(root_path, "condition", "overall"))
493
+ ],
494
+ examples_per_page=4,
495
+ inputs=cloth_image,
496
+ label="Condition Overall Examples",
497
+ )
498
+ condition_person_exm = gr.Examples(
499
+ examples=[
500
+ os.path.join(root_path, "condition", "person", _)
501
+ for _ in os.listdir(os.path.join(root_path, "condition", "person"))
502
+ ],
503
+ examples_per_page=4,
504
+ inputs=cloth_image,
505
+ label="Condition Reference Person Examples",
506
+ )
507
+ gr.Markdown(
508
+ '<span style="color: #808080; font-size: small;">*Condition examples come from the Internet. </span>'
509
+ )
510
 
511
+ image_path.change(
512
+ person_example_fn, inputs=image_path, outputs=person_image
 
513
  )
514
+
515
+ submit.click(
516
+ submit_function,
517
+ [
518
+ person_image,
519
+ cloth_image,
520
+ cloth_type,
521
+ num_inference_steps,
522
+ guidance_scale,
523
+ seed,
524
+ show_type,
525
+ ],
526
+ result_image,
527
  )
528
+
529
+ with gr.Tab("Mask-free & SD1.5"):
530
+ with gr.Row():
531
+ with gr.Column(scale=1, min_width=350):
532
+ with gr.Row():
533
+ image_path_p2p = gr.Image(
534
+ type="filepath",
535
+ interactive=True,
536
+ visible=False,
537
+ )
538
+ person_image_p2p = gr.ImageEditor(
539
+ interactive=True, label="Person Image", type="filepath"
540
+ )
541
+
542
+ with gr.Row():
543
+ with gr.Column(scale=1, min_width=230):
544
+ cloth_image_p2p = gr.Image(
545
+ interactive=True, label="Condition Image", type="filepath"
546
+ )
547
+
548
+ submit_p2p = gr.Button("Submit")
549
+ gr.Markdown(
550
+ '<center><span style="color: #FF0000">!!! Click only Once, Wait for Delay !!!</span></center>'
551
  )
552
+
553
+ gr.Markdown(
554
+ '<span style="color: #808080; font-size: small;">Advanced options can adjust details:<br>1. `Inference Step` may enhance details;<br>2. `CFG` is highly correlated with saturation;<br>3. `Random seed` may improve pseudo-shadow.</span>'
 
555
  )
556
+ with gr.Accordion("Advanced Options", open=False):
557
+ num_inference_steps_p2p = gr.Slider(
558
+ label="Inference Step", minimum=10, maximum=100, step=5, value=50
559
+ )
560
+ # Guidence Scale
561
+ guidance_scale_p2p = gr.Slider(
562
+ label="CFG Strenth", minimum=0.0, maximum=7.5, step=0.5, value=2.5
563
+ )
564
+ # Random Seed
565
+ seed_p2p = gr.Slider(
566
+ label="Seed", minimum=-1, maximum=10000, step=1, value=42
 
 
 
 
567
  )
568
+ # show_type = gr.Radio(
569
+ # label="Show Type",
570
+ # choices=["result only", "input & result", "input & mask & result"],
571
+ # value="input & mask & result",
572
+ # )
573
+
574
+ with gr.Column(scale=2, min_width=500):
575
+ result_image_p2p = gr.Image(interactive=False, label="Result")
576
+ with gr.Row():
577
+ # Photo Examples
578
+ root_path = "resource/demo/example"
579
+ with gr.Column():
580
+ gr.Examples(
581
+ examples=[
582
+ os.path.join(root_path, "person", "men", _)
583
+ for _ in os.listdir(os.path.join(root_path, "person", "men"))
584
+ ],
585
+ examples_per_page=4,
586
+ inputs=image_path_p2p,
587
+ label="Person Examples ①",
588
+ )
589
+ gr.Examples(
590
+ examples=[
591
+ os.path.join(root_path, "person", "women", _)
592
+ for _ in os.listdir(os.path.join(root_path, "person", "women"))
593
+ ],
594
+ examples_per_page=4,
595
+ inputs=image_path_p2p,
596
+ label="Person Examples ②",
597
+ )
598
+ gr.Markdown(
599
+ '<span style="color: #808080; font-size: small;">*Person examples come from the demos of <a href="https://huggingface.co/spaces/levihsu/OOTDiffusion">OOTDiffusion</a> and <a href="https://www.outfitanyone.org">OutfitAnyone</a>. </span>'
600
+ )
601
+ with gr.Column():
602
+ gr.Examples(
603
+ examples=[
604
+ os.path.join(root_path, "condition", "upper", _)
605
+ for _ in os.listdir(os.path.join(root_path, "condition", "upper"))
606
+ ],
607
+ examples_per_page=4,
608
+ inputs=cloth_image_p2p,
609
+ label="Condition Upper Examples",
610
+ )
611
+ gr.Examples(
612
+ examples=[
613
+ os.path.join(root_path, "condition", "overall", _)
614
+ for _ in os.listdir(os.path.join(root_path, "condition", "overall"))
615
+ ],
616
+ examples_per_page=4,
617
+ inputs=cloth_image_p2p,
618
+ label="Condition Overall Examples",
619
+ )
620
+ condition_person_exm = gr.Examples(
621
+ examples=[
622
+ os.path.join(root_path, "condition", "person", _)
623
+ for _ in os.listdir(os.path.join(root_path, "condition", "person"))
624
+ ],
625
+ examples_per_page=4,
626
+ inputs=cloth_image_p2p,
627
+ label="Condition Reference Person Examples",
628
+ )
629
+ gr.Markdown(
630
+ '<span style="color: #808080; font-size: small;">*Condition examples come from the Internet. </span>'
631
+ )
632
+
633
+ image_path_p2p.change(
634
+ person_example_fn, inputs=image_path_p2p, outputs=person_image_p2p
635
+ )
636
+
637
+ submit_p2p.click(
638
+ submit_function_p2p,
639
+ [
640
+ person_image_p2p,
641
+ cloth_image_p2p,
642
+ num_inference_steps_p2p,
643
+ guidance_scale_p2p,
644
+ seed_p2p],
645
+ result_image_p2p,
646
+ )
647
+
648
+ with gr.Tab("Mask-based & Flux.1 Fill Dev"):
649
+ with gr.Row():
650
+ with gr.Column(scale=1, min_width=350):
651
+ with gr.Row():
652
+ image_path_flux = gr.Image(
653
+ type="filepath",
654
+ interactive=True,
655
+ visible=False,
656
  )
657
+ person_image_flux = gr.ImageEditor(
658
+ interactive=True, label="Person Image", type="filepath"
659
  )
660
+
661
+ with gr.Row():
662
+ with gr.Column(scale=1, min_width=230):
663
+ cloth_image_flux = gr.Image(
664
+ interactive=True, label="Condition Image", type="filepath"
665
+ )
666
+ with gr.Column(scale=1, min_width=120):
667
+ gr.Markdown(
668
+ '<span style="color: #808080; font-size: small;">Two ways to provide Mask:<br>1. Upload the person image and use the `🖌️` above to draw the Mask (higher priority)<br>2. Select the `Try-On Cloth Type` to generate automatically </span>'
669
+ )
670
+ cloth_type = gr.Radio(
671
+ label="Try-On Cloth Type",
672
+ choices=["upper", "lower", "overall"],
673
+ value="upper",
674
+ )
675
+
676
+ submit_flux = gr.Button("Submit")
677
+ gr.Markdown(
678
+ '<center><span style="color: #FF0000">!!! Click only Once, Wait for Delay !!!</span></center>'
679
+ )
680
+
681
+ with gr.Accordion("Advanced Options", open=False):
682
+ num_inference_steps_flux = gr.Slider(
683
+ label="Inference Step", minimum=10, maximum=100, step=5, value=50
684
  )
685
+ # Guidence Scale
686
+ guidance_scale_flux = gr.Slider(
687
+ label="CFG Strenth", minimum=0.0, maximum=7.5, step=0.5, value=2.5
 
 
 
 
 
688
  )
689
+ # Random Seed
690
+ seed_flux = gr.Slider(
691
+ label="Seed", minimum=-1, maximum=10000, step=1, value=42
 
 
 
 
 
692
  )
693
+ show_type = gr.Radio(
694
+ label="Show Type",
695
+ choices=["result only", "input & result", "input & mask & result"],
696
+ value="input & mask & result",
697
  )
698
 
699
+ with gr.Column(scale=2, min_width=500):
700
+ result_image_flux = gr.Image(interactive=False, label="Result")
701
+
702
+ image_path_flux.change(
703
+ person_example_fn, inputs=image_path_flux, outputs=person_image_flux
704
+ )
705
+
706
+ submit_flux.click(
707
+ submit_function_flux,
708
+ [person_image_flux, cloth_image_flux, cloth_type, num_inference_steps_flux, guidance_scale_flux, seed_flux, show_type],
709
+ result_image_flux,
710
+ )
711
+
 
 
 
 
712
  demo.queue().launch(share=True, show_error=True)
713
 
714
 
model/flux/pipeline_flux_tryon.py ADDED
@@ -0,0 +1,499 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from typing import Any, Callable, Dict, List, Optional, Union
3
+
4
+ import numpy as np
5
+ import torch
6
+ from diffusers.image_processor import VaeImageProcessor
7
+ from diffusers.loaders import (
8
+ FluxLoraLoaderMixin,
9
+ FromSingleFileMixin,
10
+ TextualInversionLoaderMixin,
11
+ )
12
+ from diffusers.models.autoencoders import AutoencoderKL
13
+ from diffusers.pipelines.flux.pipeline_flux_fill import (
14
+ calculate_shift,
15
+ retrieve_latents,
16
+ retrieve_timesteps,
17
+ )
18
+ from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput
19
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
20
+ from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
21
+ from diffusers.utils import logging
22
+ from diffusers.utils.torch_utils import randn_tensor
23
+
24
+ from model.flux.transformer_flux import FluxTransformer2DModel
25
+
26
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
27
+
28
+ # Modified from `diffusers.pipelines.flux.pipeline_flux_fill.FluxFillPipeline`
29
+ class FluxTryOnPipeline(
30
+ DiffusionPipeline,
31
+ FluxLoraLoaderMixin,
32
+ FromSingleFileMixin,
33
+ TextualInversionLoaderMixin,
34
+ ):
35
+ model_cpu_offload_seq = "transformer->vae"
36
+ _optional_components = []
37
+ _callback_tensor_inputs = ["latents"]
38
+
39
+ def __init__(
40
+ self,
41
+ vae: AutoencoderKL,
42
+ scheduler: FlowMatchEulerDiscreteScheduler,
43
+ transformer: FluxTransformer2DModel,
44
+ ):
45
+ super().__init__()
46
+ self.register_modules(
47
+ vae=vae,
48
+ scheduler=scheduler,
49
+ transformer=transformer,
50
+ )
51
+
52
+ self.vae_scale_factor = (
53
+ 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
54
+ )
55
+
56
+ # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
57
+ # by the patch size. So the vae scale factor is multiplied by the patch size to account for this
58
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
59
+ self.mask_processor = VaeImageProcessor(
60
+ vae_scale_factor=self.vae_scale_factor * 2,
61
+ vae_latent_channels=self.vae.config.latent_channels,
62
+ do_normalize=False,
63
+ do_binarize=True,
64
+ do_convert_grayscale=True,
65
+ )
66
+ self.default_sample_size = 128
67
+
68
+ self.transformer.remove_text_layers() # TryOnEdit: remove text layers
69
+
70
+ @classmethod
71
+ def from_pretrained(cls, pretrained_model_name_or_path, subfolder=None, **kwargs):
72
+ transformer = FluxTransformer2DModel.from_pretrained(pretrained_model_name_or_path, subfolder="transformer")
73
+ transformer.remove_text_layers()
74
+ vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae")
75
+ scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler")
76
+ return FluxTryOnPipeline(vae, scheduler, transformer)
77
+
78
+ def prepare_mask_latents(
79
+ self,
80
+ mask,
81
+ masked_image,
82
+ batch_size,
83
+ num_channels_latents,
84
+ num_images_per_prompt,
85
+ height,
86
+ width,
87
+ dtype,
88
+ device,
89
+ generator,
90
+ ):
91
+ # 1. calculate the height and width of the latents
92
+ # VAE applies 8x compression on images but we must also account for packing which requires
93
+ # latent height and width to be divisible by 2.
94
+ height = 2 * (int(height) // (self.vae_scale_factor * 2))
95
+ width = 2 * (int(width) // (self.vae_scale_factor * 2))
96
+
97
+ # 2. encode the masked image
98
+ if masked_image.shape[1] == num_channels_latents:
99
+ masked_image_latents = masked_image
100
+ else:
101
+ masked_image_latents = retrieve_latents(self.vae.encode(masked_image), generator=generator)
102
+
103
+ masked_image_latents = (masked_image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
104
+ masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
105
+
106
+ # 3. duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
107
+ batch_size = batch_size * num_images_per_prompt
108
+ if mask.shape[0] < batch_size:
109
+ if not batch_size % mask.shape[0] == 0:
110
+ raise ValueError(
111
+ "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
112
+ f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number"
113
+ " of masks that you pass is divisible by the total requested batch size."
114
+ )
115
+ mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)
116
+ if masked_image_latents.shape[0] < batch_size:
117
+ if not batch_size % masked_image_latents.shape[0] == 0:
118
+ raise ValueError(
119
+ "The passed images and the required batch size don't match. Images are supposed to be duplicated"
120
+ f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
121
+ " Make sure the number of images that you pass is divisible by the total requested batch size."
122
+ )
123
+ masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1)
124
+
125
+ # 4. pack the masked_image_latents
126
+ # batch_size, num_channels_latents, height, width -> batch_size, height//2 * width//2 , num_channels_latents*4
127
+ masked_image_latents = self._pack_latents(
128
+ masked_image_latents,
129
+ batch_size,
130
+ num_channels_latents,
131
+ height,
132
+ width,
133
+ )
134
+
135
+ # 5.resize mask to latents shape we we concatenate the mask to the latents
136
+ mask = mask[:, 0, :, :] # batch_size, 8 * height, 8 * width (mask has not been 8x compressed)
137
+ mask = mask.view(
138
+ batch_size, height, self.vae_scale_factor, width, self.vae_scale_factor
139
+ ) # batch_size, height, 8, width, 8
140
+ mask = mask.permute(0, 2, 4, 1, 3) # batch_size, 8, 8, height, width
141
+ mask = mask.reshape(
142
+ batch_size, self.vae_scale_factor * self.vae_scale_factor, height, width
143
+ ) # batch_size, 8*8, height, width
144
+
145
+ # 6. pack the mask:
146
+ # batch_size, 64, height, width -> batch_size, height//2 * width//2 , 64*2*2
147
+ mask = self._pack_latents(
148
+ mask,
149
+ batch_size,
150
+ self.vae_scale_factor * self.vae_scale_factor,
151
+ height,
152
+ width,
153
+ )
154
+ mask = mask.to(device=device, dtype=dtype)
155
+
156
+ return mask, masked_image_latents
157
+
158
+ def check_inputs(
159
+ self,
160
+ height,
161
+ width,
162
+ callback_on_step_end_tensor_inputs=None,
163
+ max_sequence_length=None,
164
+ image=None,
165
+ mask_image=None,
166
+ condition_image=None,
167
+ masked_image_latents=None,
168
+ ):
169
+ if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
170
+ logger.warning(
171
+ f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
172
+ )
173
+
174
+ if callback_on_step_end_tensor_inputs is not None and not all(
175
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
176
+ ):
177
+ raise ValueError(
178
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
179
+ )
180
+
181
+ if max_sequence_length is not None and max_sequence_length > 512:
182
+ raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
183
+
184
+ if image is not None and masked_image_latents is not None:
185
+ raise ValueError(
186
+ "Please provide either `image` or `masked_image_latents`, `masked_image_latents` should not be passed."
187
+ )
188
+
189
+ if image is not None and mask_image is None:
190
+ raise ValueError("Please provide `mask_image` when passing `image`.")
191
+
192
+ if condition_image is None:
193
+ raise ValueError("Please provide `condition_image`.")
194
+
195
+ @staticmethod
196
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids
197
+ def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
198
+ latent_image_ids = torch.zeros(height, width, 3)
199
+ latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
200
+ latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
201
+
202
+ latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
203
+
204
+ latent_image_ids = latent_image_ids.reshape(
205
+ latent_image_id_height * latent_image_id_width, latent_image_id_channels
206
+ )
207
+
208
+ return latent_image_ids.to(device=device, dtype=dtype)
209
+
210
+ @staticmethod
211
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents
212
+ def _pack_latents(latents, batch_size, num_channels_latents, height, width):
213
+ latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
214
+ latents = latents.permute(0, 2, 4, 1, 3, 5)
215
+ latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
216
+
217
+ return latents
218
+
219
+ @staticmethod
220
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents
221
+ def _unpack_latents(latents, height, width, vae_scale_factor):
222
+ batch_size, num_patches, channels = latents.shape
223
+
224
+ # VAE applies 8x compression on images but we must also account for packing which requires
225
+ # latent height and width to be divisible by 2.
226
+ height = 2 * (int(height) // (vae_scale_factor * 2))
227
+ width = 2 * (int(width) // (vae_scale_factor * 2))
228
+
229
+ latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
230
+ latents = latents.permute(0, 3, 1, 4, 2, 5)
231
+
232
+ latents = latents.reshape(batch_size, channels // (2 * 2), height, width)
233
+
234
+ return latents
235
+
236
+ def enable_vae_slicing(self):
237
+ r"""
238
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
239
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
240
+ """
241
+ self.vae.enable_slicing()
242
+
243
+ def disable_vae_slicing(self):
244
+ r"""
245
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
246
+ computing decoding in one step.
247
+ """
248
+ self.vae.disable_slicing()
249
+
250
+ def enable_vae_tiling(self):
251
+ r"""
252
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
253
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
254
+ processing larger images.
255
+ """
256
+ self.vae.enable_tiling()
257
+
258
+ def disable_vae_tiling(self):
259
+ r"""
260
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
261
+ computing decoding in one step.
262
+ """
263
+ self.vae.disable_tiling()
264
+
265
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.prepare_latents
266
+ def prepare_latents(
267
+ self,
268
+ batch_size,
269
+ num_channels_latents,
270
+ height,
271
+ width,
272
+ dtype,
273
+ device,
274
+ generator,
275
+ latents=None,
276
+ ):
277
+ # VAE applies 8x compression on images but we must also account for packing which requires
278
+ # latent height and width to be divisible by 2.
279
+ height = 2 * (int(height) // (self.vae_scale_factor * 2))
280
+ width = 2 * (int(width) // (self.vae_scale_factor * 2))
281
+
282
+ shape = (batch_size, num_channels_latents, height, width)
283
+
284
+ if latents is not None:
285
+ latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
286
+ return latents.to(device=device, dtype=dtype), latent_image_ids
287
+
288
+ if isinstance(generator, list) and len(generator) != batch_size:
289
+ raise ValueError(
290
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
291
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
292
+ )
293
+
294
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
295
+ latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
296
+
297
+ latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
298
+
299
+ return latents, latent_image_ids
300
+
301
+ @property
302
+ def guidance_scale(self):
303
+ return self._guidance_scale
304
+
305
+ @property
306
+ def joint_attention_kwargs(self):
307
+ return self._joint_attention_kwargs
308
+
309
+ @property
310
+ def num_timesteps(self):
311
+ return self._num_timesteps
312
+
313
+ @property
314
+ def interrupt(self):
315
+ return self._interrupt
316
+
317
+ @torch.no_grad()
318
+ def __call__(
319
+ self,
320
+ image: Optional[torch.FloatTensor] = None,
321
+ condition_image: Optional[torch.FloatTensor] = None, # TryOnEdit: condition image (garment)
322
+ mask_image: Optional[torch.FloatTensor] = None,
323
+ masked_image_latents: Optional[torch.FloatTensor] = None,
324
+ height: Optional[int] = None,
325
+ width: Optional[int] = None,
326
+ num_inference_steps: int = 50,
327
+ sigmas: Optional[List[float]] = None,
328
+ guidance_scale: float = 30.0,
329
+ num_images_per_prompt: Optional[int] = 1,
330
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
331
+ latents: Optional[torch.FloatTensor] = None,
332
+ output_type: Optional[str] = "pil",
333
+ return_dict: bool = True,
334
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
335
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
336
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
337
+ max_sequence_length: int = 512,
338
+ ):
339
+ height = height or self.default_sample_size * self.vae_scale_factor
340
+ width = width or self.default_sample_size * self.vae_scale_factor
341
+
342
+ # 1. Check inputs. Raise error if not correct
343
+ self.check_inputs(
344
+ height,
345
+ width,
346
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
347
+ max_sequence_length=max_sequence_length,
348
+ image=image,
349
+ mask_image=mask_image,
350
+ condition_image=condition_image,
351
+ masked_image_latents=masked_image_latents,
352
+ )
353
+
354
+ self._guidance_scale = guidance_scale
355
+ self._joint_attention_kwargs = joint_attention_kwargs
356
+ self._interrupt = False
357
+
358
+ # 2. Define call parameters
359
+ batch_size = 1
360
+ device = self._execution_device
361
+ dtype = self.transformer.dtype
362
+
363
+ # 3. Prepare prompt embeddings
364
+ lora_scale = (
365
+ self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
366
+ )
367
+
368
+ # 4. Prepare latent variables
369
+ num_channels_latents = self.vae.config.latent_channels
370
+ latents, latent_image_ids = self.prepare_latents(
371
+ batch_size * num_images_per_prompt,
372
+ num_channels_latents,
373
+ height,
374
+ width * 2, # TryOnEdit: width * 2
375
+ dtype,
376
+ device,
377
+ generator,
378
+ latents,
379
+ )
380
+
381
+ # 5. Prepare mask and masked image latents
382
+ if masked_image_latents is not None:
383
+ masked_image_latents = masked_image_latents.to(latents.device)
384
+ else:
385
+ image = self.image_processor.preprocess(image, height=height, width=width)
386
+ condition_image = self.image_processor.preprocess(condition_image, height=height, width=width)
387
+ mask_image = self.mask_processor.preprocess(mask_image, height=height, width=width)
388
+
389
+ masked_image = image * (1 - mask_image)
390
+ masked_image = masked_image.to(device=device, dtype=dtype)
391
+
392
+ # TryOnEdit: Concat condition image to masked image
393
+ condition_image = condition_image.to(device=device, dtype=dtype)
394
+ masked_image = torch.cat((masked_image, condition_image), dim=-1)
395
+ mask_image = torch.cat((mask_image, torch.zeros_like(mask_image)), dim=-1)
396
+
397
+ height, width = image.shape[-2:]
398
+ mask, masked_image_latents = self.prepare_mask_latents(
399
+ mask_image,
400
+ masked_image,
401
+ batch_size,
402
+ num_channels_latents,
403
+ num_images_per_prompt,
404
+ height,
405
+ width * 2, # TryOnEdit: width * 2
406
+ dtype,
407
+ device,
408
+ generator,
409
+ )
410
+ masked_image_latents = torch.cat((masked_image_latents, mask), dim=-1)
411
+
412
+ # 6. Prepare timesteps
413
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
414
+ image_seq_len = latents.shape[1]
415
+ mu = calculate_shift(
416
+ image_seq_len,
417
+ self.scheduler.config.base_image_seq_len,
418
+ self.scheduler.config.max_image_seq_len,
419
+ self.scheduler.config.base_shift,
420
+ self.scheduler.config.max_shift,
421
+ )
422
+ timesteps, num_inference_steps = retrieve_timesteps(
423
+ self.scheduler,
424
+ num_inference_steps,
425
+ device,
426
+ sigmas=sigmas,
427
+ mu=mu,
428
+ )
429
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
430
+ self._num_timesteps = len(timesteps)
431
+
432
+ # handle guidance
433
+ if self.transformer.config.guidance_embeds:
434
+ guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
435
+ guidance = guidance.expand(latents.shape[0])
436
+ else:
437
+ guidance = None
438
+
439
+ # 7. Denoising loop
440
+ pooled_prompt_embeds = torch.zeros([latents.shape[0], 768], device=device, dtype=dtype) # TryOnEdit: for now, we don't use pooled prompt embeddings
441
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
442
+ for i, t in enumerate(timesteps):
443
+ if self.interrupt:
444
+ continue
445
+
446
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
447
+ timestep = t.expand(latents.shape[0]).to(latents.dtype)
448
+
449
+ noise_pred = self.transformer(
450
+ hidden_states=torch.cat((latents, masked_image_latents), dim=2),
451
+ timestep=timestep / 1000,
452
+ guidance=guidance,
453
+ pooled_projections=pooled_prompt_embeds,
454
+ encoder_hidden_states=None,
455
+ txt_ids=None,
456
+ img_ids=latent_image_ids,
457
+ joint_attention_kwargs=self.joint_attention_kwargs,
458
+ return_dict=False,
459
+ )[0]
460
+
461
+ # compute the previous noisy sample x_t -> x_t-1
462
+ latents_dtype = latents.dtype
463
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
464
+
465
+ if latents.dtype != latents_dtype:
466
+ if torch.backends.mps.is_available():
467
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
468
+ latents = latents.to(latents_dtype)
469
+
470
+ if callback_on_step_end is not None:
471
+ callback_kwargs = {}
472
+ for k in callback_on_step_end_tensor_inputs:
473
+ callback_kwargs[k] = locals()[k]
474
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
475
+
476
+ latents = callback_outputs.pop("latents", latents)
477
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
478
+
479
+ # call the callback, if provided
480
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
481
+ progress_bar.update()
482
+
483
+ # 8. Post-process the image
484
+ if output_type == "latent":
485
+ image = latents
486
+ else:
487
+ latents = self._unpack_latents(latents, height, width * 2, self.vae_scale_factor) # TryOnEdit: width * 2
488
+ latents = latents.split(latents.shape[-1] // 2, dim=-1)[0] # TryOnEdit: split along the last dimension
489
+ latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
490
+ image = self.vae.decode(latents, return_dict=False)[0]
491
+ image = self.image_processor.postprocess(image, output_type=output_type)
492
+
493
+ # Offload all models
494
+ self.maybe_free_model_hooks()
495
+
496
+ if not return_dict:
497
+ return (image,)
498
+
499
+ return FluxPipelineOutput(images=image)
model/flux/transformer_flux.py ADDED
@@ -0,0 +1,672 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Dict, Optional, Tuple, Union
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+
8
+ from diffusers.models.modeling_utils import ModelMixin
9
+ from diffusers.models.normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
10
+ from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
11
+ from diffusers.utils.torch_utils import maybe_allow_in_graph
12
+ from diffusers.models.embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed
13
+ from diffusers.models.modeling_outputs import Transformer2DModelOutput
14
+
15
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
16
+ from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin
17
+ from diffusers.models.attention import FeedForward
18
+ from diffusers.models.attention_processor import (
19
+ Attention,
20
+ AttentionProcessor,
21
+ FusedFluxAttnProcessor2_0,
22
+ )
23
+
24
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
25
+
26
+
27
+ # Modified from `diffusers.models.attention_processor.FluxAttnProcessor2_0`
28
+ class FluxAttnProcessor2_0:
29
+ """Attention processor used typically in processing the SD3-like self-attention projections."""
30
+
31
+ def __init__(self):
32
+ if not hasattr(F, "scaled_dot_product_attention"):
33
+ raise ImportError("FluxAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
34
+
35
+ def __call__(
36
+ self,
37
+ attn: Attention,
38
+ hidden_states: torch.FloatTensor,
39
+ encoder_hidden_states: torch.FloatTensor = None,
40
+ attention_mask: Optional[torch.FloatTensor] = None,
41
+ image_rotary_emb: Optional[torch.Tensor] = None,
42
+ ) -> torch.FloatTensor:
43
+ batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
44
+
45
+ # `sample` projections.
46
+ query = attn.to_q(hidden_states)
47
+ key = attn.to_k(hidden_states)
48
+ value = attn.to_v(hidden_states)
49
+
50
+ inner_dim = key.shape[-1]
51
+ head_dim = inner_dim // attn.heads
52
+
53
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
54
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
55
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
56
+
57
+ if attn.norm_q is not None:
58
+ query = attn.norm_q(query)
59
+ if attn.norm_k is not None:
60
+ key = attn.norm_k(key)
61
+
62
+ # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
63
+ if encoder_hidden_states is not None:
64
+ # `context` projections.
65
+ encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
66
+ encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
67
+ encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
68
+
69
+ encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
70
+ batch_size, -1, attn.heads, head_dim
71
+ ).transpose(1, 2)
72
+ encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
73
+ batch_size, -1, attn.heads, head_dim
74
+ ).transpose(1, 2)
75
+ encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
76
+ batch_size, -1, attn.heads, head_dim
77
+ ).transpose(1, 2)
78
+
79
+ if attn.norm_added_q is not None:
80
+ encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
81
+ if attn.norm_added_k is not None:
82
+ encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
83
+
84
+ # attention
85
+ query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
86
+ key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
87
+ value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
88
+
89
+ if image_rotary_emb is not None:
90
+ from diffusers.models.embeddings import apply_rotary_emb
91
+
92
+ query = apply_rotary_emb(query, image_rotary_emb)
93
+ key = apply_rotary_emb(key, image_rotary_emb)
94
+
95
+ hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
96
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
97
+ hidden_states = hidden_states.to(query.dtype)
98
+
99
+ if encoder_hidden_states is not None:
100
+ encoder_hidden_states, hidden_states = (
101
+ hidden_states[:, : encoder_hidden_states.shape[1]],
102
+ hidden_states[:, encoder_hidden_states.shape[1] :],
103
+ )
104
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
105
+
106
+ # edited for try-on
107
+ if not attn.pre_only:
108
+ # linear proj
109
+ hidden_states = attn.to_out[0](hidden_states)
110
+ # dropout
111
+ hidden_states = attn.to_out[1](hidden_states)
112
+
113
+ if encoder_hidden_states is not None:
114
+ return hidden_states, encoder_hidden_states
115
+ else:
116
+ return hidden_states
117
+
118
+
119
+ @maybe_allow_in_graph
120
+ class FluxSingleTransformerBlock(nn.Module):
121
+ r"""
122
+ A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3.
123
+
124
+ Reference: https://arxiv.org/abs/2403.03206
125
+
126
+ Parameters:
127
+ dim (`int`): The number of channels in the input and output.
128
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
129
+ attention_head_dim (`int`): The number of channels in each head.
130
+ context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the
131
+ processing of `context` conditions.
132
+ """
133
+
134
+ def __init__(self, dim, num_attention_heads, attention_head_dim, mlp_ratio=4.0):
135
+ super().__init__()
136
+ self.mlp_hidden_dim = int(dim * mlp_ratio)
137
+
138
+ self.norm = AdaLayerNormZeroSingle(dim)
139
+ self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim)
140
+ self.act_mlp = nn.GELU(approximate="tanh")
141
+ self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)
142
+
143
+ processor = FluxAttnProcessor2_0()
144
+ self.attn = Attention(
145
+ query_dim=dim,
146
+ cross_attention_dim=None,
147
+ dim_head=attention_head_dim,
148
+ heads=num_attention_heads,
149
+ out_dim=dim,
150
+ bias=True,
151
+ processor=processor,
152
+ qk_norm="rms_norm",
153
+ eps=1e-6,
154
+ pre_only=True,
155
+ )
156
+
157
+ def forward(
158
+ self,
159
+ hidden_states: torch.FloatTensor,
160
+ temb: torch.FloatTensor,
161
+ image_rotary_emb=None,
162
+ joint_attention_kwargs=None,
163
+ ):
164
+ residual = hidden_states
165
+ norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
166
+ mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
167
+ joint_attention_kwargs = joint_attention_kwargs or {}
168
+ attn_output = self.attn(
169
+ hidden_states=norm_hidden_states,
170
+ image_rotary_emb=image_rotary_emb,
171
+ **joint_attention_kwargs,
172
+ )
173
+
174
+ hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
175
+ gate = gate.unsqueeze(1)
176
+ hidden_states = gate * self.proj_out(hidden_states)
177
+ hidden_states = residual + hidden_states
178
+ if hidden_states.dtype == torch.float16:
179
+ hidden_states = hidden_states.clip(-65504, 65504)
180
+
181
+ return hidden_states
182
+
183
+
184
+ @maybe_allow_in_graph
185
+ class FluxTransformerBlock(nn.Module):
186
+ r"""
187
+ A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3.
188
+
189
+ Reference: https://arxiv.org/abs/2403.03206
190
+
191
+ Parameters:
192
+ dim (`int`): The number of channels in the input and output.
193
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
194
+ attention_head_dim (`int`): The number of channels in each head.
195
+ context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the
196
+ processing of `context` conditions.
197
+ """
198
+
199
+ def __init__(self, dim, num_attention_heads, attention_head_dim, qk_norm="rms_norm", eps=1e-6):
200
+ super().__init__()
201
+
202
+ self.norm1 = AdaLayerNormZero(dim)
203
+
204
+ self.norm1_context = AdaLayerNormZero(dim)
205
+
206
+ if hasattr(F, "scaled_dot_product_attention"):
207
+ processor = FluxAttnProcessor2_0()
208
+ else:
209
+ raise ValueError(
210
+ "The current PyTorch version does not support the `scaled_dot_product_attention` function."
211
+ )
212
+ self.attn = Attention(
213
+ query_dim=dim,
214
+ cross_attention_dim=None,
215
+ added_kv_proj_dim=dim,
216
+ dim_head=attention_head_dim,
217
+ heads=num_attention_heads,
218
+ out_dim=dim,
219
+ context_pre_only=False,
220
+ bias=True,
221
+ processor=processor,
222
+ qk_norm=qk_norm,
223
+ eps=eps,
224
+ )
225
+
226
+ self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
227
+ self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
228
+
229
+ self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
230
+ self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
231
+
232
+ # let chunk size default to None
233
+ self._chunk_size = None
234
+ self._chunk_dim = 0
235
+
236
+ def remove_text_layers(self):
237
+ # for try-on, we don't need the text conditioning
238
+ self.norm1_context = None
239
+ self.ff_context = None
240
+ self.norm2_context = None
241
+ self.attn.to_added_qkv = None
242
+ self.attn.norm_added_q = None
243
+ self.attn.norm_added_k = None
244
+
245
+ def forward(
246
+ self,
247
+ hidden_states: torch.FloatTensor,
248
+ encoder_hidden_states: torch.FloatTensor,
249
+ temb: torch.FloatTensor,
250
+ image_rotary_emb=None,
251
+ joint_attention_kwargs=None,
252
+ ):
253
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
254
+
255
+ if encoder_hidden_states is not None:
256
+ norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
257
+ encoder_hidden_states, emb=temb
258
+ )
259
+ else:
260
+ norm_encoder_hidden_states = None
261
+
262
+ joint_attention_kwargs = joint_attention_kwargs or {}
263
+ # Attention.
264
+
265
+ outputs = self.attn(
266
+ hidden_states=norm_hidden_states,
267
+ encoder_hidden_states=norm_encoder_hidden_states,
268
+ image_rotary_emb=image_rotary_emb,
269
+ **joint_attention_kwargs,
270
+ )
271
+ if isinstance(outputs, tuple):
272
+ attn_output, context_attn_output = outputs
273
+ else:
274
+ attn_output = outputs
275
+
276
+ # Process attention outputs for the `hidden_states`.
277
+ attn_output = gate_msa.unsqueeze(1) * attn_output
278
+ hidden_states = hidden_states + attn_output
279
+
280
+ norm_hidden_states = self.norm2(hidden_states)
281
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
282
+
283
+ ff_output = self.ff(norm_hidden_states)
284
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
285
+
286
+ hidden_states = hidden_states + ff_output
287
+
288
+ # Process attention outputs for the `encoder_hidden_states`.
289
+ if encoder_hidden_states is not None:
290
+ context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
291
+ encoder_hidden_states = encoder_hidden_states + context_attn_output
292
+
293
+ norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
294
+ norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
295
+
296
+ context_ff_output = self.ff_context(norm_encoder_hidden_states)
297
+ encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
298
+ if encoder_hidden_states.dtype == torch.float16:
299
+ encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
300
+
301
+ return encoder_hidden_states, hidden_states
302
+
303
+
304
+ class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
305
+ """
306
+ The Transformer model introduced in Flux.
307
+
308
+ Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
309
+
310
+ Parameters:
311
+ patch_size (`int`): Patch size to turn the input data into small patches.
312
+ in_channels (`int`, *optional*, defaults to 16): The number of channels in the input.
313
+ num_layers (`int`, *optional*, defaults to 18): The number of layers of MMDiT blocks to use.
314
+ num_single_layers (`int`, *optional*, defaults to 18): The number of layers of single DiT blocks to use.
315
+ attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
316
+ num_attention_heads (`int`, *optional*, defaults to 18): The number of heads to use for multi-head attention.
317
+ joint_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
318
+ pooled_projection_dim (`int`): Number of dimensions to use when projecting the `pooled_projections`.
319
+ guidance_embeds (`bool`, defaults to False): Whether to use guidance embeddings.
320
+ """
321
+
322
+ _supports_gradient_checkpointing = True
323
+ _no_split_modules = ["FluxTransformerBlock", "FluxSingleTransformerBlock"]
324
+
325
+ @register_to_config
326
+ def __init__(
327
+ self,
328
+ patch_size: int = 1,
329
+ in_channels: int = 64,
330
+ out_channels: Optional[int] = None,
331
+ num_layers: int = 19,
332
+ num_single_layers: int = 38,
333
+ attention_head_dim: int = 128,
334
+ num_attention_heads: int = 24,
335
+ joint_attention_dim: int = 4096,
336
+ pooled_projection_dim: int = 768,
337
+ guidance_embeds: bool = False,
338
+ axes_dims_rope: Tuple[int] = (16, 56, 56),
339
+ ):
340
+ super().__init__()
341
+ self.out_channels = out_channels or in_channels
342
+ self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
343
+
344
+ self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope)
345
+
346
+ text_time_guidance_cls = (
347
+ CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings
348
+ )
349
+ self.time_text_embed = text_time_guidance_cls(
350
+ embedding_dim=self.inner_dim, pooled_projection_dim=self.config.pooled_projection_dim
351
+ )
352
+
353
+ self.context_embedder = nn.Linear(self.config.joint_attention_dim, self.inner_dim)
354
+ self.x_embedder = torch.nn.Linear(self.config.in_channels, self.inner_dim)
355
+
356
+ self.transformer_blocks = nn.ModuleList(
357
+ [
358
+ FluxTransformerBlock(
359
+ dim=self.inner_dim,
360
+ num_attention_heads=self.config.num_attention_heads,
361
+ attention_head_dim=self.config.attention_head_dim,
362
+ )
363
+ for i in range(self.config.num_layers)
364
+ ]
365
+ )
366
+
367
+ self.single_transformer_blocks = nn.ModuleList(
368
+ [
369
+ FluxSingleTransformerBlock(
370
+ dim=self.inner_dim,
371
+ num_attention_heads=self.config.num_attention_heads,
372
+ attention_head_dim=self.config.attention_head_dim,
373
+ )
374
+ for i in range(self.config.num_single_layers)
375
+ ]
376
+ )
377
+
378
+ self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
379
+ self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
380
+
381
+ self.gradient_checkpointing = False
382
+
383
+ @property
384
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
385
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
386
+ r"""
387
+ Returns:
388
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
389
+ indexed by its weight name.
390
+ """
391
+ # set recursively
392
+ processors = {}
393
+
394
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
395
+ if hasattr(module, "get_processor"):
396
+ processors[f"{name}.processor"] = module.get_processor()
397
+
398
+ for sub_name, child in module.named_children():
399
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
400
+
401
+ return processors
402
+
403
+ for name, module in self.named_children():
404
+ fn_recursive_add_processors(name, module, processors)
405
+
406
+ return processors
407
+
408
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
409
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
410
+ r"""
411
+ Sets the attention processor to use to compute attention.
412
+
413
+ Parameters:
414
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
415
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
416
+ for **all** `Attention` layers.
417
+
418
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
419
+ processor. This is strongly recommended when setting trainable attention processors.
420
+
421
+ """
422
+ count = len(self.attn_processors.keys())
423
+
424
+ if isinstance(processor, dict) and len(processor) != count:
425
+ raise ValueError(
426
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
427
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
428
+ )
429
+
430
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
431
+ if hasattr(module, "set_processor"):
432
+ if not isinstance(processor, dict):
433
+ module.set_processor(processor)
434
+ else:
435
+ module.set_processor(processor.pop(f"{name}.processor"))
436
+
437
+ for sub_name, child in module.named_children():
438
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
439
+
440
+ for name, module in self.named_children():
441
+ fn_recursive_attn_processor(name, module, processor)
442
+
443
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedFluxAttnProcessor2_0
444
+ def fuse_qkv_projections(self):
445
+ """
446
+ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
447
+ are fused. For cross-attention modules, key and value projection matrices are fused.
448
+
449
+ <Tip warning={true}>
450
+
451
+ This API is 🧪 experimental.
452
+
453
+ </Tip>
454
+ """
455
+ self.original_attn_processors = None
456
+
457
+ for _, attn_processor in self.attn_processors.items():
458
+ if "Added" in str(attn_processor.__class__.__name__):
459
+ raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
460
+
461
+ self.original_attn_processors = self.attn_processors
462
+
463
+ for module in self.modules():
464
+ if isinstance(module, Attention):
465
+ module.fuse_projections(fuse=True)
466
+
467
+ self.set_attn_processor(FusedFluxAttnProcessor2_0())
468
+
469
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
470
+ def unfuse_qkv_projections(self):
471
+ """Disables the fused QKV projection if enabled.
472
+
473
+ <Tip warning={true}>
474
+
475
+ This API is 🧪 experimental.
476
+
477
+ </Tip>
478
+
479
+ """
480
+ if self.original_attn_processors is not None:
481
+ self.set_attn_processor(self.original_attn_processors)
482
+
483
+ def _set_gradient_checkpointing(self, module, value=False):
484
+ if hasattr(module, "gradient_checkpointing"):
485
+ module.gradient_checkpointing = value
486
+
487
+ def remove_text_layers(self):
488
+ self.context_embedder = None
489
+ for transformer_block in self.transformer_blocks:
490
+ transformer_block.remove_text_layers()
491
+
492
+ def forward(
493
+ self,
494
+ hidden_states: torch.Tensor,
495
+ encoder_hidden_states: torch.Tensor = None,
496
+ condition_hidden_states: torch.Tensor = None,
497
+ pooled_projections: torch.Tensor = None,
498
+ timestep: torch.LongTensor = None,
499
+ img_ids: torch.Tensor = None,
500
+ txt_ids: torch.Tensor = None,
501
+ guidance: torch.Tensor = None,
502
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
503
+ controlnet_block_samples=None,
504
+ controlnet_single_block_samples=None,
505
+ return_dict: bool = True,
506
+ controlnet_blocks_repeat: bool = False,
507
+ ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
508
+ """
509
+ The [`FluxTransformer2DModel`] forward method.
510
+
511
+ Args:
512
+ hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
513
+ Input `hidden_states`.
514
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
515
+ Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
516
+ pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
517
+ from the embeddings of input conditions.
518
+ timestep ( `torch.LongTensor`):
519
+ Used to indicate denoising step.
520
+ block_controlnet_hidden_states: (`list` of `torch.Tensor`):
521
+ A list of tensors that if specified are added to the residuals of transformer blocks.
522
+ joint_attention_kwargs (`dict`, *optional*):
523
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
524
+ `self.processor` in
525
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
526
+ return_dict (`bool`, *optional*, defaults to `True`):
527
+ Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
528
+ tuple.
529
+
530
+ Returns:
531
+ If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
532
+ `tuple` where the first element is the sample tensor.
533
+ """
534
+ if joint_attention_kwargs is not None:
535
+ joint_attention_kwargs = joint_attention_kwargs.copy()
536
+ lora_scale = joint_attention_kwargs.pop("scale", 1.0)
537
+ else:
538
+ lora_scale = 1.0
539
+
540
+ if USE_PEFT_BACKEND:
541
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
542
+ scale_lora_layers(self, lora_scale)
543
+ else:
544
+ if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
545
+ logger.warning(
546
+ "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
547
+ )
548
+
549
+ hidden_states = self.x_embedder(hidden_states)
550
+
551
+ timestep = timestep.to(hidden_states.dtype) * 1000
552
+ guidance = guidance.to(hidden_states.dtype) * 1000 if guidance is not None else None
553
+
554
+ temb = self.time_text_embed(timestep, pooled_projections) if guidance is None else self.time_text_embed(timestep, guidance, pooled_projections)
555
+
556
+ if encoder_hidden_states is not None:
557
+ encoder_hidden_states = self.context_embedder(encoder_hidden_states)
558
+
559
+ ids = torch.cat((txt_ids, img_ids), dim=0) if txt_ids is not None else img_ids # for try-on, we don't need txt_ids
560
+ image_rotary_emb = self.pos_embed(ids)
561
+
562
+ # MMDiT Blocks
563
+ for index_block, block in enumerate(self.transformer_blocks):
564
+ if self.training and self.gradient_checkpointing:
565
+ def create_custom_forward(module, return_dict=None):
566
+ def custom_forward(*inputs):
567
+ if return_dict is not None:
568
+ return module(*inputs, return_dict=return_dict)
569
+ else:
570
+ return module(*inputs)
571
+
572
+ return custom_forward
573
+
574
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
575
+ result = torch.utils.checkpoint.checkpoint(
576
+ create_custom_forward(block),
577
+ hidden_states,
578
+ encoder_hidden_states,
579
+ temb,
580
+ image_rotary_emb,
581
+ **ckpt_kwargs,
582
+ )
583
+ if isinstance(result, tuple):
584
+ encoder_hidden_states, hidden_states = result
585
+ else:
586
+ hidden_states = result
587
+
588
+ else:
589
+ result = block(
590
+ hidden_states=hidden_states,
591
+ encoder_hidden_states=encoder_hidden_states,
592
+ temb=temb,
593
+ image_rotary_emb=image_rotary_emb,
594
+ joint_attention_kwargs=joint_attention_kwargs,
595
+ )
596
+ if isinstance(result, tuple):
597
+ encoder_hidden_states, hidden_states = result
598
+ else:
599
+ hidden_states = result
600
+
601
+ # Condition residual (for try-on pose conditioning)
602
+ if condition_hidden_states is not None and index_block == 0:
603
+ hidden_states = hidden_states + condition_hidden_states
604
+
605
+ # controlnet residual
606
+ if controlnet_block_samples is not None:
607
+ interval_control = len(self.transformer_blocks) / len(controlnet_block_samples)
608
+ interval_control = int(np.ceil(interval_control))
609
+ # For Xlabs ControlNet.
610
+ if controlnet_blocks_repeat:
611
+ hidden_states = (
612
+ hidden_states + controlnet_block_samples[index_block % len(controlnet_block_samples)]
613
+ )
614
+ else:
615
+ hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
616
+
617
+ if encoder_hidden_states is not None:
618
+ hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
619
+
620
+ # Single DiT Blocks
621
+ for index_block, block in enumerate(self.single_transformer_blocks):
622
+ if self.training and self.gradient_checkpointing:
623
+
624
+ def create_custom_forward(module, return_dict=None):
625
+ def custom_forward(*inputs):
626
+ if return_dict is not None:
627
+ return module(*inputs, return_dict=return_dict)
628
+ else:
629
+ return module(*inputs)
630
+
631
+ return custom_forward
632
+
633
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
634
+ hidden_states = torch.utils.checkpoint.checkpoint(
635
+ create_custom_forward(block),
636
+ hidden_states,
637
+ temb,
638
+ image_rotary_emb,
639
+ **ckpt_kwargs,
640
+ )
641
+
642
+ else:
643
+ hidden_states = block(
644
+ hidden_states=hidden_states,
645
+ temb=temb,
646
+ image_rotary_emb=image_rotary_emb,
647
+ joint_attention_kwargs=joint_attention_kwargs,
648
+ )
649
+
650
+ # controlnet residual
651
+ if controlnet_single_block_samples is not None:
652
+ interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples)
653
+ interval_control = int(np.ceil(interval_control))
654
+ hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (
655
+ hidden_states[:, encoder_hidden_states.shape[1] :, ...]
656
+ + controlnet_single_block_samples[index_block // interval_control]
657
+ )
658
+
659
+ if encoder_hidden_states is not None:
660
+ hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
661
+
662
+ hidden_states = self.norm_out(hidden_states, temb)
663
+ output = self.proj_out(hidden_states)
664
+
665
+ if USE_PEFT_BACKEND:
666
+ # remove `lora_scale` from each PEFT layer
667
+ unscale_lora_layers(self, lora_scale)
668
+
669
+ if not return_dict:
670
+ return (output,)
671
+
672
+ return Transformer2DModelOutput(sample=output)
model/pipeline.py CHANGED
@@ -213,3 +213,120 @@ class CatVTONPipeline:
213
  if not_safe:
214
  image[i] = nsfw_image
215
  return image
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
213
  if not_safe:
214
  image[i] = nsfw_image
215
  return image
216
+
217
+
218
+ class CatVTONPix2PixPipeline(CatVTONPipeline):
219
+ def auto_attn_ckpt_load(self, attn_ckpt, version):
220
+ # TODO: Temperal fix for the model version
221
+ if os.path.exists(attn_ckpt):
222
+ load_checkpoint_in_model(self.attn_modules, os.path.join(attn_ckpt, version, 'attention'))
223
+ else:
224
+ repo_path = snapshot_download(repo_id=attn_ckpt)
225
+ print(f"Downloaded {attn_ckpt} to {repo_path}")
226
+ load_checkpoint_in_model(self.attn_modules, os.path.join(repo_path, version, 'attention'))
227
+
228
+ def check_inputs(self, image, condition_image, width, height):
229
+ if isinstance(image, torch.Tensor) and isinstance(condition_image, torch.Tensor) and isinstance(torch.Tensor):
230
+ return image, condition_image
231
+ image = resize_and_crop(image, (width, height))
232
+ condition_image = resize_and_padding(condition_image, (width, height))
233
+ return image, condition_image
234
+
235
+ @torch.no_grad()
236
+ def __call__(
237
+ self,
238
+ image: Union[PIL.Image.Image, torch.Tensor],
239
+ condition_image: Union[PIL.Image.Image, torch.Tensor],
240
+ num_inference_steps: int = 50,
241
+ guidance_scale: float = 2.5,
242
+ height: int = 1024,
243
+ width: int = 768,
244
+ generator=None,
245
+ eta=1.0,
246
+ **kwargs
247
+ ):
248
+ concat_dim = -1
249
+ # Prepare inputs to Tensor
250
+ image, condition_image = self.check_inputs(image, condition_image, width, height)
251
+ image = prepare_image(image).to(self.device, dtype=self.weight_dtype)
252
+ condition_image = prepare_image(condition_image).to(self.device, dtype=self.weight_dtype)
253
+ # VAE encoding
254
+ image_latent = compute_vae_encodings(image, self.vae)
255
+ condition_latent = compute_vae_encodings(condition_image, self.vae)
256
+ del image, condition_image
257
+ # Concatenate latents
258
+ condition_latent_concat = torch.cat([image_latent, condition_latent], dim=concat_dim)
259
+ # Prepare noise
260
+ latents = randn_tensor(
261
+ condition_latent_concat.shape,
262
+ generator=generator,
263
+ device=condition_latent_concat.device,
264
+ dtype=self.weight_dtype,
265
+ )
266
+ # Prepare timesteps
267
+ self.noise_scheduler.set_timesteps(num_inference_steps, device=self.device)
268
+ timesteps = self.noise_scheduler.timesteps
269
+ latents = latents * self.noise_scheduler.init_noise_sigma
270
+ # Classifier-Free Guidance
271
+ if do_classifier_free_guidance := (guidance_scale > 1.0):
272
+ condition_latent_concat = torch.cat(
273
+ [
274
+ torch.cat([image_latent, torch.zeros_like(condition_latent)], dim=concat_dim),
275
+ condition_latent_concat,
276
+ ]
277
+ )
278
+
279
+ # Denoising loop
280
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
281
+ num_warmup_steps = (len(timesteps) - num_inference_steps * self.noise_scheduler.order)
282
+ with tqdm.tqdm(total=num_inference_steps) as progress_bar:
283
+ for i, t in enumerate(timesteps):
284
+ # expand the latents if we are doing classifier free guidance
285
+ latent_model_input = (torch.cat([latents] * 2) if do_classifier_free_guidance else latents)
286
+ latent_model_input = self.noise_scheduler.scale_model_input(latent_model_input, t)
287
+ # prepare the input for the inpainting model
288
+ p2p_latent_model_input = torch.cat([latent_model_input, condition_latent_concat], dim=1)
289
+ # predict the noise residual
290
+ noise_pred= self.unet(
291
+ p2p_latent_model_input,
292
+ t.to(self.device),
293
+ encoder_hidden_states=None,
294
+ return_dict=False,
295
+ )[0]
296
+ # perform guidance
297
+ if do_classifier_free_guidance:
298
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
299
+ noise_pred = noise_pred_uncond + guidance_scale * (
300
+ noise_pred_text - noise_pred_uncond
301
+ )
302
+ # compute the previous noisy sample x_t -> x_t-1
303
+ latents = self.noise_scheduler.step(
304
+ noise_pred, t, latents, **extra_step_kwargs
305
+ ).prev_sample
306
+ # call the callback, if provided
307
+ if i == len(timesteps) - 1 or (
308
+ (i + 1) > num_warmup_steps
309
+ and (i + 1) % self.noise_scheduler.order == 0
310
+ ):
311
+ progress_bar.update()
312
+
313
+ # Decode the final latents
314
+ latents = latents.split(latents.shape[concat_dim] // 2, dim=concat_dim)[0]
315
+ latents = 1 / self.vae.config.scaling_factor * latents
316
+ image = self.vae.decode(latents.to(self.device, dtype=self.weight_dtype)).sample
317
+ image = (image / 2 + 0.5).clamp(0, 1)
318
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
319
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
320
+ image = numpy_to_pil(image)
321
+
322
+ # Safety Check
323
+ if not self.skip_safety_check:
324
+ current_script_directory = os.path.dirname(os.path.realpath(__file__))
325
+ nsfw_image = os.path.join(os.path.dirname(current_script_directory), 'resource', 'img', 'NSFW.jpg')
326
+ nsfw_image = PIL.Image.open(nsfw_image).resize(image[0].size)
327
+ image_np = np.array(image)
328
+ _, has_nsfw_concept = self.run_safety_checker(image=image_np)
329
+ for i, not_safe in enumerate(has_nsfw_concept):
330
+ if not_safe:
331
+ image[i] = nsfw_image
332
+ return image
requirements.txt CHANGED
@@ -1,7 +1,7 @@
1
  torch==2.1.2
2
  torchvision==0.16.2
3
  accelerate==0.31.0
4
- diffusers==0.29.2
5
  huggingface_hub==0.23.4
6
  matplotlib==3.9.1
7
  numpy==1.26.4
@@ -12,10 +12,11 @@ scipy==1.13.1
12
  setuptools==51.0.0
13
  scikit-image==0.24.0
14
  tqdm==4.66.4
15
- transformers==4.27.3
16
  fvcore==0.1.5.post20221221
17
  cloudpickle==3.0.0
18
  omegaconf==2.3.0
19
  pycocotools==2.0.8
20
  av==12.3.0
21
- gradio==4.41.0
 
 
1
  torch==2.1.2
2
  torchvision==0.16.2
3
  accelerate==0.31.0
4
+ git+https://github.com/huggingface/diffusers.git
5
  huggingface_hub==0.23.4
6
  matplotlib==3.9.1
7
  numpy==1.26.4
 
12
  setuptools==51.0.0
13
  scikit-image==0.24.0
14
  tqdm==4.66.4
15
+ transformers==4.46.3
16
  fvcore==0.1.5.post20221221
17
  cloudpickle==3.0.0
18
  omegaconf==2.3.0
19
  pycocotools==2.0.8
20
  av==12.3.0
21
+ gradio==4.41.0
22
+ peft==0.14.0