Himanshu-AT committed on
Commit 33b3b46 · 1 Parent(s): 4f8239d
Files changed (4)
  1. README.md +1 -1
  2. app.py +73 -29
  3. garment_pipeline.py +60 -0
  4. recaption.py +48 -0
README.md CHANGED
@@ -1,5 +1,5 @@
 ---
-title: Diffusion Self Distillation
+title: 0Shot
 emoji: 🦀
 colorFrom: red
 colorTo: green
app.py CHANGED
@@ -2,17 +2,15 @@ import spaces
 import gradio as gr
 import torch
 from PIL import Image
-
 from diffusers.utils import load_image
 from pipeline import FluxConditionalPipeline
 from transformer import FluxTransformer2DConditionalModel
-
+from garment_pipeline import generate_with_garment
+from recaption import enhance_prompt, enhance_garment_prompt
 import os

 pipe = None
-
 CHECKPOINT = "primecai/dsd_model"
-
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32

@@ -24,17 +22,20 @@ transformer = FluxTransformer2DConditionalModel.from_pretrained(
     ignore_mismatched_sizes=True,
     use_auth_token=os.getenv("HF_TOKEN"),
 )
+
 pipe = FluxConditionalPipeline.from_pretrained(
     "black-forest-labs/FLUX.1-dev",
     transformer=transformer,
     torch_dtype=dtype,
     use_auth_token=os.getenv("HF_TOKEN"),
 )
+
 pipe.load_lora_weights(
     CHECKPOINT,
     weight_name="pytorch_lora_weights.safetensors",
     use_auth_token=os.getenv("HF_TOKEN"),
 )
+
 pipe.to(device, dtype=dtype)

 @spaces.GPU
@@ -50,7 +51,7 @@ def generate_image(
     image = image.crop(
         ((w - min_size) // 2, (h - min_size) // 2, (w + min_size) // 2, (h + min_size) // 2)
     ).resize((512, 512))
-
+
     control_image = load_image(image)
     result_image = pipe(
         prompt=text.strip(),
@@ -64,9 +65,36 @@ def generate_image(
         guidance_scale_real_t=t_guidance,
         gemini_prompt=gemini_prompt,
     ).images[0]
-
+
     return result_image

+@spaces.GPU
+def generate_with_garment_interface(
+    garment_image: Image.Image,
+    text: str,
+    gemini_prompt: bool = True,
+    guidance: float = 3.5,
+    i_guidance: float = 1.5,  # Default higher to maintain garment fidelity
+    t_guidance: float = 1.0
+):
+    """Interface function for generating images with a garment"""
+    # Use garment-specific prompt enhancement if enabled
+    if gemini_prompt:
+        text = enhance_garment_prompt(garment_image, text)
+
+    # Call the garment-specific generation function
+    result_image = generate_with_garment(
+        pipe=pipe,
+        garment_image=garment_image,
+        text=text,
+        gemini_prompt=False,  # Already enhanced above if needed
+        guidance=guidance,
+        i_guidance=i_guidance,
+        t_guidance=t_guidance,
+        device=device
+    )
+
+    return result_image

 def get_samples():
     sample_list = [
@@ -111,7 +139,6 @@ def get_samples():
         for sample in sample_list
     ]

-
 demo = gr.Blocks()

 with demo:
@@ -128,28 +155,45 @@ with demo:
         </div>
         """
     )
-
-    iface = gr.Interface(
-        fn=generate_image,
-        inputs=[
-            gr.Image(type="pil", width=512),
-            gr.Textbox(lines=2, label="text", info="Could be something as simple as 'this character playing soccer'."),
-            gr.Checkbox(label="Gemini prompt", value=True, info="Use Gemini to enhance the prompt. This is recommended for most cases, unless you have a specific prompt similar to the examples in mind."),
-            gr.Slider(minimum=1.0, maximum=6.0, step=0.5, value=3.5, label="guidance scale", info="Tip: start with 3.5, then gradually increase if the consistency is consistently off"),
-            gr.Slider(minimum=1.0, maximum=2.0, step=0.05, value=1.5, label="real guidance scale for image", info="Tip: increase if the image is not consistent"),
-            gr.Slider(minimum=1.0, maximum=2.0, step=0.05, value=1.0, label="real guidance scale for prompt", info="Tip: increase if the prompt is not consistent"),
-        ],
-        outputs=gr.Image(type="pil"),
-        # examples=get_samples(),
-        live=False,
-    )
-    gr.Examples(
-        examples=get_samples(),
-        inputs=iface.input_components,
-        outputs=iface.output_components,
-        run_on_click=False  # Prevents auto-submission
-    )
-
+
+    with gr.Tabs():
+        with gr.TabItem("Standard Generation"):
+            iface = gr.Interface(
+                fn=generate_image,
+                inputs=[
+                    gr.Image(type="pil", width=512),
+                    gr.Textbox(lines=2, label="text", info="Could be something as simple as 'this character playing soccer'."),
+                    gr.Checkbox(label="Gemini prompt", value=True, info="Use Gemini to enhance the prompt. This is recommended for most cases, unless you have a specific prompt similar to the examples in mind."),
+                    gr.Slider(minimum=1.0, maximum=6.0, step=0.5, value=3.5, label="guidance scale", info="Tip: start with 3.5, then gradually increase if the consistency is consistently off"),
+                    gr.Slider(minimum=1.0, maximum=2.0, step=0.05, value=1.5, label="real guidance scale for image", info="Tip: increase if the image is not consistent"),
+                    gr.Slider(minimum=1.0, maximum=2.0, step=0.05, value=1.0, label="real guidance scale for prompt", info="Tip: increase if the prompt is not consistent"),
+                ],
+                outputs=gr.Image(type="pil"),
+                live=False,
+            )
+            gr.Examples(
+                examples=get_samples(),
+                inputs=iface.input_components,
+                outputs=iface.output_components,
+                run_on_click=False  # Prevents auto-submission
+            )
+
+        with gr.TabItem("Garment Generation"):
+            garment_iface = gr.Interface(
+                fn=generate_with_garment_interface,
+                inputs=[
+                    gr.Image(type="pil", width=512, label="Garment Image", info="Upload an image of the garment you want to keep in the generated output"),
+                    gr.Textbox(lines=2, label="Model and Background Description", info="Describe the model and setting you want the garment to appear in, e.g., 'A tall model on a beach at sunset'"),
+                    gr.Checkbox(label="Enhance Prompt", value=True, info="Use Gemini to enhance the prompt with detailed garment description"),
+                    gr.Slider(minimum=1.0, maximum=6.0, step=0.5, value=3.5, label="guidance scale", info="Controls overall adherence to the prompt"),
+                    gr.Slider(minimum=1.0, maximum=2.0, step=0.05, value=1.5, label="garment fidelity", info="Controls how closely the output matches the original garment - higher values preserve more details"),
+                    gr.Slider(minimum=1.0, maximum=2.0, step=0.05, value=1.0, label="prompt adherence", info="Controls how closely the output matches the text prompt for model and background"),
+                ],
+                outputs=gr.Image(type="pil"),
+                live=False,
+                description="Generate an image of a model wearing the provided garment in a new setting",
+            )
+
     gr.HTML(
         """
         <div style="text-align: center;">
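Usage note: the new "Garment Generation" tab calls generate_with_garment_interface, which, when the checkbox is ticked, first runs enhance_garment_prompt (and therefore needs GOOGLE_API_KEY), then forwards the already-enhanced text downstream with gemini_prompt=False. A minimal illustrative call of the same function outside the UI is sketched below; it assumes the pipe built at the top of app.py has been loaded, and "jacket.png" / "jacket_on_model.png" are hypothetical file names.

from PIL import Image

jacket = Image.open("jacket.png").convert("RGB")  # hypothetical local garment photo
result = generate_with_garment_interface(
    garment_image=jacket,
    text="a tall model on a beach at sunset",
    gemini_prompt=True,   # triggers the Gemini recaption step; requires GOOGLE_API_KEY
    guidance=3.5,         # same defaults the sliders expose
    i_guidance=1.5,
    t_guidance=1.0,
)
result.save("jacket_on_model.png")  # hypothetical output path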
garment_pipeline.py ADDED
@@ -0,0 +1,60 @@
+import torch
+from PIL import Image
+from diffusers.utils import load_image
+from pipeline import FluxConditionalPipeline
+
+def generate_with_garment(
+    pipe,
+    garment_image: Image.Image,
+    text: str,
+    gemini_prompt: bool = True,
+    guidance: float = 3.5,
+    i_guidance: float = 1.0,
+    t_guidance: float = 1.0,
+    device="cuda" if torch.cuda.is_available() else "cpu"
+):
+    """
+    Generates an image of a model wearing the provided garment with a new background
+
+    Args:
+        pipe: The FluxConditionalPipeline instance
+        garment_image: Image of the garment to keep in the generated output
+        text: Text prompt describing the desired output (model, pose, background)
+        gemini_prompt: Whether to enhance the prompt using Gemini
+        guidance: General guidance scale
+        i_guidance: Image-specific guidance scale
+        t_guidance: Text-specific guidance scale
+        device: The device to use for generation
+
+    Returns:
+        The generated image
+    """
+    # Process the garment image
+    w, h, min_size = garment_image.size[0], garment_image.size[1], min(garment_image.size)
+    garment_image = garment_image.crop(
+        ((w - min_size) // 2, (h - min_size) // 2, (w + min_size) // 2, (h + min_size) // 2)
+    ).resize((512, 512))
+
+    # Prepare garment image as control image
+    control_image = load_image(garment_image)
+
+    # Enhance the prompt to focus on keeping the garment while changing the model and background
+    enhanced_text = text
+    if not "garment" in enhanced_text.lower() and not "clothing" in enhanced_text.lower():
+        enhanced_text = f"A model wearing this garment, {text}"
+
+    # Generate the image
+    result_image = pipe(
+        prompt=enhanced_text.strip(),
+        negative_prompt="distorted garment, wrong clothing, deformed clothes",
+        num_inference_steps=28,
+        height=512,
+        width=1024,
+        guidance_scale=guidance,
+        image=control_image,
+        guidance_scale_real_i=i_guidance,  # Higher value to maintain garment fidelity
+        guidance_scale_real_t=t_guidance,
+        gemini_prompt=gemini_prompt,
+    ).images[0]
+
+    return result_image
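Usage note: a hedged sketch of driving the new helper directly, assuming the FluxConditionalPipeline instance (pipe) built in app.py and a hypothetical input file "dress.jpg". Because the prompt below mentions neither "garment" nor "clothing", the helper prepends "A model wearing this garment, " to it before calling the pipeline, and the image is center-cropped to 512x512 whatever its original size; the gemini_prompt flag is simply forwarded to the pipe() call, the helper itself does no recaptioning.

from PIL import Image

from garment_pipeline import generate_with_garment

dress = Image.open("dress.jpg")  # hypothetical file; any aspect ratio, it gets center-cropped
result = generate_with_garment(
    pipe=pipe,                    # the FluxConditionalPipeline loaded in app.py
    garment_image=dress,
    text="studio lighting, neutral grey backdrop",  # no garment/clothing keyword, so the prefix is added
    gemini_prompt=False,          # forwarded to pipe(); no Gemini call happens in this helper
    guidance=3.5,
    i_guidance=1.5,               # higher values preserve more garment detail
    t_guidance=1.0,
)
result.save("dress_result.png")   # hypothetical output path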
recaption.py CHANGED
@@ -30,4 +30,52 @@ def enhance_prompt(image, prompt):
     print("input_image_prompt: ", input_image_prompt)
     print("prompt: ", prompt)
     print("enhanced_prompt: ", enhanced_prompt)
     return enhanced_prompt
+
+def enhance_garment_prompt(image, prompt):
+    """
+    Enhances a prompt specifically for garment transformation tasks.
+
+    Args:
+        image: The garment image
+        prompt: User provided prompt for the desired output
+
+    Returns:
+        Enhanced prompt that preserves garment details while incorporating user requirements
+    """
+    input_caption_prompt = (
+        "Please provide a detailed description of this garment/clothing item I will show you. "
+        "Focus on describing the garment's color, pattern, style, fabric, cut, and unique details. "
+        "Be specific about the type of garment (e.g., t-shirt, dress, jacket, pants) and its defining characteristics. "
+        "The description should be detailed enough for an image generation model to recreate this exact garment. "
+        "The description should be short and precise, in one-line format."
+    )
+
+    caption_model = genai.Client(
+        vertexai=False, api_key=os.environ["GOOGLE_API_KEY"]
+    )
+
+    # Get detailed garment description
+    garment_description = caption_model.models.generate_content(
+        model='gemini-1.5-flash', contents=[input_caption_prompt, image]).text
+    garment_description = garment_description.replace('\r', '').replace('\n', '')
+
+    # Enhance user prompt to include garment details
+    enhance_instruction = (
+        f"I need to generate an image of a model wearing a specific garment. "
+        f"The garment is described as: '{garment_description}'. "
+        f"The user wants: '{prompt}'. "
+        f"Create a detailed prompt that combines these elements, ensuring the garment description is preserved exactly while "
+        f"incorporating the user's requirements for the model (person wearing it) and setting/background. "
+        f"Focus on describing a photorealistic scene with a model wearing this specific garment. "
+        f"The enhanced prompt should be short and precise, in one-line format, and should not exceed 77 tokens."
+    )
+
+    enhanced_prompt = caption_model.models.generate_content(
+        model='gemini-1.5-flash', contents=[enhance_instruction]).text.replace('\r', '').replace('\n', '')
+
+    print("garment_description: ", garment_description)
+    print("user prompt: ", prompt)
+    print("enhanced_prompt: ", enhanced_prompt)
+
+    return enhanced_prompt
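Usage note: a hedged sketch of calling the new recaption helper on its own. It assumes GOOGLE_API_KEY is set and that recaption.py already imports os and the google-genai client as genai at module level (those imports sit above this hunk); "tshirt.png" is a hypothetical image.

from PIL import Image

from recaption import enhance_garment_prompt

tshirt = Image.open("tshirt.png")  # hypothetical garment photo
enhanced = enhance_garment_prompt(tshirt, "a model walking through a city at night")
print(enhanced)  # one-line prompt folding the Gemini garment description into the requested scene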