AP123 commited on
Commit
be85eb8
1 Parent(s): ee36d88

Update app.py

Browse files

Safety checker

Files changed (1) hide show
  1. app.py +24 -7
app.py CHANGED
@@ -13,26 +13,31 @@ from diffusers import (
13
  StableDiffusionLatentUpscalePipeline,
14
  StableDiffusionImg2ImgPipeline,
15
  StableDiffusionControlNetImg2ImgPipeline,
16
- DPMSolverMultistepScheduler, # <-- Added import
17
- EulerDiscreteScheduler, # <-- Added import (5/13)
18
- StableDiffusionSafetyChecker# <-- Added import
19
  )
20
  import tempfile
21
  import time
22
  from share_btn import community_icon_html, loading_icon_html, share_js
23
  import user_history
24
  from illusion_style import css
 
 
 
25
 
26
  BASE_MODEL = "SG161222/Realistic_Vision_V5.1_noVAE"
27
 
28
  # Initialize both pipelines
29
  vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=torch.float16)
 
30
 
31
- # Initialize the safety checker
32
- safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker")
 
 
 
 
33
 
34
- #init_pipe = DiffusionPipeline.from_pretrained("SG161222/Realistic_Vision_V5.1_noVAE", torch_dtype=torch.float16)
35
- controlnet = ControlNetModel.from_pretrained("monster-labs/control_v1p_sd15_qrcode_monster", torch_dtype=torch.float16)#, torch_dtype=torch.float16)
36
  main_pipe = StableDiffusionControlNetPipeline.from_pretrained(
37
  BASE_MODEL,
38
  controlnet=controlnet,
@@ -41,6 +46,18 @@ main_pipe = StableDiffusionControlNetPipeline.from_pretrained(
41
  torch_dtype=torch.float16,
42
  ).to("cuda")
43
 
 
 
 
 
 
 
 
 
 
 
 
 
44
  #main_pipe.unet = torch.compile(main_pipe.unet, mode="reduce-overhead", fullgraph=True)
45
  #main_pipe.unet.to(memory_format=torch.channels_last)
46
  #main_pipe.unet = torch.compile(main_pipe.unet, mode="reduce-overhead", fullgraph=True)
 
13
  StableDiffusionLatentUpscalePipeline,
14
  StableDiffusionImg2ImgPipeline,
15
  StableDiffusionControlNetImg2ImgPipeline,
16
+ DPMSolverMultistepScheduler,
17
+ EulerDiscreteScheduler
 
18
  )
19
  import tempfile
20
  import time
21
  from share_btn import community_icon_html, loading_icon_html, share_js
22
  import user_history
23
  from illusion_style import css
24
+ import os
25
+ from transformers import CLIPFeatureExtractor
26
+ from safety_checker import StableDiffusionSafetyChecker
27
 
28
  BASE_MODEL = "SG161222/Realistic_Vision_V5.1_noVAE"
29
 
30
  # Initialize both pipelines
31
  vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=torch.float16)
32
+ controlnet = ControlNetModel.from_pretrained("monster-labs/control_v1p_sd15_qrcode_monster", torch_dtype=torch.float16)
33
 
34
+ # Initialize the safety checker conditionally
35
+ SAFETY_CHECKER_ENABLED = os.environ.get("SAFETY_CHECKER", "0") == "1"
36
+ safety_checker = None
37
+ if SAFETY_CHECKER_ENABLED:
38
+ safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker").to("cuda")
39
+ feature_extractor = CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32")
40
 
 
 
41
  main_pipe = StableDiffusionControlNetPipeline.from_pretrained(
42
  BASE_MODEL,
43
  controlnet=controlnet,
 
46
  torch_dtype=torch.float16,
47
  ).to("cuda")
48
 
49
+ # Function to check NSFW images
50
+ def check_nsfw_images(images: list[Image.Image]) -> tuple[list[Image.Image], list[bool]]:
51
+ if SAFETY_CHECKER_ENABLED:
52
+ safety_checker_input = feature_extractor(images, return_tensors="pt").to("cuda")
53
+ has_nsfw_concepts = safety_checker(
54
+ images=[images],
55
+ clip_input=safety_checker_input.pixel_values.to("cuda")
56
+ )
57
+ return images, has_nsfw_concepts
58
+ else:
59
+ return images, [False] * len(images)
60
+
61
  #main_pipe.unet = torch.compile(main_pipe.unet, mode="reduce-overhead", fullgraph=True)
62
  #main_pipe.unet.to(memory_format=torch.channels_last)
63
  #main_pipe.unet = torch.compile(main_pipe.unet, mode="reduce-overhead", fullgraph=True)