Ffftdtd5dtft committed on
Commit
a7950e4
·
verified ·
1 Parent(s): 29b6509

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -10
app.py CHANGED
@@ -10,23 +10,29 @@ from accelerate import init_empty_weights, load_checkpoint_and_dispatch
10
# Run the pipeline in bfloat16 to roughly halve memory use vs. float32.
dtype = torch.bfloat16

# Resolve the target device once at import time: GPU when torch can see
# one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
12
 
13
# Load the FLUX.1-schnell text-to-image pipeline once at import time.
#
# BUG FIX: the previous version built the pipeline under
# accelerate.init_empty_weights() and then passed the resulting
# DiffusionPipeline to load_checkpoint_and_dispatch().  That helper expects a
# torch.nn.Module plus a *local checkpoint path* — not a pipeline object and a
# Hub model id — so the meta-device weights were never materialized.  Calling
# .to(device) on a module dispatched with device_map="auto" also conflicts
# with accelerate's placement hooks.  Loading the pipeline normally and moving
# it to the chosen device is both correct and simpler.
pipe = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", torch_dtype=dtype
).to(device)
 
 
 
 
 
24
 
25
# Seeds are drawn from the non-negative signed 32-bit integer range
# (np.iinfo returns a plain Python int for .max).
MAX_SEED = np.iinfo("int32").max

# Hard cap, in pixels, on the requested image width/height.
MAX_IMAGE_SIZE = 2048
27
 
28
  @spaces.GPU()
29
  def infer(prompt, seed=42, randomize_seed=False, width=1024, height=1024, num_inference_steps=4, num_images=1, progress=gr.Progress(track_tqdm=True)):
 
 
30
  if randomize_seed:
31
  seed = random.randint(0, MAX_SEED)
32
  generator = torch.Generator().manual_seed(seed)
 
10
# Compute dtype: bfloat16 keeps VRAM usage low on modern accelerators.
dtype = torch.bfloat16

# Target device string, resolved once at import time (flipped condition,
# same result as the usual cuda-first ternary).
device = "cpu" if not torch.cuda.is_available() else "cuda"
12
 
13
# The pipeline is created lazily on first use and cached in this global.
pipe = None


def load_model():
    """Initialize the global FLUX.1-schnell pipeline exactly once.

    Idempotent: repeated calls return immediately after ``pipe`` is set,
    so it is safe to call at the top of every inference request.

    BUG FIX: the previous body built the pipeline under
    ``accelerate.init_empty_weights()`` and then handed the
    ``DiffusionPipeline`` to ``load_checkpoint_and_dispatch()``.  That helper
    takes a ``torch.nn.Module`` and a local checkpoint path — not a pipeline
    object and a Hub model id — so the meta-device weights were never properly
    materialized, and the trailing ``.to(device)`` conflicts with
    ``device_map="auto"`` dispatch hooks.  A plain ``from_pretrained()``
    followed by ``.to(device)`` performs the intended load correctly.
    """
    global pipe
    if pipe is None:
        pipe = DiffusionPipeline.from_pretrained(
            "black-forest-labs/FLUX.1-schnell", torch_dtype=dtype
        ).to(device)
28
 
29
# Largest value accepted for a user-supplied RNG seed: 2**31 - 1,
# i.e. the maximum of a signed 32-bit integer.
MAX_SEED = np.iinfo(np.int32).max

# Generated images are limited to at most this many pixels per side.
MAX_IMAGE_SIZE = 2048
31
 
32
  @spaces.GPU()
33
  def infer(prompt, seed=42, randomize_seed=False, width=1024, height=1024, num_inference_steps=4, num_images=1, progress=gr.Progress(track_tqdm=True)):
34
+ load_model() # Asegurarse de que el modelo esté cargado antes de la inferencia
35
+
36
  if randomize_seed:
37
  seed = random.randint(0, MAX_SEED)
38
  generator = torch.Generator().manual_seed(seed)