erwold committed
Commit 76678b6 · 1 Parent(s): bb47725

Initial Commit

Files changed (1)
  1. app.py +120 -82
app.py CHANGED
@@ -13,7 +13,6 @@ import sys
 
 from qwen2_vl.modeling_qwen2_vl import Qwen2VLSimplifiedModel
 from huggingface_hub import snapshot_download
-import spaces
 
 # Set up logging
 logging.basicConfig(
@@ -78,42 +77,53 @@ class FluxInterface:
             return
 
         logger.info("Starting model loading...")
-        # 3. Explicitly configure the PyTorch caching allocator
-        torch.cuda.set_per_process_memory_fraction(0.95)  # allow using 95% of GPU memory
-        torch.cuda.max_memory_allocated = lambda *args, **kwargs: 0  # ignore the allocated-memory limit
 
-        # Load FLUX components
+        # 1. Load the smaller models onto the GPU first
         tokenizer = CLIPTokenizer.from_pretrained(os.path.join(MODEL_CACHE_DIR, "flux/tokenizer"))
-        text_encoder = CLIPTextModel.from_pretrained(os.path.join(MODEL_CACHE_DIR, "flux/text_encoder")).to(self.dtype).to(self.device)
-        text_encoder_two = T5EncoderModel.from_pretrained(os.path.join(MODEL_CACHE_DIR, "flux/text_encoder_2")).to(self.dtype).to(self.device)
-        tokenizer_two = T5TokenizerFast.from_pretrained(os.path.join(MODEL_CACHE_DIR, "flux/tokenizer_2"))
+        text_encoder = CLIPTextModel.from_pretrained(
+            os.path.join(MODEL_CACHE_DIR, "flux/text_encoder")
+        ).to(self.dtype).to(self.device)
 
-        # Load VAE and transformer
-        vae = AutoencoderKL.from_pretrained(os.path.join(MODEL_CACHE_DIR, "flux/vae")).to(self.dtype).to(self.device)
-        transformer = FluxTransformer2DModel.from_pretrained(os.path.join(MODEL_CACHE_DIR, "flux/transformer")).to(self.dtype).to(self.device)
-        scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(os.path.join(MODEL_CACHE_DIR, "flux/scheduler"), shift=1)
+        text_encoder_two = T5EncoderModel.from_pretrained(
+            os.path.join(MODEL_CACHE_DIR, "flux/text_encoder_2")
+        ).to(self.dtype).to(self.device)
 
-        # Load Qwen2VL components
-        qwen2vl = Qwen2VLSimplifiedModel.from_pretrained(os.path.join(MODEL_CACHE_DIR, "qwen2-vl")).to(self.dtype).to(self.device)
+        tokenizer_two = T5TokenizerFast.from_pretrained(
+            os.path.join(MODEL_CACHE_DIR, "flux/tokenizer_2"))
 
-        # Load the connector
-        connector = Qwen2Connector().to(self.dtype).to(self.device)
+        # 2. Load the large models onto the CPU initially
+        vae = AutoencoderKL.from_pretrained(
+            os.path.join(MODEL_CACHE_DIR, "flux/vae")
+        ).to(torch.float32).cpu()
+
+        transformer = FluxTransformer2DModel.from_pretrained(
+            os.path.join(MODEL_CACHE_DIR, "flux/transformer")
+        ).to(torch.float32).cpu()
+
+        scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
+            os.path.join(MODEL_CACHE_DIR, "flux/scheduler"),
+            shift=1
+        )
+
+        # 3. Load Qwen2VL onto the CPU initially
+        qwen2vl = Qwen2VLSimplifiedModel.from_pretrained(
+            os.path.join(MODEL_CACHE_DIR, "qwen2-vl")
+        ).to(torch.float32).cpu()
+
+        # 4. Load the connector and embedder onto the CPU
+        connector = Qwen2Connector().to(torch.float32).cpu()
         connector_path = os.path.join(MODEL_CACHE_DIR, "qwen2-vl/connector.pt")
         connector_state = torch.load(connector_path, map_location='cpu')
-        connector_state = {k: v.to(self.dtype) for k, v in connector_state.items()}
         connector.load_state_dict(connector_state)
-        connector = connector.to(self.device)
-
-        # Load the T5 embedder
-        self.t5_context_embedder = nn.Linear(4096, 3072).to(self.dtype).to(self.device)
+
+        self.t5_context_embedder = nn.Linear(4096, 3072).to(torch.float32).cpu()
         t5_embedder_path = os.path.join(MODEL_CACHE_DIR, "qwen2-vl/t5_embedder.pt")
         t5_embedder_state = torch.load(t5_embedder_path, map_location='cpu')
-        t5_embedder_state = {k: v.to(self.dtype) for k, v in t5_embedder_state.items()}
         self.t5_context_embedder.load_state_dict(t5_embedder_state)
-        self.t5_context_embedder = self.t5_context_embedder.to(self.device)
 
-        # Set models to eval mode
-        for model in [text_encoder, text_encoder_two, vae, transformer, qwen2vl, connector, self.t5_context_embedder]:
+        # 5. Put all models in eval mode
+        for model in [text_encoder, text_encoder_two, vae, transformer, qwen2vl,
+                      connector, self.t5_context_embedder]:
            model.requires_grad_(False)
            model.eval()
 
@@ -133,9 +143,9 @@ class FluxInterface:
 
         # Initialize processor and pipeline
         self.qwen2vl_processor = AutoProcessor.from_pretrained(
-            self.MODEL_ID,
+            self.MODEL_ID,
             subfolder="qwen2-vl",
-            min_pixels=256*28*28,
+            min_pixels=256*28*28,
             max_pixels=256*28*28
         )
 
@@ -145,7 +155,61 @@ class FluxInterface:
             vae=vae,
             text_encoder=text_encoder,
             tokenizer=tokenizer,
-        )
+        )
+
+    def move_to_device(self, model, device):
+        """Helper function to move a model to the specified device"""
+        if hasattr(model, 'to'):
+            return model.to(device)
+        return model
+
+    def process_image(self, image):
+        """Process an image with the Qwen2VL model"""
+        try:
+            # 1. Move the Qwen2VL models to the GPU
+            self.models['qwen2vl'] = self.move_to_device(self.models['qwen2vl'], self.device)
+            self.models['connector'] = self.move_to_device(self.models['connector'], self.device)
+
+            message = [
+                {
+                    "role": "user",
+                    "content": [
+                        {"type": "image", "image": image},
+                        {"type": "text", "text": "Describe this image."},
+                    ]
+                }
+            ]
+            text = self.qwen2vl_processor.apply_chat_template(
+                message,
+                tokenize=False,
+                add_generation_prompt=True
+            )
+
+            with torch.no_grad():
+                inputs = self.qwen2vl_processor(
+                    text=[text],
+                    images=[image],
+                    padding=True,
+                    return_tensors="pt"
+                ).to(self.device)
+
+                output_hidden_state, image_token_mask, image_grid_thw = self.models['qwen2vl'](**inputs)
+                image_hidden_state = output_hidden_state[image_token_mask].view(1, -1, output_hidden_state.size(-1))
+                image_hidden_state = self.models['connector'](image_hidden_state)
+
+            # Keep the result on the CPU
+            result = (image_hidden_state.cpu(), image_grid_thw)
+
+            # 2. Move the Qwen2VL models back to the CPU to free GPU memory
+            self.models['qwen2vl'] = self.move_to_device(self.models['qwen2vl'], 'cpu')
+            self.models['connector'] = self.move_to_device(self.models['connector'], 'cpu')
+            torch.cuda.empty_cache()
+
+            return result
+
+        except Exception as e:
+            logger.error(f"Error in process_image: {str(e)}")
+            raise
 
     def resize_image(self, img, max_pixels=1050000):
         if not isinstance(img, Image.Image):
@@ -163,28 +227,7 @@ class FluxInterface:
         img = img.resize((new_width, new_height), Image.LANCZOS)
 
         return img
-
-    # [Previous methods remain unchanged...]
-    def process_image(self, image):
-        message = [
-            {
-                "role": "user",
-                "content": [
-                    {"type": "image", "image": image},
-                    {"type": "text", "text": "Describe this image."},
-                ]
-            }
-        ]
-        text = self.qwen2vl_processor.apply_chat_template(message, tokenize=False, add_generation_prompt=True)
-
-        with torch.no_grad():
-            inputs = self.qwen2vl_processor(text=[text], images=[image], padding=True, return_tensors="pt").to(self.device)
-            output_hidden_state, image_token_mask, image_grid_thw = self.models['qwen2vl'](**inputs)
-            image_hidden_state = output_hidden_state[image_token_mask].view(1, -1, output_hidden_state.size(-1))
-            image_hidden_state = self.models['connector'](image_hidden_state)
-
-        return image_hidden_state, image_grid_thw
-
+
     def compute_t5_text_embeddings(self, prompt):
         """Compute T5 embeddings for text prompt"""
         if prompt == "":
@@ -222,50 +265,39 @@ class FluxInterface:
 
         return pooled_prompt_embeds
 
-    @spaces.GPU(duration=120)  # 300 seconds of GPU time
-    def generate(self, input_image, prompt="", guidance_scale=3.5, num_inference_steps=28, num_images=2, seed=None, aspect_ratio="1:1"):
+    def generate(self, input_image, prompt="", guidance_scale=3.5,
+                 num_inference_steps=28, num_images=2, seed=None, aspect_ratio="1:1"):
         try:
-            logger.info(f"Starting generation with prompt: {prompt}, guidance_scale: {guidance_scale}, steps: {num_inference_steps}")
+            logger.info(f"Starting generation with prompt: {prompt}")
 
             if input_image is None:
                 raise ValueError("No input image provided")
 
             if seed is not None:
                 torch.manual_seed(seed)
-                logger.info(f"Set random seed to: {seed}")
-
-            self.load_models()
-            logger.info("Models loaded successfully")
+
+            # 1. Process the image with Qwen2VL
+            qwen2_hidden_state, image_grid_thw = self.process_image(input_image)
 
-            # Get dimensions from aspect ratio
-            if aspect_ratio not in ASPECT_RATIOS:
-                raise ValueError(f"Invalid aspect ratio. Choose from {list(ASPECT_RATIOS.keys())}")
-            width, height = ASPECT_RATIOS[aspect_ratio]
-            logger.info(f"Using dimensions: {width}x{height}")
+            # 2. Compute text embeddings
+            pooled_prompt_embeds = self.compute_text_embeddings("")
+            t5_prompt_embeds = self.compute_t5_text_embeddings(prompt)
 
-            # Process input image
-            try:
-                input_image = self.resize_image(input_image)
-                logger.info(f"Input image resized to: {input_image.size}")
-                qwen2_hidden_state, image_grid_thw = self.process_image(input_image)
-                logger.info("Input image processed successfully")
-            except Exception as e:
-                raise RuntimeError(f"Error processing input image: {str(e)}")
+            # 3. Move the Transformer and VAE to the GPU
+            self.models['transformer'] = self.move_to_device(self.models['transformer'], self.device)
+            self.models['vae'] = self.move_to_device(self.models['vae'], self.device)
 
-            try:
-                pooled_prompt_embeds = self.compute_text_embeddings("")
-                logger.info("Base text embeddings computed")
-
-                # Get T5 embeddings if prompt is provided
-                t5_prompt_embeds = self.compute_t5_text_embeddings(prompt)
-                logger.info("T5 prompt embeddings computed")
-            except Exception as e:
-                raise RuntimeError(f"Error computing embeddings: {str(e)}")
+            # Update the models used by the pipeline
+            self.pipeline.transformer = self.models['transformer']
+            self.pipeline.vae = self.models['vae']
 
-            # Generate images
+            # Get output dimensions
+            width, height = ASPECT_RATIOS[aspect_ratio]
+
+            # 4. Generate images
             try:
                 output_images = self.pipeline(
-                    prompt_embeds=qwen2_hidden_state.repeat(num_images, 1, 1),
+                    prompt_embeds=qwen2_hidden_state.to(self.device).repeat(num_images, 1, 1),
                     pooled_prompt_embeds=pooled_prompt_embeds,
                     t5_prompt_embeds=t5_prompt_embeds.repeat(num_images, 1, 1) if t5_prompt_embeds is not None else None,
                     num_inference_steps=num_inference_steps,
@@ -274,10 +306,16 @@ class FluxInterface:
                     width=width,
                 ).images
 
-                logger.info("Images generated successfully")
+                # 5. Move the Transformer and VAE back to the CPU
+                self.models['transformer'] = self.move_to_device(self.models['transformer'], 'cpu')
+                self.models['vae'] = self.move_to_device(self.models['vae'], 'cpu')
+                torch.cuda.empty_cache()
+
                 return output_images
+
             except Exception as e:
                 raise RuntimeError(f"Error generating images: {str(e)}")
+
         except Exception as e:
             logger.error(f"Error during generation: {str(e)}")
            raise gr.Error(f"Generation failed: {str(e)}")