erwold commited on
Commit
0ded2d6
·
1 Parent(s): 2645f74

Initial Commit

Browse files
Files changed (1) hide show
  1. app.py +65 -5
app.py CHANGED
@@ -7,9 +7,20 @@ from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler
7
  from flux.transformer_flux import FluxTransformer2DModel
8
  from flux.pipeline_flux_chameleon import FluxPipeline
9
  import torch.nn as nn
 
10
 
11
  MODEL_ID = "Djrango/Qwen2vl-Flux"
12
 
 
 
 
 
 
 
 
 
 
 
13
  class Qwen2Connector(nn.Module):
14
  def __init__(self, input_dim=3584, output_dim=4096):
15
  super().__init__()
@@ -88,6 +99,23 @@ class FluxInterface:
88
  text_encoder=text_encoder,
89
  tokenizer=tokenizer,
90
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91
 
92
  # [Previous methods remain unchanged...]
93
  def process_image(self, image):
@@ -109,8 +137,8 @@ class FluxInterface:
109
  image_hidden_state = self.models['connector'](image_hidden_state)
110
 
111
  return image_hidden_state, image_grid_thw
112
-
113
- def compute_text_embeddings(self, prompt):
114
  """Compute T5 embeddings for text prompt"""
115
  if prompt == "":
116
  return None
@@ -129,13 +157,36 @@ class FluxInterface:
129
 
130
  return prompt_embeds
131
 
132
- def generate(self, input_image, prompt="", guidance_scale=3.5, num_inference_steps=28, num_images=2, seed=None):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
133
  try:
134
  if seed is not None:
135
  torch.manual_seed(seed)
136
 
137
  self.load_models()
138
 
 
 
 
 
 
139
  # Process input image
140
  input_image = self.resize_image(input_image)
141
  qwen2_hidden_state, image_grid_thw = self.process_image(input_image)
@@ -151,6 +202,8 @@ class FluxInterface:
151
  t5_prompt_embeds=t5_prompt_embeds.repeat(num_images, 1, 1) if t5_prompt_embeds is not None else None,
152
  num_inference_steps=num_inference_steps,
153
  guidance_scale=guidance_scale,
 
 
154
  ).images
155
 
156
  return output_images
@@ -212,7 +265,7 @@ with gr.Blocks(
212
  with gr.Group():
213
  prompt = gr.Textbox(
214
  label="Text Prompt (Optional)",
215
- placeholder="Describe how you want to modify the image...",
216
  lines=3
217
  )
218
 
@@ -249,6 +302,12 @@ with gr.Blocks(
249
  precision=0,
250
  info="Set for reproducible results"
251
  )
 
 
 
 
 
 
252
 
253
  submit_btn = gr.Button(
254
  "🎨 Generate Variations",
@@ -288,7 +347,8 @@ with gr.Blocks(
288
  guidance,
289
  steps,
290
  num_images,
291
- seed
 
292
  ],
293
  outputs=output_gallery,
294
  show_progress="minimal"
 
7
  from flux.transformer_flux import FluxTransformer2DModel
8
  from flux.pipeline_flux_chameleon import FluxPipeline
9
  import torch.nn as nn
10
+ import math
11
 
12
  MODEL_ID = "Djrango/Qwen2vl-Flux"
13
 
14
+ # Add aspect ratio options
15
+ ASPECT_RATIOS = {
16
+ "1:1": (1024, 1024),
17
+ "16:9": (1344, 768),
18
+ "9:16": (768, 1344),
19
+ "2.4:1": (1536, 640),
20
+ "3:4": (896, 1152),
21
+ "4:3": (1152, 896),
22
+ }
23
+
24
  class Qwen2Connector(nn.Module):
25
  def __init__(self, input_dim=3584, output_dim=4096):
26
  super().__init__()
 
99
  text_encoder=text_encoder,
100
  tokenizer=tokenizer,
101
  )
102
+
103
+ def resize_image(self, img, max_pixels=1050000):
104
+ if not isinstance(img, Image.Image):
105
+ img = Image.fromarray(img)
106
+
107
+ width, height = img.size
108
+ num_pixels = width * height
109
+
110
+ if num_pixels > max_pixels:
111
+ scale = math.sqrt(max_pixels / num_pixels)
112
+ new_width = int(width * scale)
113
+ new_height = int(height * scale)
114
+ new_width = new_width - (new_width % 8)
115
+ new_height = new_height - (new_height % 8)
116
+ img = img.resize((new_width, new_height), Image.LANCZOS)
117
+
118
+ return img
119
 
120
  # [Previous methods remain unchanged...]
121
  def process_image(self, image):
 
137
  image_hidden_state = self.models['connector'](image_hidden_state)
138
 
139
  return image_hidden_state, image_grid_thw
140
+
141
+ def compute_t5_text_embeddings(self, prompt):
142
  """Compute T5 embeddings for text prompt"""
143
  if prompt == "":
144
  return None
 
157
 
158
  return prompt_embeds
159
 
160
+ def compute_text_embeddings(self, prompt=""):
161
+ with torch.no_grad():
162
+ text_inputs = self.models['tokenizer'](
163
+ prompt,
164
+ padding="max_length",
165
+ max_length=77,
166
+ truncation=True,
167
+ return_tensors="pt"
168
+ ).to(self.device)
169
+
170
+ prompt_embeds = self.models['text_encoder'](
171
+ text_inputs.input_ids,
172
+ output_hidden_states=False
173
+ )
174
+ pooled_prompt_embeds = prompt_embeds.pooler_output.to(self.dtype)
175
+
176
+ return pooled_prompt_embeds
177
+
178
+ def generate(self, input_image, prompt="", guidance_scale=3.5, num_inference_steps=28, num_images=2, seed=None, aspect_ratio="1:1"):
179
  try:
180
  if seed is not None:
181
  torch.manual_seed(seed)
182
 
183
  self.load_models()
184
 
185
+ # Get dimensions from aspect ratio
186
+ if aspect_ratio not in ASPECT_RATIOS:
187
+ raise ValueError(f"Invalid aspect ratio. Choose from {list(ASPECT_RATIOS.keys())}")
188
+ width, height = ASPECT_RATIOS[aspect_ratio]
189
+
190
  # Process input image
191
  input_image = self.resize_image(input_image)
192
  qwen2_hidden_state, image_grid_thw = self.process_image(input_image)
 
202
  t5_prompt_embeds=t5_prompt_embeds.repeat(num_images, 1, 1) if t5_prompt_embeds is not None else None,
203
  num_inference_steps=num_inference_steps,
204
  guidance_scale=guidance_scale,
205
+ height=height,
206
+ width=width,
207
  ).images
208
 
209
  return output_images
 
265
  with gr.Group():
266
  prompt = gr.Textbox(
267
  label="Text Prompt (Optional)",
268
+ placeholder="As Long As Possible...",
269
  lines=3
270
  )
271
 
 
302
  precision=0,
303
  info="Set for reproducible results"
304
  )
305
+ aspect_ratio = gr.Radio(
306
+ label="Aspect Ratio",
307
+ choices=["1:1", "16:9", "9:16", "2.4:1", "3:4", "4:3"],
308
+ value="1:1",
309
+ info="Choose aspect ratio for generated images"
310
+ )
311
 
312
  submit_btn = gr.Button(
313
  "🎨 Generate Variations",
 
347
  guidance,
348
  steps,
349
  num_images,
350
+ seed,
351
+ aspect_ratio
352
  ],
353
  outputs=output_gallery,
354
  show_progress="minimal"