erwold committed on
Commit 2c1a6cc · 1 Parent(s): e152f3a
Files changed (1)
  1. app.py +192 -357
app.py CHANGED
@@ -10,44 +10,104 @@ import torch.nn as nn
 import math
 import logging
 import sys
-
-from qwen2_vl.modeling_qwen2_vl import Qwen2VLSimplifiedModel
 from huggingface_hub import snapshot_download

 # Set up logging
 logging.basicConfig(
     level=logging.INFO,
     format='%(asctime)s - %(levelname)s - %(message)s',
-    handlers=[
-        logging.StreamHandler(sys.stdout)
-    ]
 )
 logger = logging.getLogger(__name__)

 MODEL_ID = "Djrango/Qwen2vl-Flux"
 MODEL_CACHE_DIR = "model_cache"

-# Pre-download all models
-def download_models():
     logger.info("Starting model download...")
     try:
-        # Download the full model repository
         snapshot_download(
             repo_id=MODEL_ID,
             local_dir=MODEL_CACHE_DIR,
             local_dir_use_symlinks=False
         )
-
         logger.info("Model download completed successfully")
     except Exception as e:
         logger.error(f"Error downloading models: {str(e)}")
         raise

-# Download models at script startup
-if not os.path.exists(MODEL_CACHE_DIR):
-    download_models()

-# Add aspect ratio options
 ASPECT_RATIOS = {
     "1:1": (1024, 1024),
     "16:9": (1344, 768),
@@ -57,304 +117,121 @@ ASPECT_RATIOS = {
     "4:3": (1152, 896),
 }

-class Qwen2Connector(nn.Module):
-    def __init__(self, input_dim=3584, output_dim=4096):
-        super().__init__()
-        self.linear = nn.Linear(input_dim, output_dim)
-
-    def forward(self, x):
-        return self.linear(x)
-
-class FluxInterface:
-    def __init__(self, device="cuda" if torch.cuda.is_available() else "cpu"):
-        self.device = device
-        self.dtype = torch.bfloat16
-        self.models = None
-        self.MODEL_ID = "Djrango/Qwen2vl-Flux"
-
-    def load_models(self):
-        if self.models is not None:
-            return
-
-        logger.info("Starting model loading...")
-
-        # 1. Load the smaller models onto the GPU first
-        tokenizer = CLIPTokenizer.from_pretrained(os.path.join(MODEL_CACHE_DIR, "flux/tokenizer"))
-        text_encoder = CLIPTextModel.from_pretrained(
-            os.path.join(MODEL_CACHE_DIR, "flux/text_encoder")
-        ).to(self.dtype).to(self.device)
-
-        text_encoder_two = T5EncoderModel.from_pretrained(
-            os.path.join(MODEL_CACHE_DIR, "flux/text_encoder_2")
-        ).to(self.dtype).to(self.device)
-
-        tokenizer_two = T5TokenizerFast.from_pretrained(
-            os.path.join(MODEL_CACHE_DIR, "flux/tokenizer_2"))
-
-        # 2. Load the large models on the CPU, keeping bfloat16 precision
-        vae = AutoencoderKL.from_pretrained(
-            os.path.join(MODEL_CACHE_DIR, "flux/vae")
-        ).to(self.dtype).cpu()
-
-        transformer = FluxTransformer2DModel.from_pretrained(
-            os.path.join(MODEL_CACHE_DIR, "flux/transformer")
-        ).to(self.dtype).cpu()
-
-        scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
-            os.path.join(MODEL_CACHE_DIR, "flux/scheduler"),
-            shift=1
-        )
-
-        # 3. Load Qwen2VL on the CPU, keeping bfloat16
-        qwen2vl = Qwen2VLSimplifiedModel.from_pretrained(
-            os.path.join(MODEL_CACHE_DIR, "qwen2-vl")
-        ).to(self.dtype).cpu()
-
-        # 4. Load the connector and embedder, keeping bfloat16
-        connector = Qwen2Connector().to(self.dtype).cpu()
-        connector_path = os.path.join(MODEL_CACHE_DIR, "qwen2-vl/connector.pt")
-        connector_state = torch.load(connector_path, map_location='cpu')
-        connector_state = {k.replace('module.', ''): v.to(self.dtype) for k, v in connector_state.items()}
-        connector.load_state_dict(connector_state)
-
-        self.t5_context_embedder = nn.Linear(4096, 3072).to(self.dtype).cpu()
-        t5_embedder_path = os.path.join(MODEL_CACHE_DIR, "qwen2-vl/t5_embedder.pt")
-        t5_embedder_state = torch.load(t5_embedder_path, map_location='cpu')
-        t5_embedder_state = {k: v.to(self.dtype) for k, v in t5_embedder_state.items()}
-        self.t5_context_embedder.load_state_dict(t5_embedder_state)
-
-        # 5. Put all models in eval mode
-        for model in [text_encoder, text_encoder_two, vae, transformer, qwen2vl,
-                      connector, self.t5_context_embedder]:
-            model.requires_grad_(False)
-            model.eval()
-
-        logger.info("All models loaded successfully")
-
-        self.models = {
-            'tokenizer': tokenizer,
-            'text_encoder': text_encoder,
-            'text_encoder_two': text_encoder_two,
-            'tokenizer_two': tokenizer_two,
-            'vae': vae,
-            'transformer': transformer,
-            'scheduler': scheduler,
-            'qwen2vl': qwen2vl,
-            'connector': connector
-        }
-
-        self.qwen2vl_processor = AutoProcessor.from_pretrained(
-            self.MODEL_ID,
-            subfolder="qwen2-vl",
-            min_pixels=256*28*28,
-            max_pixels=256*28*28
-        )
-
-        self.pipeline = FluxPipeline(
-            transformer=transformer,
-            scheduler=scheduler,
-            vae=vae,
-            text_encoder=text_encoder,
-            tokenizer=tokenizer,
         )

-    def move_to_device(self, model, device):
-        """Helper function to move model to specified device"""
-        if hasattr(model, 'to'):
-            return model.to(self.dtype).to(device)
-        return model
-
-    def process_image(self, image):
-        """Process image with Qwen2VL model"""
-        try:
-            # 1. Move the Qwen2VL models to the GPU
-            logger.info("Moving Qwen2VL models to GPU...")
-            self.models['qwen2vl'] = self.models['qwen2vl'].to(self.device)
-            self.models['connector'] = self.models['connector'].to(self.device)
-            logger.info("Qwen2VL models moved to GPU")
-
-            message = [
-                {
-                    "role": "user",
-                    "content": [
-                        {"type": "image", "image": image},
-                        {"type": "text", "text": "Describe this image."},
-                    ]
-                }
-            ]
-            text = self.qwen2vl_processor.apply_chat_template(
-                message,
-                tokenize=False,
-                add_generation_prompt=True
-            )
-
-            with torch.no_grad():
-                inputs = self.qwen2vl_processor(
-                    text=[text],
-                    images=[image],
-                    padding=True,
-                    return_tensors="pt"
-                ).to(self.device)
-
-                output_hidden_state, image_token_mask, image_grid_thw = self.models['qwen2vl'](**inputs)
-                image_hidden_state = output_hidden_state[image_token_mask].view(1, -1, output_hidden_state.size(-1))
-                image_hidden_state = self.models['connector'](image_hidden_state)
-
-            # Keep the result on the CPU
-            result = (image_hidden_state.cpu(), image_grid_thw)

-            # 2. Move the Qwen2VL models back to the CPU
-            logger.info("Moving Qwen2VL models back to CPU...")
-            self.models['qwen2vl'] = self.models['qwen2vl'].cpu()
-            self.models['connector'] = self.models['connector'].cpu()
-            torch.cuda.empty_cache()
-            logger.info("Qwen2VL models moved to CPU and GPU cache cleared")

-            return result

-        except Exception as e:
-            logger.error(f"Error in process_image: {str(e)}")
-            raise
-
-    def resize_image(self, img, max_pixels=1050000):
-        if not isinstance(img, Image.Image):
-            img = Image.fromarray(img)
-
-        width, height = img.size
-        num_pixels = width * height
-
-        if num_pixels > max_pixels:
-            scale = math.sqrt(max_pixels / num_pixels)
-            new_width = int(width * scale)
-            new_height = int(height * scale)
-            new_width = new_width - (new_width % 8)
-            new_height = new_height - (new_height % 8)
-            img = img.resize((new_width, new_height), Image.LANCZOS)

-        return img
-
-    def compute_t5_text_embeddings(self, prompt):
-        """Compute T5 embeddings for text prompt"""
-        if prompt == "":
-            return None
-
-        text_inputs = self.models['tokenizer_two'](
             prompt,
             padding="max_length",
-            max_length=256,
             truncation=True,
             return_tensors="pt"
-        ).to(self.device)
-
-        prompt_embeds = self.models['text_encoder_two'](text_inputs.input_ids)[0]
-        prompt_embeds = self.t5_context_embedder.to(self.device)(prompt_embeds)
-        self.t5_context_embedder = self.t5_context_embedder.cpu()
-
-        return prompt_embeds
-
-    def compute_text_embeddings(self, prompt=""):
-        with torch.no_grad():
-            text_inputs = self.models['tokenizer'](
-                prompt,
-                padding="max_length",
-                max_length=77,
-                truncation=True,
-                return_tensors="pt"
-            ).to(self.device)
-
-            prompt_embeds = self.models['text_encoder'](
-                text_inputs.input_ids,
-                output_hidden_states=False
-            )
-            pooled_prompt_embeds = prompt_embeds.pooler_output
-            return pooled_prompt_embeds


-    def generate(self, input_image, prompt="", guidance_scale=3.5,
-                 num_inference_steps=28, num_images=2, seed=None, aspect_ratio="1:1"):
-        try:
-            logger.info(f"Starting generation with prompt: {prompt}")
-
-            if input_image is None:
-                raise ValueError("No input image provided")
-
-            if seed is not None:
-                torch.manual_seed(seed)
-                logger.info(f"Set random seed to: {seed}")
-
-            # 1. Process the image with Qwen2VL
-            logger.info("Processing input image with Qwen2VL...")
-            qwen2_hidden_state, image_grid_thw = self.process_image(input_image)
-            logger.info("Image processing completed")
-
-            # 2. Compute text embeddings
-            logger.info("Computing text embeddings...")
-            pooled_prompt_embeds = self.compute_text_embeddings(prompt)
-            t5_prompt_embeds = self.compute_t5_text_embeddings(prompt)
-            logger.info("Text embeddings computed")
-
-            # 3. Move the Transformer and VAE to the GPU
-            logger.info("Moving Transformer and VAE to GPU...")
-            self.models['transformer'] = self.models['transformer'].to(self.device)
-            self.models['vae'] = self.models['vae'].to(self.device)

-            # Update the model references in the pipeline
-            self.pipeline.transformer = self.models['transformer']
-            self.pipeline.vae = self.models['vae']
-            logger.info("Models moved to GPU")

-            # Get the output dimensions
-            width, height = ASPECT_RATIOS[aspect_ratio]
-            logger.info(f"Using dimensions: {width}x{height}")

-            # 4. Generate the images
-            try:
-                logger.info("Starting image generation...")
-                output_images = self.pipeline(
-                    prompt_embeds=qwen2_hidden_state.to(self.device).repeat(num_images, 1, 1),
-                    pooled_prompt_embeds=pooled_prompt_embeds,
-                    t5_prompt_embeds=t5_prompt_embeds.repeat(num_images, 1, 1) if t5_prompt_embeds is not None else None,
-                    num_inference_steps=num_inference_steps,
-                    guidance_scale=guidance_scale,
-                    height=height,
-                    width=width,
-                ).images
-                logger.info("Image generation completed")
-
-                # 5. Move the Transformer and VAE back to the CPU
-                logger.info("Moving models back to CPU...")
-                self.models['transformer'] = self.models['transformer'].cpu()
-                self.models['vae'] = self.models['vae'].cpu()
-                torch.cuda.empty_cache()
-                logger.info("Models moved to CPU and GPU cache cleared")
-
-                return output_images
-
-            except Exception as e:
-                raise RuntimeError(f"Error generating images: {str(e)}")
-
         except Exception as e:
-            logger.error(f"Error during generation: {str(e)}")
-            raise gr.Error(f"Generation failed: {str(e)}")
-
-# Initialize the interface
-interface = FluxInterface()
-
-def process_request(input_image, prompt="", guidance_scale=3.5, num_inference_steps=28, num_images=2, seed=None, aspect_ratio="1:1"):
-    """Main handler that serves user requests directly"""
-    try:
-        if interface.models is None:
-            interface.load_models()

-        return interface.generate(
-            input_image=input_image,
-            prompt=prompt,
-            guidance_scale=guidance_scale,
-            num_inference_steps=num_inference_steps,
-            num_images=num_images,
-            seed=seed,
-            aspect_ratio=aspect_ratio
-        )
     except Exception as e:
         logger.error(f"Error during generation: {str(e)}")
         raise gr.Error(f"Generation failed: {str(e)}")
@@ -363,41 +240,22 @@ def process_request(input_image, prompt="", guidance_scale=3.5, num_inference_st
 with gr.Blocks(
     theme=gr.themes.Soft(),
     css="""
-    .container {
-        max-width: 1200px;
-        margin: auto;
-        padding: 0 20px;
-    }
-    .header {
-        text-align: center;
-        margin: 20px 0 40px 0;
-        padding: 20px;
-        background: #f7f7f7;
-        border-radius: 12px;
-    }
-    .param-row {
-        padding: 10px 0;
-    }
-    footer {
-        margin-top: 40px;
-        padding: 20px;
-        border-top: 1px solid #eee;
-    }
     """
 ) as demo:
     with gr.Column(elem_classes="container"):
-        gr.Markdown(
-            """
             <div class="header">
             # 🎨 Qwen2vl-Flux Image Variation Demo
             Generate creative variations of your images with optional text guidance
             </div>
-            """
-        )

         with gr.Row(equal_height=True):
             with gr.Column(scale=1):
-                # Input Section
                 input_image = gr.Image(
                     label="Upload Your Image",
                     type="pil",
@@ -419,48 +277,38 @@ with gr.Blocks(
                     maximum=10,
                     value=3.5,
                     step=0.5,
-                    label="Guidance Scale",
-                    info="Higher values follow prompt more closely"
                 )
                 steps = gr.Slider(
                     minimum=1,
-                    maximum=50,
                     value=28,
                     step=1,
-                    label="Sampling Steps",
-                    info="More steps = better quality but slower"
                 )

             with gr.Row(elem_classes="param-row"):
                 num_images = gr.Slider(
                     minimum=1,
-                    maximum=4,
-                    value=2,
                     step=1,
-                    label="Number of Images",
-                    info="Generate multiple variations at once"
                 )
                 seed = gr.Number(
                     label="Random Seed",
                     value=None,
-                    precision=0,
-                    info="Set for reproducible results"
                 )
                 aspect_ratio = gr.Radio(
                     label="Aspect Ratio",
                     choices=["1:1", "16:9", "9:16", "2.4:1", "3:4", "4:3"],
-                    value="1:1",
-                    info="Choose aspect ratio for generated images"
                 )

-                submit_btn = gr.Button(
-                    "🎨 Generate Variations",
-                    variant="primary",
-                    size="lg"
-                )

         with gr.Column(scale=1):
-            # Output Section
             output_gallery = gr.Gallery(
                 label="Generated Variations",
                 columns=2,
@@ -468,23 +316,11 @@ with gr.Blocks(
                 height=700,
                 object_fit="contain",
                 show_label=True,
-                allow_preview=True,
-                preview=True
             )
-            error_message = gr.Textbox(visible=False)

-        with gr.Row(elem_classes="footer"):
-            gr.Markdown("""
-            ### Tips:
-            - 📸 Upload any image to get started
-            - 💡 Add an optional text prompt to guide the generation
-            - 🎯 Adjust guidance scale to control prompt influence
-            - ⚙️ Increase steps for higher quality
-            - 🎲 Use seeds for reproducible results
-            """)
-
         submit_btn.click(
-            fn=process_request,
             inputs=[
                 input_image,
                 prompt,
@@ -493,15 +329,14 @@ with gr.Blocks(
                 num_images,
                 seed,
                 aspect_ratio
-            ],
             outputs=[output_gallery],
             show_progress=True
         )

-# Launch the app
 if __name__ == "__main__":
     demo.launch(
-        server_name="0.0.0.0",  # Listen on all network interfaces
-        server_port=7860,  # Use a specific port
-        share=False,  # Disable public URL sharing
     )
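The removed `FluxInterface` kept the heavy modules on the CPU in bfloat16 and moved each one onto the GPU only for the step that needed it, emptying the CUDA cache once it was done. A minimal sketch of that borrow-the-GPU pattern (an illustrative helper, not part of the repository; assumes PyTorch and a CUDA device):

```python
import torch

def run_on_gpu(module, fn, *args, device="cuda"):
    """Move a module to the GPU, run fn with it, then park it back on the CPU."""
    module = module.to(device)
    try:
        return fn(module, *args)
    finally:
        module.cpu()
        torch.cuda.empty_cache()  # release the memory the module was holding
```

The new version of app.py added by this commit follows.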
 
 import math
 import logging
 import sys
 from huggingface_hub import snapshot_download
+from qwen2_vl.modeling_qwen2_vl import Qwen2VLSimplifiedModel
+import huggingface_hub.spaces as spaces

 # Set up logging
 logging.basicConfig(
     level=logging.INFO,
     format='%(asctime)s - %(levelname)s - %(message)s',
+    handlers=[logging.StreamHandler(sys.stdout)]
 )
 logger = logging.getLogger(__name__)

 MODEL_ID = "Djrango/Qwen2vl-Flux"
 MODEL_CACHE_DIR = "model_cache"
+device = "cuda" if torch.cuda.is_available() else "cpu"
+dtype = torch.bfloat16

+# Pre-download the models
+if not os.path.exists(MODEL_CACHE_DIR):
     logger.info("Starting model download...")
     try:
         snapshot_download(
             repo_id=MODEL_ID,
             local_dir=MODEL_CACHE_DIR,
             local_dir_use_symlinks=False
         )
         logger.info("Model download completed successfully")
     except Exception as e:
         logger.error(f"Error downloading models: {str(e)}")
         raise

+# Load all models into global variables
+logger.info("Loading models...")
+tokenizer = CLIPTokenizer.from_pretrained(os.path.join(MODEL_CACHE_DIR, "flux/tokenizer"))
+text_encoder = CLIPTextModel.from_pretrained(
+    os.path.join(MODEL_CACHE_DIR, "flux/text_encoder")
+).to(dtype)
+
+text_encoder_two = T5EncoderModel.from_pretrained(
+    os.path.join(MODEL_CACHE_DIR, "flux/text_encoder_2")
+).to(dtype)
+
+tokenizer_two = T5TokenizerFast.from_pretrained(
+    os.path.join(MODEL_CACHE_DIR, "flux/tokenizer_2"))
+
+vae = AutoencoderKL.from_pretrained(
+    os.path.join(MODEL_CACHE_DIR, "flux/vae")
+).to(dtype)
+
+transformer = FluxTransformer2DModel.from_pretrained(
+    os.path.join(MODEL_CACHE_DIR, "flux/transformer")
+).to(dtype)
+
+scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
+    os.path.join(MODEL_CACHE_DIR, "flux/scheduler"),
+    shift=1
+)

+qwen2vl = Qwen2VLSimplifiedModel.from_pretrained(
+    os.path.join(MODEL_CACHE_DIR, "qwen2-vl")
+).to(dtype)
+
+qwen2vl_processor = AutoProcessor.from_pretrained(
+    MODEL_ID,
+    subfolder="qwen2-vl",
+    min_pixels=256*28*28,
+    max_pixels=256*28*28
+)
+
+# Load the connector and embedder
+connector = nn.Linear(3584, 4096).to(dtype)
+connector_path = os.path.join(MODEL_CACHE_DIR, "qwen2-vl/connector.pt")
+connector_state = torch.load(connector_path, map_location='cpu')
+connector_state = {k.replace('module.', ''): v.to(dtype) for k, v in connector_state.items()}
+connector.load_state_dict(connector_state)
+
+t5_context_embedder = nn.Linear(4096, 3072).to(dtype)
+t5_embedder_path = os.path.join(MODEL_CACHE_DIR, "qwen2-vl/t5_embedder.pt")
+t5_embedder_state = torch.load(t5_embedder_path, map_location='cpu')
+t5_embedder_state = {k: v.to(dtype) for k, v in t5_embedder_state.items()}
+t5_context_embedder.load_state_dict(t5_embedder_state)
+
+# Build the pipeline
+pipeline = FluxPipeline(
+    transformer=transformer,
+    scheduler=scheduler,
+    vae=vae,
+    text_encoder=text_encoder,
+    tokenizer=tokenizer,
+)
+
+# Put all models in eval mode
+for model in [text_encoder, text_encoder_two, vae, transformer, qwen2vl,
+              connector, t5_context_embedder]:
+    model.requires_grad_(False)
+    model.eval()
+
+# Aspect ratio options
 ASPECT_RATIOS = {
     "1:1": (1024, 1024),
     "16:9": (1344, 768),

     "4:3": (1152, 896),
 }

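The two projection layers loaded above bridge the embedding widths: `connector` maps Qwen2-VL's 3584-dim hidden states to 4096, and `t5_context_embedder` maps the 4096-dim T5 features down to 3072 to match the Flux transformer's context width. A shape-only sketch with random, untrained weights (illustrative; batch and sequence sizes are arbitrary):

```python
import torch
import torch.nn as nn

dtype = torch.bfloat16
connector = nn.Linear(3584, 4096).to(dtype)
t5_context_embedder = nn.Linear(4096, 3072).to(dtype)

image_tokens = torch.randn(1, 256, 3584, dtype=dtype)  # Qwen2-VL image hidden states
t5_tokens = torch.randn(1, 256, 4096, dtype=dtype)     # T5 encoder output

print(connector(image_tokens).shape)         # torch.Size([1, 256, 4096])
print(t5_context_embedder(t5_tokens).shape)  # torch.Size([1, 256, 3072])
```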
+def process_image(image):
+    """Process image with Qwen2VL model"""
+    try:
+        message = [
+            {
+                "role": "user",
+                "content": [
+                    {"type": "image", "image": image},
+                    {"type": "text", "text": "Describe this image."},
+                ]
+            }
+        ]
+        text = qwen2vl_processor.apply_chat_template(
+            message,
+            tokenize=False,
+            add_generation_prompt=True
         )

+        with torch.no_grad():
+            inputs = qwen2vl_processor(
+                text=[text],
+                images=[image],
+                padding=True,
+                return_tensors="pt"
+            ).to(device)

+            output_hidden_state, image_token_mask, image_grid_thw = qwen2vl(**inputs)
+            image_hidden_state = output_hidden_state[image_token_mask].view(1, -1, output_hidden_state.size(-1))
+            image_hidden_state = connector(image_hidden_state)

+        return (image_hidden_state, image_grid_thw)

+    except Exception as e:
+        logger.error(f"Error in process_image: {str(e)}")
+        raise
+
+def compute_t5_text_embeddings(prompt):
+    """Compute T5 embeddings for text prompt"""
+    if prompt == "":
+        return None

+    text_inputs = tokenizer_two(
+        prompt,
+        padding="max_length",
+        max_length=256,
+        truncation=True,
+        return_tensors="pt"
+    ).to(device)
+
+    prompt_embeds = text_encoder_two(text_inputs.input_ids)[0]
+    prompt_embeds = t5_context_embedder(prompt_embeds)
+
+    return prompt_embeds
+
+def compute_text_embeddings(prompt=""):
+    """Compute text embeddings for the prompt"""
+    with torch.no_grad():
+        text_inputs = tokenizer(
             prompt,
             padding="max_length",
+            max_length=77,
             truncation=True,
             return_tensors="pt"
+        ).to(device)

+        prompt_embeds = text_encoder(
+            text_inputs.input_ids,
+            output_hidden_states=False
+        )
+        return prompt_embeds.pooler_output

+@spaces.GPU(duration=120)  # Use the ZeroGPU decorator
+def generate_images(input_image, prompt="", guidance_scale=3.5,
+                    num_inference_steps=28, num_images=1, seed=None, aspect_ratio="1:1"):
+    """Generate images using the pipeline"""
+    try:
+        logger.info(f"Starting generation with prompt: {prompt}")
+
+        if input_image is None:
+            raise ValueError("No input image provided")

+        if seed is not None:
+            torch.manual_seed(seed)
+            logger.info(f"Set random seed to: {seed}")
+
+        # Process image with Qwen2VL
+        qwen2_hidden_state, image_grid_thw = process_image(input_image)
+
+        # Compute text embeddings
+        pooled_prompt_embeds = compute_text_embeddings(prompt)
+        t5_prompt_embeds = compute_t5_text_embeddings(prompt)
+
+        # Get dimensions
+        width, height = ASPECT_RATIOS[aspect_ratio]
+        logger.info(f"Using dimensions: {width}x{height}")
+
+        # Generate images
+        try:
+            logger.info("Starting image generation...")
+            output_images = pipeline(
+                prompt_embeds=qwen2_hidden_state.repeat(num_images, 1, 1),
+                pooled_prompt_embeds=pooled_prompt_embeds,
+                t5_prompt_embeds=t5_prompt_embeds.repeat(num_images, 1, 1) if t5_prompt_embeds is not None else None,
+                num_inference_steps=num_inference_steps,
+                guidance_scale=guidance_scale,
+                height=height,
+                width=width,
+            ).images
+            logger.info("Image generation completed")

+            return output_images

         except Exception as e:
+            raise RuntimeError(f"Error generating images: {str(e)}")

     except Exception as e:
         logger.error(f"Error during generation: {str(e)}")
         raise gr.Error(f"Generation failed: {str(e)}")
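`generate_images` is wrapped in `@spaces.GPU(duration=120)`, the ZeroGPU pattern where a GPU is attached to the Space only while the decorated call is running. On Hugging Face Spaces the decorator normally comes from the standalone `spaces` package; a minimal sketch of the pattern with a no-op fallback for local runs (illustrative, not part of this commit):

```python
try:
    import spaces  # available on Hugging Face ZeroGPU Spaces
except ImportError:
    class spaces:  # local fallback: make the decorator a no-op
        @staticmethod
        def GPU(duration=60):
            def wrap(fn):
                return fn
            return wrap

@spaces.GPU(duration=120)  # GPU is allocated only for the duration of this call
def heavy_step(x):
    return x
```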
 
 with gr.Blocks(
     theme=gr.themes.Soft(),
     css="""
+    .container { max-width: 1200px; margin: auto; padding: 0 20px; }
+    .header { text-align: center; margin: 20px 0 40px 0; padding: 20px; background: #f7f7f7; border-radius: 12px; }
+    .param-row { padding: 10px 0; }
+    footer { margin-top: 40px; padding: 20px; border-top: 1px solid #eee; }
     """
 ) as demo:
     with gr.Column(elem_classes="container"):
+        gr.Markdown("""
             <div class="header">
             # 🎨 Qwen2vl-Flux Image Variation Demo
             Generate creative variations of your images with optional text guidance
             </div>
+        """)

         with gr.Row(equal_height=True):
             with gr.Column(scale=1):
                 input_image = gr.Image(
                     label="Upload Your Image",
                     type="pil",

                     maximum=10,
                     value=3.5,
                     step=0.5,
+                    label="Guidance Scale"
                 )
                 steps = gr.Slider(
                     minimum=1,
+                    maximum=30,
                     value=28,
                     step=1,
+                    label="Sampling Steps"
                 )

             with gr.Row(elem_classes="param-row"):
                 num_images = gr.Slider(
                     minimum=1,
+                    maximum=2,
+                    value=1,  # Default changed to 1
                     step=1,
+                    label="Number of Images"
                 )
                 seed = gr.Number(
                     label="Random Seed",
                     value=None,
+                    precision=0
                 )
                 aspect_ratio = gr.Radio(
                     label="Aspect Ratio",
                     choices=["1:1", "16:9", "9:16", "2.4:1", "3:4", "4:3"],
+                    value="1:1"
                 )

+                submit_btn = gr.Button("🎨 Generate", variant="primary", size="lg")

         with gr.Column(scale=1):
             output_gallery = gr.Gallery(
                 label="Generated Variations",
                 columns=2,

                 height=700,
                 object_fit="contain",
                 show_label=True,
+                allow_preview=True
             )

         submit_btn.click(
+            fn=generate_images,
             inputs=[
                 input_image,
                 prompt,

                 num_images,
                 seed,
                 aspect_ratio
+            ],
             outputs=[output_gallery],
             show_progress=True
         )

 if __name__ == "__main__":
     demo.launch(
+        server_name="0.0.0.0",
+        server_port=7860,
+        share=False
     )
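Outside the Gradio UI, the module-level `generate_images` can also be exercised directly once the globals above are loaded; a hypothetical smoke test (the file name, prompt, and seed are placeholders, and it assumes a device with enough memory for the pipeline):

```python
from PIL import Image

image = Image.open("example.jpg")  # placeholder input image
outputs = generate_images(
    input_image=image,
    prompt="a watercolor variation",
    num_images=1,
    seed=42,
    aspect_ratio="1:1",
)
for i, img in enumerate(outputs):
    img.save(f"variation_{i}.png")
```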