Files changed (1) hide show
  1. app.py +227 -1
app.py CHANGED
@@ -1,2 +1,228 @@
 
 
1
  import os
2
- exec(os.environ.get('APP'))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import spaces
2
+ import argparse
3
  import os
4
+ import time
5
+ from os import path
6
+ import shutil
7
+ from datetime import datetime
8
+ from safetensors.torch import load_file
9
+ from huggingface_hub import hf_hub_download
10
+ import gradio as gr
11
+ import torch
12
+ from diffusers import FluxPipeline
13
+ from PIL import Image
14
+ from transformers import pipeline
15
+
16
+ translator = pipeline("translation", model="Helsinki-NLP/opus-mt-ko-en")
17
+
18
+ # Hugging Face 토큰 μ„€μ •
19
+ HF_TOKEN = os.getenv("HF_TOKEN")
20
+ if HF_TOKEN is None:
21
+ raise ValueError("HF_TOKEN environment variable is not set")
22
+
23
+ # Setup and initialization code
24
+ cache_path = path.join(path.dirname(path.abspath(__file__)), "models")
25
+ PERSISTENT_DIR = os.environ.get("PERSISTENT_DIR", ".")
26
+ gallery_path = path.join(PERSISTENT_DIR, "gallery")
27
+
28
+ os.environ["TRANSFORMERS_CACHE"] = cache_path
29
+ os.environ["HF_HUB_CACHE"] = cache_path
30
+ os.environ["HF_HOME"] = cache_path
31
+
32
+ torch.backends.cuda.matmul.allow_tf32 = True
33
+
34
+ # Create gallery directory if it doesn't exist
35
+ if not path.exists(gallery_path):
36
+ os.makedirs(gallery_path, exist_ok=True)
37
+
38
+ class timer:
39
+ def __init__(self, method_name="timed process"):
40
+ self.method = method_name
41
+ def __enter__(self):
42
+ self.start = time.time()
43
+ print(f"{self.method} starts")
44
+ def __exit__(self, exc_type, exc_val, exc_tb):
45
+ end = time.time()
46
+ print(f"{self.method} took {str(round(end - self.start, 2))}s")
47
+
48
+ # Model initialization
49
+ if not path.exists(cache_path):
50
+ os.makedirs(cache_path, exist_ok=True)
51
+
52
+ # 인증된 λͺ¨λΈ λ‘œλ“œ
53
+ pipe = FluxPipeline.from_pretrained(
54
+ "black-forest-labs/FLUX.1-dev",
55
+ torch_dtype=torch.bfloat16,
56
+ use_auth_token=HF_TOKEN
57
+ )
58
+
59
+ # Hyper-SD LoRA λ‘œλ“œ (인증 포함)
60
+ pipe.load_lora_weights(
61
+ hf_hub_download(
62
+ "ByteDance/Hyper-SD",
63
+ "Hyper-FLUX.1-dev-8steps-lora.safetensors",
64
+ use_auth_token=HF_TOKEN
65
+ )
66
+ )
67
+ pipe.fuse_lora(lora_scale=0.125)
68
+ pipe.to(device="cuda", dtype=torch.bfloat16)
69
+
70
+ def save_image(image):
71
+ """Save the generated image and return the path"""
72
+ try:
73
+ if not os.path.exists(gallery_path):
74
+ try:
75
+ os.makedirs(gallery_path, exist_ok=True)
76
+ except Exception as e:
77
+ print(f"Failed to create gallery directory: {str(e)}")
78
+ return None
79
+
80
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
81
+ random_suffix = os.urandom(4).hex()
82
+ filename = f"generated_{timestamp}_{random_suffix}.png"
83
+ filepath = os.path.join(gallery_path, filename)
84
+
85
+ try:
86
+ if isinstance(image, Image.Image):
87
+ image.save(filepath, "PNG", quality=100)
88
+ else:
89
+ image = Image.fromarray(image)
90
+ image.save(filepath, "PNG", quality=100)
91
+
92
+ if not os.path.exists(filepath):
93
+ print(f"Warning: Failed to verify saved image at {filepath}")
94
+ return None
95
+
96
+ return filepath
97
+ except Exception as e:
98
+ print(f"Failed to save image: {str(e)}")
99
+ return None
100
+
101
+ except Exception as e:
102
+ print(f"Error in save_image: {str(e)}")
103
+ return None
104
+
105
+ # Create Gradio interface
106
+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
107
+ with gr.Row():
108
+ with gr.Column(scale=3):
109
+ prompt = gr.Textbox(
110
+ label="Image Description",
111
+ placeholder="Describe the image you want to create...",
112
+ lines=3
113
+ )
114
+
115
+ with gr.Accordion("Advanced Settings", open=False):
116
+ with gr.Row():
117
+ height = gr.Slider(
118
+ label="Height",
119
+ minimum=256,
120
+ maximum=1152,
121
+ step=64,
122
+ value=1024
123
+ )
124
+ width = gr.Slider(
125
+ label="Width",
126
+ minimum=256,
127
+ maximum=1152,
128
+ step=64,
129
+ value=1024
130
+ )
131
+
132
+ with gr.Row():
133
+ steps = gr.Slider(
134
+ label="Inference Steps",
135
+ minimum=6,
136
+ maximum=25,
137
+ step=1,
138
+ value=8
139
+ )
140
+ scales = gr.Slider(
141
+ label="Guidance Scale",
142
+ minimum=0.0,
143
+ maximum=5.0,
144
+ step=0.1,
145
+ value=3.5
146
+ )
147
+
148
+ def get_random_seed():
149
+ return torch.randint(0, 1000000, (1,)).item()
150
+
151
+ seed = gr.Number(
152
+ label="Seed (random by default, set for reproducibility)",
153
+ value=get_random_seed(),
154
+ precision=0
155
+ )
156
+
157
+ randomize_seed = gr.Button("🎲 Randomize Seed", elem_classes=["generate-btn"])
158
+
159
+ generate_btn = gr.Button(
160
+ "✨ Generate Image",
161
+ elem_classes=["generate-btn"]
162
+ )
163
+
164
+ with gr.Column(scale=4, elem_classes=["fixed-width"]):
165
+ output = gr.Image(
166
+ label="Generated Image",
167
+ elem_id="output-image",
168
+ elem_classes=["output-image", "fixed-width"]
169
+ )
170
+
171
+ @spaces.GPU
172
+ def process_and_save_image(height, width, steps, scales, prompt, seed):
173
+ global pipe
174
+
175
+ # ν•œκΈ€ 감지 및 λ²ˆμ—­
176
+ def contains_korean(text):
177
+ return any(ord('κ°€') <= ord(c) <= ord('힣') for c in text)
178
+
179
+ # ν”„λ‘¬ν”„νŠΈ μ „μ²˜λ¦¬
180
+ if contains_korean(prompt):
181
+ # ν•œκΈ€μ„ μ˜μ–΄λ‘œ λ²ˆμ—­
182
+ translated = translator(prompt)[0]['translation_text']
183
+ prompt = translated
184
+
185
+ # ν”„λ‘¬ν”„νŠΈ ν˜•μ‹ κ°•μ œ
186
+ formatted_prompt = f"wbgmsst, 3D, {prompt} ,white background"
187
+
188
+ with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16), timer("inference"):
189
+ try:
190
+ generated_image = pipe(
191
+ prompt=[formatted_prompt],
192
+ generator=torch.Generator().manual_seed(int(seed)),
193
+ num_inference_steps=int(steps),
194
+ guidance_scale=float(scales),
195
+ height=int(height),
196
+ width=int(width),
197
+ max_sequence_length=256
198
+ ).images[0]
199
+
200
+ saved_path = save_image(generated_image)
201
+ if saved_path is None:
202
+ print("Warning: Failed to save generated image")
203
+
204
+ return generated_image
205
+ except Exception as e:
206
+ print(f"Error in image generation: {str(e)}")
207
+ return None
208
+
209
+ def update_seed():
210
+ return get_random_seed()
211
+
212
+ # Click event handlers inside gr.Blocks context
213
+ generate_btn.click(
214
+ process_and_save_image,
215
+ inputs=[height, width, steps, scales, prompt, seed],
216
+ outputs=output
217
+ ).then(
218
+ update_seed,
219
+ outputs=[seed]
220
+ )
221
+
222
+ randomize_seed.click(
223
+ update_seed,
224
+ outputs=[seed]
225
+ )
226
+
227
+ if __name__ == "__main__":
228
+ demo.launch(allowed_paths=[PERSISTENT_DIR])