Alibrown commited on
Commit
7c29612
Β·
verified Β·
1 Parent(s): 0e9e4bc

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +79 -0
app.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import random
3
+ import numpy as np
4
+ import torch
5
+ import gradio as gr
6
+ from diffusers import StableDiffusionPipeline
7
+ import paramiko
8
+
9
+ # Konfiguration
10
+ STORAGE_DOMAIN = os.getenv('STORAGE_DOMAIN', '').strip() # SFTP Server Domain
11
+ STORAGE_USER = os.getenv('STORAGE_USER', '').strip() # SFTP User
12
+ STORAGE_PSWD = os.getenv('STORAGE_PSWD', '').strip() # SFTP Passwort
13
+ STORAGE_PORT = int(os.getenv('STORAGE_PORT', '22').strip()) # SFTP Port
14
+ STORAGE_SECRET = os.getenv('STORAGE_SECRET', '').strip() # Secret Token
15
+
16
+ # Modell laden
17
+ device = "cuda" if torch.cuda.is_available() else "cpu"
18
+ repo = "stabilityai/stable-diffusion-3-medium-diffusers"
19
+ pipe = StableDiffusionPipeline.from_pretrained(repo, torch_dtype=torch.float16).to(device)
20
+
21
+ # Maximalwerte
22
+ MAX_SEED = np.iinfo(np.int32).max
23
+ MAX_IMAGE_SIZE = 1344
24
+
25
+ # SFTP-Funktion
26
+ def upload_to_sftp(local_file, remote_path):
27
+ try:
28
+ transport = paramiko.Transport((STORAGE_DOMAIN, STORAGE_PORT))
29
+ transport.connect(username=STORAGE_USER, password=STORAGE_PSWD)
30
+ sftp = paramiko.SFTPClient.from_transport(transport)
31
+ sftp.put(local_file, remote_path)
32
+ sftp.close()
33
+ transport.close()
34
+ print(f"File {local_file} successfully uploaded to {remote_path}")
35
+ return True
36
+ except Exception as e:
37
+ print(f"Error during SFTP upload: {e}")
38
+ return False
39
+
40
+ # Inferenz-Funktion
41
+ def infer(prompt, width, height, guidance_scale, num_inference_steps, seed, randomize_seed):
42
+ if randomize_seed:
43
+ seed = random.randint(0, MAX_SEED)
44
+
45
+ generator = torch.manual_seed(seed)
46
+ image = pipe(prompt, guidance_scale=guidance_scale, num_inference_steps=num_inference_steps, width=width, height=height, generator=generator).images[0]
47
+
48
+ # Speichere Bild lokal
49
+ local_file = f"/tmp/generated_image_{seed}.png"
50
+ image.save(local_file)
51
+
52
+ # Hochladen zu SFTP
53
+ remote_path = f"/uploads/generated_image_{seed}.png"
54
+ if upload_to_sftp(local_file, remote_path):
55
+ os.remove(local_file)
56
+ return f"Image uploaded to {remote_path}", seed
57
+ else:
58
+ return "Failed to upload image", seed
59
+
60
+ # Gradio-App
61
+ with gr.Blocks() as demo:
62
+ gr.Markdown("### Stable Diffusion 3 - Test App")
63
+ prompt = gr.Textbox(label="Prompt", placeholder="Enter your prompt here")
64
+ width = gr.Slider(256, MAX_IMAGE_SIZE, step=64, value=512, label="Width")
65
+ height = gr.Slider(256, MAX_IMAGE_SIZE, step=64, value=512, label="Height")
66
+ guidance_scale = gr.Slider(0.0, 10.0, step=0.1, value=7.5, label="Guidance Scale")
67
+ num_inference_steps = gr.Slider(1, 50, step=1, value=25, label="Inference Steps")
68
+ seed = gr.Number(value=42, label="Seed")
69
+ randomize_seed = gr.Checkbox(value=False, label="Randomize Seed")
70
+ generate_button = gr.Button("Generate Image")
71
+ output = gr.Text(label="Output")
72
+
73
+ generate_button.click(
74
+ infer,
75
+ inputs=[prompt, width, height, guidance_scale, num_inference_steps, seed, randomize_seed],
76
+ outputs=[output, seed]
77
+ )
78
+
79
+ demo.launch()