Surn committed
Commit 98af578 · Parent: bc2ae02

Fix depth estimation

app.py CHANGED
@@ -100,7 +100,7 @@ from utils.version_info import (
     #release_torch_resources,
     #get_torch_info
 )
-
+from utils.depth_estimation import (get_depth_map_from_state)
 
 input_image_palette = []
 current_prerendered_image = gr.State("./images/images/Beeuty-1.png")
@@ -722,7 +722,8 @@ def generate_3d_asset(depth_image_source, randomize_seed, seed, input_image, out
     # Validate the mesh
     mesh = outputs['mesh'][0]
     # Depending on the mesh format (it might be a dict or an object)
-    if isinstance(mesh, dict):
+    meshisdict = isinstance(mesh, dict)
+    if meshisdict:
         vertices = mesh['vertices']
         faces = mesh['faces']
     else:
@@ -738,20 +739,29 @@ def generate_3d_asset(depth_image_source, randomize_seed, seed, input_image, out
     if not vertices.is_cuda or not faces.is_cuda:
         raise ValueError("Mesh data must be on GPU")
     if vertices.dtype != torch.float32 or faces.dtype != torch.int32:
-        raise ValueError("Mesh vertices must be float32, faces must be int32")
+        if meshisdict:
+            mesh['faces'] = faces.to(torch.int32)
+            mesh['vertices'] = vertices.to(torch.float32)
+        else:
+            mesh.faces = faces.to(torch.int32)
+            mesh.vertices = vertices.to(torch.float32)
+        #raise ValueError("Mesh vertices must be float32, faces must be int32")
 
     # Save the video to a temporary file
     user_dir = os.path.join(constants.TMPDIR, str(req.session_hash))
     os.makedirs(user_dir, exist_ok=True)
 
     video = render_utils.render_video(outputs['gaussian'][0], resolution=576, num_frames=60, r=1)['color']
-    snapshot_results = render_utils.render_snapshot(outputs['gaussian'][0], resolution=576)
-    depth_snapshot = snapshot_results['depth'][0]
     video_geo = render_utils.render_video(outputs['mesh'][0], resolution=576, num_frames=60, r=1)['normal']
     video = [np.concatenate([video[i], video_geo[i]], axis=1) for i in range(len(video))]
     video_path = os.path.join(user_dir, f'{output_name}.mp4')
-    imageio.mimsave(video_path, video, fps=15)
+    imageio.mimsave(video_path, video, fps=10)
+
+    snapshot_results = render_utils.render_snapshot_depth(outputs['mesh'][0], resolution=768)['depth']
+    depth_snapshot = Image.fromarray(snapshot_results[0])
+
     state = pack_state(outputs['gaussian'][0], outputs['mesh'][0], output_name)
+    #depth_snapshot = get_depth_map_from_state(state, image_raw.size[0], image_raw.size[1])
     torch.cuda.empty_cache()
     return [state, video_path, depth_snapshot]
 
@@ -1106,7 +1116,7 @@ with gr.Blocks(css_paths="style_20250128.css", title=title, theme='Surn/beeuty',
     with gr.Row():
         with gr.Column():
             # Use standard seed settings only
-            seed_3d = gr.Slider(0, constants.MAX_SEED, label="Seed (3D Generation)", value=0, step=1)
+            seed_3d = gr.Slider(0, constants.MAX_SEED, label="Seed (3D Generation)", value=0, step=1, randomize=True)
             randomize_seed_3d = gr.Checkbox(label="Randomize Seed (3D Generation)", value=True)
         with gr.Column():
             depth_image_source = gr.Radio(
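
The depth snapshot now comes from the mesh render rather than the Gaussian render. A minimal sketch of the new path, assuming render_snapshot_depth (added in trellis/utils/render_utils.py below) returns a dict whose 'depth' entry is a list of HxW uint8 arrays; make_depth_snapshot and mesh_output are hypothetical names used only for illustration:

from PIL import Image
from trellis.utils import render_utils

def make_depth_snapshot(mesh_output, resolution=768):
    # Assumption: render_snapshot_depth returns {'depth': [HxW uint8 array or None, ...]}
    depth_frames = render_utils.render_snapshot_depth(mesh_output, resolution=resolution)['depth']
    first = depth_frames[0]
    if first is None:
        return None  # the renderer produced no depth buffer for this view
    # Mirrors depth_snapshot = Image.fromarray(snapshot_results[0]) in the diff above
    return Image.fromarray(first)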
trellis/renderers/gaussian_render.py CHANGED
@@ -11,7 +11,6 @@
 
 import torch
 import math
-from easydict import EasyDict as edict
 import numpy as np
 from ..representations.gaussian import Gaussian
 from .sh_utils import eval_sh
trellis/utils/render_utils.py CHANGED
@@ -67,6 +67,53 @@ def render_frames(sample, extrinsics, intrinsics, options={}, colors_overwrite=N
     else:
         raise ValueError(f'Unsupported sample type: {type(sample)}')
 
+    rets = {}
+    for j, (extr, intr) in tqdm(enumerate(zip(extrinsics, intrinsics)), desc='Rendering', disable=not verbose):
+        if not isinstance(sample, MeshExtractResult):
+            res = renderer.render(sample, extr, intr, colors_overwrite=colors_overwrite)
+            if 'color' not in rets: rets['color'] = []
+            # if 'depth' not in rets: rets['depth'] = []
+            rets['color'].append(np.clip(res['color'].detach().cpu().numpy().transpose(1, 2, 0) * 255, 0, 255).astype(np.uint8))
+            # if 'percent_depth' in res:
+            #     rets['depth'].append(res['percent_depth'].detach().cpu().numpy())
+            # elif 'depth' in res:
+            #     rets['depth'].append(res['depth'].detach().cpu().numpy())
+            # else:
+            #     rets['depth'].append(None)
+        else:
+            res = renderer.render(sample, extr, intr)
+            if 'normal' not in rets: rets['normal'] = []
+            rets['normal'].append(np.clip(res['normal'].detach().cpu().numpy().transpose(1, 2, 0) * 255, 0, 255).astype(np.uint8))
+
+    return rets
+
+def render_frames_depth(sample, extrinsics, intrinsics, options={}, colors_overwrite=None, verbose=True, **kwargs):
+    if isinstance(sample, Octree):
+        renderer = OctreeRenderer()
+        renderer.rendering_options.resolution = options.get('resolution', 512)
+        renderer.rendering_options.near = options.get('near', 0.8)
+        renderer.rendering_options.far = options.get('far', 1.6)
+        renderer.rendering_options.bg_color = options.get('bg_color', (0, 0, 0))
+        renderer.rendering_options.ssaa = options.get('ssaa', 4)
+        renderer.pipe.primitive = sample.primitive
+    elif isinstance(sample, Gaussian):
+        renderer = GaussianRenderer()
+        renderer.rendering_options.resolution = options.get('resolution', 512)
+        renderer.rendering_options.near = options.get('near', 0.8)
+        renderer.rendering_options.far = options.get('far', 1.6)
+        renderer.rendering_options.bg_color = options.get('bg_color', (0, 0, 0))
+        renderer.rendering_options.ssaa = options.get('ssaa', 1)
+        renderer.pipe.kernel_size = kwargs.get('kernel_size', 0.1)
+        renderer.pipe.use_mip_gaussian = True
+    elif isinstance(sample, MeshExtractResult):
+        renderer = MeshRenderer()
+        renderer.rendering_options.resolution = options.get('resolution', 512)
+        renderer.rendering_options.near = options.get('near', 1)
+        renderer.rendering_options.far = options.get('far', 100)
+        renderer.rendering_options.ssaa = options.get('ssaa', 4)
+    else:
+        raise ValueError(f'Unsupported sample type: {type(sample)}')
+
     rets = {}
     for j, (extr, intr) in tqdm(enumerate(zip(extrinsics, intrinsics)), desc='Rendering', disable=not verbose):
         if not isinstance(sample, MeshExtractResult):
@@ -82,12 +129,14 @@ def render_frames(sample, extrinsics, intrinsics, options={}, colors_overwrite=N
                 rets['depth'].append(None)
         else:
             res = renderer.render(sample, extr, intr)
-            if 'normal' not in rets: rets['normal'] = []
-            rets['normal'].append(np.clip(res['normal'].detach().cpu().numpy().transpose(1, 2, 0) * 255, 0, 255).astype(np.uint8))
+            if 'depth' not in rets: rets['depth'] = []
+            if 'depth' in res:
+                rets['depth'].append(np.clip(res['depth'].detach().cpu().numpy(), 0, 255).astype(np.uint8))
+            else:
+                rets['depth'].append(None)
     return rets
 
-
-def render_video(sample, resolution=512, bg_color=(0, 0, 0), num_frames=300, r=2, fov=40, **kwargs):
+def render_video(sample, resolution=512, bg_color=(0, 0, 0), num_frames=300, r=2, fov=60, **kwargs):
     yaws = torch.linspace(0, 2 * 3.1415, num_frames)
     pitch = 0.25 + 0.5 * torch.sin(torch.linspace(0, 2 * 3.1415, num_frames))
     yaws = yaws.tolist()
@@ -114,3 +163,11 @@ def render_snapshot(samples, resolution=512, bg_color=(0, 0, 0), offset=(-16 / 1
     pitch = [offset[1] for _ in range(4)]
     extrinsics, intrinsics = yaw_pitch_r_fov_to_extrinsics_intrinsics(yaw, pitch, r, fov)
     return render_frames(samples, extrinsics, intrinsics, {'resolution': resolution, 'bg_color': bg_color}, **kwargs)
+
+def render_snapshot_depth(samples, resolution=512, bg_color=(0, 0, 0), offset=(0, np.pi/2), r=1, fov=80, **kwargs):
+    yaw = [0, np.pi/2, np.pi, 3*np.pi/2]
+    yaw_offset = offset[0]
+    yaw = [y + yaw_offset for y in yaw]
+    pitch = [offset[1] for _ in range(4)]
+    extrinsics, intrinsics = yaw_pitch_r_fov_to_extrinsics_intrinsics(yaw, pitch, r, fov)
+    return render_frames_depth(samples, extrinsics, intrinsics, {'resolution': resolution, 'bg_color': bg_color}, **kwargs)
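
render_frames_depth duplicates the renderer setup of render_frames but keeps the 'depth' channel: via the renderer's 'depth' output for mesh samples, and via the 'percent_depth'/'depth' chain (the retained context lines) for Gaussian and Octree samples. A hedged usage sketch; four_view_depths is a hypothetical helper and the input is assumed to be a TRELLIS sample such as outputs['mesh'][0]:

from trellis.utils import render_utils

def four_view_depths(sample, resolution=512):
    # render_snapshot_depth renders four views at yaw 0/90/180/270 degrees
    rets = render_utils.render_snapshot_depth(sample, resolution=resolution)
    # For mesh inputs each entry is an HxW uint8 array (or None); for
    # Gaussian/Octree inputs it is a raw float depth array per the code above.
    return [d for d in rets['depth'] if d is not None]

# e.g. depths = four_view_depths(outputs['mesh'][0], resolution=768)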
utils/depth_estimation.py CHANGED
@@ -12,6 +12,8 @@ from utils.image_utils import (
     resize_image_with_aspect_ratio
 )
 from utils.constants import TMPDIR
+from easydict import EasyDict as edict
+
 
 # Load models once during module import
 image_processor = DPTImageProcessor.from_pretrained("Intel/dpt-large")
@@ -258,10 +260,10 @@ def depth_process_image(image_path, resized_width=800, z_scale=208):
     torch.cuda.ipc_collect()
     return [img, gltf_path, gltf_path]
 
-def get_depth_map_from_state(state):
+def get_depth_map_from_state(state, image_height=1024, image_width=1024):
     from diff_gaussian_rasterization import GaussianRasterizer, GaussianRasterizationSettings
 
-    settings = GaussianRasterizationSettings(image_height=1024, image_width=1024, bg_color=(1.0, 1.0, 1.0), max_num_points_per_tile=100, tile_size=32, filter_mode="linear", precompute_cov3D=True, precompute_cov2D=True, precompute_colors=True)
+    settings = GaussianRasterizationSettings(image_height=image_height, image_width=image_width, kernel_size=0.01, bg=(0.0, 0.0, 0.0))
     rasterizer = GaussianRasterizer(settings)
     # Assume state has necessary data like means3D, scales, etc.
     rendered_image, rendered_depth, _, _, _, _ = rasterizer(means3D=state["means3D"], means2D=state["means2D"], shs=state["shs"], colors_precomp=state["colors_precomp"], opacities=state["opacities"], scales=state["scales"], rotations=state["rotations"], cov3D_precomp=state["cov3D_precomp"])
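
For reference, a hypothetical caller for the revised get_depth_map_from_state (the commented-out line in app.py above suggests this intent). It assumes state carries the packed Gaussian tensors named in the function body (means3D, shs, opacities, ...) and that the function's tail, which falls outside this hunk, returns the rendered depth; the 8-bit normalization is likewise an assumption made only for display:

import numpy as np
import torch
from PIL import Image
from utils.depth_estimation import get_depth_map_from_state

def depth_image_from_state(state, width=1024, height=1024):
    # Assumption: the function returns the rendered depth map
    # (a torch tensor or ndarray); the hunk above ends before its return.
    depth = get_depth_map_from_state(state, image_height=height, image_width=width)
    d = depth.detach().cpu().numpy() if torch.is_tensor(depth) else np.asarray(depth)
    # Normalize to 0-255 for a viewable grayscale image
    d = (255.0 * (d - d.min()) / max(float(d.max() - d.min()), 1e-8)).astype(np.uint8)
    return Image.fromarray(d.squeeze())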