# Flux-LoRA-Generation-Advanced / app-v2-working.py
import os
import json
import copy
import time
import random
import logging
import numpy as np
from typing import Any, Dict, List, Optional, Union
import torch
from PIL import Image
import gradio as gr
from diffusers import (
DiffusionPipeline,
AutoencoderTiny,
AutoencoderKL,
AutoPipelineForImage2Image,
FluxPipeline,
FlowMatchEulerDiscreteScheduler
)
from huggingface_hub import (
hf_hub_download,
HfFileSystem,
ModelCard,
snapshot_download
)
from diffusers.utils import load_image
import spaces
# Import the prompt enhancer generator from enhance.py
from enhance import generate as enhance_generate
# Attempt to import loras from lora.py; otherwise use a default placeholder.
try:
from lora import loras
except ImportError:
loras = [
{"image": "placeholder.jpg", "title": "Placeholder LoRA", "repo": "placeholder/repo", "weights": None, "trigger_word": ""}
]
# --- If running locally or in Colab ---
# (Optional: authenticate with Hugging Face so gated models such as FLUX.1-dev can be downloaded.)
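# A minimal sketch of that optional login, assuming the access token is exposed through an
# HF_TOKEN environment variable (the variable name is an assumption of this sketch, not a requirement of the app):
if os.environ.get("HF_TOKEN"):
    from huggingface_hub import login  # huggingface_hub is already a dependency above
    login(token=os.environ["HF_TOKEN"])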
def calculate_shift(
image_seq_len,
base_seq_len: int = 256,
max_seq_len: int = 4096,
base_shift: float = 0.5,
max_shift: float = 1.16,
):
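    # Linearly interpolate the scheduler's dynamic-shift parameter (mu) between
    # base_shift and max_shift according to the image token sequence length.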
m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
b = base_shift - m * base_seq_len
mu = image_seq_len * m + b
return mu
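# Configure the scheduler from either an explicit `timesteps` list, a custom `sigmas`
# schedule, or a plain step count, and return the resolved (timesteps, num_inference_steps).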
def retrieve_timesteps(
scheduler,
num_inference_steps: Optional[int] = None,
device: Optional[Union[str, torch.device]] = None,
timesteps: Optional[List[int]] = None,
sigmas: Optional[List[float]] = None,
**kwargs,
):
if timesteps is not None and sigmas is not None:
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
if timesteps is not None:
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
elif sigmas is not None:
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
else:
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
timesteps = scheduler.timesteps
return timesteps, num_inference_steps
# FLUX pipeline
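# Streaming variant of the FLUX __call__: after each denoising step the current latents are
# decoded with the pipeline's (tiny) VAE and yielded as a preview image; the final latents
# are decoded once more with the full-quality VAE passed in as `good_vae`.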
@torch.inference_mode()
def flux_pipe_call_that_returns_an_iterable_of_images(
self,
prompt: Union[str, List[str]] = None,
prompt_2: Optional[Union[str, List[str]]] = None,
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 28,
timesteps: List[int] = None,
guidance_scale: float = 3.5,
num_images_per_prompt: Optional[int] = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
max_sequence_length: int = 512,
good_vae: Optional[Any] = None,
):
height = height or self.default_sample_size * self.vae_scale_factor
width = width or self.default_sample_size * self.vae_scale_factor
self.check_inputs(
prompt,
prompt_2,
height,
width,
prompt_embeds=prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
max_sequence_length=max_sequence_length,
)
self._guidance_scale = guidance_scale
self._joint_attention_kwargs = joint_attention_kwargs
self._interrupt = False
batch_size = 1 if isinstance(prompt, str) else len(prompt)
device = self._execution_device
lora_scale = joint_attention_kwargs.get("scale", None) if joint_attention_kwargs is not None else None
prompt_embeds, pooled_prompt_embeds, text_ids = self.encode_prompt(
prompt=prompt,
prompt_2=prompt_2,
prompt_embeds=prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
device=device,
num_images_per_prompt=num_images_per_prompt,
max_sequence_length=max_sequence_length,
lora_scale=lora_scale,
)
num_channels_latents = self.transformer.config.in_channels // 4
latents, latent_image_ids = self.prepare_latents(
batch_size * num_images_per_prompt,
num_channels_latents,
height,
width,
prompt_embeds.dtype,
device,
generator,
latents,
)
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
image_seq_len = latents.shape[1]
mu = calculate_shift(
image_seq_len,
self.scheduler.config.base_image_seq_len,
self.scheduler.config.max_image_seq_len,
self.scheduler.config.base_shift,
self.scheduler.config.max_shift,
)
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler,
num_inference_steps,
device,
timesteps,
sigmas,
mu=mu,
)
self._num_timesteps = len(timesteps)
guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32).expand(latents.shape[0]) if self.transformer.config.guidance_embeds else None
for i, t in enumerate(timesteps):
if self.interrupt:
continue
timestep = t.expand(latents.shape[0]).to(latents.dtype)
noise_pred = self.transformer(
hidden_states=latents,
timestep=timestep / 1000,
guidance=guidance,
pooled_projections=pooled_prompt_embeds,
encoder_hidden_states=prompt_embeds,
txt_ids=text_ids,
img_ids=latent_image_ids,
joint_attention_kwargs=self.joint_attention_kwargs,
return_dict=False,
)[0]
latents_for_image = self._unpack_latents(latents, height, width, self.vae_scale_factor)
latents_for_image = (latents_for_image / self.vae.config.scaling_factor) + self.vae.config.shift_factor
image = self.vae.decode(latents_for_image, return_dict=False)[0]
yield self.image_processor.postprocess(image, output_type=output_type)[0]
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
torch.cuda.empty_cache()
latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
latents = (latents / good_vae.config.scaling_factor) + good_vae.config.shift_factor
image = good_vae.decode(latents, return_dict=False)[0]
self.maybe_free_model_hooks()
torch.cuda.empty_cache()
yield self.image_processor.postprocess(image, output_type=output_type)[0]
#--------------------------------------------------Model Initialization-----------------------------------------------------------------------------------------#
dtype = torch.bfloat16
device = "cuda" if torch.cuda.is_available() else "cpu"
base_model = "black-forest-labs/FLUX.1-dev"
# TAEF1 is a very tiny autoencoder which uses the same "latent API" as FLUX.1's VAE.
taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype).to(device)
good_vae = AutoencoderKL.from_pretrained(base_model, subfolder="vae", torch_dtype=dtype).to(device)
pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=dtype, vae=taef1).to(device)
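# The img2img pipeline shares the transformer and both text encoders with `pipe`,
# so the base model is only loaded into memory once; it keeps the full-quality VAE.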
pipe_i2i = AutoPipelineForImage2Image.from_pretrained(
base_model,
vae=good_vae,
transformer=pipe.transformer,
text_encoder=pipe.text_encoder,
tokenizer=pipe.tokenizer,
text_encoder_2=pipe.text_encoder_2,
tokenizer_2=pipe.tokenizer_2,
torch_dtype=dtype,
).to(device)
MAX_SEED = 2**32-1
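# Bind the streaming generation function above as a method on the text-to-image pipeline.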
pipe.flux_pipe_call_that_returns_an_iterable_of_images = flux_pipe_call_that_returns_an_iterable_of_images.__get__(pipe)
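# Small context manager that prints how long the wrapped block of work took.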
class calculateDuration:
def __init__(self, activity_name=""):
self.activity_name = activity_name
def __enter__(self):
self.start_time = time.time()
return self
def __exit__(self, exc_type, exc_value, traceback):
self.end_time = time.time()
self.elapsed_time = self.end_time - self.start_time
if self.activity_name:
print(f"Elapsed time for {self.activity_name}: {self.elapsed_time:.6f} seconds")
else:
print(f"Elapsed time: {self.elapsed_time:.6f} seconds")
def update_selection(evt: gr.SelectData, width, height):
selected_lora = loras[evt.index]
new_placeholder = f"Type a prompt for {selected_lora['title']}"
lora_repo = selected_lora["repo"]
updated_text = f"### Selected: [{lora_repo}](https://huggingface.co./{lora_repo}) ✅"
if "aspect" in selected_lora:
if selected_lora["aspect"] == "portrait":
width = 768
height = 1024
elif selected_lora["aspect"] == "landscape":
width = 1024
height = 768
else:
width = 1024
height = 1024
return (
gr.update(placeholder=new_placeholder),
updated_text,
evt.index,
width,
height,
)
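# Text-to-image path: streams intermediate previews from the custom FLUX call above.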
@spaces.GPU(duration=100)
def generate_image(prompt_mash, steps, seed, cfg_scale, width, height, lora_scale, progress):
pipe.to("cuda")
generator = torch.Generator(device="cuda").manual_seed(seed)
with calculateDuration("Generating image"):
for img in pipe.flux_pipe_call_that_returns_an_iterable_of_images(
prompt=prompt_mash,
num_inference_steps=steps,
guidance_scale=cfg_scale,
width=width,
height=height,
generator=generator,
joint_attention_kwargs={"scale": lora_scale},
output_type="pil",
good_vae=good_vae,
):
yield img
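# Image-to-image path: single-shot call through the img2img pipeline (no streaming).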
def generate_image_to_image(prompt_mash, image_input_path, image_strength, steps, cfg_scale, width, height, lora_scale, seed):
generator = torch.Generator(device="cuda").manual_seed(seed)
pipe_i2i.to("cuda")
image_input = load_image(image_input_path)
final_image = pipe_i2i(
prompt=prompt_mash,
image=image_input,
strength=image_strength,
num_inference_steps=steps,
guidance_scale=cfg_scale,
width=width,
height=height,
generator=generator,
joint_attention_kwargs={"scale": lora_scale},
output_type="pil",
).images[0]
return final_image
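# Main UI entry point: builds the final prompt (trigger word plus optional enhancer),
# swaps the selected LoRA weights into the relevant pipeline, then dispatches to the
# image-to-image or streaming text-to-image path.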
@spaces.GPU(duration=100)
def run_lora(prompt, image_input, image_strength, cfg_scale, steps, selected_index, randomize_seed, seed, width, height, lora_scale, use_enhancer, progress=gr.Progress(track_tqdm=True)):
# Check if a LoRA is selected.
if selected_index is None:
raise gr.Error("You must select a LoRA before proceeding.🧨")
selected_lora = loras[selected_index]
lora_path = selected_lora["repo"]
trigger_word = selected_lora["trigger_word"]
# Prepare prompt by appending/prepending trigger word if available.
if trigger_word:
if "trigger_position" in selected_lora and selected_lora["trigger_position"] == "prepend":
prompt_mash = f"{trigger_word} {prompt}"
else:
prompt_mash = f"{prompt} {trigger_word}"
else:
prompt_mash = prompt
# If prompt enhancer is enabled, stream the enhanced prompt.
enhanced_text = ""
if use_enhancer:
for enhanced_chunk in enhance_generate(prompt_mash):
enhanced_text = enhanced_chunk
# Yield intermediate output (no image yet, but update enhanced prompt textbox)
yield None, seed, gr.update(visible=False), enhanced_text
prompt_mash = enhanced_text # Use final enhanced prompt for generation
# Else, leave prompt_mash as is.
with calculateDuration("Unloading LoRA"):
pipe.unload_lora_weights()
pipe_i2i.unload_lora_weights()
with calculateDuration(f"Loading LoRA weights for {selected_lora['title']}"):
pipe_to_use = pipe_i2i if image_input is not None else pipe
weight_name = selected_lora.get("weights", None)
pipe_to_use.load_lora_weights(
lora_path,
weight_name=weight_name,
low_cpu_mem_usage=True
)
with calculateDuration("Randomizing seed"):
if randomize_seed:
seed = random.randint(0, MAX_SEED)
if image_input is not None:
final_image = generate_image_to_image(prompt_mash, image_input, image_strength, steps, cfg_scale, width, height, lora_scale, seed)
yield final_image, seed, gr.update(visible=False), enhanced_text
else:
image_generator = generate_image(prompt_mash, steps, seed, cfg_scale, width, height, lora_scale, progress)
final_image = None
step_counter = 0
for image in image_generator:
step_counter += 1
final_image = image
progress_bar = f'<div class="progress-container"><div class="progress-bar" style="--current: {step_counter}; --total: {steps};"></div></div>'
yield image, seed, gr.update(value=progress_bar, visible=True), enhanced_text
yield final_image, seed, gr.update(value=progress_bar, visible=False), enhanced_text
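# Resolve a "user/repo" Hugging Face reference to (title, repo, weight filename,
# trigger word, preview image URL), validating that the repo hosts a FLUX LoRA.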
def get_huggingface_safetensors(link):
split_link = link.split("/")
if len(split_link) == 2:
model_card = ModelCard.load(link)
base_model = model_card.data.get("base_model")
print(base_model)
        if base_model not in ("black-forest-labs/FLUX.1-dev", "black-forest-labs/FLUX.1-schnell"):
            raise Exception("Not a FLUX LoRA: the base model must be FLUX.1-dev or FLUX.1-schnell")
image_path = model_card.data.get("widget", [{}])[0].get("output", {}).get("url", None)
trigger_word = model_card.data.get("instance_prompt", "")
image_url = f"https://huggingface.co./{link}/resolve/main/{image_path}" if image_path else None
        safetensors_name = None
        fs = HfFileSystem()
try:
list_of_files = fs.ls(link, detail=False)
for file in list_of_files:
if file.endswith(".safetensors"):
safetensors_name = file.split("/")[-1]
if not image_url and file.lower().endswith((".jpg", ".jpeg", ".png", ".webp")):
image_elements = file.split("/")
image_url = f"https://huggingface.co./{link}/resolve/main/{image_elements[-1]}"
except Exception as e:
print(e)
gr.Warning("You didn't include a link nor a valid Hugging Face repository with a *.safetensors LoRA")
raise Exception("Invalid LoRA repository")
return split_link[1], link, safetensors_name, trigger_word, image_url
else:
raise Exception("Invalid LoRA link format")
def check_custom_model(link):
    if link.startswith("https://"):
        if link.startswith("https://huggingface.co.") or link.startswith("https://www.huggingface.co"):
            link_split = link.split("huggingface.co/")
            return get_huggingface_safetensors(link_split[1])
        raise Exception("Only huggingface.co links are supported")
    else:
        return get_huggingface_safetensors(link)
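# Callback for the custom LoRA textbox: validate the entry, render a preview card,
# and append the LoRA to the in-memory list if it is not already present.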
def add_custom_lora(custom_lora):
global loras
if custom_lora:
try:
title, repo, path, trigger_word, image = check_custom_model(custom_lora)
print(f"Loaded custom LoRA: {repo}")
card = f'''
<div class="custom_lora_card">
<span>Loaded custom LoRA:</span>
<div class="card_internal">
<img src="{image}" />
<div>
<h3>{title}</h3>
<small>{"Using: <code><b>" + trigger_word + "</b></code> as the trigger word" if trigger_word else "No trigger word found. Include it in your prompt"}<br></small>
</div>
</div>
</div>
'''
existing_item_index = next((index for (index, item) in enumerate(loras) if item['repo'] == repo), None)
            if existing_item_index is None:
new_item = {
"image": image,
"title": title,
"repo": repo,
"weights": path,
"trigger_word": trigger_word
}
print(new_item)
existing_item_index = len(loras)
loras.append(new_item)
return gr.update(visible=True, value=card), gr.update(visible=True), gr.Gallery(selected_index=None), f"Custom: {path}", existing_item_index, trigger_word
except Exception as e:
gr.Warning("Invalid LoRA: either you entered an invalid link or a non-FLUX LoRA")
return gr.update(visible=True, value="Invalid LoRA"), gr.update(visible=False), gr.update(), "", None, ""
else:
return gr.update(visible=False), gr.update(visible=False), gr.update(), "", None, ""
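# Reset the custom LoRA UI elements back to their hidden/default state.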
def remove_custom_lora():
return gr.update(visible=False), gr.update(visible=False), gr.update(), "", None, ""
run_lora.zerogpu = True
css = '''
#gen_btn { height: 100%; }
#gen_column { align-self: stretch; }
#title { text-align: center; }
#title h1 { font-size: 3em; display:inline-flex; align-items:center; }
#title img { width: 100px; margin-right: 0.5em; }
#gallery .grid-wrap { height: 10vh; }
#lora_list { background: var(--block-background-fill); padding: 0 1em .3em; font-size: 90%; }
.card_internal { display: flex; height: 100px; margin-top: .5em; }
.card_internal img { margin-right: 1em; }
.styler { --form-gap-width: 0px !important; }
#progress { height:30px; }
#progress .generating { display:none; }
.progress-container { width: 100%; height: 30px; background-color: #f0f0f0; border-radius: 15px; overflow: hidden; margin-bottom: 20px; }
.progress-bar { height: 100%; background-color: #4f46e5; width: calc(var(--current) / var(--total) * 100%); transition: width 0.5s ease-in-out; }
'''
with gr.Blocks(theme=gr.themes.Base(), css=css, delete_cache=(60, 60)) as app:
title = gr.HTML(
"""<h1>Flux LoRA Generation</h1>""",
elem_id="title",
)
selected_index = gr.State(None)
with gr.Row():
with gr.Column(scale=3):
prompt = gr.Textbox(label="Prompt", lines=1, placeholder=":/ choose the LoRA and type the prompt ")
with gr.Column(scale=1, elem_id="gen_column"):
generate_button = gr.Button("Generate", variant="primary", elem_id="gen_btn")
with gr.Row():
with gr.Column():
selected_info = gr.Markdown("")
gallery = gr.Gallery(
[(item["image"], item["title"]) for item in loras],
label="LoRA DLC's",
allow_preview=False,
columns=3,
elem_id="gallery",
show_share_button=False
)
with gr.Group():
custom_lora = gr.Textbox(label="Enter Custom LoRA", placeholder="prithivMLmods/Canopus-LoRA-Flux-Anime")
gr.Markdown("[Check the list of FLUX LoRA's](https://huggingface.co./models?other=base_model:adapter:black-forest-labs/FLUX.1-dev)", elem_id="lora_list")
custom_lora_info = gr.HTML(visible=False)
custom_lora_button = gr.Button("Remove custom LoRA", visible=False)
with gr.Column():
progress_bar = gr.Markdown(elem_id="progress", visible=False)
result = gr.Image(label="Generated Image")
with gr.Row():
with gr.Accordion("Advanced Settings", open=False):
with gr.Row():
input_image = gr.Image(label="Input image", type="filepath")
image_strength = gr.Slider(label="Denoise Strength", info="Lower means more image influence", minimum=0.1, maximum=1.0, step=0.01, value=0.75)
with gr.Column():
with gr.Row():
cfg_scale = gr.Slider(label="CFG Scale", minimum=1, maximum=20, step=0.5, value=3.5)
steps = gr.Slider(label="Steps", minimum=1, maximum=50, step=1, value=28)
with gr.Row():
width = gr.Slider(label="Width", minimum=256, maximum=1536, step=64, value=1024)
height = gr.Slider(label="Height", minimum=256, maximum=1536, step=64, value=1024)
with gr.Row():
randomize_seed = gr.Checkbox(True, label="Randomize seed")
seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0, randomize=True)
lora_scale = gr.Slider(label="LoRA Scale", minimum=0, maximum=3, step=0.01, value=0.95)
with gr.Row():
use_enhancer = gr.Checkbox(value=False, label="Use Prompt Enhancer")
show_enhanced_prompt = gr.Checkbox(value=False, label="Display Enhanced Prompt")
enhanced_prompt_box = gr.Textbox(label="Enhanced Prompt", visible=False)
# Add the change event so that the enhanced prompt box visibility toggles.
show_enhanced_prompt.change(fn=lambda show: gr.update(visible=show),
inputs=show_enhanced_prompt,
outputs=enhanced_prompt_box)
gallery.select(
update_selection,
inputs=[width, height],
outputs=[prompt, selected_info, selected_index, width, height]
)
custom_lora.input(
add_custom_lora,
inputs=[custom_lora],
outputs=[custom_lora_info, custom_lora_button, gallery, selected_info, selected_index, prompt]
)
custom_lora_button.click(
remove_custom_lora,
outputs=[custom_lora_info, custom_lora_button, gallery, selected_info, selected_index, custom_lora]
)
gr.on(
triggers=[generate_button.click, prompt.submit],
fn=run_lora,
inputs=[prompt, input_image, image_strength, cfg_scale, steps, selected_index, randomize_seed, seed, width, height, lora_scale, use_enhancer],
outputs=[result, seed, progress_bar, enhanced_prompt_box]
)
with gr.Row():
gr.HTML("<div style='text-align:center; font-size:0.9em; margin-top:20px;'>Credits: <a href='https://ruslanmv.com' target='_blank'>ruslanmv.com</a></div>")
app.queue()
app.launch(debug=True)