gensym commited on
Commit
8e0b0a8
·
1 Parent(s): 7be00f2
Files changed (1) hide show
  1. app.py +37 -37
app.py CHANGED
@@ -55,8 +55,8 @@ def pack_state(gs: Gaussian, mesh: MeshExtractResult, trial_id: str) -> dict:
55
  },
56
  'trial_id': trial_id,
57
  }
58
-
59
-
60
  def unpack_state(state: dict) -> Tuple[Gaussian, edict, str]:
61
  gs = Gaussian(
62
  aabb=state['gaussian']['aabb'],
@@ -71,12 +71,12 @@ def unpack_state(state: dict) -> Tuple[Gaussian, edict, str]:
71
  gs._scaling = torch.tensor(state['gaussian']['_scaling'], device='cuda')
72
  gs._rotation = torch.tensor(state['gaussian']['_rotation'], device='cuda')
73
  gs._opacity = torch.tensor(state['gaussian']['_opacity'], device='cuda')
74
-
75
  mesh = edict(
76
  vertices=torch.tensor(state['mesh']['vertices'], device='cuda'),
77
  faces=torch.tensor(state['mesh']['faces'], device='cuda'),
78
  )
79
-
80
  return gs, mesh, state['trial_id']
81
 
82
 
@@ -159,36 +159,36 @@ with gr.Blocks() as demo:
159
  * Upload an image and click "Generate" to create a 3D asset. If the image has alpha channel, it be used as the mask. Otherwise, we use `rembg` to remove the background.
160
  * If you find the generated 3D asset satisfactory, click "Extract GLB" to extract the GLB file and download it.
161
  """)
162
-
163
- with gr.Row():
164
- with gr.Column():
165
- image_prompt = gr.Image(label="Image Prompt", image_mode="RGBA", type="pil", height=300)
166
-
167
- with gr.Accordion(label="Generation Settings", open=False):
168
- seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1)
169
- randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
170
- gr.Markdown("Stage 1: Sparse Structure Generation")
171
- with gr.Row():
172
- ss_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=7.5, step=0.1)
173
- ss_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
174
- gr.Markdown("Stage 2: Structured Latent Generation")
175
- with gr.Row():
176
- slat_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=3.0, step=0.1)
177
- slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
178
-
179
- generate_btn = gr.Button("Generate")
180
-
181
- with gr.Accordion(label="GLB Extraction Settings", open=False):
182
- mesh_simplify = gr.Slider(0.9, 0.98, label="Simplify", value=0.95, step=0.01)
183
- texture_size = gr.Slider(512, 2048, label="Texture Size", value=1024, step=512)
184
-
185
- extract_glb_btn = gr.Button("Extract GLB", interactive=False)
186
-
187
- with gr.Column():
188
- video_output = gr.Video(label="Generated 3D Asset", autoplay=True, loop=True, height=300)
189
- model_output = LitModel3D(label="Extracted GLB", exposure=20.0, height=300)
190
- download_glb = gr.DownloadButton(label="Download GLB", interactive=False)
191
-
192
  trial_id = gr.Textbox(visible=False)
193
  output_buf = gr.State()
194
 
@@ -244,7 +244,7 @@ with gr.Blocks() as demo:
244
  deactivate_button,
245
  outputs=[download_glb],
246
  )
247
-
248
 
249
  # Cleans up the temporary directory every 10 minutes
250
  import threading
@@ -258,10 +258,10 @@ def cleanup_tmp_dir():
258
  if time.time() - os.path.getmtime(os.path.join(TMP_DIR, file)) > 600:
259
  os.remove(os.path.join(TMP_DIR, file))
260
  time.sleep(600)
261
-
262
  cleanup_thread = threading.Thread(target=cleanup_tmp_dir)
263
  cleanup_thread.start()
264
-
265
 
266
  # Launch the Gradio app
267
  if __name__ == "__main__":
 
55
  },
56
  'trial_id': trial_id,
57
  }
58
+
59
+
60
  def unpack_state(state: dict) -> Tuple[Gaussian, edict, str]:
61
  gs = Gaussian(
62
  aabb=state['gaussian']['aabb'],
 
71
  gs._scaling = torch.tensor(state['gaussian']['_scaling'], device='cuda')
72
  gs._rotation = torch.tensor(state['gaussian']['_rotation'], device='cuda')
73
  gs._opacity = torch.tensor(state['gaussian']['_opacity'], device='cuda')
74
+
75
  mesh = edict(
76
  vertices=torch.tensor(state['mesh']['vertices'], device='cuda'),
77
  faces=torch.tensor(state['mesh']['faces'], device='cuda'),
78
  )
79
+
80
  return gs, mesh, state['trial_id']
81
 
82
 
 
159
  * Upload an image and click "Generate" to create a 3D asset. If the image has alpha channel, it be used as the mask. Otherwise, we use `rembg` to remove the background.
160
  * If you find the generated 3D asset satisfactory, click "Extract GLB" to extract the GLB file and download it.
161
  """)
162
+
163
+ # with gr.Row():
164
+ # with gr.Column():
165
+ # image_prompt = gr.Image(label="Image Prompt", image_mode="RGBA", type="pil", height=300)
166
+
167
+ # with gr.Accordion(label="Generation Settings", open=False):
168
+ # seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1)
169
+ # randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
170
+ # gr.Markdown("Stage 1: Sparse Structure Generation")
171
+ # with gr.Row():
172
+ # ss_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=7.5, step=0.1)
173
+ # ss_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
174
+ # gr.Markdown("Stage 2: Structured Latent Generation")
175
+ # with gr.Row():
176
+ # slat_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=3.0, step=0.1)
177
+ # slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
178
+
179
+ # generate_btn = gr.Button("Generate")
180
+
181
+ # with gr.Accordion(label="GLB Extraction Settings", open=False):
182
+ # mesh_simplify = gr.Slider(0.9, 0.98, label="Simplify", value=0.95, step=0.01)
183
+ # texture_size = gr.Slider(512, 2048, label="Texture Size", value=1024, step=512)
184
+
185
+ # extract_glb_btn = gr.Button("Extract GLB", interactive=False)
186
+
187
+ # with gr.Column():
188
+ # video_output = gr.Video(label="Generated 3D Asset", autoplay=True, loop=True, height=300)
189
+ # model_output = LitModel3D(label="Extracted GLB", exposure=20.0, height=300)
190
+ # download_glb = gr.DownloadButton(label="Download GLB", interactive=False)
191
+
192
  trial_id = gr.Textbox(visible=False)
193
  output_buf = gr.State()
194
 
 
244
  deactivate_button,
245
  outputs=[download_glb],
246
  )
247
+
248
 
249
  # Cleans up the temporary directory every 10 minutes
250
  import threading
 
258
  if time.time() - os.path.getmtime(os.path.join(TMP_DIR, file)) > 600:
259
  os.remove(os.path.join(TMP_DIR, file))
260
  time.sleep(600)
261
+
262
  cleanup_thread = threading.Thread(target=cleanup_tmp_dir)
263
  cleanup_thread.start()
264
+
265
 
266
  # Launch the Gradio app
267
  if __name__ == "__main__":