Spaces:
Running
on
Zero
Running
on
Zero
Himanshu-AT
committed on
Commit
·
33b3b46
1
Parent(s):
4f8239d
modify
Browse files- README.md +1 -1
- app.py +73 -29
- garment_pipeline.py +60 -0
- recaption.py +48 -0
README.md
CHANGED
@@ -1,5 +1,5 @@
|
|
1 |
---
|
2 |
-
title:
|
3 |
emoji: 🦀
|
4 |
colorFrom: red
|
5 |
colorTo: green
|
|
|
1 |
---
|
2 |
+
title: 0Shot
|
3 |
emoji: 🦀
|
4 |
colorFrom: red
|
5 |
colorTo: green
|
app.py
CHANGED
@@ -2,17 +2,15 @@ import spaces
|
|
2 |
import gradio as gr
|
3 |
import torch
|
4 |
from PIL import Image
|
5 |
-
|
6 |
from diffusers.utils import load_image
|
7 |
from pipeline import FluxConditionalPipeline
|
8 |
from transformer import FluxTransformer2DConditionalModel
|
9 |
-
|
|
|
10 |
import os
|
11 |
|
12 |
pipe = None
|
13 |
-
|
14 |
CHECKPOINT = "primecai/dsd_model"
|
15 |
-
|
16 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
17 |
dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
|
18 |
|
@@ -24,17 +22,20 @@ transformer = FluxTransformer2DConditionalModel.from_pretrained(
|
|
24 |
ignore_mismatched_sizes=True,
|
25 |
use_auth_token=os.getenv("HF_TOKEN"),
|
26 |
)
|
|
|
27 |
pipe = FluxConditionalPipeline.from_pretrained(
|
28 |
"black-forest-labs/FLUX.1-dev",
|
29 |
transformer=transformer,
|
30 |
torch_dtype=dtype,
|
31 |
use_auth_token=os.getenv("HF_TOKEN"),
|
32 |
)
|
|
|
33 |
pipe.load_lora_weights(
|
34 |
CHECKPOINT,
|
35 |
weight_name="pytorch_lora_weights.safetensors",
|
36 |
use_auth_token=os.getenv("HF_TOKEN"),
|
37 |
)
|
|
|
38 |
pipe.to(device, dtype=dtype)
|
39 |
|
40 |
@spaces.GPU
|
@@ -50,7 +51,7 @@ def generate_image(
|
|
50 |
image = image.crop(
|
51 |
((w - min_size) // 2, (h - min_size) // 2, (w + min_size) // 2, (h + min_size) // 2)
|
52 |
).resize((512, 512))
|
53 |
-
|
54 |
control_image = load_image(image)
|
55 |
result_image = pipe(
|
56 |
prompt=text.strip(),
|
@@ -64,9 +65,36 @@ def generate_image(
|
|
64 |
guidance_scale_real_t=t_guidance,
|
65 |
gemini_prompt=gemini_prompt,
|
66 |
).images[0]
|
67 |
-
|
68 |
return result_image
|
69 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
70 |
|
71 |
def get_samples():
|
72 |
sample_list = [
|
@@ -111,7 +139,6 @@ def get_samples():
|
|
111 |
for sample in sample_list
|
112 |
]
|
113 |
|
114 |
-
|
115 |
demo = gr.Blocks()
|
116 |
|
117 |
with demo:
|
@@ -128,28 +155,45 @@ with demo:
|
|
128 |
</div>
|
129 |
"""
|
130 |
)
|
131 |
-
|
132 |
-
|
133 |
-
|
134 |
-
|
135 |
-
|
136 |
-
|
137 |
-
|
138 |
-
|
139 |
-
|
140 |
-
|
141 |
-
|
142 |
-
|
143 |
-
|
144 |
-
|
145 |
-
|
146 |
-
|
147 |
-
|
148 |
-
|
149 |
-
|
150 |
-
|
151 |
-
|
152 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
153 |
gr.HTML(
|
154 |
"""
|
155 |
<div style="text-align: center;">
|
|
|
2 |
import gradio as gr
|
3 |
import torch
|
4 |
from PIL import Image
|
|
|
5 |
from diffusers.utils import load_image
|
6 |
from pipeline import FluxConditionalPipeline
|
7 |
from transformer import FluxTransformer2DConditionalModel
|
8 |
+
from garment_pipeline import generate_with_garment
|
9 |
+
from recaption import enhance_prompt, enhance_garment_prompt
|
10 |
import os
|
11 |
|
12 |
pipe = None
|
|
|
13 |
CHECKPOINT = "primecai/dsd_model"
|
|
|
14 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
15 |
dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
|
16 |
|
|
|
22 |
ignore_mismatched_sizes=True,
|
23 |
use_auth_token=os.getenv("HF_TOKEN"),
|
24 |
)
|
25 |
+
|
26 |
pipe = FluxConditionalPipeline.from_pretrained(
|
27 |
"black-forest-labs/FLUX.1-dev",
|
28 |
transformer=transformer,
|
29 |
torch_dtype=dtype,
|
30 |
use_auth_token=os.getenv("HF_TOKEN"),
|
31 |
)
|
32 |
+
|
33 |
pipe.load_lora_weights(
|
34 |
CHECKPOINT,
|
35 |
weight_name="pytorch_lora_weights.safetensors",
|
36 |
use_auth_token=os.getenv("HF_TOKEN"),
|
37 |
)
|
38 |
+
|
39 |
pipe.to(device, dtype=dtype)
|
40 |
|
41 |
@spaces.GPU
|
|
|
51 |
image = image.crop(
|
52 |
((w - min_size) // 2, (h - min_size) // 2, (w + min_size) // 2, (h + min_size) // 2)
|
53 |
).resize((512, 512))
|
54 |
+
|
55 |
control_image = load_image(image)
|
56 |
result_image = pipe(
|
57 |
prompt=text.strip(),
|
|
|
65 |
guidance_scale_real_t=t_guidance,
|
66 |
gemini_prompt=gemini_prompt,
|
67 |
).images[0]
|
68 |
+
|
69 |
return result_image
|
70 |
|
71 |
+
@spaces.GPU
def generate_with_garment_interface(
    garment_image: Image.Image,
    text: str,
    gemini_prompt: bool = True,
    guidance: float = 3.5,
    i_guidance: float = 1.5,  # Default higher to maintain garment fidelity
    t_guidance: float = 1.0,
):
    """Gradio callback for the garment tab: optionally enhance the prompt, then generate.

    Args:
        garment_image: Garment picture supplied by the UI image component.
        text: User description of the model and background.
        gemini_prompt: When true, ``text`` is first rewritten by
            ``enhance_garment_prompt`` using the garment image.
        guidance: Forwarded to ``generate_with_garment`` as ``guidance``.
        i_guidance: Forwarded as ``i_guidance`` (image guidance).
        t_guidance: Forwarded as ``t_guidance`` (text guidance).

    Returns:
        The image returned by ``generate_with_garment``.

    Raises:
        gr.Error: If no garment image was provided.
    """
    # Gradio hands the callback ``None`` when the image input is left empty.
    # Fail early with a message the UI can display instead of sending ``None``
    # to the captioning model or crashing inside the crop step.
    if garment_image is None:
        raise gr.Error("Please upload a garment image before generating.")

    # Use garment-specific prompt enhancement if enabled
    if gemini_prompt:
        text = enhance_garment_prompt(garment_image, text)

    # The prompt has already been enhanced above when requested, so the
    # downstream call is told not to enhance it a second time.
    return generate_with_garment(
        pipe=pipe,
        garment_image=garment_image,
        text=text,
        gemini_prompt=False,
        guidance=guidance,
        i_guidance=i_guidance,
        t_guidance=t_guidance,
        device=device,
    )
|
98 |
|
99 |
def get_samples():
|
100 |
sample_list = [
|
|
|
139 |
for sample in sample_list
|
140 |
]
|
141 |
|
|
|
142 |
demo = gr.Blocks()
|
143 |
|
144 |
with demo:
|
|
|
155 |
</div>
|
156 |
"""
|
157 |
)
|
158 |
+
|
159 |
+
# Two-tab layout: the original character-consistent generation UI and the
# garment-preserving generation UI. Each tab wraps its callback in a
# gr.Interface so inputs map positionally onto the callback's parameters.
# NOTE(review): in app.py this statement appears to sit inside the
# `with demo:` block — re-indent accordingly when applying.
with gr.Tabs():
    with gr.TabItem("Standard Generation"):
        iface = gr.Interface(
            fn=generate_image,
            inputs=[
                gr.Image(type="pil", width=512),
                gr.Textbox(lines=2, label="text", info="Could be something as simple as 'this character playing soccer'."),
                gr.Checkbox(label="Gemini prompt", value=True, info="Use Gemini to enhance the prompt. This is recommended for most cases, unless you have a specific prompt similar to the examples in mind."),
                gr.Slider(minimum=1.0, maximum=6.0, step=0.5, value=3.5, label="guidance scale", info="Tip: start with 3.5, then gradually increase if the consistency is consistently off"),
                gr.Slider(minimum=1.0, maximum=2.0, step=0.05, value=1.5, label="real guidance scale for image", info="Tip: increase if the image is not consistent"),
                gr.Slider(minimum=1.0, maximum=2.0, step=0.05, value=1.0, label="real guidance scale for prompt", info="Tip: increase if the prompt is not consistent"),
            ],
            outputs=gr.Image(type="pil"),
            live=False,
        )
        gr.Examples(
            examples=get_samples(),
            inputs=iface.input_components,
            outputs=iface.output_components,
            run_on_click=False  # Prevents auto-submission
        )

    with gr.TabItem("Garment Generation"):
        garment_iface = gr.Interface(
            fn=generate_with_garment_interface,
            inputs=[
                # gr.Image has no `info` keyword (unlike Textbox/Checkbox/Slider);
                # passing it raises TypeError while the UI is being built, so the
                # upload hint lives in the Interface description below instead.
                gr.Image(type="pil", width=512, label="Garment Image"),
                gr.Textbox(lines=2, label="Model and Background Description", info="Describe the model and setting you want the garment to appear in, e.g., 'A tall model on a beach at sunset'"),
                gr.Checkbox(label="Enhance Prompt", value=True, info="Use Gemini to enhance the prompt with detailed garment description"),
                gr.Slider(minimum=1.0, maximum=6.0, step=0.5, value=3.5, label="guidance scale", info="Controls overall adherence to the prompt"),
                gr.Slider(minimum=1.0, maximum=2.0, step=0.05, value=1.5, label="garment fidelity", info="Controls how closely the output matches the original garment - higher values preserve more details"),
                gr.Slider(minimum=1.0, maximum=2.0, step=0.05, value=1.0, label="prompt adherence", info="Controls how closely the output matches the text prompt for model and background"),
            ],
            outputs=gr.Image(type="pil"),
            live=False,
            description=(
                "Generate an image of a model wearing the provided garment in a new setting. "
                "Upload an image of the garment you want to keep in the generated output."
            ),
        )
|
196 |
+
|
197 |
gr.HTML(
|
198 |
"""
|
199 |
<div style="text-align: center;">
|
garment_pipeline.py
ADDED
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from PIL import Image
|
3 |
+
from diffusers.utils import load_image
|
4 |
+
from pipeline import FluxConditionalPipeline
|
5 |
+
|
6 |
+
def generate_with_garment(
    pipe,
    garment_image: Image.Image,
    text: str,
    gemini_prompt: bool = True,
    guidance: float = 3.5,
    i_guidance: float = 1.0,
    t_guidance: float = 1.0,
    device="cuda" if torch.cuda.is_available() else "cpu"
):
    """
    Generates an image of a model wearing the provided garment with a new background

    Args:
        pipe: The pipeline object; it is called with the prompt, the control image
            and the guidance values, and its result must expose ``.images``.
        garment_image: Image of the garment to keep in the generated output
        text: Text prompt describing the desired output (model, pose, background).
            May be empty, in which case a generic garment prompt is used.
        gemini_prompt: Forwarded unchanged to the pipeline call as ``gemini_prompt``
        guidance: Passed to the pipeline as ``guidance_scale``
        i_guidance: Passed to the pipeline as ``guidance_scale_real_i``
        t_guidance: Passed to the pipeline as ``guidance_scale_real_t``
        device: Currently unused by this function; kept so existing callers that
            pass it keep working.

    Returns:
        The first image of the pipeline result (``.images[0]``)

    Raises:
        ValueError: If ``garment_image`` is None.
    """
    if garment_image is None:
        raise ValueError("garment_image is required to generate a garment image")

    # Center-crop the garment to a square on its shorter side, then scale it to
    # the 512x512 control-image size.
    w, h = garment_image.size
    min_size = min(w, h)
    crop_box = (
        (w - min_size) // 2,
        (h - min_size) // 2,
        (w + min_size) // 2,
        (h + min_size) // 2,
    )
    garment_image = garment_image.crop(crop_box).resize((512, 512))

    # Prepare garment image as control image
    control_image = load_image(garment_image)

    # Make sure the prompt mentions the garment. If the user already talks about
    # a garment/clothing the text is used as-is; otherwise a fixed lead-in is
    # prepended. An empty description yields just the lead-in, without a
    # dangling ", " at the end.
    user_text = (text or "").strip()
    lowered = user_text.lower()
    if "garment" not in lowered and "clothing" not in lowered:
        lead_in = "A model wearing this garment"
        enhanced_text = f"{lead_in}, {user_text}" if user_text else lead_in
    else:
        enhanced_text = user_text

    # Generate the image. The requested canvas is 1024 wide by 512 high while
    # the control image is 512x512.
    # NOTE(review): presumably the pipeline lays out reference and result side
    # by side — confirm against FluxConditionalPipeline.
    result = pipe(
        prompt=enhanced_text,
        negative_prompt="distorted garment, wrong clothing, deformed clothes",
        num_inference_steps=28,
        height=512,
        width=1024,
        guidance_scale=guidance,
        image=control_image,
        guidance_scale_real_i=i_guidance,  # Higher value to maintain garment fidelity
        guidance_scale_real_t=t_guidance,
        gemini_prompt=gemini_prompt,
    )

    return result.images[0]
|
recaption.py
CHANGED
@@ -30,4 +30,52 @@ def enhance_prompt(image, prompt):
|
|
30 |
print("input_image_prompt: ", input_image_prompt)
|
31 |
print("prompt: ", prompt)
|
32 |
print("enhanced_prompt: ", enhanced_prompt)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
33 |
return enhanced_prompt
|
|
|
30 |
print("input_image_prompt: ", input_image_prompt)
|
31 |
print("prompt: ", prompt)
|
32 |
print("enhanced_prompt: ", enhanced_prompt)
|
33 |
+
return enhanced_prompt
|
34 |
+
|
35 |
+
def enhance_garment_prompt(image, prompt):
    """
    Enhances a prompt specifically for garment transformation tasks.

    Two Gemini requests are made: the first asks for a one-line description of
    the garment in ``image``; the second asks for a single prompt that combines
    that description with the user's ``prompt``.

    Args:
        image: The garment image
        prompt: User provided prompt for the desired output

    Returns:
        The enhanced prompt collapsed onto one line. If the model returns no
        text for the enhancement step, the user's original ``prompt`` is
        returned unchanged.

    Raises:
        KeyError: If the ``GOOGLE_API_KEY`` environment variable is not set.
    """

    def _one_line(response_text):
        # Collapse every run of whitespace (including line breaks) to a single
        # space. Deleting "\n" outright would glue the last word of one line to
        # the first word of the next. ``None`` (no text in the response) is
        # treated as an empty string.
        return " ".join((response_text or "").split())

    input_caption_prompt = (
        "Please provide a detailed description of this garment/clothing item I will show you. "
        "Focus on describing the garment's color, pattern, style, fabric, cut, and unique details. "
        "Be specific about the type of garment (e.g., t-shirt, dress, jacket, pants) and its defining characteristics. "
        "The description should be detailed enough for an image generation model to recreate this exact garment. "
        "The description should be short and precise, in one-line format."
    )

    caption_model = genai.Client(
        vertexai=False, api_key=os.environ["GOOGLE_API_KEY"]
    )

    # Get detailed garment description
    caption_response = caption_model.models.generate_content(
        model='gemini-1.5-flash', contents=[input_caption_prompt, image])
    garment_description = _one_line(caption_response.text)

    # Enhance user prompt to include garment details
    enhance_instruction = (
        "I need to generate an image of a model wearing a specific garment. "
        f"The garment is described as: '{garment_description}'. "
        f"The user wants: '{prompt}'. "
        "Create a detailed prompt that combines these elements, ensuring the garment description is preserved exactly while "
        "incorporating the user's requirements for the model (person wearing it) and setting/background. "
        "Focus on describing a photorealistic scene with a model wearing this specific garment. "
        "The enhanced prompt should be short and precise, in one-line format, and should not exceed 77 tokens."
    )

    enhance_response = caption_model.models.generate_content(
        model='gemini-1.5-flash', contents=[enhance_instruction])
    enhanced_prompt = _one_line(enhance_response.text)

    # Best effort: an empty enhancement would give the image model nothing to
    # work with, so fall back to what the user typed.
    if not enhanced_prompt:
        enhanced_prompt = prompt

    print("garment_description: ", garment_description)
    print("user prompt: ", prompt)
    print("enhanced_prompt: ", enhanced_prompt)

    return enhanced_prompt
|