Ffftdtd5dtft commited on
Commit
b8e822a
·
verified ·
1 Parent(s): 2152339

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +88 -29
app.py CHANGED
@@ -1,7 +1,7 @@
1
  import gradio as gr
2
  import torch
3
  from diffusers import DiffusionPipeline
4
- from accelerate import init_empty_weights, load_checkpoint_and_dispatch
5
 
6
  # Configuración para usar bfloat16 y CUDA si está disponible
7
  dtype = torch.bfloat16
@@ -13,44 +13,43 @@ pipe = None
13
  def load_model():
14
  global pipe
15
  if pipe is None:
16
- with init_empty_weights():
17
- pipe = DiffusionPipeline.from_pretrained(
18
- "black-forest-labs/FLUX.1-schnell",
19
- torch_dtype=dtype
20
- )
21
- # Cargar el modelo en la GPU sin intentar acceder a `named_parameters`
22
  pipe = load_checkpoint_and_dispatch(
23
  pipe,
24
  "black-forest-labs/FLUX.1-schnell",
25
- device_map="auto",
26
- offload_folder=None,
27
- ).to(device)
 
28
 
29
  MAX_SEED = torch.iinfo(torch.int32).max
30
- MAX_IMAGE_SIZE = 2048
31
 
32
- def infer(prompt, seed=42, randomize_seed=False, width=1024, height=1024, num_inference_steps=4, num_images=1, progress=gr.Progress(track_tqdm=True)):
 
33
  load_model() # Asegurarse de que el modelo esté cargado antes de la inferencia
34
 
35
  if randomize_seed:
36
  seed = torch.randint(0, MAX_SEED, (1,)).item()
37
- generator = torch.Generator(device=device).manual_seed(seed)
38
 
39
  images = []
40
  for _ in range(num_images):
41
  image = pipe(
42
- prompt=prompt,
43
- width=width,
44
- height=height,
45
- num_inference_steps=num_inference_steps,
46
- generator=generator,
47
- guidance_scale=0.0
48
  ).images[0]
49
  images.append(image)
50
 
51
  return images, seed
52
 
53
- # Gradio Interface
54
  examples = [
55
  "a tiny astronaut hatching from an egg on the moon",
56
  "a cat holding a sign that says hello world",
@@ -73,21 +72,81 @@ with gr.Blocks(css=css) as demo:
73
  """)
74
 
75
  with gr.Row():
76
- prompt = gr.Textbox(label="Prompt", show_label=False, max_lines=1, placeholder="Enter your prompt")
77
- run_button = gr.Button("Run")
 
 
 
 
 
 
 
 
78
 
79
  results = gr.Gallery(label="Results", show_label=False, elem_id="image-gallery")
80
 
81
  with gr.Accordion("Advanced Settings", open=False):
82
- seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
 
 
 
 
 
 
 
 
83
  randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
84
- width = gr.Slider(label="Width", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=1024)
85
- height = gr.Slider(label="Height", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=1024)
86
- num_inference_steps = gr.Slider(label="Number of inference steps", minimum=1, maximum=50, step=1, value=4)
87
- num_images = gr.Slider(label="Number of images", minimum=1, maximum 10, step=1, value=1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88
 
89
- gr.Examples(examples, inputs=[prompt], outputs=[results, seed])
90
- run_button.click(infer, inputs=[prompt, seed, randomize_seed, width, height, num_inference_steps, num_images], outputs=[results, seed])
 
 
 
 
 
 
 
 
 
 
 
 
91
 
92
  # Crear un enlace público con share=True
93
  demo.launch(share=True)
 
1
  import gradio as gr
2
  import torch
3
  from diffusers import DiffusionPipeline
4
+ from accelerate import load_checkpoint_and_dispatch
5
 
6
  # Configuración para usar bfloat16 y CUDA si está disponible
7
  dtype = torch.bfloat16
 
13
  def load_model():
14
  global pipe
15
  if pipe is None:
16
+ pipe = DiffusionPipeline.from_pretrained(
17
+ "black-forest-labs/FLUX.1-schnell",
18
+ torch_dtype=dtype
19
+ )
20
+ # Despachar los pesos a la GPU, evitando acceder a named_parameters
 
21
  pipe = load_checkpoint_and_dispatch(
22
  pipe,
23
  "black-forest-labs/FLUX.1-schnell",
24
+ device_map="auto", # Automatiza el uso de RAM y GPU
25
+ offload_folder=None # Evita que se almacenen los pesos temporalmente en el disco
26
+ )
27
+ pipe.to(device)
28
 
29
  MAX_SEED = torch.iinfo(torch.int32).max
 
30
 
31
+ @gr.Interface()
32
+ def infer(prompt, seed=42, randomize_seed=False, width=1024, height=1024, num_inference_steps=4, num_images=1):
33
  load_model() # Asegurarse de que el modelo esté cargado antes de la inferencia
34
 
35
  if randomize_seed:
36
  seed = torch.randint(0, MAX_SEED, (1,)).item()
37
+ generator = torch.Generator(device).manual_seed(seed)
38
 
39
  images = []
40
  for _ in range(num_images):
41
  image = pipe(
42
+ prompt=prompt,
43
+ width=width,
44
+ height=height,
45
+ num_inference_steps=num_inference_steps,
46
+ generator=generator,
47
+ guidance_scale=0.0
48
  ).images[0]
49
  images.append(image)
50
 
51
  return images, seed
52
 
 
53
  examples = [
54
  "a tiny astronaut hatching from an egg on the moon",
55
  "a cat holding a sign that says hello world",
 
72
  """)
73
 
74
  with gr.Row():
75
+
76
+ prompt = gr.Text(
77
+ label="Prompt",
78
+ show_label=False,
79
+ max_lines=1,
80
+ placeholder="Enter your prompt",
81
+ container=False,
82
+ )
83
+
84
+ run_button = gr.Button("Run", scale=0)
85
 
86
  results = gr.Gallery(label="Results", show_label=False, elem_id="image-gallery")
87
 
88
  with gr.Accordion("Advanced Settings", open=False):
89
+
90
+ seed = gr.Slider(
91
+ label="Seed",
92
+ minimum=0,
93
+ maximum=MAX_SEED,
94
+ step=1,
95
+ value=0,
96
+ )
97
+
98
  randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
99
+
100
+ with gr.Row():
101
+
102
+ width = gr.Slider(
103
+ label="Width",
104
+ minimum=256,
105
+ maximum=2048,
106
+ step=32,
107
+ value=1024,
108
+ )
109
+
110
+ height = gr.Slider(
111
+ label="Height",
112
+ minimum=256,
113
+ maximum=2048,
114
+ step=32,
115
+ value=1024,
116
+ )
117
+
118
+ with gr.Row():
119
+
120
+ num_inference_steps = gr.Slider(
121
+ label="Number of inference steps",
122
+ minimum=1,
123
+ maximum=50,
124
+ step=1,
125
+ value=4,
126
+ )
127
+
128
+ num_images = gr.Slider(
129
+ label="Number of images",
130
+ minimum=1,
131
+ maximum=300,
132
+ step=1,
133
+ value=1,
134
+ )
135
 
136
+ gr.Examples(
137
+ examples = examples,
138
+ fn = infer,
139
+ inputs = [prompt],
140
+ outputs = [results, seed],
141
+ cache_examples="lazy"
142
+ )
143
+
144
+ gr.on(
145
+ triggers=[run_button.click, prompt.submit],
146
+ fn = infer,
147
+ inputs = [prompt, seed, randomize_seed, width, height, num_inference_steps, num_images],
148
+ outputs = [results, seed]
149
+ )
150
 
151
  # Crear un enlace público con share=True
152
  demo.launch(share=True)