File size: 5,624 Bytes
8ccf632
 
81b26b5
ee4c1bd
06f0278
0624a48
 
 
9934eed
9675628
9934eed
 
0624a48
 
ee4c1bd
0624a48
 
8ccf632
ee4c1bd
a7950e4
29b6509
a7950e4
 
 
e0e4f11
 
 
0624a48
b8e822a
 
 
 
ee4c1bd
a7950e4
 
 
ee4c1bd
b8e822a
 
 
8ccf632
2152339
8ccf632
b8e822a
a7950e4
 
54192f0
2152339
b8e822a
29b6509
 
 
 
b8e822a
 
 
 
 
 
29b6509
 
 
 
1f2e94a
8ccf632
06f0278
 
 
8ccf632
 
1f2e94a
8ccf632
 
e2944a6
8ccf632
 
 
 
 
 
6ebb7df
 
4ea3b6f
8ccf632
 
 
b8e822a
 
 
 
 
 
a478964
b8e822a
 
 
8ccf632
de83e05
29b6509
8ccf632
 
b8e822a
 
 
 
 
 
 
 
 
8ccf632
b8e822a
 
 
 
 
 
de83e05
b8e822a
 
 
 
 
 
 
de83e05
b8e822a
 
 
 
 
 
 
 
 
 
 
 
 
 
de83e05
b8e822a
 
 
de83e05
b8e822a
 
 
8ccf632
b8e822a
 
 
 
 
 
 
 
f284516
 
 
 
 
b8e822a
f284516
 
de83e05
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
import gradio as gr
import torch
from diffusers import DiffusionPipeline
from accelerate import init_empty_weights, load_checkpoint_and_dispatch

# Detección y configuración del dispositivo para compatibilidad con GPU o CPU
if torch.cuda.is_available():
    device = "cuda"  # Para GPUs NVIDIA
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_built():
    device = "mps"  # Para GPUs Apple Silicon (M1/M2) y otras GPUs con soporte Metal
elif hasattr(torch.backends, "rocm") and torch.backends.rocm.is_available():
    device = "rocm"  # Para GPUs AMD con ROCm, si está disponible
else:
    device = "cpu"  # En caso de no tener GPU disponible

# Definir el tipo de dato, usando bfloat16 si es compatible, si no, usar float32
dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32

# Inicializar el modelo solo una vez y cargarlo en RAM y GPU/CPU
pipe = None

def load_model():
    global pipe
    if pipe is None:
        # Inicializar ZeroGPU antes de cargar el modelo
        init_empty_weights()

        # Cargar el modelo y configurarlo para usar el dispositivo adecuado
        pipe = DiffusionPipeline.from_pretrained(
            "black-forest-labs/FLUX.1-schnell", 
            torch_dtype=dtype
        )
        # Despachar los pesos al dispositivo adecuado (GPU o CPU)
        pipe = load_checkpoint_and_dispatch(
            pipe, 
            "black-forest-labs/FLUX.1-schnell", 
            device_map="auto",  # Automatiza el uso de RAM, GPU o CPU
            offload_folder=None  # Evita que se almacenen los pesos temporalmente en el disco
        )
        pipe.to(device)

MAX_SEED = torch.iinfo(torch.int32).max

def infer(prompt, seed=42, randomize_seed=False, width=1024, height=1024, num_inference_steps=4, num_images=1):
    load_model()  # Asegurarse de que el modelo esté cargado antes de la inferencia
    
    if randomize_seed:
        seed = torch.randint(0, MAX_SEED, (1,)).item()
    generator = torch.Generator(device).manual_seed(seed)
    
    images = []
    for _ in range(num_images):
        image = pipe(
                prompt=prompt, 
                width=width,
                height=height,
                num_inference_steps=num_inference_steps, 
                generator=generator,
                guidance_scale=0.0
        ).images[0]
        images.append(image)
    
    return images, seed

examples = [
    "a tiny astronaut hatching from an egg on the moon",
    "a cat holding a sign that says hello world",
    "an anime illustration of a wiener schnitzel",
]

css = """
#col-container {
    margin: 0 auto;
    max-width: 520px;
}
"""

with gr.Blocks(css=css) as demo:
    
    with gr.Column(elem_id="col-container"):
        gr.Markdown(f"""# FLUX.1 [schnell]
12B param rectified flow transformer distilled from [FLUX.1 [pro]](https://blackforestlabs.ai/) for 4 step generation
[[blog](https://blackforestlabs.ai/announcing-black-forest-labs/)] [[model](https://huggingface.co./black-forest-labs/FLUX.1-schnell)]
        """)
        
        with gr.Row():
            
            prompt = gr.Text(
                label="Prompt",
                show_label=False,
                max_lines=1,
                placeholder="Enter your prompt",
                container=False
            )
            
            run_button = gr.Button("Run", scale=0)
        
        # Usamos gr.Gallery para mostrar múltiples imágenes
        results = gr.Gallery(label="Results", show_label=False, elem_id="image-gallery")
        
        with gr.Accordion("Advanced Settings", open=False):
            
            seed = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=0,
            )
            
            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
            
            with gr.Row():
                
                width = gr.Slider(
                    label="Width",
                    minimum=256,
                    maximum=2048,  # Ajusta el tamaño máximo según sea necesario
                    step=32,
                    value=1024,
                )
                
                height = gr.Slider(
                    label="Height",
                    minimum=256,
                    maximum=2048,  # Ajusta el tamaño máximo según sea necesario
                    step=32,
                    value=1024,
                )
            
            with gr.Row():
                
                num_inference_steps = gr.Slider(
                    label="Number of inference steps",
                    minimum=1,
                    maximum=50,
                    step=1,
                    value=4,
                )
            
            # Control para el número de imágenes a generar
            num_images = gr.Slider(
                label="Number of images",
                minimum=1,
                maximum=10,  # Ajusta el número máximo de imágenes según sea necesario
                step=1,
                value=1,
            )
        
        gr.Examples(
            examples = examples,
            fn = infer,
            inputs = [prompt],
            outputs = [results, seed],
            cache_examples="lazy"
        )

    # Conectar el botón y el campo de texto a la función infer
    run_button.click(
        fn=infer,
        inputs=[prompt, seed, randomize_seed, width, height, num_inference_steps, num_images],
        outputs=[results, seed]
    )
    
    # Crear un enlace público con share=True
    demo.launch(share=True)