davidvgilmore committed on
Commit b58f246 · verified · 1 Parent(s): b56a409

Upload hy3dgen/shapegen/models/vae.py with huggingface_hub

Files changed (1)
  1. hy3dgen/shapegen/models/vae.py +636 -0
hy3dgen/shapegen/models/vae.py ADDED
@@ -0,0 +1,636 @@
1
+ # Open Source Model Licensed under the Apache License Version 2.0
2
+ # and Other Licenses of the Third-Party Components therein:
3
+ # The below Model in this distribution may have been modified by THL A29 Limited
4
+ # ("Tencent Modifications"). All Tencent Modifications are Copyright (C) 2024 THL A29 Limited.
5
+
6
+ # Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
7
+ # The below software and/or models in this distribution may have been
8
+ # modified by THL A29 Limited ("Tencent Modifications").
9
+ # All Tencent Modifications are Copyright (C) THL A29 Limited.
10
+
11
+ # Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
12
+ # except for the third-party components listed below.
13
+ # Hunyuan 3D does not impose any additional limitations beyond what is outlined
14
+ # in the respective licenses of these third-party components.
15
+ # Users must comply with all terms and conditions of original licenses of these third-party
16
+ # components and must ensure that the usage of the third party components adheres to
17
+ # all relevant laws and regulations.
18
+
19
+ # For avoidance of doubts, Hunyuan 3D means the large language models and
20
+ # their software and algorithms, including trained model weights, parameters (including
21
+ # optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
22
+ # fine-tuning enabling code and other elements of the foregoing made publicly available
23
+ # by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
24
+
25
+ from typing import Tuple, List, Union, Optional
26
+
27
+ import numpy as np
28
+ import torch
29
+ import torch.nn as nn
30
+ import torch.nn.functional as F
31
+ from einops import rearrange, repeat
32
+ from skimage import measure
33
+ from tqdm import tqdm
34
+
35
+
36
+ class FourierEmbedder(nn.Module):
37
+ """The sin/cosine positional embedding. Given an input tensor `x` of shape [n_batch, ..., c_dim], it converts
38
+ each feature dimension of `x[..., i]` into:
39
+ [
40
+ sin(x[..., i]),
41
+ sin(f_1*x[..., i]),
42
+ sin(f_2*x[..., i]),
43
+ ...
44
+ sin(f_N * x[..., i]),
45
+ cos(x[..., i]),
46
+ cos(f_1*x[..., i]),
47
+ cos(f_2*x[..., i]),
48
+ ...
49
+ cos(f_N * x[..., i]),
50
+ x[..., i] # only present if include_input is True.
51
+ ], here f_i is the frequency.
52
+
53
+ If logspace is True, the frequencies are f_i = 2^i for i = 0, 1, ..., num_freqs - 1;
54
+ otherwise, the frequencies are linearly spaced between 1.0 and 2^(num_freqs - 1).
55
+ When include_pi is True, every frequency is additionally multiplied by pi.
56
+
57
+ Args:
58
+ num_freqs (int): the number of frequencies, default is 6;
59
+ logspace (bool): if True, the frequencies are powers of two, f_i = 2^i,
60
+ otherwise, the frequencies are linearly spaced between [1.0, 2^(num_freqs - 1)];
61
+ input_dim (int): the input dimension, default is 3;
62
+ include_input (bool): include the input tensor or not, default is True.
63
+
64
+ Attributes:
65
+ frequencies (torch.Tensor): if logspace is True, the frequencies are [2^0, 2^1, ..., 2^(num_freqs - 1)],
66
+ otherwise, the frequencies are linearly spaced between [1.0, 2^(num_freqs - 1)];
67
+
68
+ out_dim (int): the embedding size, if include_input is True, it is input_dim * (num_freqs * 2 + 1),
69
+ otherwise, it is input_dim * num_freqs * 2.
70
+
71
+ """
72
+
73
+ def __init__(self,
74
+ num_freqs: int = 6,
75
+ logspace: bool = True,
76
+ input_dim: int = 3,
77
+ include_input: bool = True,
78
+ include_pi: bool = True) -> None:
79
+
80
+ """The initialization"""
81
+
82
+ super().__init__()
83
+
84
+ if logspace:
85
+ frequencies = 2.0 ** torch.arange(
86
+ num_freqs,
87
+ dtype=torch.float32
88
+ )
89
+ else:
90
+ frequencies = torch.linspace(
91
+ 1.0,
92
+ 2.0 ** (num_freqs - 1),
93
+ num_freqs,
94
+ dtype=torch.float32
95
+ )
96
+
97
+ if include_pi:
98
+ frequencies *= torch.pi
99
+
100
+ self.register_buffer("frequencies", frequencies, persistent=False)
101
+ self.include_input = include_input
102
+ self.num_freqs = num_freqs
103
+
104
+ self.out_dim = self.get_dims(input_dim)
105
+
106
+ def get_dims(self, input_dim):
107
+ temp = 1 if self.include_input or self.num_freqs == 0 else 0
108
+ out_dim = input_dim * (self.num_freqs * 2 + temp)
109
+
110
+ return out_dim
111
+
112
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
113
+ """ Forward process.
114
+
115
+ Args:
116
+ x: tensor of shape [..., dim]
117
+
118
+ Returns:
119
+ embedding: an embedding of `x` of shape [..., dim * (num_freqs * 2 + temp)]
120
+ where temp is 1 if include_input is True and 0 otherwise.
121
+ """
122
+
123
+ if self.num_freqs > 0:
124
+ embed = (x[..., None].contiguous() * self.frequencies).view(*x.shape[:-1], -1)
125
+ if self.include_input:
126
+ return torch.cat((x, embed.sin(), embed.cos()), dim=-1)
127
+ else:
128
+ return torch.cat((embed.sin(), embed.cos()), dim=-1)
129
+ else:
130
+ return x
131
+
132
+
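+ # Example (illustrative): with the defaults num_freqs=6, input_dim=3, include_input=True,
+ # out_dim = 3 * (2 * 6 + 1) = 39, so FourierEmbedder()(torch.rand(2, 1024, 3))
+ # returns a tensor of shape [2, 1024, 39].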
133
+ class DropPath(nn.Module):
134
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
135
+ """
136
+
137
+ def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True):
138
+ super(DropPath, self).__init__()
139
+ self.drop_prob = drop_prob
140
+ self.scale_by_keep = scale_by_keep
141
+
142
+ def forward(self, x):
143
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
144
+
145
+ This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
146
+ the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
147
+ See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
148
+ changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
149
+ 'survival rate' as the argument.
150
+
151
+ """
152
+ if self.drop_prob == 0. or not self.training:
153
+ return x
154
+ keep_prob = 1 - self.drop_prob
155
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
156
+ random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
157
+ if keep_prob > 0.0 and self.scale_by_keep:
158
+ random_tensor.div_(keep_prob)
159
+ return x * random_tensor
160
+
161
+ def extra_repr(self):
162
+ return f'drop_prob={round(self.drop_prob, 3):0.3f}'
163
+
164
+
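+ # Example (illustrative): during training, DropPath(0.1) zeroes a sample's entire residual
+ # branch with probability 0.1 and rescales surviving samples by 1 / 0.9; in eval mode it is a no-op.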
165
+ class MLP(nn.Module):
166
+ def __init__(
167
+ self, *,
168
+ width: int,
169
+ output_width: int = None,
170
+ drop_path_rate: float = 0.0
171
+ ):
172
+ super().__init__()
173
+ self.width = width
174
+ self.c_fc = nn.Linear(width, width * 4)
175
+ self.c_proj = nn.Linear(width * 4, output_width if output_width is not None else width)
176
+ self.gelu = nn.GELU()
177
+ self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
178
+
179
+ def forward(self, x):
180
+ return self.drop_path(self.c_proj(self.gelu(self.c_fc(x))))
181
+
182
+
183
+ class QKVMultiheadCrossAttention(nn.Module):
184
+ def __init__(
185
+ self,
186
+ *,
187
+ heads: int,
188
+ n_data: Optional[int] = None,
189
+ width=None,
190
+ qk_norm=False,
191
+ norm_layer=nn.LayerNorm
192
+ ):
193
+ super().__init__()
194
+ self.heads = heads
195
+ self.n_data = n_data
196
+ self.q_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
197
+ self.k_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
198
+
199
+ def forward(self, q, kv):
200
+ _, n_ctx, _ = q.shape
201
+ bs, n_data, width = kv.shape
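+ # kv packs keys and values along the channel dimension (it comes from a width * 2 projection),
+ # so each head's key/value size is width // heads // 2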
202
+ attn_ch = width // self.heads // 2
203
+ q = q.view(bs, n_ctx, self.heads, -1)
204
+ kv = kv.view(bs, n_data, self.heads, -1)
205
+ k, v = torch.split(kv, attn_ch, dim=-1)
206
+
207
+ q = self.q_norm(q)
208
+ k = self.k_norm(k)
209
+
210
+ q, k, v = map(lambda t: rearrange(t, 'b n h d -> b h n d', h=self.heads), (q, k, v))
211
+ out = F.scaled_dot_product_attention(q, k, v).transpose(1, 2).reshape(bs, n_ctx, -1)
212
+
213
+ return out
214
+
215
+
216
+ class MultiheadCrossAttention(nn.Module):
217
+ def __init__(
218
+ self,
219
+ *,
220
+ width: int,
221
+ heads: int,
222
+ qkv_bias: bool = True,
223
+ n_data: Optional[int] = None,
224
+ data_width: Optional[int] = None,
225
+ norm_layer=nn.LayerNorm,
226
+ qk_norm: bool = False
227
+ ):
228
+ super().__init__()
229
+ self.n_data = n_data
230
+ self.width = width
231
+ self.heads = heads
232
+ self.data_width = width if data_width is None else data_width
233
+ self.c_q = nn.Linear(width, width, bias=qkv_bias)
234
+ self.c_kv = nn.Linear(self.data_width, width * 2, bias=qkv_bias)
235
+ self.c_proj = nn.Linear(width, width)
236
+ self.attention = QKVMultiheadCrossAttention(
237
+ heads=heads,
238
+ n_data=n_data,
239
+ width=width,
240
+ norm_layer=norm_layer,
241
+ qk_norm=qk_norm
242
+ )
243
+
244
+ def forward(self, x, data):
245
+ x = self.c_q(x)
246
+ data = self.c_kv(data)
247
+ x = self.attention(x, data)
248
+ x = self.c_proj(x)
249
+ return x
250
+
251
+
252
+ class ResidualCrossAttentionBlock(nn.Module):
253
+ def __init__(
254
+ self,
255
+ *,
256
+ n_data: Optional[int] = None,
257
+ width: int,
258
+ heads: int,
259
+ data_width: Optional[int] = None,
260
+ qkv_bias: bool = True,
261
+ norm_layer=nn.LayerNorm,
262
+ qk_norm: bool = False
263
+ ):
264
+ super().__init__()
265
+
266
+ if data_width is None:
267
+ data_width = width
268
+
269
+ self.attn = MultiheadCrossAttention(
270
+ n_data=n_data,
271
+ width=width,
272
+ heads=heads,
273
+ data_width=data_width,
274
+ qkv_bias=qkv_bias,
275
+ norm_layer=norm_layer,
276
+ qk_norm=qk_norm
277
+ )
278
+ self.ln_1 = norm_layer(width, elementwise_affine=True, eps=1e-6)
279
+ self.ln_2 = norm_layer(data_width, elementwise_affine=True, eps=1e-6)
280
+ self.ln_3 = norm_layer(width, elementwise_affine=True, eps=1e-6)
281
+ self.mlp = MLP(width=width)
282
+
283
+ def forward(self, x: torch.Tensor, data: torch.Tensor):
284
+ x = x + self.attn(self.ln_1(x), self.ln_2(data))
285
+ x = x + self.mlp(self.ln_3(x))
286
+ return x
287
+
288
+
289
+ class QKVMultiheadAttention(nn.Module):
290
+ def __init__(
291
+ self,
292
+ *,
293
+ heads: int,
294
+ n_ctx: int,
295
+ width=None,
296
+ qk_norm=False,
297
+ norm_layer=nn.LayerNorm
298
+ ):
299
+ super().__init__()
300
+ self.heads = heads
301
+ self.n_ctx = n_ctx
302
+ self.q_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
303
+ self.k_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
304
+
305
+ def forward(self, qkv):
306
+ bs, n_ctx, width = qkv.shape
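+ # qkv packs queries, keys and values along the channel dimension (a width * 3 projection),
+ # so each head's channel size is width // heads // 3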
307
+ attn_ch = width // self.heads // 3
308
+ qkv = qkv.view(bs, n_ctx, self.heads, -1)
309
+ q, k, v = torch.split(qkv, attn_ch, dim=-1)
310
+
311
+ q = self.q_norm(q)
312
+ k = self.k_norm(k)
313
+
314
+ q, k, v = map(lambda t: rearrange(t, 'b n h d -> b h n d', h=self.heads), (q, k, v))
315
+ out = F.scaled_dot_product_attention(q, k, v).transpose(1, 2).reshape(bs, n_ctx, -1)
316
+ return out
317
+
318
+
319
+ class MultiheadAttention(nn.Module):
320
+ def __init__(
321
+ self,
322
+ *,
323
+ n_ctx: int,
324
+ width: int,
325
+ heads: int,
326
+ qkv_bias: bool,
327
+ norm_layer=nn.LayerNorm,
328
+ qk_norm: bool = False,
329
+ drop_path_rate: float = 0.0
330
+ ):
331
+ super().__init__()
332
+ self.n_ctx = n_ctx
333
+ self.width = width
334
+ self.heads = heads
335
+ self.c_qkv = nn.Linear(width, width * 3, bias=qkv_bias)
336
+ self.c_proj = nn.Linear(width, width)
337
+ self.attention = QKVMultiheadAttention(
338
+ heads=heads,
339
+ n_ctx=n_ctx,
340
+ width=width,
341
+ norm_layer=norm_layer,
342
+ qk_norm=qk_norm
343
+ )
344
+ self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
345
+
346
+ def forward(self, x):
347
+ x = self.c_qkv(x)
348
+ x = self.attention(x)
349
+ x = self.drop_path(self.c_proj(x))
350
+ return x
351
+
352
+
353
+ class ResidualAttentionBlock(nn.Module):
354
+ def __init__(
355
+ self,
356
+ *,
357
+ n_ctx: int,
358
+ width: int,
359
+ heads: int,
360
+ qkv_bias: bool = True,
361
+ norm_layer=nn.LayerNorm,
362
+ qk_norm: bool = False,
363
+ drop_path_rate: float = 0.0,
364
+ ):
365
+ super().__init__()
366
+ self.attn = MultiheadAttention(
367
+ n_ctx=n_ctx,
368
+ width=width,
369
+ heads=heads,
370
+ qkv_bias=qkv_bias,
371
+ norm_layer=norm_layer,
372
+ qk_norm=qk_norm,
373
+ drop_path_rate=drop_path_rate
374
+ )
375
+ self.ln_1 = norm_layer(width, elementwise_affine=True, eps=1e-6)
376
+ self.mlp = MLP(width=width, drop_path_rate=drop_path_rate)
377
+ self.ln_2 = norm_layer(width, elementwise_affine=True, eps=1e-6)
378
+
379
+ def forward(self, x: torch.Tensor):
380
+ x = x + self.attn(self.ln_1(x))
381
+ x = x + self.mlp(self.ln_2(x))
382
+ return x
383
+
384
+
385
+ class Transformer(nn.Module):
386
+ def __init__(
387
+ self,
388
+ *,
389
+ n_ctx: int,
390
+ width: int,
391
+ layers: int,
392
+ heads: int,
393
+ qkv_bias: bool = True,
394
+ norm_layer=nn.LayerNorm,
395
+ qk_norm: bool = False,
396
+ drop_path_rate: float = 0.0
397
+ ):
398
+ super().__init__()
399
+ self.n_ctx = n_ctx
400
+ self.width = width
401
+ self.layers = layers
402
+ self.resblocks = nn.ModuleList(
403
+ [
404
+ ResidualAttentionBlock(
405
+ n_ctx=n_ctx,
406
+ width=width,
407
+ heads=heads,
408
+ qkv_bias=qkv_bias,
409
+ norm_layer=norm_layer,
410
+ qk_norm=qk_norm,
411
+ drop_path_rate=drop_path_rate
412
+ )
413
+ for _ in range(layers)
414
+ ]
415
+ )
416
+
417
+ def forward(self, x: torch.Tensor):
418
+ for block in self.resblocks:
419
+ x = block(x)
420
+ return x
421
+
422
+
423
+ class CrossAttentionDecoder(nn.Module):
424
+
425
+ def __init__(
426
+ self,
427
+ *,
428
+ num_latents: int,
429
+ out_channels: int,
430
+ fourier_embedder: FourierEmbedder,
431
+ width: int,
432
+ heads: int,
433
+ qkv_bias: bool = True,
434
+ qk_norm: bool = False,
435
+ label_type: str = "binary"
436
+ ):
437
+ super().__init__()
438
+
439
+ self.fourier_embedder = fourier_embedder
440
+
441
+ self.query_proj = nn.Linear(self.fourier_embedder.out_dim, width)
442
+
443
+ self.cross_attn_decoder = ResidualCrossAttentionBlock(
444
+ n_data=num_latents,
445
+ width=width,
446
+ heads=heads,
447
+ qkv_bias=qkv_bias,
448
+ qk_norm=qk_norm
449
+ )
450
+
451
+ self.ln_post = nn.LayerNorm(width)
452
+ self.output_proj = nn.Linear(width, out_channels)
453
+ self.label_type = label_type
454
+
455
+ def forward(self, queries: torch.FloatTensor, latents: torch.FloatTensor):
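+ # queries: [B, N, 3] xyz sample points; latents: [B, num_latents, width] tokens produced by the ShapeVAE transformer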
456
+ queries = self.query_proj(self.fourier_embedder(queries).to(latents.dtype))
457
+ x = self.cross_attn_decoder(queries, latents)
458
+ x = self.ln_post(x)
459
+ occ = self.output_proj(x)
460
+ return occ
461
+
462
+
463
+ def generate_dense_grid_points(bbox_min: np.ndarray,
464
+ bbox_max: np.ndarray,
465
+ octree_depth: int,
466
+ indexing: str = "ij",
467
+ octree_resolution: int = None,
468
+ ):
469
+ length = bbox_max - bbox_min
470
+ num_cells = np.exp2(octree_depth)
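+ # an explicit octree_resolution, if given, overrides the 2 ** octree_depth cell count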
471
+ if octree_resolution is not None:
472
+ num_cells = octree_resolution
473
+
474
+ x = np.linspace(bbox_min[0], bbox_max[0], int(num_cells) + 1, dtype=np.float32)
475
+ y = np.linspace(bbox_min[1], bbox_max[1], int(num_cells) + 1, dtype=np.float32)
476
+ z = np.linspace(bbox_min[2], bbox_max[2], int(num_cells) + 1, dtype=np.float32)
477
+ [xs, ys, zs] = np.meshgrid(x, y, z, indexing=indexing)
478
+ xyz = np.stack((xs, ys, zs), axis=-1)
479
+ xyz = xyz.reshape(-1, 3)
480
+ grid_size = [int(num_cells) + 1, int(num_cells) + 1, int(num_cells) + 1]
481
+
482
+ return xyz, grid_size, length
483
+
484
+
485
+ def center_vertices(vertices):
486
+ """Translate the vertices so that bounding box is centered at zero."""
487
+ vert_min = vertices.min(dim=0)[0]
488
+ vert_max = vertices.max(dim=0)[0]
489
+ vert_center = 0.5 * (vert_min + vert_max)
490
+ return vertices - vert_center
491
+
492
+
493
+ class Latent2MeshOutput:
494
+
495
+ def __init__(self, mesh_v=None, mesh_f=None):
496
+ self.mesh_v = mesh_v
497
+ self.mesh_f = mesh_f
498
+
499
+
500
+ class ShapeVAE(nn.Module):
501
+ def __init__(
502
+ self,
503
+ *,
504
+ num_latents: int,
505
+ embed_dim: int,
506
+ width: int,
507
+ heads: int,
508
+ num_decoder_layers: int,
509
+ num_freqs: int = 8,
510
+ include_pi: bool = True,
511
+ qkv_bias: bool = True,
512
+ qk_norm: bool = False,
513
+ label_type: str = "binary",
514
+ drop_path_rate: float = 0.0,
515
+ scale_factor: float = 1.0,
516
+ ):
517
+ super().__init__()
518
+ self.fourier_embedder = FourierEmbedder(num_freqs=num_freqs, include_pi=include_pi)
519
+
520
+ self.post_kl = nn.Linear(embed_dim, width)
521
+
522
+ self.transformer = Transformer(
523
+ n_ctx=num_latents,
524
+ width=width,
525
+ layers=num_decoder_layers,
526
+ heads=heads,
527
+ qkv_bias=qkv_bias,
528
+ qk_norm=qk_norm,
529
+ drop_path_rate=drop_path_rate
530
+ )
531
+
532
+ self.geo_decoder = CrossAttentionDecoder(
533
+ fourier_embedder=self.fourier_embedder,
534
+ out_channels=1,
535
+ num_latents=num_latents,
536
+ width=width,
537
+ heads=heads,
538
+ qkv_bias=qkv_bias,
539
+ qk_norm=qk_norm,
540
+ label_type=label_type,
541
+ )
542
+
543
+ self.scale_factor = scale_factor
544
+ self.latent_shape = (num_latents, embed_dim)
545
+
546
+ def forward(self, latents):
547
+ latents = self.post_kl(latents)
548
+ latents = self.transformer(latents)
549
+ return latents
550
+
551
+ @torch.no_grad()
552
+ def latents2mesh(
553
+ self,
554
+ latents: torch.FloatTensor,
555
+ bounds: Union[Tuple[float], List[float], float] = 1.1,
556
+ octree_depth: int = 7,
557
+ num_chunks: int = 10000,
558
+ mc_level: float = -1 / 512,
559
+ octree_resolution: int = None,
560
+ mc_algo: str = 'dmc',
561
+ ):
562
+ device = latents.device
563
+
564
+ # 1. generate query points
565
+ if isinstance(bounds, float):
566
+ bounds = [-bounds, -bounds, -bounds, bounds, bounds, bounds]
567
+ bbox_min = np.array(bounds[0:3])
568
+ bbox_max = np.array(bounds[3:6])
569
+ bbox_size = bbox_max - bbox_min
570
+ xyz_samples, grid_size, length = generate_dense_grid_points(
571
+ bbox_min=bbox_min,
572
+ bbox_max=bbox_max,
573
+ octree_depth=octree_depth,
574
+ octree_resolution=octree_resolution,
575
+ indexing="ij"
576
+ )
577
+ xyz_samples = torch.FloatTensor(xyz_samples)
578
+
579
+ # 2. latents to 3d volume
580
+ batch_logits = []
581
+ batch_size = latents.shape[0]
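+ # evaluate the implicit decoder in chunks of num_chunks query points to bound peak memory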
582
+ for start in tqdm(range(0, xyz_samples.shape[0], num_chunks),
583
+ desc=f"MC Level {mc_level} Implicit Function:"):
584
+ queries = xyz_samples[start: start + num_chunks, :].to(device)
585
+ queries = queries.half()
586
+ batch_queries = repeat(queries, "p c -> b p c", b=batch_size)
587
+
588
+ logits = self.geo_decoder(batch_queries.to(latents.dtype), latents)
589
+ if mc_level == -1:
590
+ mc_level = 0
591
+ logits = torch.sigmoid(logits) * 2 - 1
592
+ print(f'Training with soft labels, inference with sigmoid and marching cubes level 0.')
593
+ batch_logits.append(logits)
594
+ grid_logits = torch.cat(batch_logits, dim=1)
595
+ grid_logits = grid_logits.view((batch_size, grid_size[0], grid_size[1], grid_size[2])).float()
596
+
597
+ # 3. extract surface
598
+ outputs = []
599
+ for i in range(batch_size):
600
+ try:
601
+ if mc_algo == 'mc':
602
+ vertices, faces, normals, _ = measure.marching_cubes(
603
+ grid_logits[i].cpu().numpy(),
604
+ mc_level,
605
+ method="lewiner"
606
+ )
607
+ vertices = vertices / grid_size * bbox_size + bbox_min
608
+ elif mc_algo == 'dmc':
609
+ if not hasattr(self, 'dmc'):
610
+ try:
611
+ from diso import DiffDMC
612
+ except ImportError:
613
+ raise ImportError("Please install diso via `pip install diso`, or set mc_algo to 'mc'")
614
+ self.dmc = DiffDMC(dtype=torch.float32).to(device)
615
+ octree_resolution = 2 ** octree_depth if octree_resolution is None else octree_resolution
616
+ sdf = -grid_logits[i] / octree_resolution
617
+ verts, faces = self.dmc(sdf, deform=None, return_quads=False, normalize=True)
618
+ verts = center_vertices(verts)
619
+ vertices = verts.detach().cpu().numpy()
620
+ faces = faces.detach().cpu().numpy()[:, ::-1]
621
+ else:
622
+ raise ValueError(f"mc_algo {mc_algo} not supported.")
623
+
624
+ outputs.append(
625
+ Latent2MeshOutput(
626
+ mesh_v=vertices.astype(np.float32),
627
+ mesh_f=np.ascontiguousarray(faces)
628
+ )
629
+ )
630
+
631
+ except ValueError:
632
+ outputs.append(None)
633
+ except RuntimeError:
634
+ outputs.append(None)
635
+
636
+ return outputs
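+
+
+ # Usage sketch (illustrative; the hyperparameter values below are placeholders,
+ # the real ones come from the pretrained checkpoint config):
+ # vae = ShapeVAE(num_latents=256, embed_dim=64, width=1024, heads=16, num_decoder_layers=16)
+ # latents = torch.randn(1, *vae.latent_shape) # [1, num_latents, embed_dim]
+ # latents = vae(latents) # post_kl projection + transformer
+ # meshes = vae.latents2mesh(latents, octree_resolution=256, mc_algo='mc') # list of Latent2MeshOutput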