File size: 10,988 Bytes
98ade97 db6a3b7 98ade97 3057b36 98ade97 db6a3b7 690b53e 98ade97 db6a3b7 98ade97 9880f3d 98ade97 7d475c1 98ade97 db6a3b7 98ade97 db6a3b7 98ade97 9880f3d 98ade97 db6a3b7 eee56a3 98ade97 db6a3b7 98ade97 9880f3d 98ade97 db6a3b7 bd46f72 a898014 bd46f72 d7b1815 ae973c3 a898014 db6a3b7 a898014 db6a3b7 a898014 ae973c3 db894f7 a898014 db6a3b7 a898014 9880f3d a898014 9880f3d 8e0b0a8 9880f3d 8e0b0a8 9880f3d 8e0b0a8 a898014 9880f3d 3057b36 a898014 db6a3b7 a898014 bd46f72 db6a3b7 bd46f72 ae973c3 db894f7 a898014 db894f7 bd46f72 7d475c1 15fe7bc a898014 db6a3b7 7d475c1 a898014 9880f3d db6a3b7 3057b36 9880f3d db6a3b7 9880f3d db6a3b7 a898014 690b53e a898014 db6a3b7 6e47b5e db6a3b7 7d475c1 8e0b0a8 c60c59d 6e47b5e c60c59d 6e47b5e a085c25 4e2347a 8e0b0a8 6a6e1c3 8e0b0a8 6a6e1c3 8e0b0a8 6a6e1c3 8e0b0a8 d35de72 6bef8b1 b77fe63 6bef8b1 6a3f855 b77fe63 d35de72 8e0b0a8 efbad73 db6a3b7 ae973c3 6e47b5e 65146bb 6e47b5e 9067242 d66f772 6a3f855 8e0b0a8 db6a3b7 edf8619 8e0b0a8 edf8619 8e0b0a8 ae973c3 ee3dbc1 ae973c3 ba1ed50 a8e79b5 915ce08 639772e a8e79b5 ae973c3 db6a3b7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 |
print("importing gradio")
import gradio as gr
print("importing spaces")
import spaces
print("importing os")
import os
os.environ['SPCONV_ALGO'] = 'native'
print("importing typing")
from typing import *
print("importing torch")
import torch
print("importing numpy")
import numpy as np
print("importing imageio")
import imageio
print("importing uuid")
import uuid
print("importing easydict")
from easydict import EasyDict as edict
print("importing PIL")
from PIL import Image
# Report the CUDA status at startup. The device name is only queried when CUDA
# is actually available, because torch.cuda.current_device() raises otherwise
# and would abort the import on a CPU-only host.
print(f"Is CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")
print("importing trellis image to 3d pipeline")
from trellis.pipelines import TrellisImageTo3DPipeline
print("importing trellis representations")
from trellis.representations import Gaussian, MeshExtractResult
print("importing trellis utils")
from trellis.utils import render_utils, postprocessing_utils
# Largest value of a signed 32-bit integer; used as the upper bound of the seed
# slider and of the randomly drawn seed.
MAX_SEED = np.iinfo(np.int32).max
# Scratch directory that holds the preprocessed images, preview videos and
# exported GLB files; it is created at import time if it does not exist yet.
TMP_DIR = "/tmp/Trellis-demo"
os.makedirs(TMP_DIR, exist_ok=True)
@spaces.GPU
def preprocess_image(image: Image.Image) -> Tuple[str, Image.Image]:
    """
    Run the pipeline's image preprocessing and cache the result on disk.

    Args:
        image (Image.Image): The image uploaded by the user.

    Returns:
        str: A freshly generated uuid identifying this trial.
        Image.Image: The preprocessed image, which is also saved as
            ``{TMP_DIR}/{trial_id}.png``.
    """
    # preload() is expected to make the global ``pipeline`` available.
    preload()
    result = pipeline.preprocess_image(image)
    trial_id = str(uuid.uuid4())
    result.save(f"{TMP_DIR}/{trial_id}.png")
    return trial_id, result
def pack_state(gs: Gaussian, mesh: MeshExtractResult, trial_id: str) -> dict:
    """
    Bundle a Gaussian and a mesh into a plain dict of NumPy arrays.

    Args:
        gs (Gaussian): The Gaussian whose ``init_params`` and tensors are stored.
        mesh (MeshExtractResult): The mesh whose vertices and faces are stored.
        trial_id (str): Identifier stored alongside the data.

    Returns:
        dict: ``{'gaussian': ..., 'mesh': ..., 'trial_id': ...}`` with every
            tensor moved to the CPU and converted to a NumPy array.
    """
    tensor_fields = ('_xyz', '_features_dc', '_scaling', '_rotation', '_opacity')
    # Start from the constructor parameters, then add the tensor payloads.
    gaussian_state = dict(gs.init_params)
    for field in tensor_fields:
        gaussian_state[field] = getattr(gs, field).cpu().numpy()
    mesh_state = {
        'vertices': mesh.vertices.cpu().numpy(),
        'faces': mesh.faces.cpu().numpy(),
    }
    return {'gaussian': gaussian_state, 'mesh': mesh_state, 'trial_id': trial_id}
def unpack_state(state: dict) -> Tuple[Gaussian, edict, str]:
    """
    Rebuild the Gaussian and mesh from a dict produced by ``pack_state``.

    Args:
        state (dict): Dict with the ``'gaussian'``, ``'mesh'`` and
            ``'trial_id'`` entries.

    Returns:
        Gaussian: A new Gaussian whose tensors are placed on the CUDA device.
        edict: Mesh with ``vertices`` and ``faces`` tensors on the CUDA device.
        str: The stored trial id.
    """
    packed_gs = state['gaussian']
    init_keys = (
        'aabb',
        'sh_degree',
        'mininum_kernel_size',
        'scaling_bias',
        'opacity_bias',
        'scaling_activation',
    )
    gs = Gaussian(**{key: packed_gs[key] for key in init_keys})
    # Restore the tensor attributes in the same order they were packed.
    for attr in ('_xyz', '_features_dc', '_scaling', '_rotation', '_opacity'):
        setattr(gs, attr, torch.tensor(packed_gs[attr], device='cuda'))
    packed_mesh = state['mesh']
    mesh = edict(
        vertices=torch.tensor(packed_mesh['vertices'], device='cuda'),
        faces=torch.tensor(packed_mesh['faces'], device='cuda'),
    )
    return gs, mesh, state['trial_id']
@spaces.GPU
def image_to_3d(trial_id: str, seed: int, randomize_seed: bool, ss_guidance_strength: float, ss_sampling_steps: int, slat_guidance_strength: float, slat_sampling_steps: int) -> Tuple[dict, str]:
    """
    Convert an image to a 3D model.

    Args:
        trial_id (str): The uuid of the trial (names the preprocessed PNG in TMP_DIR).
        seed (int): The random seed.
        randomize_seed (bool): Whether to randomize the seed.
        ss_guidance_strength (float): The guidance strength for sparse structure generation.
        ss_sampling_steps (int): The number of sampling steps for sparse structure generation.
        slat_guidance_strength (float): The guidance strength for structured latent generation.
        slat_sampling_steps (int): The number of sampling steps for structured latent generation.

    Returns:
        dict: The information of the generated 3D model.
        str: The path to the video of the 3D model.

    Raises:
        gr.Error: If no image has been uploaded, or the preprocessed image is
            no longer present in TMP_DIR.
    """
    if not trial_id:
        raise gr.Error("Please upload an image before generating.")
    if randomize_seed:
        # Convert NumPy's integer type to a plain Python int.
        seed = int(np.random.randint(0, MAX_SEED))
    preload()

    try:
        image = Image.open(f"{TMP_DIR}/{trial_id}.png")
    except FileNotFoundError as err:
        # The periodic cleanup deletes old files from TMP_DIR.
        raise gr.Error("The uploaded image has expired. Please upload it again.") from err

    # Keep the file open only for the duration of the pipeline run.
    with image:
        outputs = pipeline.run(
            image,
            seed=seed,
            formats=["gaussian", "mesh"],
            preprocess_image=False,
            sparse_structure_sampler_params={
                "steps": ss_sampling_steps,
                "cfg_strength": ss_guidance_strength,
            },
            slat_sampler_params={
                "steps": slat_sampling_steps,
                "cfg_strength": slat_guidance_strength,
            },
        )

    gaussian = outputs['gaussian'][0]
    mesh = outputs['mesh'][0]
    color_frames = render_utils.render_video(gaussian, num_frames=120)['color']
    normal_frames = render_utils.render_video(mesh, num_frames=120)['normal']
    # Place the colour render and the normal render side by side in each frame.
    video = [
        np.concatenate([color, normal], axis=1)
        for color, normal in zip(color_frames, normal_frames)
    ]

    # Each generation gets its own id so repeated runs on the same image do not
    # overwrite each other's outputs; stored as str to match the state schema.
    output_id = str(uuid.uuid4())
    video_path = f"{TMP_DIR}/{output_id}.mp4"
    os.makedirs(os.path.dirname(video_path), exist_ok=True)
    imageio.mimsave(video_path, video, fps=15)
    state = pack_state(gaussian, mesh, output_id)
    return state, video_path
@spaces.GPU
def extract_glb(state: dict, mesh_simplify: float, texture_size: int) -> Tuple[str, str]:
    """
    Export the generated 3D model as a GLB file.

    Args:
        state (dict): The packed state of the generated 3D model.
        mesh_simplify (float): The mesh simplification factor.
        texture_size (int): The texture resolution.

    Returns:
        str: The path to the extracted GLB file (for the 3D viewer).
        str: The same path again (for the download button).
    """
    gaussian, mesh, trial_id = unpack_state(state)
    glb_path = f"{TMP_DIR}/{trial_id}.glb"
    postprocessing_utils.to_glb(
        gaussian,
        mesh,
        simplify=mesh_simplify,
        texture_size=texture_size,
        verbose=False,
    ).export(glb_path)
    return glb_path, glb_path
def activate_button() -> gr.Button:
    """Return a ``gr.Button`` with ``interactive=True`` to enable the output button."""
    enabled = gr.Button(interactive=True)
    return enabled
def deactivate_button() -> gr.Button:
    """Return a ``gr.Button`` with ``interactive=False`` to disable the output button."""
    disabled = gr.Button(interactive=False)
    return disabled
def update(name):
    """Build the welcome greeting for *name*."""
    greeting = f"Welcome to Gradio, {name}!"
    return greeting
# Gradio UI: the left column holds the image input and all settings, the right
# column holds the preview video, the 3D viewer and the download button.
with gr.Blocks() as demo:
    gr.Markdown("""
## Image to 3D Asset with [TRELLIS](https://trellis3d.github.io/)
* Upload an image and click "Generate" to create a 3D asset. If the image has alpha channel, it be used as the mask. Otherwise, we use `rembg` to remove the background.
* If you find the generated 3D asset satisfactory, click "Extract GLB" to extract the GLB file and download it.
""")
    with gr.Row():
        with gr.Column():
            image_prompt = gr.Image(label="Image Prompt", image_mode="RGBA", type="pil", height=300)
            with gr.Accordion(label="Generation Settings", open=False):
                seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1)
                randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
                gr.Markdown("Stage 1: Sparse Structure Generation")
                with gr.Row():
                    ss_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=7.5, step=0.1)
                    ss_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
                gr.Markdown("Stage 2: Structured Latent Generation")
                with gr.Row():
                    slat_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=3.0, step=0.1)
                    slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
            generate_btn = gr.Button("Generate")
            with gr.Accordion(label="GLB Extraction Settings", open=False):
                mesh_simplify = gr.Slider(0.9, 0.98, label="Simplify", value=0.95, step=0.01)
                texture_size = gr.Slider(512, 2048, label="Texture Size", value=1024, step=512)
            # Starts disabled; enabled after a generation finishes (see handlers below).
            extract_glb_btn = gr.Button("Extract GLB", interactive=False)
        with gr.Column():
            # Todo - getting errors when using video output
            video_output = gr.Video(label="Generated 3D Asset", autoplay=True, loop=True, height=300)
            model_output = gr.Model3D(label="Extracted GLB")
            # Todo - getting errors when using model output
            # model_output = LitModel3D(label="Extracted GLB", exposure=20.0, height=300)
            # Starts disabled; enabled after a GLB has been extracted.
            download_glb = gr.DownloadButton(label="Download GLB", interactive=False)
    # Hidden textbox carrying the trial id returned by preprocess_image.
    trial_id = gr.Textbox(visible=False)
    # Holds the packed state returned by image_to_3d until extract_glb reads it.
    output_buf = gr.State()
    # # Example images at the bottom of the page
    # with gr.Row():
    #     examples = gr.Examples(
    #         examples=[
    #             f'assets/example_image/{image}'
    #             for image in os.listdir("assets/example_image")
    #         ],
    #         inputs=[image_prompt],
    #         fn=preprocess_image,
    #         outputs=[trial_id, image_prompt],
    #         run_on_click=True,
    #         examples_per_page=64,
    #     )
    # Handlers
    # Uploading an image preprocesses it and replaces the shown image with the result.
    image_prompt.upload(
        preprocess_image,
        inputs=[image_prompt],
        outputs=[trial_id, image_prompt],
    )
    # Clearing the image resets the trial id to an empty string.
    image_prompt.clear(
        lambda: '',
        outputs=[trial_id],
    )
    # Generate: run image_to_3d, then enable the "Extract GLB" button.
    generate_btn.click(
        image_to_3d,
        inputs=[trial_id, seed, randomize_seed, ss_guidance_strength, ss_sampling_steps, slat_guidance_strength, slat_sampling_steps],
        outputs=[output_buf, video_output],
    ).then(
        activate_button,
        outputs=[extract_glb_btn],
    )
    # Clearing the video disables the "Extract GLB" button again.
    video_output.clear(
        deactivate_button,
        outputs=[extract_glb_btn],
    )
    # Extract GLB: run extract_glb, then enable the download button.
    extract_glb_btn.click(
        extract_glb,
        inputs=[output_buf, mesh_simplify, texture_size],
        outputs=[model_output, download_glb],
    ).then(
        activate_button,
        outputs=[download_glb],
    )
    # Clearing the 3D viewer disables the download button again.
    model_output.clear(
        deactivate_button,
        outputs=[download_glb],
    )
# Cleans up the temporary directory every 10 minutes
import threading
import time


def cleanup_tmp_dir(directory=None, max_age=600.0, interval=600.0, max_passes=None):
    """
    Periodically delete stale files from a directory.

    Args:
        directory (str | None): Directory to clean. Defaults to ``TMP_DIR``.
        max_age (float): Files whose modification time is older than this many
            seconds are removed. Defaults to 10 minutes.
        interval (float): Seconds to sleep between two passes. Defaults to
            10 minutes.
        max_passes (int | None): Number of passes to run before returning.
            ``None`` (the default) loops forever.

    Only regular files are removed. A missing directory, or a file that cannot
    be inspected or removed, is skipped so that one failure does not stop the
    loop.
    """
    target = TMP_DIR if directory is None else directory
    passes = 0
    while max_passes is None or passes < max_passes:
        try:
            names = os.listdir(target)
        except OSError:
            # Directory missing or unreadable: nothing to clean on this pass.
            names = []
        now = time.time()
        for name in names:
            path = os.path.join(target, name)
            try:
                if os.path.isfile(path) and now - os.path.getmtime(path) > max_age:
                    os.remove(path)
            except OSError:
                # The file may have been removed concurrently, or be locked.
                continue
        passes += 1
        time.sleep(interval)


# Daemon thread: the endless cleanup loop must not keep the process alive at exit.
cleanup_thread = threading.Thread(target=cleanup_tmp_dir, daemon=True)
cleanup_thread.start()
@spaces.GPU
def preload():
    """
    Load the TRELLIS pipeline into the module-level ``pipeline`` global once.

    Later calls return immediately. The ``preloaded`` flag is only set after
    the pipeline has been created and moved to CUDA, so a failed load is
    retried on the next call.

    NOTE(review): if ``@spaces.GPU`` runs this function in a separate worker
    process, the globals assigned here may not be visible to the caller —
    confirm against the ``spaces`` runtime.
    """
    global preloaded, pipeline
    # ``preloaded`` is never initialised at module level, so read it defensively.
    if globals().get("preloaded", False):
        return
    loaded = TrellisImageTo3DPipeline.from_pretrained("JeffreyXiang/TRELLIS-image-large")
    loaded.cuda()
    try:
        loaded.preprocess_image(Image.fromarray(np.zeros((512, 512, 3), dtype=np.uint8)))  # Preload rembg
    except Exception as exc:
        # The warm-up is best effort: report the failure but keep the pipeline.
        print(f"rembg warm-up failed (continuing): {exc}")
    pipeline = loaded
    preloaded = True
# Launch the Gradio app
if __name__ == "__main__":
    # The pipeline is not loaded here: the request handlers call preload()
    # before they use it.
    demo.launch()
|