ruslanmv committed on
Commit 4eb9e6f · 1 Parent(s): 83dbf4b
Files changed (2)
  1. app-v3-working.py +534 -0
  2. app.py +1 -2
app-v3-working.py ADDED
@@ -0,0 +1,534 @@
+import os
+import json
+import copy
+import time
+import random
+import logging
+import numpy as np
+from typing import Any, Dict, List, Optional, Union
+
+import torch
+from PIL import Image
+import gradio as gr
+
+from diffusers import (
+    DiffusionPipeline,
+    AutoencoderTiny,
+    AutoencoderKL,
+    AutoPipelineForImage2Image,
+    FluxPipeline,
+    FlowMatchEulerDiscreteScheduler
+)
+
+from huggingface_hub import (
+    hf_hub_download,
+    HfFileSystem,
+    ModelCard,
+    snapshot_download
+)
+
+from diffusers.utils import load_image
+
+import spaces
+
+# Import the prompt enhancer generator from enhance.py
+from enhance import generate as enhance_generate
+
+# Attempt to import loras from lora.py; otherwise use a default placeholder.
+try:
+    from lora import loras
+except ImportError:
+    loras = [
+        {"image": "placeholder.jpg", "title": "Placeholder LoRA", "repo": "placeholder/repo", "weights": None, "trigger_word": ""}
+    ]
+
+#---if workspace = local or colab---
+# (Optional: add Hugging Face login code here)
+
+def calculate_shift(
+    image_seq_len,
+    base_seq_len: int = 256,
+    max_seq_len: int = 4096,
+    base_shift: float = 0.5,
+    max_shift: float = 1.16,
+):
+    m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
+    b = base_shift - m * base_seq_len
+    mu = image_seq_len * m + b
+    return mu
+
+def retrieve_timesteps(
+    scheduler,
+    num_inference_steps: Optional[int] = None,
+    device: Optional[Union[str, torch.device]] = None,
+    timesteps: Optional[List[int]] = None,
+    sigmas: Optional[List[float]] = None,
+    **kwargs,
+):
+    if timesteps is not None and sigmas is not None:
+        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+    if timesteps is not None:
+        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+        timesteps = scheduler.timesteps
+        num_inference_steps = len(timesteps)
+    elif sigmas is not None:
+        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+        timesteps = scheduler.timesteps
+        num_inference_steps = len(timesteps)
+    else:
+        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+        timesteps = scheduler.timesteps
+    return timesteps, num_inference_steps
+
+# FLUX pipeline
+@torch.inference_mode()
+def flux_pipe_call_that_returns_an_iterable_of_images(
+    self,
+    prompt: Union[str, List[str]] = None,
+    prompt_2: Optional[Union[str, List[str]]] = None,
+    height: Optional[int] = None,
+    width: Optional[int] = None,
+    num_inference_steps: int = 28,
+    timesteps: List[int] = None,
+    guidance_scale: float = 3.5,
+    num_images_per_prompt: Optional[int] = 1,
+    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+    latents: Optional[torch.FloatTensor] = None,
+    prompt_embeds: Optional[torch.FloatTensor] = None,
+    pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+    output_type: Optional[str] = "pil",
+    return_dict: bool = True,
+    joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+    max_sequence_length: int = 512,
+    good_vae: Optional[Any] = None,
+):
+    height = height or self.default_sample_size * self.vae_scale_factor
+    width = width or self.default_sample_size * self.vae_scale_factor
+
+    self.check_inputs(
+        prompt,
+        prompt_2,
+        height,
+        width,
+        prompt_embeds=prompt_embeds,
+        pooled_prompt_embeds=pooled_prompt_embeds,
+        max_sequence_length=max_sequence_length,
+    )
+
+    self._guidance_scale = guidance_scale
+    self._joint_attention_kwargs = joint_attention_kwargs
+    self._interrupt = False
+
+    batch_size = 1 if isinstance(prompt, str) else len(prompt)
+    device = self._execution_device
+
+    lora_scale = joint_attention_kwargs.get("scale", None) if joint_attention_kwargs is not None else None
+    prompt_embeds, pooled_prompt_embeds, text_ids = self.encode_prompt(
+        prompt=prompt,
+        prompt_2=prompt_2,
+        prompt_embeds=prompt_embeds,
+        pooled_prompt_embeds=pooled_prompt_embeds,
+        device=device,
+        num_images_per_prompt=num_images_per_prompt,
+        max_sequence_length=max_sequence_length,
+        lora_scale=lora_scale,
+    )
+
+    num_channels_latents = self.transformer.config.in_channels // 4
+    latents, latent_image_ids = self.prepare_latents(
+        batch_size * num_images_per_prompt,
+        num_channels_latents,
+        height,
+        width,
+        prompt_embeds.dtype,
+        device,
+        generator,
+        latents,
+    )
+
+    sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
+    image_seq_len = latents.shape[1]
+    mu = calculate_shift(
+        image_seq_len,
+        self.scheduler.config.base_image_seq_len,
+        self.scheduler.config.max_image_seq_len,
+        self.scheduler.config.base_shift,
+        self.scheduler.config.max_shift,
+    )
+    timesteps, num_inference_steps = retrieve_timesteps(
+        self.scheduler,
+        num_inference_steps,
+        device,
+        timesteps,
+        sigmas,
+        mu=mu,
+    )
+    self._num_timesteps = len(timesteps)
+
+    guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32).expand(latents.shape[0]) if self.transformer.config.guidance_embeds else None
+
+    for i, t in enumerate(timesteps):
+        if self.interrupt:
+            continue
+
+        timestep = t.expand(latents.shape[0]).to(latents.dtype)
+
+        noise_pred = self.transformer(
+            hidden_states=latents,
+            timestep=timestep / 1000,
+            guidance=guidance,
+            pooled_projections=pooled_prompt_embeds,
+            encoder_hidden_states=prompt_embeds,
+            txt_ids=text_ids,
+            img_ids=latent_image_ids,
+            joint_attention_kwargs=self.joint_attention_kwargs,
+            return_dict=False,
+        )[0]
+
+        latents_for_image = self._unpack_latents(latents, height, width, self.vae_scale_factor)
+        latents_for_image = (latents_for_image / self.vae.config.scaling_factor) + self.vae.config.shift_factor
+        image = self.vae.decode(latents_for_image, return_dict=False)[0]
+        yield self.image_processor.postprocess(image, output_type=output_type)[0]
+        latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+        torch.cuda.empty_cache()
+
+    latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
+    latents = (latents / good_vae.config.scaling_factor) + good_vae.config.shift_factor
+    image = good_vae.decode(latents, return_dict=False)[0]
+    self.maybe_free_model_hooks()
+    torch.cuda.empty_cache()
+    yield self.image_processor.postprocess(image, output_type=output_type)[0]
+
+#--------------------------------------------------Model Initialization-----------------------------------------------------------------------------------------#
+dtype = torch.bfloat16
+device = "cuda" if torch.cuda.is_available() else "cpu"
+base_model = "black-forest-labs/FLUX.1-dev"
+
+# TAEF1 is a very tiny autoencoder which uses the same "latent API" as FLUX.1's VAE.
+taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype).to(device)
+good_vae = AutoencoderKL.from_pretrained(base_model, subfolder="vae", torch_dtype=dtype).to(device)
+pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=dtype, vae=taef1).to(device)
+pipe_i2i = AutoPipelineForImage2Image.from_pretrained(
+    base_model,
+    vae=good_vae,
+    transformer=pipe.transformer,
+    text_encoder=pipe.text_encoder,
+    tokenizer=pipe.tokenizer,
+    text_encoder_2=pipe.text_encoder_2,
+    tokenizer_2=pipe.tokenizer_2,
+    torch_dtype=dtype,
+).to(device)
+MAX_SEED = 2**32-1
+
+pipe.flux_pipe_call_that_returns_an_iterable_of_images = flux_pipe_call_that_returns_an_iterable_of_images.__get__(pipe)
+
+class calculateDuration:
+    def __init__(self, activity_name=""):
+        self.activity_name = activity_name
+    def __enter__(self):
+        self.start_time = time.time()
+        return self
+    def __exit__(self, exc_type, exc_value, traceback):
+        self.end_time = time.time()
+        self.elapsed_time = self.end_time - self.start_time
+        if self.activity_name:
+            print(f"Elapsed time for {self.activity_name}: {self.elapsed_time:.6f} seconds")
+        else:
+            print(f"Elapsed time: {self.elapsed_time:.6f} seconds")
+
+def update_selection(evt: gr.SelectData, width, height):
+    selected_lora = loras[evt.index]
+    new_placeholder = f"Type a prompt for {selected_lora['title']}"
+    lora_repo = selected_lora["repo"]
+    updated_text = f"### Selected: [{lora_repo}](https://huggingface.co/{lora_repo}) ✅"
+    if "aspect" in selected_lora:
+        if selected_lora["aspect"] == "portrait":
+            width = 768
+            height = 1024
+        elif selected_lora["aspect"] == "landscape":
+            width = 1024
+            height = 768
+        else:
+            width = 1024
+            height = 1024
+    return (
+        gr.update(placeholder=new_placeholder),
+        updated_text,
+        evt.index,
+        width,
+        height,
+    )
+
+@spaces.GPU(duration=100)
+def generate_image(prompt_mash, steps, seed, cfg_scale, width, height, lora_scale, progress):
+    pipe.to("cuda")
+    generator = torch.Generator(device="cuda").manual_seed(seed)
+    with calculateDuration("Generating image"):
+        for img in pipe.flux_pipe_call_that_returns_an_iterable_of_images(
+            prompt=prompt_mash,
+            num_inference_steps=steps,
+            guidance_scale=cfg_scale,
+            width=width,
+            height=height,
+            generator=generator,
+            joint_attention_kwargs={"scale": lora_scale},
+            output_type="pil",
+            good_vae=good_vae,
+        ):
+            yield img
+
+def generate_image_to_image(prompt_mash, image_input_path, image_strength, steps, cfg_scale, width, height, lora_scale, seed):
+    generator = torch.Generator(device="cuda").manual_seed(seed)
+    pipe_i2i.to("cuda")
+    image_input = load_image(image_input_path)
+    final_image = pipe_i2i(
+        prompt=prompt_mash,
+        image=image_input,
+        strength=image_strength,
+        num_inference_steps=steps,
+        guidance_scale=cfg_scale,
+        width=width,
+        height=height,
+        generator=generator,
+        joint_attention_kwargs={"scale": lora_scale},
+        output_type="pil",
+    ).images[0]
+    return final_image
+
+@spaces.GPU(duration=100)
+def run_lora(prompt, image_input, image_strength, cfg_scale, steps, selected_index, randomize_seed, seed, width, height, lora_scale, use_enhancer, progress=gr.Progress(track_tqdm=True)):
+    # Check if a LoRA is selected.
+    if selected_index is None:
+        return None, seed, gr.update(value="Please select a LoRA from the 'LoRA DLC's gallery above before generating images.", visible=False), ""  # Return None for image and update prompt box
+
+    selected_lora = loras[selected_index]
+    lora_path = selected_lora["repo"]
+    trigger_word = selected_lora["trigger_word"]
+
+    # Prepare prompt by appending/prepending trigger word if available.
+    if trigger_word:
+        if "trigger_position" in selected_lora and selected_lora["trigger_position"] == "prepend":
+            prompt_mash = f"{trigger_word} {prompt}"
+        else:
+            prompt_mash = f"{prompt} {trigger_word}"
+    else:
+        prompt_mash = prompt
+
+    # If prompt enhancer is enabled, stream the enhanced prompt.
+    enhanced_text = ""
+    if use_enhancer:
+        for enhanced_chunk in enhance_generate(prompt_mash):
+            enhanced_text = enhanced_chunk
+            # Yield intermediate output (no image yet, but update enhanced prompt textbox)
+            yield None, seed, gr.update(visible=False), enhanced_text
+        prompt_mash = enhanced_text  # Use final enhanced prompt for generation
+    # Else, leave prompt_mash as is.
+
+    with calculateDuration("Unloading LoRA"):
+        pipe.unload_lora_weights()
+        pipe_i2i.unload_lora_weights()
+
+    with calculateDuration(f"Loading LoRA weights for {selected_lora['title']}"):
+        pipe_to_use = pipe_i2i if image_input is not None else pipe
+        weight_name = selected_lora.get("weights", None)
+        pipe_to_use.load_lora_weights(
+            lora_path,
+            weight_name=weight_name,
+            low_cpu_mem_usage=True
+        )
+
+    with calculateDuration("Randomizing seed"):
+        if randomize_seed:
+            seed = random.randint(0, MAX_SEED)
+
+    if image_input is not None:
+        final_image = generate_image_to_image(prompt_mash, image_input, image_strength, steps, cfg_scale, width, height, lora_scale, seed)
+        yield final_image, seed, gr.update(visible=False), enhanced_text
+    else:
+        image_generator = generate_image(prompt_mash, steps, seed, cfg_scale, width, height, lora_scale, progress)
+        final_image = None
+        step_counter = 0
+        for image in image_generator:
+            step_counter += 1
+            final_image = image
+            progress_bar = f'<div class="progress-container"><div class="progress-bar" style="--current: {step_counter}; --total: {steps};"></div></div>'
+            yield image, seed, gr.update(value=progress_bar, visible=True), enhanced_text
+        yield final_image, seed, gr.update(value=progress_bar, visible=False), enhanced_text
+
+
+
+def get_huggingface_safetensors(link):
+    split_link = link.split("/")
+    if len(split_link) == 2:
+        model_card = ModelCard.load(link)
+        base_model = model_card.data.get("base_model")
+        print(base_model)
+        if (base_model != "black-forest-labs/FLUX.1-dev") and (base_model != "black-forest-labs/FLUX.1-schnell"):
+            raise Exception("Flux LoRA Not Found!")
+        image_path = model_card.data.get("widget", [{}])[0].get("output", {}).get("url", None)
+        trigger_word = model_card.data.get("instance_prompt", "")
+        image_url = f"https://huggingface.co/{link}/resolve/main/{image_path}" if image_path else None
+        fs = HfFileSystem()
+        try:
+            list_of_files = fs.ls(link, detail=False)
+            for file in list_of_files:
+                if file.endswith(".safetensors"):
+                    safetensors_name = file.split("/")[-1]
+                if not image_url and file.lower().endswith((".jpg", ".jpeg", ".png", ".webp")):
+                    image_elements = file.split("/")
+                    image_url = f"https://huggingface.co/{link}/resolve/main/{image_elements[-1]}"
+        except Exception as e:
+            print(e)
+            gr.Warning("You didn't include a link nor a valid Hugging Face repository with a *.safetensors LoRA")
+            raise Exception("Invalid LoRA repository")
+        return split_link[1], link, safetensors_name, trigger_word, image_url
+    else:
+        raise Exception("Invalid LoRA link format")
+
+def check_custom_model(link):
+    if link.startswith("https://"):
+        if link.startswith("https://huggingface.co") or link.startswith("https://www.huggingface.co"):
+            link_split = link.split("huggingface.co/")
+            return get_huggingface_safetensors(link_split[1])
+    else:
+        return get_huggingface_safetensors(link)
+
+def add_custom_lora(custom_lora):
+    global loras
+    if custom_lora:
+        try:
+            title, repo, path, trigger_word, image = check_custom_model(custom_lora)
+            print(f"Loaded custom LoRA: {repo}")
+            card = f'''
+            <div class="custom_lora_card">
+              <span>Loaded custom LoRA:</span>
+              <div class="card_internal">
+                <img src="{image}" />
+                <div>
+                  <h3>{title}</h3>
+                  <small>{"Using: <code><b>" + trigger_word + "</b></code> as the trigger word" if trigger_word else "No trigger word found. Include it in your prompt"}<br></small>
+                </div>
+              </div>
+            </div>
+            '''
+            existing_item_index = next((index for (index, item) in enumerate(loras) if item['repo'] == repo), None)
+            if not existing_item_index:
+                new_item = {
+                    "image": image,
+                    "title": title,
+                    "repo": repo,
+                    "weights": path,
+                    "trigger_word": trigger_word
+                }
+                print(new_item)
+                existing_item_index = len(loras)
+                loras.append(new_item)
+
+            return gr.update(visible=True, value=card), gr.update(visible=True), gr.Gallery(selected_index=None), f"Custom: {path}", existing_item_index, trigger_word
+        except Exception as e:
+            gr.Warning("Invalid LoRA: either you entered an invalid link or a non-FLUX LoRA")
+            return gr.update(visible=True, value="Invalid LoRA"), gr.update(visible=False), gr.update(), "", None, ""
+    else:
+        return gr.update(visible=False), gr.update(visible=False), gr.update(), "", None, ""
+
+def remove_custom_lora():
+    return gr.update(visible=False), gr.update(visible=False), gr.update(), "", None, ""
+
+run_lora.zerogpu = True
+
+css = '''
+#gen_btn { height: 100%; }
+#gen_column { align-self: stretch; }
+#title { text-align: center; }
+#title h1 { font-size: 3em; display:inline-flex; align-items:center; }
+#title img { width: 100px; margin-right: 0.5em; }
+#gallery .grid-wrap { height: 10vh; }
+#lora_list { background: var(--block-background-fill); padding: 0 1em .3em; font-size: 90%; }
+.card_internal { display: flex; height: 100px; margin-top: .5em; }
+.card_internal img { margin-right: 1em; }
+.styler { --form-gap-width: 0px !important; }
+#progress { height:30px; }
+#progress .generating { display:none; }
+.progress-container { width: 100%; height: 30px; background-color: #f0f0f0; border-radius: 15px; overflow: hidden; margin-bottom: 20px; }
+.progress-bar { height: 100%; background-color: #4f46e5; width: calc(var(--current) / var(--total) * 100%); transition: width 0.5s ease-in-out; }
+'''
+
+with gr.Blocks(theme=gr.themes.Base(), css=css, delete_cache=(60, 60)) as app:
+    title = gr.HTML(
+        """<h1>Flux LoRA Generation</h1>""",
+        elem_id="title",
+    )
+    selected_index = gr.State(None)
+    with gr.Row():
+        with gr.Column(scale=3):
+            prompt = gr.Textbox(label="Prompt", lines=1, placeholder=":/ choose the LoRA and type the prompt ")
+        with gr.Column(scale=1, elem_id="gen_column"):
+            generate_button = gr.Button("Generate", variant="primary", elem_id="gen_btn")
+    with gr.Row():
+        with gr.Column():
+            selected_info = gr.Markdown("")
+            gallery = gr.Gallery(
+                [(item["image"], item["title"]) for item in loras],
+                label="LoRA DLC's",
+                allow_preview=False,
+                columns=3,
+                elem_id="gallery",
+                show_share_button=False
+            )
+            with gr.Group():
+                custom_lora = gr.Textbox(label="Enter Custom LoRA", placeholder="prithivMLmods/Canopus-LoRA-Flux-Anime")
+                gr.Markdown("[Check the list of FLUX LoRA's](https://huggingface.co/models?other=base_model:adapter:black-forest-labs/FLUX.1-dev)", elem_id="lora_list")
+            custom_lora_info = gr.HTML(visible=False)
+            custom_lora_button = gr.Button("Remove custom LoRA", visible=False)
+        with gr.Column():
+            progress_bar = gr.Markdown(elem_id="progress", visible=False)
+            result = gr.Image(label="Generated Image")
+    with gr.Row():
+        with gr.Accordion("Advanced Settings", open=False):
+            with gr.Row():
+                input_image = gr.Image(label="Input image", type="filepath")
+                image_strength = gr.Slider(label="Denoise Strength", info="Lower means more image influence", minimum=0.1, maximum=1.0, step=0.01, value=0.75)
+            with gr.Column():
+                with gr.Row():
+                    cfg_scale = gr.Slider(label="CFG Scale", minimum=1, maximum=20, step=0.5, value=3.5)
+                    steps = gr.Slider(label="Steps", minimum=1, maximum=50, step=1, value=28)
+                with gr.Row():
+                    width = gr.Slider(label="Width", minimum=256, maximum=1536, step=64, value=1024)
+                    height = gr.Slider(label="Height", minimum=256, maximum=1536, step=64, value=1024)
+                with gr.Row():
+                    randomize_seed = gr.Checkbox(True, label="Randomize seed")
+                    seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0, randomize=True)
+                    lora_scale = gr.Slider(label="LoRA Scale", minimum=0, maximum=3, step=0.01, value=0.95)
+                with gr.Row():
+                    use_enhancer = gr.Checkbox(value=False, label="Use Prompt Enhancer")
+                    show_enhanced_prompt = gr.Checkbox(value=False, label="Display Enhanced Prompt")
+                enhanced_prompt_box = gr.Textbox(label="Enhanced Prompt", visible=False)
+    # Add the change event so that the enhanced prompt box visibility toggles.
+    show_enhanced_prompt.change(fn=lambda show: gr.update(visible=show),
+                                inputs=show_enhanced_prompt,
+                                outputs=enhanced_prompt_box)
+    gallery.select(
+        update_selection,
+        inputs=[width, height],
+        outputs=[prompt, selected_info, selected_index, width, height]
+    )
+    custom_lora.input(
+        add_custom_lora,
+        inputs=[custom_lora],
+        outputs=[custom_lora_info, custom_lora_button, gallery, selected_info, selected_index, prompt]
+    )
+    custom_lora_button.click(
+        remove_custom_lora,
+        outputs=[custom_lora_info, custom_lora_button, gallery, selected_info, selected_index, custom_lora]
+    )
+    gr.on(
+        triggers=[generate_button.click, prompt.submit],
+        fn=run_lora,
+        inputs=[prompt, input_image, image_strength, cfg_scale, steps, selected_index, randomize_seed, seed, width, height, lora_scale, use_enhancer],
+        outputs=[result, seed, progress_bar, enhanced_prompt_box]
+    )
+    with gr.Row():
+        gr.HTML("<div style='text-align:center; font-size:0.9em; margin-top:20px;'>Credits: <a href='https://ruslanmv.com' target='_blank'>ruslanmv.com</a></div>")
+
+app.queue()
+app.launch(debug=True)
app.py CHANGED
@@ -299,7 +299,7 @@ def generate_image_to_image(prompt_mash, image_input_path, image_strength, steps
 def run_lora(prompt, image_input, image_strength, cfg_scale, steps, selected_index, randomize_seed, seed, width, height, lora_scale, use_enhancer, progress=gr.Progress(track_tqdm=True)):
     # Check if a LoRA is selected.
     if selected_index is None:
-        return None, seed, gr.update(value="Please select a LoRA from the 'LoRA DLC's gallery above before generating images.", visible=False), ""  # Return None for image and update prompt box
+        return "You must select a LoRA before proceeding. 🧨", seed, gr.update(visible=False), ""  # Return message for image output, update prompt box
 
     selected_lora = loras[selected_index]
     lora_path = selected_lora["repo"]
@@ -356,7 +356,6 @@ def run_lora(prompt, image_input, image_strength, cfg_scale, steps, selected_ind
         yield final_image, seed, gr.update(value=progress_bar, visible=False), enhanced_text
 
 
-
 def get_huggingface_safetensors(link):
     split_link = link.split("/")
     if len(split_link) == 2: