File size: 10,988 Bytes
98ade97
db6a3b7
98ade97
3057b36
98ade97
db6a3b7
690b53e
98ade97
db6a3b7
98ade97
9880f3d
98ade97
7d475c1
98ade97
db6a3b7
98ade97
db6a3b7
98ade97
9880f3d
98ade97
db6a3b7
eee56a3
 
 
 
 
98ade97
db6a3b7
98ade97
9880f3d
98ade97
db6a3b7
 
 
bd46f72
a898014
bd46f72
d7b1815
 
ae973c3
a898014
db6a3b7
 
 
 
 
 
 
a898014
db6a3b7
 
a898014
ae973c3
db894f7
a898014
 
db6a3b7
 
a898014
9880f3d
 
 
 
 
 
 
 
 
 
 
 
 
a898014
9880f3d
8e0b0a8
 
9880f3d
 
 
 
 
 
 
 
 
 
 
 
 
 
8e0b0a8
9880f3d
 
 
 
8e0b0a8
a898014
9880f3d
 
3057b36
a898014
db6a3b7
 
 
 
a898014
bd46f72
 
 
 
 
 
db6a3b7
 
 
 
 
bd46f72
 
ae973c3
 
 
db894f7
a898014
db894f7
bd46f72
 
 
 
 
 
 
 
 
 
 
7d475c1
15fe7bc
 
a898014
 
db6a3b7
7d475c1
a898014
9880f3d
db6a3b7
 
3057b36
9880f3d
db6a3b7
 
 
 
9880f3d
db6a3b7
 
 
 
 
 
a898014
690b53e
a898014
db6a3b7
 
 
 
 
 
 
 
 
 
 
 
6e47b5e
 
 
db6a3b7
7d475c1
 
 
 
 
8e0b0a8
c60c59d
6e47b5e
c60c59d
6e47b5e
a085c25
 
4e2347a
 
 
 
 
 
 
 
 
 
 
8e0b0a8
6a6e1c3
8e0b0a8
6a6e1c3
 
 
8e0b0a8
6a6e1c3
8e0b0a8
d35de72
6bef8b1
 
b77fe63
6bef8b1
6a3f855
b77fe63
 
d35de72
8e0b0a8
efbad73
 
db6a3b7
ae973c3
 
 
 
 
 
 
 
 
 
 
 
 
6e47b5e
65146bb
 
 
 
 
 
 
 
 
 
6e47b5e
9067242
d66f772
 
 
 
 
 
 
 
 
 
 
 
 
 
6a3f855
 
 
 
 
 
 
 
 
 
 
 
 
8e0b0a8
db6a3b7
edf8619
 
 
 
 
 
 
 
 
 
 
 
8e0b0a8
edf8619
 
8e0b0a8
ae973c3
ee3dbc1
 
ae973c3
 
 
 
ba1ed50
a8e79b5
915ce08
 
 
 
639772e
 
 
a8e79b5
ae973c3
 
 
 
 
 
db6a3b7
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
print("importing gradio")
import gradio as gr
print("importing spaces")
import spaces
print("importing os")
import os
os.environ['SPCONV_ALGO'] = 'native'
print("importing typing")
from typing import *
print("importing torch")
import torch
print("importing numpy")
import numpy as np
print("importing imageio")
import imageio
print("importing uuid")
import uuid
print("importing easydict")
from easydict import EasyDict as edict
print("importing PIL")
from PIL import Image

# Report GPU availability at startup.
_cuda_available = torch.cuda.is_available()
print(f"Is CUDA available: {_cuda_available}")
if _cuda_available:
    print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")
else:
    # current_device()/get_device_name() raise when no CUDA device is visible,
    # so only query the device name when CUDA is actually available.
    print("CUDA device: none")


print("importing trellis image to 3d pipeline")
from trellis.pipelines import TrellisImageTo3DPipeline
print("importing trellis representations")
from trellis.representations import Gaussian, MeshExtractResult
print("importing trellis utils")
from trellis.utils import render_utils, postprocessing_utils


# Scratch directory for per-trial files (preprocessed PNGs, preview MP4s, exported GLBs).
TMP_DIR = "/tmp/Trellis-demo"
# Upper bound for seeds: the largest signed 32-bit integer (2**31 - 1).
MAX_SEED = np.iinfo(np.int32).max

# Make sure the scratch directory exists before any handler writes into it.
os.makedirs(name=TMP_DIR, exist_ok=True)

@spaces.GPU
def preprocess_image(image: Image.Image) -> Tuple[str, Image.Image]:
    """
    Run the pipeline's preprocessing on an uploaded image and cache the result on disk.

    Args:
        image (Image.Image): The image received from the upload widget.

    Returns:
        str: A freshly generated uuid4 string identifying this trial.
        Image.Image: The preprocessed image, also written to ``{TMP_DIR}/{trial_id}.png``.
    """
    new_trial = str(uuid.uuid4())
    # Make sure the global pipeline is loaded before using it.
    preload()
    prepared = pipeline.preprocess_image(image)
    # image_to_3d() later reloads the image from this path using the returned id.
    prepared.save(f"{TMP_DIR}/{new_trial}.png")
    return new_trial, prepared


def pack_state(gs: Gaussian, mesh: MeshExtractResult, trial_id: str) -> dict:
    """
    Snapshot a generated asset into plain NumPy arrays so it can be kept in a ``gr.State``.

    Args:
        gs (Gaussian): The Gaussian representation; its ``init_params`` are copied verbatim
            and its ``_xyz``, ``_features_dc``, ``_scaling``, ``_rotation`` and ``_opacity``
            tensors are moved to CPU and converted to NumPy.
        mesh (MeshExtractResult): The extracted mesh; ``vertices`` and ``faces`` are converted likewise.
        trial_id (str): Identifier stored alongside the arrays.

    Returns:
        dict: ``{'gaussian': {...}, 'mesh': {'vertices', 'faces'}, 'trial_id': trial_id}``,
        the layout read back by ``unpack_state``.
    """
    def as_numpy(tensor):
        return tensor.cpu().numpy()

    gaussian_part = dict(gs.init_params)
    for attr in ('_xyz', '_features_dc', '_scaling', '_rotation', '_opacity'):
        gaussian_part[attr] = as_numpy(getattr(gs, attr))

    mesh_part = {'vertices': as_numpy(mesh.vertices), 'faces': as_numpy(mesh.faces)}

    return {'gaussian': gaussian_part, 'mesh': mesh_part, 'trial_id': trial_id}


def unpack_state(state: dict) -> Tuple[Gaussian, edict, str]:
    """
    Rebuild model objects on the GPU from a dictionary produced by ``pack_state``.

    Args:
        state (dict): Mapping with ``'gaussian'``, ``'mesh'`` and ``'trial_id'`` entries.

    Returns:
        Gaussian: A Gaussian built from the stored constructor arguments, with its
            stored arrays restored as tensors on the ``'cuda'`` device.
        edict: An object exposing ``vertices`` and ``faces`` tensors on ``'cuda'``.
        str: The trial id stored in the state.
    """
    g_state = state['gaussian']
    gs = Gaussian(
        aabb=g_state['aabb'],
        sh_degree=g_state['sh_degree'],
        mininum_kernel_size=g_state['mininum_kernel_size'],
        scaling_bias=g_state['scaling_bias'],
        opacity_bias=g_state['opacity_bias'],
        scaling_activation=g_state['scaling_activation'],
    )
    # Restore the stored arrays as tensors on the 'cuda' device.
    for attr in ('_xyz', '_features_dc', '_scaling', '_rotation', '_opacity'):
        setattr(gs, attr, torch.tensor(g_state[attr], device='cuda'))

    m_state = state['mesh']
    mesh = edict(
        vertices=torch.tensor(m_state['vertices'], device='cuda'),
        faces=torch.tensor(m_state['faces'], device='cuda'),
    )

    return gs, mesh, state['trial_id']


@spaces.GPU
def image_to_3d(trial_id: str, seed: int, randomize_seed: bool, ss_guidance_strength: float, ss_sampling_steps: int, slat_guidance_strength: float, slat_sampling_steps: int) -> Tuple[dict, str]:
    """
    Convert a preprocessed image to a 3D model and render a preview video.

    Args:
        trial_id (str): The uuid of the trial; the preprocessed image is read from
            ``{TMP_DIR}/{trial_id}.png`` (written by ``preprocess_image``).
        seed (int): The random seed.
        randomize_seed (bool): Whether to replace ``seed`` with a random value.
        ss_guidance_strength (float): The guidance strength for sparse structure generation.
        ss_sampling_steps (int): The number of sampling steps for sparse structure generation.
        slat_guidance_strength (float): The guidance strength for structured latent generation.
        slat_sampling_steps (int): The number of sampling steps for structured latent generation.

    Returns:
        dict: The packed 3D model (see ``pack_state``), keyed by a new per-generation id.
        str: The path to the preview video (color render beside normal render).

    Raises:
        gr.Error: If no preprocessed image exists for ``trial_id`` (nothing uploaded yet,
            or the file was already removed by the temporary-directory cleanup).
    """
    image_path = f"{TMP_DIR}/{trial_id}.png"
    if not trial_id or not os.path.isfile(image_path):
        raise gr.Error("Please upload an image first.")

    if randomize_seed:
        # Note: the randomized seed is not returned, so the Seed slider keeps its old value.
        seed = np.random.randint(0, MAX_SEED)

    preload()
    # Context manager so the PNG file handle is closed once generation is done.
    with Image.open(image_path) as image:
        outputs = pipeline.run(
            image,
            seed=seed,
            formats=["gaussian", "mesh"],
            preprocess_image=False,
            sparse_structure_sampler_params={
                "steps": ss_sampling_steps,
                "cfg_strength": ss_guidance_strength,
            },
            slat_sampler_params={
                "steps": slat_sampling_steps,
                "cfg_strength": slat_guidance_strength,
            },
        )
    gaussian = outputs['gaussian'][0]
    mesh = outputs['mesh'][0]

    color_frames = render_utils.render_video(gaussian, num_frames=120)['color']
    normal_frames = render_utils.render_video(mesh, num_frames=120)['normal']
    # Place each color frame beside its normal-map frame.
    frames = [np.concatenate([color, normal], axis=1) for color, normal in zip(color_frames, normal_frames)]

    # A new id per generation gives each run distinct video/GLB file names instead of
    # overwriting the previous run's outputs. Kept as a str to match pack_state's signature.
    output_id = str(uuid.uuid4())
    video_path = f"{TMP_DIR}/{output_id}.mp4"
    # The directory may have been removed externally; recreate it before writing.
    os.makedirs(os.path.dirname(video_path), exist_ok=True)
    imageio.mimsave(video_path, frames, fps=15)

    state = pack_state(gaussian, mesh, output_id)
    return state, video_path


@spaces.GPU
def extract_glb(state: dict, mesh_simplify: float, texture_size: int) -> Tuple[str, str]:
    """
    Extract a GLB file from the 3D model.

    Args:
        state (dict): The packed 3D model produced by ``image_to_3d``. This may be
            empty if nothing has been generated yet.
        mesh_simplify (float): The mesh simplification factor, passed as ``simplify`` to ``to_glb``.
        texture_size (int): The texture resolution.

    Returns:
        str: The path to the extracted GLB file (for the 3D viewer).
        str: The same path again (for the download button).

    Raises:
        gr.Error: If no model has been generated yet.
    """
    if not state:
        raise gr.Error("Please generate a 3D model first.")
    gs, mesh, trial_id = unpack_state(state)
    glb = postprocessing_utils.to_glb(gs, mesh, simplify=mesh_simplify, texture_size=texture_size, verbose=False)
    glb_path = f"{TMP_DIR}/{trial_id}.glb"
    glb.export(glb_path)
    # Returned twice: one output feeds the Model3D viewer, the other the DownloadButton.
    return glb_path, glb_path


def activate_button() -> gr.Button:
    """Return a Button update that makes the target button clickable."""
    enabled = gr.Button(interactive=True)
    return enabled


def deactivate_button() -> gr.Button:
    """Return a Button update that makes the target button non-clickable."""
    disabled = gr.Button(interactive=False)
    return disabled


def update(name):
    """Return a greeting for *name* (not referenced by the UI wiring in this file)."""
    greeting = "Welcome to Gradio, {}!".format(name)
    return greeting

with gr.Blocks() as demo:
    gr.Markdown("""
    ## Image to 3D Asset with [TRELLIS](https://trellis3d.github.io/)
    * Upload an image and click "Generate" to create a 3D asset. If the image has an alpha channel, it will be used as the mask. Otherwise, we use `rembg` to remove the background.
    * If you find the generated 3D asset satisfactory, click "Extract GLB" to extract the GLB file and download it.
    """)


    with gr.Row():
        # Left column: inputs and settings.
        with gr.Column():

            image_prompt = gr.Image(label="Image Prompt", image_mode="RGBA", type="pil", height=300)

            with gr.Accordion(label="Generation Settings", open=False):
                seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1)
                randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
                gr.Markdown("Stage 1: Sparse Structure Generation")
                with gr.Row():
                    ss_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=7.5, step=0.1)
                    ss_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
                gr.Markdown("Stage 2: Structured Latent Generation")
                with gr.Row():
                    slat_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=3.0, step=0.1)
                    slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)

            generate_btn = gr.Button("Generate")

            with gr.Accordion(label="GLB Extraction Settings", open=False):
                mesh_simplify = gr.Slider(0.9, 0.98, label="Simplify", value=0.95, step=0.01)
                texture_size = gr.Slider(512, 2048, label="Texture Size", value=1024, step=512)

            # Disabled until a generation has succeeded.
            extract_glb_btn = gr.Button("Extract GLB", interactive=False)

        # Right column: outputs.
        with gr.Column():

            # TODO: errors were reported when using the video output; investigate.
            video_output = gr.Video(label="Generated 3D Asset", autoplay=True, loop=True, height=300)

            model_output = gr.Model3D(label="Extracted GLB")
            # TODO: errors were reported when using the LitModel3D output; investigate.
            # model_output = LitModel3D(label="Extracted GLB", exposure=20.0, height=300)
            # Disabled until a GLB has been extracted.
            download_glb = gr.DownloadButton(label="Download GLB", interactive=False)

    # Hidden holder for the id returned by preprocess_image().
    trial_id = gr.Textbox(visible=False)
    # Per-session holder for the packed model returned by image_to_3d().
    output_buf = gr.State()

    # # Example images at the bottom of the page
    # with gr.Row():
    #     examples = gr.Examples(
    #         examples=[
    #             f'assets/example_image/{image}'
    #             for image in os.listdir("assets/example_image")
    #         ],
    #         inputs=[image_prompt],
    #         fn=preprocess_image,
    #         outputs=[trial_id, image_prompt],
    #         run_on_click=True,
    #         examples_per_page=64,
    #     )

    # Handlers
    image_prompt.upload(
        preprocess_image,
        inputs=[image_prompt],
        outputs=[trial_id, image_prompt],
    )
    image_prompt.clear(
        lambda: '',
        outputs=[trial_id],
    )


    # .success (rather than .then) so the Extract GLB button is only enabled
    # when image_to_3d finished without raising.
    generate_btn.click(
        image_to_3d,
        inputs=[trial_id, seed, randomize_seed, ss_guidance_strength, ss_sampling_steps, slat_guidance_strength, slat_sampling_steps],
        outputs=[output_buf, video_output],
    ).success(
        activate_button,
        outputs=[extract_glb_btn],
    )

    video_output.clear(
        deactivate_button,
        outputs=[extract_glb_btn],
    )

    # Likewise, only enable the download button once extraction actually succeeded.
    extract_glb_btn.click(
        extract_glb,
        inputs=[output_buf, mesh_simplify, texture_size],
        outputs=[model_output, download_glb],
    ).success(
        activate_button,
        outputs=[download_glb],
    )

    model_output.clear(
        deactivate_button,
        outputs=[download_glb],
    )


# Background cleanup: every 10 minutes, delete files in TMP_DIR that are more than 10 minutes old.
import threading
import time

def cleanup_tmp_dir():
    """
    Delete stale files from TMP_DIR in an endless loop; meant to run on a background thread.

    Every 600 seconds, each regular file in TMP_DIR whose modification time is more
    than 600 seconds in the past is removed. Errors are handled per file, so one
    failure (a file deleted concurrently, a subdirectory, a permission problem)
    cannot terminate the loop and silently stop all future cleanup.
    """
    max_age_s = 600    # remove files older than 10 minutes
    interval_s = 600   # sweep every 10 minutes
    while True:
        now = time.time()
        try:
            # Context manager closes the directory iterator promptly.
            with os.scandir(TMP_DIR) as it:
                entries = list(it)
        except FileNotFoundError:
            # Directory does not exist (yet or any more): nothing to clean this round.
            entries = []
        except OSError as exc:
            print(f"cleanup_tmp_dir: cannot list {TMP_DIR}: {exc}")
            entries = []

        for entry in entries:
            try:
                # Only regular files: os.remove() cannot delete directories.
                if entry.is_file() and now - entry.stat().st_mtime > max_age_s:
                    os.remove(entry.path)
            except FileNotFoundError:
                # Already removed between listing and deleting; nothing to do.
                pass
            except OSError as exc:
                print(f"cleanup_tmp_dir: could not remove {entry.path}: {exc}")

        time.sleep(interval_s)

# daemon=True: cleanup_tmp_dir() never returns, so a non-daemon thread would keep the
# interpreter alive after the main thread finishes (e.g. after demo.launch() is interrupted).
cleanup_thread = threading.Thread(target=cleanup_tmp_dir, name="tmp-dir-cleanup", daemon=True)
cleanup_thread.start()


# Lazily initialised by preload(); `pipeline` is read by preprocess_image() and image_to_3d().
preloaded = False
pipeline = None


@spaces.GPU
def preload():
    """
    Load the TRELLIS pipeline onto the GPU once and publish it as the global ``pipeline``.

    Later calls return immediately after a successful load. ``preloaded`` is only set
    once the pipeline is fully loaded. If loading raises, the next call retries instead
    of leaving ``pipeline`` unset.

    NOTE(review): with ZeroGPU, ``@spaces.GPU`` functions may execute in a separate
    worker process. In that case these globals might not persist between requests — confirm.
    """
    # Without this declaration both names would be function locals, and the handlers
    # that read the global `pipeline` would fail.
    global preloaded, pipeline
    if preloaded:
        return
    loaded = TrellisImageTo3DPipeline.from_pretrained("JeffreyXiang/TRELLIS-image-large")
    loaded.cuda()
    try:
        # Warm-up call on a blank image to preload rembg before the first real request.
        loaded.preprocess_image(Image.fromarray(np.zeros((512, 512, 3), dtype=np.uint8)))
    except Exception as exc:
        # Best-effort warm-up: report the failure but keep serving requests.
        print(f"rembg warm-up failed, continuing: {exc}")
    pipeline = loaded
    preloaded = True

# Launch the Gradio app
if __name__ == "__main__":
    # The pipeline is not loaded here: preprocess_image() and image_to_3d() call
    # preload(), which loads it on first use.
    demo.launch()