erwold committed
Commit 9590121 · 1 Parent(s): f53a34a
Files changed (1):
  1. app.py +52 -13
app.py CHANGED
@@ -41,27 +41,29 @@ if not os.path.exists(MODEL_CACHE_DIR):
         logger.error(f"Error downloading models: {str(e)}")
         raise
 
-# Load all models into global variables
-logger.info("Loading models...")
+# Load the small models onto the GPU
+logger.info("Loading small models to GPU...")
 tokenizer = CLIPTokenizer.from_pretrained(os.path.join(MODEL_CACHE_DIR, "flux/tokenizer"))
 text_encoder = CLIPTextModel.from_pretrained(
     os.path.join(MODEL_CACHE_DIR, "flux/text_encoder")
-).to(dtype)
+).to(dtype).to(device)
 
 text_encoder_two = T5EncoderModel.from_pretrained(
     os.path.join(MODEL_CACHE_DIR, "flux/text_encoder_2")
-).to(dtype)
+).to(dtype).to(device)
 
 tokenizer_two = T5TokenizerFast.from_pretrained(
     os.path.join(MODEL_CACHE_DIR, "flux/tokenizer_2"))
 
+# Load the large models onto the CPU initially
+logger.info("Loading large models to CPU...")
 vae = AutoencoderKL.from_pretrained(
     os.path.join(MODEL_CACHE_DIR, "flux/vae")
-).to(dtype)
+).to(dtype).cpu()
 
 transformer = FluxTransformer2DModel.from_pretrained(
     os.path.join(MODEL_CACHE_DIR, "flux/transformer")
-).to(dtype)
+).to(dtype).cpu()
 
 scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
     os.path.join(MODEL_CACHE_DIR, "flux/scheduler"),
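
Note that the chained .to(dtype).to(device) calls above pass over the weights twice; nn.Module.to() accepts a device and a dtype together. A minimal equivalent form for one of the encoders, assuming the device and dtype globals defined earlier in app.py:

text_encoder = CLIPTextModel.from_pretrained(
    os.path.join(MODEL_CACHE_DIR, "flux/text_encoder")
).to(device, dtype=dtype)  # cast and move in a single pass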
@@ -70,7 +72,7 @@ scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
 
 qwen2vl = Qwen2VLSimplifiedModel.from_pretrained(
     os.path.join(MODEL_CACHE_DIR, "qwen2-vl")
-).to(dtype)
+).to(dtype).cpu()
 
 qwen2vl_processor = AutoProcessor.from_pretrained(
     MODEL_ID,
@@ -79,20 +81,20 @@ qwen2vl_processor = AutoProcessor.from_pretrained(
     max_pixels=256*28*28
 )
 
-# Load the connector and embedder
-connector = nn.Linear(3584, 4096).to(dtype)
+# Load the connector and embedder onto the CPU
+connector = nn.Linear(3584, 4096).to(dtype).cpu()
 connector_path = os.path.join(MODEL_CACHE_DIR, "qwen2-vl/connector.pt")
 connector_state = torch.load(connector_path, map_location='cpu')
 connector_state = {k.replace('module.', ''): v.to(dtype) for k, v in connector_state.items()}
 connector.load_state_dict(connector_state)
 
-t5_context_embedder = nn.Linear(4096, 3072).to(dtype)
+t5_context_embedder = nn.Linear(4096, 3072).to(dtype).cpu()
 t5_embedder_path = os.path.join(MODEL_CACHE_DIR, "qwen2-vl/t5_embedder.pt")
 t5_embedder_state = torch.load(t5_embedder_path, map_location='cpu')
 t5_embedder_state = {k: v.to(dtype) for k, v in t5_embedder_state.items()}
 t5_context_embedder.load_state_dict(t5_embedder_state)
 
-# Create the pipeline
+# Create the pipeline (with the models still on CPU)
 pipeline = FluxPipeline(
     transformer=transformer,
     scheduler=scheduler,
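
The k.replace('module.', '') rewrite in this hunk strips the prefix that nn.DataParallel / DistributedDataParallel prepends when a wrapped model's state dict is saved. A small self-contained illustration of why the rename is needed (the layer sizes mirror the connector; this snippet is not part of the commit):

import torch.nn as nn

wrapped = nn.DataParallel(nn.Linear(3584, 4096))
state = wrapped.state_dict()          # keys are 'module.weight', 'module.bias'
bare = nn.Linear(3584, 4096)
bare.load_state_dict({k.replace('module.', ''): v for k, v in state.items()})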
@@ -120,6 +122,11 @@ ASPECT_RATIOS = {
 def process_image(image):
     """Process image with Qwen2VL model"""
     try:
+        # Move the Qwen2VL models to the GPU
+        logger.info("Moving Qwen2VL models to GPU...")
+        qwen2vl.to(device)
+        connector.to(device)
+
         message = [
             {
                 "role": "user",
@@ -147,7 +154,16 @@ def process_image(image):
         image_hidden_state = output_hidden_state[image_token_mask].view(1, -1, output_hidden_state.size(-1))
         image_hidden_state = connector(image_hidden_state)
 
-        return (image_hidden_state, image_grid_thw)
+        # Keep the result on the CPU
+        result = (image_hidden_state.cpu(), image_grid_thw)
+
+        # Move the models back to the CPU and free GPU memory
+        logger.info("Moving Qwen2VL models back to CPU...")
+        qwen2vl.cpu()
+        connector.cpu()
+        torch.cuda.empty_cache()
+
+        return result
 
     except Exception as e:
         logger.error(f"Error in process_image: {str(e)}")
@@ -167,8 +183,14 @@ def compute_t5_text_embeddings(prompt):
     ).to(device)
 
     prompt_embeds = text_encoder_two(text_inputs.input_ids)[0]
+
+    # Move t5_context_embedder to the GPU
+    t5_context_embedder.to(device)
     prompt_embeds = t5_context_embedder(prompt_embeds)
 
+    # Move t5_context_embedder back to the CPU
+    t5_context_embedder.cpu()
+
    return prompt_embeds
 
 def compute_text_embeddings(prompt=""):
@@ -216,8 +238,18 @@ def generate_images(input_image, prompt="", guidance_scale=3.5,
     # Generate images
     try:
         logger.info("Starting image generation...")
+
+        # Move the Transformer and VAE to the GPU
+        logger.info("Moving Transformer and VAE to GPU...")
+        transformer.to(device)
+        vae.to(device)
+
+        # Update the model references held by the pipeline
+        pipeline.transformer = transformer
+        pipeline.vae = vae
+
         output_images = pipeline(
-            prompt_embeds=qwen2_hidden_state.repeat(num_images, 1, 1),
+            prompt_embeds=qwen2_hidden_state.to(device).repeat(num_images, 1, 1),
             pooled_prompt_embeds=pooled_prompt_embeds,
             t5_prompt_embeds=t5_prompt_embeds.repeat(num_images, 1, 1) if t5_prompt_embeds is not None else None,
             num_inference_steps=num_inference_steps,
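
A note on the reference updates in this hunk: nn.Module.to() moves parameters in place and returns the same object, so the components the pipeline already holds point at the moved weights. The reassignment is defensive rather than strictly required, assuming pipeline.transformer is the same object passed to FluxPipeline above:

transformer.to(device)
assert pipeline.transformer is transformer  # .to() did not create a new module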
@@ -225,8 +257,15 @@ def generate_images(input_image, prompt="", guidance_scale=3.5,
             height=height,
             width=width,
         ).images
+
         logger.info("Image generation completed")
 
+        # Move the Transformer and VAE back to the CPU
+        logger.info("Moving models back to CPU...")
+        transformer.cpu()
+        vae.cpu()
+        torch.cuda.empty_cache()
+
         return output_images
 
     except Exception as e:
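
If this FluxPipeline derives from diffusers' DiffusionPipeline, the library ships a built-in version of the same strategy: enable_model_cpu_offload() (it requires accelerate) installs hooks that move each component to the GPU only for its own forward pass and return it to the CPU afterwards. A possible alternative to the manual bookkeeping above, assuming that subclassing holds:

pipeline.enable_model_cpu_offload()  # per-component CPU<->GPU swapping handled by accelerate hooks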
 