SD3-dev1 / _app.py
Alibrown's picture
Update _app.py
bb0d65c verified
raw
history blame
3.44 kB
import os
import random
import tempfile
import uuid

import gradio as gr
import numpy as np
import paramiko
import torch
from diffusers import StableDiffusion3Pipeline
from diffusers import StableDiffusionPipeline
from huggingface_hub import login
# Hugging Face token; the script refuses to start without it.
HF_TOKEN = os.getenv('HF_TOKEN', '').strip()
if not HF_TOKEN:
    # The message must name the variable that is actually read (HF_TOKEN).
    raise ValueError("HF_TOKEN is not set. Please set the token as an environment variable.")
# Authenticate against the Hugging Face Hub for the model download below.
login(token=HF_TOKEN)
# Configuration: SFTP storage settings read from environment variables.
STORAGE_DOMAIN = os.getenv('STORAGE_DOMAIN', '').strip()  # SFTP server domain
STORAGE_USER = os.getenv('STORAGE_USER', '').strip()  # SFTP user
STORAGE_PSWD = os.getenv('STORAGE_PSWD', '').strip()  # SFTP password
# `or '22'` also covers a variable that is set but empty; int('') would
# otherwise raise ValueError while the module is being imported.
STORAGE_PORT = int(os.getenv('STORAGE_PORT', '').strip() or '22')  # SFTP port
STORAGE_SECRET = os.getenv('STORAGE_SECRET', '').strip()  # Secret token (NOTE(review): not referenced elsewhere in this file)
# Load the model.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Half precision saves GPU memory but is poorly supported on CPU, so only use
# float16 when a CUDA device is available.
dtype = torch.float16 if device == "cuda" else torch.float32
repo = "stabilityai/stable-diffusion-3-medium-diffusers"
try:
    # SD3 checkpoints need the SD3 pipeline class; the UNet-based
    # StableDiffusionPipeline (SD 1.x/2.x) does not match this repo's layout.
    pipe = StableDiffusion3Pipeline.from_pretrained(repo, torch_dtype=dtype).to(device)
except Exception as e:
    raise RuntimeError(
        f"Failed to load the model. Ensure the token has access to the repo. Error: {e}"
    ) from e
# Upper limits shared by the seed randomizer and the UI sliders.
MAX_SEED = 2**31 - 1  # largest signed 32-bit integer, equal to np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 1344  # largest selectable width/height, in pixels
# SFTP helper
def upload_to_sftp(local_file, remote_path):
    """Upload *local_file* to *remote_path* on the configured SFTP server.

    Connection settings come from the module-level STORAGE_DOMAIN,
    STORAGE_PORT, STORAGE_USER and STORAGE_PSWD values.

    Returns:
        True if the upload succeeded, False otherwise. Errors are printed
        instead of raised, so a failed upload does not crash the caller.
    """
    try:
        # Both paramiko objects are context managers: the SFTP channel and the
        # underlying socket are closed even when connect() or put() raises.
        with paramiko.Transport((STORAGE_DOMAIN, STORAGE_PORT)) as transport:
            transport.connect(username=STORAGE_USER, password=STORAGE_PSWD)
            with paramiko.SFTPClient.from_transport(transport) as sftp:
                sftp.put(local_file, remote_path)
    except Exception as e:  # best-effort boundary: report and signal failure
        print(f"Error during SFTP upload: {e}")
        return False
    print(f"File {local_file} successfully uploaded to {remote_path}")
    return True
# Inference function
def infer(prompt, width, height, guidance_scale, num_inference_steps, seed, randomize_seed):
    """Generate one image for *prompt*, upload it via SFTP and report the outcome.

    Args:
        prompt: Text prompt passed to the diffusion pipeline.
        width: Image width in pixels.
        height: Image height in pixels.
        guidance_scale: Guidance scale forwarded to the pipeline.
        num_inference_steps: Number of denoising steps forwarded to the pipeline.
        seed: Seed for the random generator; replaced when *randomize_seed* is set.
        randomize_seed: If true, draw a new seed uniformly from ``[0, MAX_SEED]``.

    Returns:
        ``(status_message, seed)`` where *seed* is the integer seed actually
        used, so the UI can display it after randomization.
    """
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    # UI components may deliver floats (e.g. 42.0); normalize so the seed and
    # the file names derived from it are clean integers.
    seed = int(seed)
    # A private CPU generator produces the same stream as torch.manual_seed(seed)
    # without mutating the global torch RNG shared by concurrent requests.
    generator = torch.Generator().manual_seed(seed)
    image = pipe(
        prompt,
        guidance_scale=guidance_scale,
        num_inference_steps=int(num_inference_steps),
        width=int(width),
        height=int(height),
        generator=generator,
    ).images[0]

    # Unique name per run: with a fixed seed (the UI default) every run would
    # otherwise overwrite the previous image, both locally and on the server.
    file_name = f"generated_image_{seed}_{uuid.uuid4().hex[:8]}.png"
    local_file = os.path.join(tempfile.gettempdir(), file_name)
    remote_path = f"/uploads/{file_name}"
    try:
        image.save(local_file)
        uploaded = upload_to_sftp(local_file, remote_path)
    finally:
        # The local file only stages the upload and is never returned to the
        # user, so remove it on success and failure alike.
        if os.path.exists(local_file):
            os.remove(local_file)

    if uploaded:
        return f"Image uploaded to {remote_path}", seed
    return "Failed to upload image", seed
# Gradio app
with gr.Blocks() as demo:
    gr.Markdown("### Stable Diffusion 3 - Test App")
    prompt = gr.Textbox(label="Prompt", placeholder="Enter your prompt here")
    width = gr.Slider(256, MAX_IMAGE_SIZE, step=64, value=512, label="Width")
    height = gr.Slider(256, MAX_IMAGE_SIZE, step=64, value=512, label="Height")
    guidance_scale = gr.Slider(0.0, 10.0, step=0.1, value=7.5, label="Guidance Scale")
    num_inference_steps = gr.Slider(1, 50, step=1, value=25, label="Inference Steps")
    # precision=0 makes the component deliver an integer seed instead of a float.
    seed = gr.Number(value=42, label="Seed", precision=0)
    randomize_seed = gr.Checkbox(value=False, label="Randomize Seed")
    generate_button = gr.Button("Generate Image")
    output = gr.Text(label="Output")
    # infer returns (status message, seed used); writing the seed back into
    # the Seed field lets the user reproduce a randomized run.
    generate_button.click(
        infer,
        inputs=[prompt, width, height, guidance_scale, num_inference_steps, seed, randomize_seed],
        outputs=[output, seed],
    )

# Only start the server when run as a script, not when the module is imported.
if __name__ == "__main__":
    demo.launch()