ifire committed
Commit
a6bbecf
•
1 Parent(s): 0547c7e

Format code and change app.py.

This view is limited to 50 files because it contains too many changes. See the raw diff for the remaining files.
Files changed (50)
  1. .gitignore +1 -0
  2. .pre-commit-config.yaml +11 -0
  3. README.md +1 -1
  4. app.py +155 -93
  5. extensions/nvdiffrast/README.md +1 -1
  6. extensions/nvdiffrast/nvdiffrast/__init__.py +1 -1
  7. extensions/nvdiffrast/nvdiffrast/common/antialias.cu +1 -1
  8. extensions/nvdiffrast/nvdiffrast/common/cudaraster/CudaRaster.hpp +0 -1
  9. extensions/nvdiffrast/nvdiffrast/common/cudaraster/impl/RasterImpl.hpp +0 -1
  10. extensions/nvdiffrast/nvdiffrast/common/interpolate.cu +2 -2
  11. extensions/nvdiffrast/nvdiffrast/common/texture.cpp +1 -1
  12. extensions/nvdiffrast/nvdiffrast/tensorflow/ops.py +161 -53
  13. extensions/nvdiffrast/nvdiffrast/tensorflow/plugin_loader.py +170 -76
  14. extensions/nvdiffrast/nvdiffrast/tensorflow/tf_antialias.cu +2 -2
  15. extensions/nvdiffrast/nvdiffrast/tensorflow/tf_interpolate.cu +5 -5
  16. extensions/nvdiffrast/nvdiffrast/tensorflow/tf_texture.cu +3 -3
  17. extensions/nvdiffrast/nvdiffrast/torch/__init__.py +27 -2
  18. extensions/nvdiffrast/nvdiffrast/torch/ops.py +325 -139
  19. extensions/nvdiffrast/setup copy.py +21 -18
  20. extensions/nvdiffrast/setup.py +22 -22
  21. requirements.txt +1 -1
  22. trellis/models/__init__.py +23 -15
  23. trellis/models/sparse_structure_flow.py +49 -26
  24. trellis/models/sparse_structure_vae.py +50 -42
  25. trellis/models/structured_latent_flow.py +70 -45
  26. trellis/models/structured_latent_vae/base.py +36 -20
  27. trellis/models/structured_latent_vae/decoder_gs.py +54 -23
  28. trellis/models/structured_latent_vae/decoder_mesh.py +43 -26
  29. trellis/models/structured_latent_vae/decoder_rf.py +38 -14
  30. trellis/models/structured_latent_vae/encoder.py +5 -3
  31. trellis/modules/attention/__init__.py +18 -11
  32. trellis/modules/attention/full_attn.py +62 -43
  33. trellis/modules/attention/modules.py +47 -22
  34. trellis/modules/norm.py +5 -5
  35. trellis/modules/sparse/__init__.py +51 -44
  36. trellis/modules/sparse/attention/full_attn.py +149 -66
  37. trellis/modules/sparse/attention/modules.py +44 -17
  38. trellis/modules/sparse/attention/serialized_attn.py +108 -51
  39. trellis/modules/sparse/attention/windowed_attn.py +88 -45
  40. trellis/modules/sparse/basic.py +198 -118
  41. trellis/modules/sparse/conv/__init__.py +12 -7
  42. trellis/modules/sparse/conv/conv_spconv.py +78 -20
  43. trellis/modules/sparse/conv/conv_torchsparse.py +52 -14
  44. trellis/modules/sparse/linear.py +1 -3
  45. trellis/modules/sparse/nonlinearity.py +2 -8
  46. trellis/modules/sparse/norm.py +10 -5
  47. trellis/modules/sparse/spatial.py +49 -29
  48. trellis/modules/sparse/transformer/__init__.py +1 -1
  49. trellis/modules/sparse/transformer/blocks.py +14 -4
  50. trellis/modules/sparse/transformer/modulated.py +44 -15
.gitignore ADDED
@@ -0,0 +1 @@
+__pycache__/
.pre-commit-config.yaml ADDED
@@ -0,0 +1,11 @@
+repos:
+-   repo: https://github.com/pre-commit/pre-commit-hooks
+    rev: v2.3.0
+    hooks:
+    -   id: check-yaml
+    -   id: end-of-file-fixer
+    -   id: trailing-whitespace
+-   repo: https://github.com/psf/black
+    rev: 22.10.0
+    hooks:
+    -   id: black
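The hooks added above are the standard pre-commit sample set (whitespace/YAML hygiene) plus the Black formatter, which appears to be what reformatted the Python files in this commit. As a sketch only (not part of the commit), the hooks can also be driven from Python; assumes the `pre-commit` CLI from https://pre-commit.com is installed:

```python
# Sketch: register the git hook, then run every hook across the whole repo once.
import subprocess

subprocess.run(["pre-commit", "install"], check=True)
subprocess.run(["pre-commit", "run", "--all-files"], check=True)
```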
README.md CHANGED
@@ -13,4 +13,4 @@ short_description: Scalable and Versatile 3D Generation from images

 Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference

-Paper: https://huggingface.co/papers/2412.01506
+Paper: https://huggingface.co/papers/2412.01506
app.py CHANGED
@@ -4,7 +4,8 @@ from gradio_litmodel3d import LitModel3D

 import os
 import shutil
-os.environ['SPCONV_ALGO'] = 'native'
+
+os.environ["SPCONV_ALGO"] = "native"
 from typing import *
 import torch
 import numpy as np
@@ -17,15 +18,24 @@ from trellis.utils import render_utils, postprocessing_utils


 MAX_SEED = np.iinfo(np.int32).max
-TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
+TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "tmp")
 os.makedirs(TMP_DIR, exist_ok=True)

+pipeline = TrellisImageTo3DPipeline.from_pretrained("JeffreyXiang/TRELLIS-image-large")
+pipeline.cuda()
+try:
+    pipeline.preprocess_image(
+        Image.fromarray(np.zeros((512, 512, 3), dtype=np.uint8))
+    )  # Preload rembg
+except:
+    pass
+

 def start_session(req: gr.Request):
     user_dir = os.path.join(TMP_DIR, str(req.session_hash))
     os.makedirs(user_dir, exist_ok=True)
-
-
+
+
 def end_session(req: gr.Request):
     user_dir = os.path.join(TMP_DIR, str(req.session_hash))
     shutil.rmtree(user_dir)
@@ -48,10 +58,10 @@ def preprocess_image(image: Image.Image) -> Image.Image:
 def preprocess_images(images: List[Tuple[Image.Image, str]]) -> List[Image.Image]:
     """
     Preprocess a list of input images.
-
+
     Args:
         images (List[Tuple[Image.Image, str]]): The input images.
-
+
     Returns:
         List[Image.Image]: The preprocessed images.
     """
@@ -62,41 +72,41 @@ def preprocess_images(images: List[Tuple[Image.Image, str]]) -> List[Image.Image

 def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict:
     return {
-        'gaussian': {
+        "gaussian": {
             **gs.init_params,
-            '_xyz': gs._xyz.cpu().numpy(),
-            '_features_dc': gs._features_dc.cpu().numpy(),
-            '_scaling': gs._scaling.cpu().numpy(),
-            '_rotation': gs._rotation.cpu().numpy(),
-            '_opacity': gs._opacity.cpu().numpy(),
+            "_xyz": gs._xyz.cpu().numpy(),
+            "_features_dc": gs._features_dc.cpu().numpy(),
+            "_scaling": gs._scaling.cpu().numpy(),
+            "_rotation": gs._rotation.cpu().numpy(),
+            "_opacity": gs._opacity.cpu().numpy(),
         },
-        'mesh': {
-            'vertices': mesh.vertices.cpu().numpy(),
-            'faces': mesh.faces.cpu().numpy(),
+        "mesh": {
+            "vertices": mesh.vertices.cpu().numpy(),
+            "faces": mesh.faces.cpu().numpy(),
         },
     }
-
-
+
+
 def unpack_state(state: dict) -> Tuple[Gaussian, edict, str]:
     gs = Gaussian(
-        aabb=state['gaussian']['aabb'],
-        sh_degree=state['gaussian']['sh_degree'],
-        mininum_kernel_size=state['gaussian']['mininum_kernel_size'],
-        scaling_bias=state['gaussian']['scaling_bias'],
-        opacity_bias=state['gaussian']['opacity_bias'],
-        scaling_activation=state['gaussian']['scaling_activation'],
+        aabb=state["gaussian"]["aabb"],
+        sh_degree=state["gaussian"]["sh_degree"],
+        mininum_kernel_size=state["gaussian"]["mininum_kernel_size"],
+        scaling_bias=state["gaussian"]["scaling_bias"],
+        opacity_bias=state["gaussian"]["opacity_bias"],
+        scaling_activation=state["gaussian"]["scaling_activation"],
    )
-    gs._xyz = torch.tensor(state['gaussian']['_xyz'], device='cuda')
-    gs._features_dc = torch.tensor(state['gaussian']['_features_dc'], device='cuda')
-    gs._scaling = torch.tensor(state['gaussian']['_scaling'], device='cuda')
-    gs._rotation = torch.tensor(state['gaussian']['_rotation'], device='cuda')
-    gs._opacity = torch.tensor(state['gaussian']['_opacity'], device='cuda')
-
+    gs._xyz = torch.tensor(state["gaussian"]["_xyz"], device="cuda")
+    gs._features_dc = torch.tensor(state["gaussian"]["_features_dc"], device="cuda")
+    gs._scaling = torch.tensor(state["gaussian"]["_scaling"], device="cuda")
+    gs._rotation = torch.tensor(state["gaussian"]["_rotation"], device="cuda")
+    gs._opacity = torch.tensor(state["gaussian"]["_opacity"], device="cuda")
+
     mesh = edict(
-        vertices=torch.tensor(state['mesh']['vertices'], device='cuda'),
-        faces=torch.tensor(state['mesh']['faces'], device='cuda'),
+        vertices=torch.tensor(state["mesh"]["vertices"], device="cuda"),
+        faces=torch.tensor(state["mesh"]["faces"], device="cuda"),
     )
-
+
     return gs, mesh
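For orientation (not part of the commit): `pack_state` moves every tensor to CPU NumPy so the result can live inside a `gr.State`, and `unpack_state` rebuilds the CUDA tensors from it. A minimal sketch of the round trip, assuming `outputs` is a pipeline result with `gaussian` and `mesh` entries as produced in `image_to_3d` below:

```python
# Hypothetical usage sketch of the two helpers above.
state = pack_state(outputs["gaussian"][0], outputs["mesh"][0])  # tensors -> CPU numpy dict
gs, mesh = unpack_state(state)                                  # numpy dict -> CUDA tensors
```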
@@ -170,12 +180,14 @@ def image_to_3d(
         },
         mode=multiimage_algo,
     )
-    video = render_utils.render_video(outputs['gaussian'][0], num_frames=120)['color']
-    video_geo = render_utils.render_video(outputs['mesh'][0], num_frames=120)['normal']
-    video = [np.concatenate([video[i], video_geo[i]], axis=1) for i in range(len(video))]
-    video_path = os.path.join(user_dir, 'sample.mp4')
+    video = render_utils.render_video(outputs["gaussian"][0], num_frames=120)["color"]
+    video_geo = render_utils.render_video(outputs["mesh"][0], num_frames=120)["normal"]
+    video = [
+        np.concatenate([video[i], video_geo[i]], axis=1) for i in range(len(video))
+    ]
+    video_path = os.path.join(user_dir, "sample.mp4")
     imageio.mimsave(video_path, video, fps=15)
-    state = pack_state(outputs['gaussian'][0], outputs['mesh'][0])
+    state = pack_state(outputs["gaussian"][0], outputs["mesh"][0])
     torch.cuda.empty_cache()
     return state, video_path
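The preview video above places the color render and the normal render side by side by concatenating each frame pair along the width axis. A self-contained sketch of that frame stitching (synthetic frames stand in for the actual renders):

```python
import numpy as np

# Two synthetic 120-frame clips of 64x64 RGB frames standing in for the
# 'color' and 'normal' outputs of render_utils.render_video.
color = [np.zeros((64, 64, 3), dtype=np.uint8) for _ in range(120)]
normal = [np.ones((64, 64, 3), dtype=np.uint8) for _ in range(120)]

# axis=1 is the width axis, so each stitched frame is 64x128: color | normal.
frames = [np.concatenate([c, n], axis=1) for c, n in zip(color, normal)]
assert frames[0].shape == (64, 128, 3)
```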
@@ -200,8 +212,10 @@ def extract_glb(
     """
     user_dir = os.path.join(TMP_DIR, str(req.session_hash))
     gs, mesh = unpack_state(state)
-    glb = postprocessing_utils.to_glb(gs, mesh, simplify=mesh_simplify, texture_size=texture_size, verbose=False)
-    glb_path = os.path.join(user_dir, 'sample.glb')
+    glb = postprocessing_utils.to_glb(
+        gs, mesh, simplify=mesh_simplify, texture_size=texture_size, verbose=False
+    )
+    glb_path = os.path.join(user_dir, "sample.glb")
     glb.export(glb_path)
     torch.cuda.empty_cache()
     return glb_path, glb_path
@@ -220,19 +234,21 @@ def extract_gaussian(state: dict, req: gr.Request) -> Tuple[str, str]:
     """
     user_dir = os.path.join(TMP_DIR, str(req.session_hash))
     gs, _ = unpack_state(state)
-    gaussian_path = os.path.join(user_dir, 'sample.ply')
+    gaussian_path = os.path.join(user_dir, "sample.ply")
     gs.save_ply(gaussian_path)
     torch.cuda.empty_cache()
     return gaussian_path, gaussian_path


 def prepare_multi_example() -> List[Image.Image]:
-    multi_case = list(set([i.split('_')[0] for i in os.listdir("assets/example_multi_image")]))
+    multi_case = list(
+        set([i.split("_")[0] for i in os.listdir("assets/example_multi_image")])
+    )
     images = []
     for case in multi_case:
         _images = []
         for i in range(1, 4):
-            img = Image.open(f'assets/example_multi_image/{case}_{i}.png')
+            img = Image.open(f"assets/example_multi_image/{case}_{i}.png")
             W, H = img.size
             img = img.resize((int(W / H * 512), 512))
             _images.append(np.array(img))
@@ -246,71 +262,113 @@ def split_image(image: Image.Image) -> List[Image.Image]:
     """
     image = np.array(image)
     alpha = image[..., 3]
-    alpha = np.any(alpha>0, axis=0)
+    alpha = np.any(alpha > 0, axis=0)
     start_pos = np.where(~alpha[:-1] & alpha[1:])[0].tolist()
     end_pos = np.where(alpha[:-1] & ~alpha[1:])[0].tolist()
     images = []
     for s, e in zip(start_pos, end_pos):
-        images.append(Image.fromarray(image[:, s:e+1]))
+        images.append(Image.fromarray(image[:, s : e + 1]))
     return [preprocess_image(image) for image in images]


 with gr.Blocks(delete_cache=(600, 600)) as demo:
-    gr.Markdown("""
+    gr.Markdown(
+        """
     ## Image to 3D Asset with [TRELLIS](https://trellis3d.github.io/)
     * Upload an image and click "Generate" to create a 3D asset. If the image has alpha channel, it be used as the mask. Otherwise, we use `rembg` to remove the background.
     * If you find the generated 3D asset satisfactory, click "Extract GLB" to extract the GLB file and download it.
-
+
     ✨New: 1) Experimental multi-image support. 2) Gaussian file extraction.
-    """)
-
+    """
+    )
+
     with gr.Row():
         with gr.Column():
             with gr.Tabs() as input_tabs:
                 with gr.Tab(label="Single Image", id=0) as single_image_input_tab:
-                    image_prompt = gr.Image(label="Image Prompt", format="png", image_mode="RGBA", type="pil", height=300)
+                    image_prompt = gr.Image(
+                        label="Image Prompt",
+                        format="png",
+                        image_mode="RGBA",
+                        type="pil",
+                        height=300,
+                    )
                 with gr.Tab(label="Multiple Images", id=1) as multiimage_input_tab:
-                    multiimage_prompt = gr.Gallery(label="Image Prompt", format="png", type="pil", height=300, columns=3)
-                    gr.Markdown("""
-                    Input different views of the object in separate images.
-
+                    multiimage_prompt = gr.Gallery(
+                        label="Image Prompt",
+                        format="png",
+                        type="pil",
+                        height=300,
+                        columns=3,
+                    )
+                    gr.Markdown(
+                        """
+                    Input different views of the object in separate images.
+
                     *NOTE: this is an experimental algorithm without training a specialized model. It may not produce the best results for all images, especially those having different poses or inconsistent details.*
-                    """)
-
+                    """
+                    )
+
             with gr.Accordion(label="Generation Settings", open=False):
                 seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1)
                 randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
                 gr.Markdown("Stage 1: Sparse Structure Generation")
                 with gr.Row():
-                    ss_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=7.5, step=0.1)
-                    ss_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
+                    ss_guidance_strength = gr.Slider(
+                        0.0, 10.0, label="Guidance Strength", value=7.5, step=0.1
+                    )
+                    ss_sampling_steps = gr.Slider(
+                        1, 50, label="Sampling Steps", value=12, step=1
+                    )
                 gr.Markdown("Stage 2: Structured Latent Generation")
                 with gr.Row():
-                    slat_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=3.0, step=0.1)
-                    slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
-                multiimage_algo = gr.Radio(["stochastic", "multidiffusion"], label="Multi-image Algorithm", value="stochastic")
+                    slat_guidance_strength = gr.Slider(
+                        0.0, 10.0, label="Guidance Strength", value=3.0, step=0.1
+                    )
+                    slat_sampling_steps = gr.Slider(
+                        1, 50, label="Sampling Steps", value=12, step=1
+                    )
+                multiimage_algo = gr.Radio(
+                    ["stochastic", "multidiffusion"],
+                    label="Multi-image Algorithm",
+                    value="stochastic",
+                )

             generate_btn = gr.Button("Generate")
-
+
             with gr.Accordion(label="GLB Extraction Settings", open=False):
-                mesh_simplify = gr.Slider(0.0, 0.98, label="Simplify", value=0.0, step=0.01)
-                texture_size = gr.Slider(512, 2048, label="Texture Size", value=2048, step=512)
-
+                mesh_simplify = gr.Slider(
+                    0.0, 0.98, label="Simplify", value=0.0, step=0.01
+                )
+                texture_size = gr.Slider(
+                    512, 2048, label="Texture Size", value=2048, step=512
+                )
+
             with gr.Row():
                 extract_glb_btn = gr.Button("Extract GLB", interactive=False)
                 extract_gs_btn = gr.Button("Extract Gaussian", interactive=False)
-            gr.Markdown("""
+            gr.Markdown(
+                """
                 *NOTE: Gaussian file can be very large (~50MB), it will take a while to display and download.*
-                """)
+                """
+            )

         with gr.Column():
-            video_output = gr.Video(label="Generated 3D Asset", autoplay=True, loop=True, height=300)
-            model_output = LitModel3D(label="Extracted GLB/Gaussian", exposure=10.0, height=300)
-
+            video_output = gr.Video(
+                label="Generated 3D Asset", autoplay=True, loop=True, height=300
+            )
+            model_output = LitModel3D(
+                label="Extracted GLB/Gaussian", exposure=10.0, height=300
+            )
+
             with gr.Row():
-                download_glb = gr.DownloadButton(label="Download GLB", interactive=False)
-                download_gs = gr.DownloadButton(label="Download Gaussian", interactive=False)
-
+                download_glb = gr.DownloadButton(
+                    label="Download GLB", interactive=False
+                )
+                download_gs = gr.DownloadButton(
+                    label="Download Gaussian", interactive=False
+                )
+
     is_multiimage = gr.State(False)
     output_buf = gr.State()
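Aside (not part of the commit): `split_image` in the hunk above finds each object in a wide RGBA sheet by collapsing the alpha channel into a per-column occupancy mask and detecting rising/falling edges with `np.where`. A self-contained sketch of that edge detection on a toy mask:

```python
import numpy as np

# Toy per-column occupancy mask: two "objects" spanning columns 2-4 and 7-8.
alpha = np.array([0, 0, 1, 1, 1, 0, 0, 1, 1, 0], dtype=bool)

# A rising edge at index i means column i is empty and column i+1 is occupied;
# a falling edge is the reverse. This mirrors the logic in split_image.
start_pos = np.where(~alpha[:-1] & alpha[1:])[0].tolist()  # [1, 6]
end_pos = np.where(alpha[:-1] & ~alpha[1:])[0].tolist()    # [4, 8]

for s, e in zip(start_pos, end_pos):
    print(f"object spans columns {s + 1}..{e}")  # app.py slices image[:, s:e+1]
```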
@@ -318,7 +376,7 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
     with gr.Row() as single_image_example:
         examples = gr.Examples(
             examples=[
-                f'assets/example_image/{image}'
+                f"assets/example_image/{image}"
                 for image in os.listdir("assets/example_image")
             ],
             inputs=[image_prompt],
@@ -340,16 +398,20 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
     # Handlers
     demo.load(start_session)
     demo.unload(end_session)
-
+
     single_image_input_tab.select(
-        lambda: tuple([False, gr.Row.update(visible=True), gr.Row.update(visible=False)]),
-        outputs=[is_multiimage, single_image_example, multiimage_example]
+        lambda: tuple(
+            [False, gr.Row.update(visible=True), gr.Row.update(visible=False)]
+        ),
+        outputs=[is_multiimage, single_image_example, multiimage_example],
     )
     multiimage_input_tab.select(
-        lambda: tuple([True, gr.Row.update(visible=False), gr.Row.update(visible=True)]),
-        outputs=[is_multiimage, single_image_example, multiimage_example]
+        lambda: tuple(
+            [True, gr.Row.update(visible=False), gr.Row.update(visible=True)]
+        ),
+        outputs=[is_multiimage, single_image_example, multiimage_example],
     )
-
+
     image_prompt.upload(
         preprocess_image,
         inputs=[image_prompt],
@@ -361,13 +423,19 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
         outputs=[multiimage_prompt],
     )

-    generate_btn.click(
-        get_seed,
-        inputs=[randomize_seed, seed],
-        outputs=[seed],
-    ).then(
+    generate_btn.click(get_seed, inputs=[randomize_seed, seed], outputs=[seed],).then(
         image_to_3d,
-        inputs=[image_prompt, multiimage_prompt, is_multiimage, seed, ss_guidance_strength, ss_sampling_steps, slat_guidance_strength, slat_sampling_steps, multiimage_algo],
+        inputs=[
+            image_prompt,
+            multiimage_prompt,
+            is_multiimage,
+            seed,
+            ss_guidance_strength,
+            ss_sampling_steps,
+            slat_guidance_strength,
+            slat_sampling_steps,
+            multiimage_algo,
+        ],
         outputs=[output_buf, video_output],
     ).then(
         lambda: tuple([gr.Button(interactive=True), gr.Button(interactive=True)]),
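For readers unfamiliar with the pattern (not part of the commit): Gradio runs each `.then()` step only after the previous handler finishes, so the seed is resolved before generation starts and the extract buttons are re-enabled only afterwards. A minimal self-contained sketch of the same chaining with hypothetical toy functions:

```python
import gradio as gr

def step_one() -> str:
    return "ready"

def step_two(text: str) -> str:
    return text.upper()

with gr.Blocks() as toy:
    out = gr.Textbox()
    btn = gr.Button("Run")
    # Each .then() waits for the previous handler, mirroring the
    # get_seed -> image_to_3d -> enable-buttons chain above.
    btn.click(step_one, outputs=[out]).then(step_two, inputs=[out], outputs=[out])
```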
@@ -387,7 +455,7 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
         lambda: gr.Button(interactive=True),
         outputs=[download_glb],
     )
-
+
     extract_gs_btn.click(
         extract_gaussian,
         inputs=[output_buf],
@@ -401,14 +469,8 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
         lambda: gr.Button(interactive=False),
         outputs=[download_glb],
     )
-
+

 # Launch the Gradio app
 if __name__ == "__main__":
-    pipeline = TrellisImageTo3DPipeline.from_pretrained("JeffreyXiang/TRELLIS-image-large")
-    pipeline.cuda()
-    try:
-        pipeline.preprocess_image(Image.fromarray(np.zeros((512, 512, 3), dtype=np.uint8))) # Preload rembg
-    except:
-        pass
     demo.launch()
extensions/nvdiffrast/README.md CHANGED
@@ -21,7 +21,7 @@ We do not currently accept outside code contributions in the form of pull reques

 Environment map stored as part of `samples/data/envphong.npz` is derived from a Wave Engine
 [sample material](https://github.com/WaveEngine/Samples-2.5/tree/master/Materials/EnvironmentMap/Content/Assets/CubeMap.cubemap)
-originally shared under
+originally shared under
 [MIT License](https://github.com/WaveEngine/Samples-2.5/blob/master/LICENSE.md).
 Mesh and texture stored as part of `samples/data/earth.npz` are derived from
 [3D Earth Photorealistic 2K](https://www.turbosquid.com/3d-models/3d-realistic-earth-photorealistic-2k-1279125)
extensions/nvdiffrast/nvdiffrast/__init__.py CHANGED
@@ -6,4 +6,4 @@
 # distribution of this software and related documentation without an express
 # license agreement from NVIDIA CORPORATION is strictly prohibited.

-__version__ = '0.3.3'
+__version__ = "0.3.3"
extensions/nvdiffrast/nvdiffrast/common/antialias.cu CHANGED
@@ -112,7 +112,7 @@ static __device__ __forceinline__ void evhash_insert_vertex(const AntialiasKerne
 {
     if (va == vb)
         return;
-
+
     uint64_t v0 = (uint32_t)min(va, vb) + 1; // canonical vertex order
     uint64_t v1 = (uint32_t)max(va, vb) + 1;
     uint64_t vk = v0 | (v1 << 32); // hash key
extensions/nvdiffrast/nvdiffrast/common/cudaraster/CudaRaster.hpp CHANGED
@@ -60,4 +60,3 @@ private:

 //------------------------------------------------------------------------
 } // namespace CR
-
extensions/nvdiffrast/nvdiffrast/common/cudaraster/impl/RasterImpl.hpp CHANGED
@@ -99,4 +99,3 @@ private:

 //------------------------------------------------------------------------
 } // namespace CR
-
extensions/nvdiffrast/nvdiffrast/common/interpolate.cu CHANGED
@@ -94,9 +94,9 @@ static __forceinline__ __device__ void InterpolateFwdKernelTemplate(const Interp
     float dvdx = db.z;
     float dvdy = db.w;

-    // Calculate the pixel differentials of chosen attributes.
+    // Calculate the pixel differentials of chosen attributes.
     for (int i=0; i < p.numDiffAttr; i++)
-    {
+    {
         // Input attribute index.
         int j = p.diff_attrs_all ? i : p.diffAttrs[i];
         if (j < 0)
extensions/nvdiffrast/nvdiffrast/common/texture.cpp CHANGED
@@ -47,7 +47,7 @@ void raiseMipSizeError(NVDR_CTX_ARGS, const TextureKernelParams& p)

     // Append level size to error message.
     snprintf(buf, bufsz, "mip %-2d ", level);
-    msg += buf;
+    msg += buf;
     if (ew) snprintf(buf, bufsz, " err ");
     else    snprintf(buf, bufsz, "%5d ", w);
     msg += buf;
extensions/nvdiffrast/nvdiffrast/tensorflow/ops.py CHANGED
@@ -11,22 +11,26 @@ import numpy as np
 import os
 from . import plugin_loader

-#----------------------------------------------------------------------------
+# ----------------------------------------------------------------------------
 # Helpers.
-#----------------------------------------------------------------------------
+# ----------------------------------------------------------------------------

 # OpenGL-related linker options depending on platform.
 def _get_gl_opts():
     libs = {
-        'posix': ['GL', 'EGL'],
-        'nt': ['gdi32', 'opengl32', 'user32', 'setgpu'],
+        "posix": ["GL", "EGL"],
+        "nt": ["gdi32", "opengl32", "user32", "setgpu"],
     }
-    return ['-l' + x for x in libs[os.name]]
+    return ["-l" + x for x in libs[os.name]]
+

 # Load the cpp plugin.
 def _get_plugin():
-    fn = os.path.join(os.path.dirname(__file__), 'tf_all.cu')
-    return plugin_loader.get_plugin(fn, extra_nvcc_options=_get_gl_opts() + ['-DNVDR_TENSORFLOW'])
+    fn = os.path.join(os.path.dirname(__file__), "tf_all.cu")
+    return plugin_loader.get_plugin(
+        fn, extra_nvcc_options=_get_gl_opts() + ["-DNVDR_TENSORFLOW"]
+    )
+

 # Convert parameter to a numpy array if possible.
 def _get_constant(x, dtype):
@@ -35,19 +39,24 @@ def _get_constant(x, dtype):
     except (TypeError, ValueError):
         return None

+
 # Tests for a construction-time constantness instead of tf.constant node because
 # the latter can be overridden in Session.run() feed_dict at evaluation time.
 def _is_constant(x, dtype):
     if isinstance(x, np.ndarray):
-        return np.can_cast(x.dtype, dtype, 'unsafe')
+        return np.can_cast(x.dtype, dtype, "unsafe")
     else:
         return _get_constant(x, dtype) is not None

-#----------------------------------------------------------------------------
+
+# ----------------------------------------------------------------------------
 # Rasterize.
-#----------------------------------------------------------------------------
+# ----------------------------------------------------------------------------

-def rasterize(pos, tri, resolution, ranges=None, tri_const=False, output_db=True, grad_db=True):
+
+def rasterize(
+    pos, tri, resolution, ranges=None, tri_const=False, output_db=True, grad_db=True
+):
     assert tri_const is True or tri_const is False
     assert output_db is True or output_db is False
@@ -63,15 +72,19 @@ def rasterize(pos, tri, resolution, ranges=None, tri_const=False, output_db=True
     pos = tf.convert_to_tensor(pos, dtype=tf.float32)
     resolution = tf.convert_to_tensor(resolution, dtype=tf.int32)
     if ranges is None:
-        ranges = tf.convert_to_tensor(np.zeros(shape=[0, 2], dtype=np.int32)) # Empty tensor.
+        ranges = tf.convert_to_tensor(
+            np.zeros(shape=[0, 2], dtype=np.int32)
+        )  # Empty tensor.
     else:
-        ranges = tf.convert_to_tensor(ranges, dtype=tf.int32) # Convert input to tensor.
+        ranges = tf.convert_to_tensor(
+            ranges, dtype=tf.int32
+        )  # Convert input to tensor.

     # Infer as much about the output shape as possible.
     out_shape = [None, None, None, 4]
-    if pos.shape.rank == 3: # Instanced mode.
+    if pos.shape.rank == 3:  # Instanced mode.
         out_shape[0] = pos.shape[0].value
-    elif pos.shape.rank == 2: # Range mode.
+    elif pos.shape.rank == 2:  # Range mode.
         if ranges.shape.rank not in [None, 0]:
             out_shape[0] = ranges.shape[0].value
     if resolution_c is not None:
@@ -81,24 +94,32 @@ def rasterize(pos, tri, resolution, ranges=None, tri_const=False, output_db=True
     # Output pixel differentials.
     @tf.custom_gradient
     def func_db(pos):
-        out, out_db = _get_plugin().rasterize_fwd(pos, tri, resolution, ranges, 1, tri_const)
+        out, out_db = _get_plugin().rasterize_fwd(
+            pos, tri, resolution, ranges, 1, tri_const
+        )
         out.set_shape(out_shape)
         out_db.set_shape(out_shape)
+
         def grad(dy, ddb):
             if grad_db:
                 return _get_plugin().rasterize_grad_db(pos, tri, out, dy, ddb)
             else:
                 return _get_plugin().rasterize_grad(pos, tri, out, dy)
+
         return (out, out_db), grad

     # Do not output pixel differentials.
     @tf.custom_gradient
     def func(pos):
-        out, out_db = _get_plugin().rasterize_fwd(pos, tri, resolution, ranges, 0, tri_const)
+        out, out_db = _get_plugin().rasterize_fwd(
+            pos, tri, resolution, ranges, 0, tri_const
+        )
         out.set_shape(out_shape)
-        out_db.set_shape(out_shape[:-1] + [0]) # Zero channels in out_db.
+        out_db.set_shape(out_shape[:-1] + [0])  # Zero channels in out_db.
+
         def grad(dy, _):
             return _get_plugin().rasterize_grad(pos, tri, out, dy)
+
         return (out, out_db), grad

     # Choose stub.
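Background (not part of the commit): each stub above wires a hand-written backward pass to the forward C++ op via `tf.custom_gradient`. A minimal self-contained sketch of the same pattern, using a toy square function in place of the plugin ops:

```python
import tensorflow as tf

@tf.custom_gradient
def square(x):
    y = x * x

    def grad(dy):
        # Hand-written backward pass, analogous to rasterize_grad above.
        return dy * 2.0 * x

    return y, grad
```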
@@ -107,15 +128,17 @@ def rasterize(pos, tri, resolution, ranges=None, tri_const=False, output_db=True
     else:
         return func(pos)

-#----------------------------------------------------------------------------
+
+# ----------------------------------------------------------------------------
 # Interpolate.
-#----------------------------------------------------------------------------
+# ----------------------------------------------------------------------------
+

 def interpolate(attr, rast, tri, rast_db=None, diff_attrs=None):
     # Sanitize the list of pixel differential attributes.
     if diff_attrs is None:
         diff_attrs = []
-    elif diff_attrs != 'all':
+    elif diff_attrs != "all":
         diff_attrs = _get_constant(diff_attrs, np.int32)
         assert (diff_attrs is not None) and len(diff_attrs.shape) == 1
         diff_attrs = diff_attrs.tolist()
@@ -130,16 +153,23 @@ def interpolate(attr, rast, tri, rast_db=None, diff_attrs=None):
     # Infer output shape.
     out_shape = [None, None, None, None]
     if rast.shape.rank is not None:
-        out_shape = [rast.shape[0].value, rast.shape[1].value, rast.shape[2].value, None]
+        out_shape = [
+            rast.shape[0].value,
+            rast.shape[1].value,
+            rast.shape[2].value,
+            None,
+        ]
     if attr.shape.rank in [2, 3]:
         out_shape[3] = attr.shape[-1].value

     # Output pixel differentials for at least some attributes.
     @tf.custom_gradient
     def func_da(attr, rast, rast_db):
-        diff_attrs_all = int(diff_attrs == 'all')
+        diff_attrs_all = int(diff_attrs == "all")
         diff_attrs_list = [] if diff_attrs_all else diff_attrs
-        out, out_da = _get_plugin().interpolate_fwd_da(attr, rast, tri, rast_db, diff_attrs_all, diff_attrs_list)
+        out, out_da = _get_plugin().interpolate_fwd_da(
+            attr, rast, tri, rast_db, diff_attrs_all, diff_attrs_list
+        )

         # Infer number of channels in out_da.
         if not diff_attrs_all:
@@ -154,7 +184,10 @@ def interpolate(attr, rast, tri, rast_db=None, diff_attrs=None):
         out_da.set_shape([out_shape[0], out_shape[1], out_shape[2], da_channels])

         def grad(dy, dda):
-            return _get_plugin().interpolate_grad_da(attr, rast, tri, dy, rast_db, dda, diff_attrs_all, diff_attrs_list)
+            return _get_plugin().interpolate_grad_da(
+                attr, rast, tri, dy, rast_db, dda, diff_attrs_all, diff_attrs_list
+            )
+
         return (out, out_da), grad

     # No pixel differentials for any attribute.
@@ -162,9 +195,11 @@ def interpolate(attr, rast, tri, rast_db=None, diff_attrs=None):
     def func(attr, rast):
         out, out_da = _get_plugin().interpolate_fwd(attr, rast, tri)
         out.set_shape(out_shape)
-        out_da.set_shape(out_shape[:-1] + [0]) # Zero channels in out_da.
+        out_da.set_shape(out_shape[:-1] + [0])  # Zero channels in out_da.
+
         def grad(dy, _):
             return _get_plugin().interpolate_grad(attr, rast, tri, dy)
+
         return (out, out_da), grad

     # Choose stub.
@@ -173,16 +208,26 @@ def interpolate(attr, rast, tri, rast_db=None, diff_attrs=None):
     else:
         return func(attr, rast)

-#----------------------------------------------------------------------------
-# Texture.
-#----------------------------------------------------------------------------

-def texture(tex, uv, uv_da=None, filter_mode='auto', boundary_mode='wrap', tex_const=False, max_mip_level=None):
+
+# ----------------------------------------------------------------------------
+# Texture.
+# ----------------------------------------------------------------------------
+
+
+def texture(
+    tex,
+    uv,
+    uv_da=None,
+    filter_mode="auto",
+    boundary_mode="wrap",
+    tex_const=False,
+    max_mip_level=None,
+):
     assert tex_const is True or tex_const is False

     # Default filter mode.
-    if filter_mode == 'auto':
-        filter_mode = 'linear-mipmap-linear' if (uv_da is not None) else 'linear'
+    if filter_mode == "auto":
+        filter_mode = "linear-mipmap-linear" if (uv_da is not None) else "linear"

     # Known constant texture?
     tex_const = tex_const or _is_constant(tex, np.float32)
@@ -198,7 +243,7 @@ def texture(tex, uv, uv_da=None, filter_mode='auto', boundary_mode='wrap', tex_c
     # Convert inputs to tensors.
     tex = tf.convert_to_tensor(tex, dtype=tf.float32)
     uv = tf.convert_to_tensor(uv, dtype=tf.float32)
-    if 'mipmap' in filter_mode:
+    if "mipmap" in filter_mode:
         uv_da = tf.convert_to_tensor(uv_da, dtype=tf.float32)

     # Infer output shape.
@@ -207,37 +252,83 @@ def texture(tex, uv, uv_da=None, filter_mode='auto', boundary_mode='wrap', tex_c
     assert uv.shape.rank == 4
     out_shape = [uv.shape[0].value, uv.shape[1].value, uv.shape[2].value, None]
     if tex.shape.rank is not None:
-        assert tex.shape.rank == (5 if boundary_mode == 'cube' else 4)
+        assert tex.shape.rank == (5 if boundary_mode == "cube" else 4)
         out_shape[-1] = tex.shape[-1].value

     # If mipping disabled via max level=0, we may as well use simpler filtering internally.
-    if max_mip_level == 0 and filter_mode in ['linear-mipmap-nearest', 'linear-mipmap-linear']:
-        filter_mode = 'linear'
+    if max_mip_level == 0 and filter_mode in [
+        "linear-mipmap-nearest",
+        "linear-mipmap-linear",
+    ]:
+        filter_mode = "linear"

     # Convert filter mode to internal enumeration.
-    filter_mode_dict = {'nearest': 0, 'linear': 1, 'linear-mipmap-nearest': 2, 'linear-mipmap-linear': 3}
+    filter_mode_dict = {
+        "nearest": 0,
+        "linear": 1,
+        "linear-mipmap-nearest": 2,
+        "linear-mipmap-linear": 3,
+    }
     filter_mode_enum = filter_mode_dict[filter_mode]

     # Convert boundary mode to internal enumeration.
-    boundary_mode_dict = {'cube': 0, 'wrap': 1, 'clamp': 2, 'zero': 3}
+    boundary_mode_dict = {"cube": 0, "wrap": 1, "clamp": 2, "zero": 3}
     boundary_mode_enum = boundary_mode_dict[boundary_mode]

     # Linear-mipmap-linear: Mipmaps enabled, all gradients active.
     @tf.custom_gradient
     def func_linear_mipmap_linear(tex, uv, uv_da):
-        out, mip = _get_plugin().texture_fwd_mip(tex, uv, uv_da, filter_mode_enum, boundary_mode_enum, tex_const, max_mip_level)
+        out, mip = _get_plugin().texture_fwd_mip(
+            tex,
+            uv,
+            uv_da,
+            filter_mode_enum,
+            boundary_mode_enum,
+            tex_const,
+            max_mip_level,
+        )
         out.set_shape(out_shape)
+
         def grad(dy):
-            return _get_plugin().texture_grad_linear_mipmap_linear(tex, uv, dy, uv_da, mip, filter_mode_enum, boundary_mode_enum, max_mip_level)
+            return _get_plugin().texture_grad_linear_mipmap_linear(
+                tex,
+                uv,
+                dy,
+                uv_da,
+                mip,
+                filter_mode_enum,
+                boundary_mode_enum,
+                max_mip_level,
+            )
+
         return out, grad

     # Linear-mipmap-nearest: Mipmaps enabled, no gradients to uv_da.
     @tf.custom_gradient
     def func_linear_mipmap_nearest(tex, uv):
-        out, mip = _get_plugin().texture_fwd_mip(tex, uv, uv_da, filter_mode_enum, boundary_mode_enum, tex_const, max_mip_level)
+        out, mip = _get_plugin().texture_fwd_mip(
+            tex,
+            uv,
+            uv_da,
+            filter_mode_enum,
+            boundary_mode_enum,
+            tex_const,
+            max_mip_level,
+        )
         out.set_shape(out_shape)
+
         def grad(dy):
-            return _get_plugin().texture_grad_linear_mipmap_nearest(tex, uv, dy, uv_da, mip, filter_mode_enum, boundary_mode_enum, max_mip_level)
+            return _get_plugin().texture_grad_linear_mipmap_nearest(
+                tex,
+                uv,
+                dy,
+                uv_da,
+                mip,
+                filter_mode_enum,
+                boundary_mode_enum,
+                max_mip_level,
+            )
+
         return out, grad

     # Linear: Mipmaps disabled, no uv_da, no gradients to uv_da.
@@ -245,8 +336,12 @@ def texture(tex, uv, uv_da=None, filter_mode='auto', boundary_mode='wrap', tex_c
     def func_linear(tex, uv):
         out = _get_plugin().texture_fwd(tex, uv, filter_mode_enum, boundary_mode_enum)
         out.set_shape(out_shape)
+
         def grad(dy):
-            return _get_plugin().texture_grad_linear(tex, uv, dy, filter_mode_enum, boundary_mode_enum)
+            return _get_plugin().texture_grad_linear(
+                tex, uv, dy, filter_mode_enum, boundary_mode_enum
+            )
+
         return out, grad

     # Nearest: Mipmaps disabled, no uv_da, no gradients to uv_da or uv.
@@ -254,23 +349,29 @@ def texture(tex, uv, uv_da=None, filter_mode='auto', boundary_mode='wrap', tex_c
     def func_nearest(tex):
         out = _get_plugin().texture_fwd(tex, uv, filter_mode_enum, boundary_mode_enum)
         out.set_shape(out_shape)
+
         def grad(dy):
-            return _get_plugin().texture_grad_nearest(tex, uv, dy, filter_mode_enum, boundary_mode_enum)
+            return _get_plugin().texture_grad_nearest(
+                tex, uv, dy, filter_mode_enum, boundary_mode_enum
+            )
+
         return out, grad

     # Choose stub.
-    if filter_mode == 'linear-mipmap-linear':
+    if filter_mode == "linear-mipmap-linear":
         return func_linear_mipmap_linear(tex, uv, uv_da)
-    elif filter_mode == 'linear-mipmap-nearest':
+    elif filter_mode == "linear-mipmap-nearest":
         return func_linear_mipmap_nearest(tex, uv)
-    elif filter_mode == 'linear':
+    elif filter_mode == "linear":
         return func_linear(tex, uv)
-    elif filter_mode == 'nearest':
+    elif filter_mode == "nearest":
         return func_nearest(tex)

-#----------------------------------------------------------------------------
+
+# ----------------------------------------------------------------------------
 # Antialias.
-#----------------------------------------------------------------------------
+# ----------------------------------------------------------------------------
+

 def antialias(color, rast, pos, tri, tri_const=False, pos_gradient_boost=1.0):
     assert tri_const is True or tri_const is False
@@ -289,15 +390,22 @@ def antialias(color, rast, pos, tri, tri_const=False, pos_gradient_boost=1.0):

     @tf.custom_gradient
     def func(color, pos):
-        color_out, work_buffer = _get_plugin().antialias_fwd(color, rast, pos, tri, tri_const)
+        color_out, work_buffer = _get_plugin().antialias_fwd(
+            color, rast, pos, tri, tri_const
+        )
         color_out.set_shape(color.shape)
+
         def grad(dy):
-            grad_color, grad_pos = _get_plugin().antialias_grad(color, rast, pos, tri, dy, work_buffer)
+            grad_color, grad_pos = _get_plugin().antialias_grad(
+                color, rast, pos, tri, dy, work_buffer
+            )
             if pos_gradient_boost != 1.0:
                 grad_pos = grad_pos * pos_gradient_boost
             return grad_color, grad_pos
+
         return color_out, grad

     return func(color, pos)

-#----------------------------------------------------------------------------
+
+# ----------------------------------------------------------------------------
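Aside (not part of the commit): the reformatted `texture()` hunks above only change presentation; the mode-selection logic is unchanged. A condensed standalone restatement of how a filter mode resolves to the internal enum:

```python
# Toy restatement of the selection logic in texture() above.
def resolve_filter_mode(filter_mode="auto", uv_da=None, max_mip_level=None):
    if filter_mode == "auto":
        filter_mode = "linear-mipmap-linear" if uv_da is not None else "linear"
    if max_mip_level == 0 and filter_mode.startswith("linear-mipmap"):
        filter_mode = "linear"  # mipmapping disabled, fall back to plain linear
    enums = {"nearest": 0, "linear": 1, "linear-mipmap-nearest": 2, "linear-mipmap-linear": 3}
    return enums[filter_mode]

assert resolve_filter_mode() == 1                         # no uv_da -> linear
assert resolve_filter_mode(uv_da="da") == 3               # uv_da -> trilinear mipmap
assert resolve_filter_mode(uv_da="da", max_mip_level=0) == 1
```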
extensions/nvdiffrast/nvdiffrast/tensorflow/plugin_loader.py CHANGED
@@ -14,15 +14,16 @@ import hashlib
 import tempfile
 import shutil
 import tensorflow as tf
-from tensorflow.python.client import device_lib # pylint: disable=no-name-in-module

-#----------------------------------------------------------------------------
 # Global options.

 _nvdiffrast_cache_dir = None

 def set_cache_dir(path: str) -> None:
-    '''Set CUDA kernel compilation temp dir.

     If `set_cache_dir` is not called, the cache directory will default to
     one of the below:
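Context (not part of the commit): `make_cache_dir_path`, shown in the next hunk, resolves the kernel build cache directory with a fixed precedence: an explicit `set_cache_dir` path, then `NVDIFFRAST_CACHE_DIR`, then `HOME` or `USERPROFILE`, then the system temp dir. A simplified standalone sketch of that precedence:

```python
import os
import tempfile
from typing import Optional

def resolve_cache_dir(explicit: Optional[str] = None) -> str:
    # Same precedence order as make_cache_dir_path (simplified: no *paths join).
    if explicit is not None:
        return explicit
    if "NVDIFFRAST_CACHE_DIR" in os.environ:
        return os.environ["NVDIFFRAST_CACHE_DIR"]
    for var in ("HOME", "USERPROFILE"):
        if var in os.environ:
            return os.path.join(os.environ[var], ".cache", "nvdiffrast")
    return os.path.join(tempfile.gettempdir(), ".cache", "nvdiffrast")
```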
@@ -33,103 +34,164 @@ def set_cache_dir(path: str) -> None:
     Args:
         path: Where to save CUDA kernel build temporaries
-    '''
     global _nvdiffrast_cache_dir
     _nvdiffrast_cache_dir = path

 def make_cache_dir_path(*paths: str) -> str:
     if _nvdiffrast_cache_dir is not None:
         return os.path.join(_nvdiffrast_cache_dir, *paths)
-    if 'NVDIFFRAST_CACHE_DIR' in os.environ:
-        return os.path.join(os.environ['NVDIFFRAST_CACHE_DIR'], *paths)
-    if 'HOME' in os.environ:
-        return os.path.join(os.environ['HOME'], '.cache', 'nvdiffrast', *paths)
-    if 'USERPROFILE' in os.environ:
-        return os.path.join(os.environ['USERPROFILE'], '.cache', 'nvdiffrast', *paths)
-    return os.path.join(tempfile.gettempdir(), '.cache', 'nvdiffrast', *paths)
-
-cuda_cache_version_tag = 'v1'
-do_not_hash_included_headers = False # Speed up compilation by assuming that headers included by the CUDA code never change. Unsafe!
-verbose = True # Print status messages to stdout.
-
-#----------------------------------------------------------------------------
 # Internal helper funcs.

 def _find_compiler_bindir():
-    hostx64_paths = sorted(glob.glob('C:/Program Files/Microsoft Visual Studio/*/Enterprise/VC/Tools/MSVC/*/bin/Hostx64/x64'), reverse=True)
     if hostx64_paths != []:
         return hostx64_paths[0]
-    hostx64_paths = sorted(glob.glob('C:/Program Files (x86)/Microsoft Visual Studio/*/Enterprise/VC/Tools/MSVC/*/bin/Hostx64/x64'), reverse=True)
     if hostx64_paths != []:
         return hostx64_paths[0]
-    hostx64_paths = sorted(glob.glob('C:/Program Files/Microsoft Visual Studio/*/Professional/VC/Tools/MSVC/*/bin/Hostx64/x64'), reverse=True)
     if hostx64_paths != []:
         return hostx64_paths[0]
-    hostx64_paths = sorted(glob.glob('C:/Program Files (x86)/Microsoft Visual Studio/*/Professional/VC/Tools/MSVC/*/bin/Hostx64/x64'), reverse=True)
     if hostx64_paths != []:
         return hostx64_paths[0]
-    hostx64_paths = sorted(glob.glob('C:/Program Files/Microsoft Visual Studio/*/BuildTools/VC/Tools/MSVC/*/bin/Hostx64/x64'), reverse=True)
     if hostx64_paths != []:
         return hostx64_paths[0]
-    hostx64_paths = sorted(glob.glob('C:/Program Files (x86)/Microsoft Visual Studio/*/BuildTools/VC/Tools/MSVC/*/bin/Hostx64/x64'), reverse=True)
     if hostx64_paths != []:
         return hostx64_paths[0]
-    hostx64_paths = sorted(glob.glob('C:/Program Files/Microsoft Visual Studio/*/Community/VC/Tools/MSVC/*/bin/Hostx64/x64'), reverse=True)
     if hostx64_paths != []:
         return hostx64_paths[0]
-    hostx64_paths = sorted(glob.glob('C:/Program Files (x86)/Microsoft Visual Studio/*/Community/VC/Tools/MSVC/*/bin/Hostx64/x64'), reverse=True)
     if hostx64_paths != []:
         return hostx64_paths[0]
-    vc_bin_dir = 'C:/Program Files (x86)/Microsoft Visual Studio 14.0/vc/bin'
     if os.path.isdir(vc_bin_dir):
         return vc_bin_dir
     return None

 def _get_compute_cap(device):
     caps_str = device.physical_device_desc
-    m = re.search('compute capability: (\\d+).(\\d+)', caps_str)
     major = m.group(1)
     minor = m.group(2)
     return (major, minor)

 def _get_cuda_gpu_arch_string():
-    gpus = [x for x in device_lib.list_local_devices() if x.device_type == 'GPU']
     if len(gpus) == 0:
-        raise RuntimeError('No GPU devices found')
     (major, minor) = _get_compute_cap(gpus[0])
-    return 'sm_%s%s' % (major, minor)

 def _run_cmd(cmd):
     with os.popen(cmd) as pipe:
         output = pipe.read()
         status = pipe.close()
     if status is not None:
-        raise RuntimeError('NVCC returned an error. See below for full command line and output log:\n\n%s\n\n%s' % (cmd, output))

 def _prepare_nvcc_cli(opts):
-    cmd = 'nvcc ' + opts.strip()
-    cmd += ' --disable-warnings'
     cmd += ' --include-path "%s"' % tf.sysconfig.get_include()
-    cmd += ' --include-path "%s"' % os.path.join(tf.sysconfig.get_include(), 'external', 'protobuf_archive', 'src')
-    cmd += ' --include-path "%s"' % os.path.join(tf.sysconfig.get_include(), 'external', 'com_google_absl')
-    cmd += ' --include-path "%s"' % os.path.join(tf.sysconfig.get_include(), 'external', 'eigen_archive')

     compiler_bindir = _find_compiler_bindir()
     if compiler_bindir is None:
         # Require that _find_compiler_bindir succeeds on Windows. Allow
         # nvcc to use whatever is the default on Linux.
-        if os.name == 'nt':
-            raise RuntimeError('Could not find MSVC/GCC/CLANG installation on this computer. Check compiler_bindir_search_path list in "%s".' % __file__)
     else:
         cmd += ' --compiler-bindir "%s"' % compiler_bindir
-    cmd += ' 2>&1'
     return cmd

-#----------------------------------------------------------------------------
 # Main entry point.

 _plugin_cache = dict()

 def get_plugin(cuda_file, extra_nvcc_options=[]):
     cuda_file_base = os.path.basename(cuda_file)
     cuda_file_name, cuda_file_ext = os.path.splitext(cuda_file_base)
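Aside (not part of the commit): `_get_cuda_gpu_arch_string` in the hunk above turns the first GPU's compute capability into an nvcc `--gpu-architecture` value such as `sm_86`. A toy restatement of that string building; the sample description mimics TensorFlow's `physical_device_desc` format:

```python
import re

desc = "device: 0, name: NVIDIA A100, pci bus id: 0000:00:04.0, compute capability: 8.0"
m = re.search(r"compute capability: (\d+)\.(\d+)", desc)
arch = "sm_%s%s" % (m.group(1), m.group(2))  # same formatting as the helper
assert arch == "sm_80"
```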
@@ -140,80 +202,112 @@ def get_plugin(cuda_file, extra_nvcc_options=[]):
     # Setup plugin.
     if verbose:
-        print('Setting up TensorFlow plugin "%s": ' % cuda_file_base, end='', flush=True)
     try:
         # Hash CUDA source.
         md5 = hashlib.md5()
-        with open(cuda_file, 'rb') as f:
             md5.update(f.read())
-        md5.update(b'\n')

         # Hash headers included by the CUDA code by running it through the preprocessor.
         if not do_not_hash_included_headers:
             if verbose:
-                print('Preprocessing... ', end='', flush=True)
             with tempfile.TemporaryDirectory() as tmp_dir:
-                tmp_file = os.path.join(tmp_dir, cuda_file_name + '_tmp' + cuda_file_ext)
-                _run_cmd(_prepare_nvcc_cli('"%s" --preprocess -o "%s" --keep --keep-dir "%s"' % (cuda_file, tmp_file, tmp_dir)))
-                with open(tmp_file, 'rb') as f:
-                    bad_file_str = ('"' + cuda_file.replace('\\', '/') + '"').encode('utf-8') # __FILE__ in error check macros
-                    good_file_str = ('"' + cuda_file_base + '"').encode('utf-8')
                     for ln in f:
-                        if not ln.startswith(b'# ') and not ln.startswith(b'#line '): # ignore line number pragmas
                             ln = ln.replace(bad_file_str, good_file_str)
                             md5.update(ln)
-                    md5.update(b'\n')

         # Select compiler options.
-        compile_opts = ''
-        if os.name == 'nt':
-            compile_opts += '"%s"' % os.path.join(tf.sysconfig.get_lib(), 'python', '_pywrap_tensorflow_internal.lib')
-            compile_opts += ' --library-path="%s"' % (os.path.dirname(__file__) + r"\..\lib") # Find libraries during compilation.
-        elif os.name == 'posix':
-            compile_opts += '"%s"' % os.path.join(tf.sysconfig.get_lib(), 'python', '_pywrap_tensorflow_internal.so')
-            compile_opts += ' --compiler-options \'-fPIC -D_GLIBCXX_USE_CXX11_ABI=0\''
         else:
-            assert False # not Windows or Linux, w00t?
-        compile_opts += ' --gpu-architecture=%s' % _get_cuda_gpu_arch_string()
-        compile_opts += ' --use_fast_math'
         for opt in extra_nvcc_options:
-            compile_opts += ' ' + opt
         nvcc_cmd = _prepare_nvcc_cli(compile_opts)

         # Hash build configuration.
-        md5.update(('nvcc_cmd: ' + nvcc_cmd).encode('utf-8') + b'\n')
-        md5.update(('tf.VERSION: ' + tf.VERSION).encode('utf-8') + b'\n')
-        md5.update(('cuda_cache_version_tag: ' + cuda_cache_version_tag).encode('utf-8') + b'\n')

         # Compile if not already compiled.
-        bin_file_ext = '.dll' if os.name == 'nt' else '.so'
         cuda_cache_path = make_cache_dir_path()
-        bin_file = os.path.join(make_cache_dir_path(), cuda_file_name + '_' + md5.hexdigest() + bin_file_ext)
         if not os.path.isfile(bin_file):
             if verbose:
-                print('Compiling... ', end='', flush=True)
             with tempfile.TemporaryDirectory() as tmp_dir:
-                tmp_file = os.path.join(tmp_dir, cuda_file_name + '_tmp' + bin_file_ext)
-                _run_cmd(nvcc_cmd + ' "%s" --shared -o "%s" --keep --keep-dir "%s"' % (cuda_file, tmp_file, tmp_dir))
                 os.makedirs(cuda_cache_path, exist_ok=True)
-                intermediate_file = os.path.join(cuda_cache_path, cuda_file_name + '_' + uuid.uuid4().hex + '_tmp' + bin_file_ext)
                 shutil.copyfile(tmp_file, intermediate_file)
-                os.rename(intermediate_file, bin_file) # atomic

         # Load.
         if verbose:
-            print('Loading... ', end='', flush=True)
         plugin = tf.load_op_library(bin_file)

         # Add to cache.
         _plugin_cache[cuda_file] = plugin
         if verbose:
-            print('Done.', flush=True)
         return plugin

     except:
         if verbose:
-            print('Failed!', flush=True)
         raise
219
- #----------------------------------------------------------------------------

  import tempfile
  import shutil
  import tensorflow as tf
+ from tensorflow.python.client import device_lib  # pylint: disable=no-name-in-module

+ # ----------------------------------------------------------------------------
  # Global options.

  _nvdiffrast_cache_dir = None

+
  def set_cache_dir(path: str) -> None:
+     """Set CUDA kernel compilation temp dir.

      If `set_cache_dir` is not called, the cache directory will default to
      one of the below:

      Args:
          path: Where to save CUDA kernel build temporaries
+     """
      global _nvdiffrast_cache_dir
      _nvdiffrast_cache_dir = path

+
  def make_cache_dir_path(*paths: str) -> str:
      if _nvdiffrast_cache_dir is not None:
          return os.path.join(_nvdiffrast_cache_dir, *paths)
+     if "NVDIFFRAST_CACHE_DIR" in os.environ:
+         return os.path.join(os.environ["NVDIFFRAST_CACHE_DIR"], *paths)
+     if "HOME" in os.environ:
+         return os.path.join(os.environ["HOME"], ".cache", "nvdiffrast", *paths)
+     if "USERPROFILE" in os.environ:
+         return os.path.join(os.environ["USERPROFILE"], ".cache", "nvdiffrast", *paths)
+     return os.path.join(tempfile.gettempdir(), ".cache", "nvdiffrast", *paths)
+
+
+ cuda_cache_version_tag = "v1"
+ do_not_hash_included_headers = False  # Speed up compilation by assuming that headers included by the CUDA code never change. Unsafe!
+ verbose = True  # Print status messages to stdout.
+
+ # ----------------------------------------------------------------------------
  # Internal helper funcs.

+
  def _find_compiler_bindir():
+     hostx64_paths = sorted(
+         glob.glob(
+             "C:/Program Files/Microsoft Visual Studio/*/Enterprise/VC/Tools/MSVC/*/bin/Hostx64/x64"
+         ),
+         reverse=True,
+     )
      if hostx64_paths != []:
          return hostx64_paths[0]
+     hostx64_paths = sorted(
+         glob.glob(
+             "C:/Program Files (x86)/Microsoft Visual Studio/*/Enterprise/VC/Tools/MSVC/*/bin/Hostx64/x64"
+         ),
+         reverse=True,
+     )
      if hostx64_paths != []:
          return hostx64_paths[0]
+     hostx64_paths = sorted(
+         glob.glob(
+             "C:/Program Files/Microsoft Visual Studio/*/Professional/VC/Tools/MSVC/*/bin/Hostx64/x64"
+         ),
+         reverse=True,
+     )
      if hostx64_paths != []:
          return hostx64_paths[0]
+     hostx64_paths = sorted(
+         glob.glob(
+             "C:/Program Files (x86)/Microsoft Visual Studio/*/Professional/VC/Tools/MSVC/*/bin/Hostx64/x64"
+         ),
+         reverse=True,
+     )
      if hostx64_paths != []:
          return hostx64_paths[0]
+     hostx64_paths = sorted(
+         glob.glob(
+             "C:/Program Files/Microsoft Visual Studio/*/BuildTools/VC/Tools/MSVC/*/bin/Hostx64/x64"
+         ),
+         reverse=True,
+     )
      if hostx64_paths != []:
          return hostx64_paths[0]
+     hostx64_paths = sorted(
+         glob.glob(
+             "C:/Program Files (x86)/Microsoft Visual Studio/*/BuildTools/VC/Tools/MSVC/*/bin/Hostx64/x64"
+         ),
+         reverse=True,
+     )
      if hostx64_paths != []:
          return hostx64_paths[0]
+     hostx64_paths = sorted(
+         glob.glob(
+             "C:/Program Files/Microsoft Visual Studio/*/Community/VC/Tools/MSVC/*/bin/Hostx64/x64"
+         ),
+         reverse=True,
+     )
      if hostx64_paths != []:
          return hostx64_paths[0]
+     hostx64_paths = sorted(
+         glob.glob(
+             "C:/Program Files (x86)/Microsoft Visual Studio/*/Community/VC/Tools/MSVC/*/bin/Hostx64/x64"
+         ),
+         reverse=True,
+     )
      if hostx64_paths != []:
          return hostx64_paths[0]
+     vc_bin_dir = "C:/Program Files (x86)/Microsoft Visual Studio 14.0/vc/bin"
      if os.path.isdir(vc_bin_dir):
          return vc_bin_dir
      return None

+
  def _get_compute_cap(device):
      caps_str = device.physical_device_desc
+     m = re.search("compute capability: (\\d+).(\\d+)", caps_str)
      major = m.group(1)
      minor = m.group(2)
      return (major, minor)

+
  def _get_cuda_gpu_arch_string():
+     gpus = [x for x in device_lib.list_local_devices() if x.device_type == "GPU"]
      if len(gpus) == 0:
+         raise RuntimeError("No GPU devices found")
      (major, minor) = _get_compute_cap(gpus[0])
+     return "sm_%s%s" % (major, minor)
+

  def _run_cmd(cmd):
      with os.popen(cmd) as pipe:
          output = pipe.read()
          status = pipe.close()
      if status is not None:
+         raise RuntimeError(
+             "NVCC returned an error. See below for full command line and output log:\n\n%s\n\n%s"
+             % (cmd, output)
+         )
+

  def _prepare_nvcc_cli(opts):
+     cmd = "nvcc " + opts.strip()
+     cmd += " --disable-warnings"
      cmd += ' --include-path "%s"' % tf.sysconfig.get_include()
+     cmd += ' --include-path "%s"' % os.path.join(
+         tf.sysconfig.get_include(), "external", "protobuf_archive", "src"
+     )
+     cmd += ' --include-path "%s"' % os.path.join(
+         tf.sysconfig.get_include(), "external", "com_google_absl"
+     )
+     cmd += ' --include-path "%s"' % os.path.join(
+         tf.sysconfig.get_include(), "external", "eigen_archive"
+     )

      compiler_bindir = _find_compiler_bindir()
      if compiler_bindir is None:
          # Require that _find_compiler_bindir succeeds on Windows. Allow
          # nvcc to use whatever is the default on Linux.
+         if os.name == "nt":
+             raise RuntimeError(
+                 'Could not find MSVC/GCC/CLANG installation on this computer. Check compiler_bindir_search_path list in "%s".'
+                 % __file__
+             )
      else:
          cmd += ' --compiler-bindir "%s"' % compiler_bindir
+     cmd += " 2>&1"
      return cmd

+
+ # ----------------------------------------------------------------------------
  # Main entry point.

  _plugin_cache = dict()

+
  def get_plugin(cuda_file, extra_nvcc_options=[]):
      cuda_file_base = os.path.basename(cuda_file)
      cuda_file_name, cuda_file_ext = os.path.splitext(cuda_file_base)

      # Setup plugin.
      if verbose:
+         print(
+             'Setting up TensorFlow plugin "%s": ' % cuda_file_base, end="", flush=True
+         )
      try:
          # Hash CUDA source.
          md5 = hashlib.md5()
+         with open(cuda_file, "rb") as f:
              md5.update(f.read())
+         md5.update(b"\n")

          # Hash headers included by the CUDA code by running it through the preprocessor.
          if not do_not_hash_included_headers:
              if verbose:
+                 print("Preprocessing... ", end="", flush=True)
              with tempfile.TemporaryDirectory() as tmp_dir:
+                 tmp_file = os.path.join(
+                     tmp_dir, cuda_file_name + "_tmp" + cuda_file_ext
+                 )
+                 _run_cmd(
+                     _prepare_nvcc_cli(
+                         '"%s" --preprocess -o "%s" --keep --keep-dir "%s"'
+                         % (cuda_file, tmp_file, tmp_dir)
+                     )
+                 )
+                 with open(tmp_file, "rb") as f:
+                     bad_file_str = ('"' + cuda_file.replace("\\", "/") + '"').encode(
+                         "utf-8"
+                     )  # __FILE__ in error check macros
+                     good_file_str = ('"' + cuda_file_base + '"').encode("utf-8")
                      for ln in f:
+                         if not ln.startswith(b"# ") and not ln.startswith(
+                             b"#line "
+                         ):  # ignore line number pragmas
                              ln = ln.replace(bad_file_str, good_file_str)
                              md5.update(ln)
+                     md5.update(b"\n")

          # Select compiler options.
+         compile_opts = ""
+         if os.name == "nt":
+             compile_opts += '"%s"' % os.path.join(
+                 tf.sysconfig.get_lib(), "python", "_pywrap_tensorflow_internal.lib"
+             )
+             compile_opts += ' --library-path="%s"' % (
+                 os.path.dirname(__file__) + r"\..\lib"
+             )  # Find libraries during compilation.
+         elif os.name == "posix":
+             compile_opts += '"%s"' % os.path.join(
+                 tf.sysconfig.get_lib(), "python", "_pywrap_tensorflow_internal.so"
+             )
+             compile_opts += " --compiler-options '-fPIC -D_GLIBCXX_USE_CXX11_ABI=0'"
          else:
+             assert False  # not Windows or Linux, w00t?
+         compile_opts += " --gpu-architecture=%s" % _get_cuda_gpu_arch_string()
+         compile_opts += " --use_fast_math"
          for opt in extra_nvcc_options:
+             compile_opts += " " + opt
          nvcc_cmd = _prepare_nvcc_cli(compile_opts)

          # Hash build configuration.
+         md5.update(("nvcc_cmd: " + nvcc_cmd).encode("utf-8") + b"\n")
+         md5.update(("tf.VERSION: " + tf.VERSION).encode("utf-8") + b"\n")
+         md5.update(
+             ("cuda_cache_version_tag: " + cuda_cache_version_tag).encode("utf-8")
+             + b"\n"
+         )

          # Compile if not already compiled.
+         bin_file_ext = ".dll" if os.name == "nt" else ".so"
          cuda_cache_path = make_cache_dir_path()
+         bin_file = os.path.join(
+             make_cache_dir_path(), cuda_file_name + "_" + md5.hexdigest() + bin_file_ext
+         )
          if not os.path.isfile(bin_file):
              if verbose:
+                 print("Compiling... ", end="", flush=True)
              with tempfile.TemporaryDirectory() as tmp_dir:
+                 tmp_file = os.path.join(tmp_dir, cuda_file_name + "_tmp" + bin_file_ext)
+                 _run_cmd(
+                     nvcc_cmd
+                     + ' "%s" --shared -o "%s" --keep --keep-dir "%s"'
+                     % (cuda_file, tmp_file, tmp_dir)
+                 )
                  os.makedirs(cuda_cache_path, exist_ok=True)
+                 intermediate_file = os.path.join(
+                     cuda_cache_path,
+                     cuda_file_name + "_" + uuid.uuid4().hex + "_tmp" + bin_file_ext,
+                 )
                  shutil.copyfile(tmp_file, intermediate_file)
+                 os.rename(intermediate_file, bin_file)  # atomic

          # Load.
          if verbose:
+             print("Loading... ", end="", flush=True)
          plugin = tf.load_op_library(bin_file)

          # Add to cache.
          _plugin_cache[cuda_file] = plugin
          if verbose:
+             print("Done.", flush=True)
          return plugin

      except:
          if verbose:
+             print("Failed!", flush=True)
          raise

+
+ # ----------------------------------------------------------------------------
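
Note on the cache-dir logic above: `make_cache_dir_path()` resolves the build directory in a fixed priority order (explicit `set_cache_dir()` value, then `NVDIFFRAST_CACHE_DIR`, then `HOME`/`USERPROFILE`, then the system temp dir). A minimal usage sketch, assuming the package is importable as `nvdiffrast.tensorflow.plugin_loader`; the `/scratch` path is hypothetical:

    import nvdiffrast.tensorflow.plugin_loader as plugin_loader

    # Must run before the first get_plugin() call; afterwards the compiled
    # .so/.dll binaries are cached under this directory.
    plugin_loader.set_cache_dir("/scratch/nvdiffrast-cache")
    print(plugin_loader.make_cache_dir_path())  # -> /scratch/nvdiffrast-cache
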
extensions/nvdiffrast/nvdiffrast/tensorflow/tf_antialias.cu CHANGED
@@ -100,13 +100,13 @@ struct AntialiasFwdOp : public OpKernel

    // (Re-)calculate opposite vertex hash.
    if (!p.evHash || !p.tri_const)
-   {
        if (p.allocTriangles < p.numTriangles)
        {
            p.allocTriangles = max(p.allocTriangles, 64);
            while (p.allocTriangles < p.numTriangles)
                p.allocTriangles <<= 1; // Must be power of two.
-
            // (Re-)allocate memory for the hash.
            OP_CHECK_CUDA_ERROR(ctx, cudaFree(p.evHash));
            OP_CHECK_CUDA_ERROR(ctx, cudaMalloc(&p.evHash, p.allocTriangles * AA_HASH_ELEMENTS_PER_TRIANGLE(p.allocTriangles) * sizeof(uint4)));

    // (Re-)calculate opposite vertex hash.
    if (!p.evHash || !p.tri_const)
+   {
        if (p.allocTriangles < p.numTriangles)
        {
            p.allocTriangles = max(p.allocTriangles, 64);
            while (p.allocTriangles < p.numTriangles)
                p.allocTriangles <<= 1; // Must be power of two.
+
            // (Re-)allocate memory for the hash.
            OP_CHECK_CUDA_ERROR(ctx, cudaFree(p.evHash));
            OP_CHECK_CUDA_ERROR(ctx, cudaMalloc(&p.evHash, p.allocTriangles * AA_HASH_ELEMENTS_PER_TRIANGLE(p.allocTriangles) * sizeof(uint4)));
extensions/nvdiffrast/nvdiffrast/tensorflow/tf_interpolate.cu CHANGED
@@ -112,7 +112,7 @@ struct InterpolateFwdOp : public OpKernel

    // Verify that buffers are aligned to allow float2/float4 operations.
    OP_REQUIRES(ctx, !((uintptr_t)p.rast & 15), errors::Internal("rast input tensor not aligned to float4"));
-   OP_REQUIRES(ctx, !((uintptr_t)p.rastDB & 15), errors::Internal("rast_db input tensor not aligned to float4"));
    if (ENABLE_DA)
        OP_REQUIRES(ctx, !((uintptr_t)p.outDA & 7), errors::Internal("out_da output tensor not aligned to float2"));

@@ -158,7 +158,7 @@ struct InterpolateGradOp : public OpKernel
    InterpolateGradOp(OpKernelConstruction* ctx): OpKernel(ctx)
    {
        memset(&m_attribs, 0, sizeof(m_attribs));
-       interpolateParseOpAttributes(ctx, m_attribs, ENABLE_DA);
    }

    void Compute(OpKernelContext* ctx)
@@ -247,7 +247,7 @@ struct InterpolateGradOp : public OpKernel
        OP_REQUIRES_OK(ctx, ctx->allocate_output(2, grad_rast_shape, &grad_rast_db_tensor));
        p.gradRasterDB = grad_rast_db_tensor->flat<float>().data();
    }
-
    // Clear attribute gradients.
    cudaMemsetAsync(p.gradAttr, 0, attr_depth * p.numVertices * p.numAttr * sizeof(float), stream);

@@ -257,10 +257,10 @@ struct InterpolateGradOp : public OpKernel
    if (ENABLE_DA)
    {
        OP_REQUIRES(ctx, !((uintptr_t)p.dda & 7), errors::Internal("dda input tensor not aligned to float2"));
-       OP_REQUIRES(ctx, !((uintptr_t)p.rastDB & 15), errors::Internal("rast_db input tensor not aligned to float4"));
        OP_REQUIRES(ctx, !((uintptr_t)p.gradRasterDB & 15), errors::Internal("grad_rast_db output tensor not aligned to float4"));
    }
-
    // Choose launch parameters.
    dim3 blockSize = getLaunchBlockSize(IP_GRAD_MAX_KERNEL_BLOCK_WIDTH, IP_GRAD_MAX_KERNEL_BLOCK_HEIGHT, p.width, p.height);
    dim3 gridSize = getLaunchGridSize(blockSize, p.width, p.height, p.depth);

    // Verify that buffers are aligned to allow float2/float4 operations.
    OP_REQUIRES(ctx, !((uintptr_t)p.rast & 15), errors::Internal("rast input tensor not aligned to float4"));
+   OP_REQUIRES(ctx, !((uintptr_t)p.rastDB & 15), errors::Internal("rast_db input tensor not aligned to float4"));
    if (ENABLE_DA)
        OP_REQUIRES(ctx, !((uintptr_t)p.outDA & 7), errors::Internal("out_da output tensor not aligned to float2"));

    InterpolateGradOp(OpKernelConstruction* ctx): OpKernel(ctx)
    {
        memset(&m_attribs, 0, sizeof(m_attribs));
+       interpolateParseOpAttributes(ctx, m_attribs, ENABLE_DA);
    }

    void Compute(OpKernelContext* ctx)

        OP_REQUIRES_OK(ctx, ctx->allocate_output(2, grad_rast_shape, &grad_rast_db_tensor));
        p.gradRasterDB = grad_rast_db_tensor->flat<float>().data();
    }
+
    // Clear attribute gradients.
    cudaMemsetAsync(p.gradAttr, 0, attr_depth * p.numVertices * p.numAttr * sizeof(float), stream);

    if (ENABLE_DA)
    {
        OP_REQUIRES(ctx, !((uintptr_t)p.dda & 7), errors::Internal("dda input tensor not aligned to float2"));
+       OP_REQUIRES(ctx, !((uintptr_t)p.rastDB & 15), errors::Internal("rast_db input tensor not aligned to float4"));
        OP_REQUIRES(ctx, !((uintptr_t)p.gradRasterDB & 15), errors::Internal("grad_rast_db output tensor not aligned to float4"));
    }
+
    // Choose launch parameters.
    dim3 blockSize = getLaunchBlockSize(IP_GRAD_MAX_KERNEL_BLOCK_WIDTH, IP_GRAD_MAX_KERNEL_BLOCK_HEIGHT, p.width, p.height);
    dim3 gridSize = getLaunchGridSize(blockSize, p.width, p.height, p.depth);
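
Side note on the alignment guards in this file: the kernels read `rast`/`rast_db` as float4 and the derivative buffers as float2, so the `OP_REQUIRES` checks test the low bits of each base address. A rough Python illustration of the same predicate on a PyTorch CUDA tensor (illustration only, not part of nvdiffrast):

    import torch

    def assert_float4_aligned(t: torch.Tensor, name: str) -> None:
        # Mirrors `!((uintptr_t)ptr & 15)`: the base address must be 16-byte
        # aligned for float4 loads. Fresh CUDA allocations are aligned much
        # more coarsely, so this mainly catches views with a storage offset.
        if t.data_ptr() & 15:
            raise RuntimeError("%s tensor not aligned to float4" % name)
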
extensions/nvdiffrast/nvdiffrast/tensorflow/tf_texture.cu CHANGED
@@ -503,7 +503,7 @@ REGISTER_OP("TextureGradLinearMipmapNearest")
    .Attr      ("filter_mode: int")
    .Attr      ("boundary_mode: int")
    .Attr      ("max_mip_level: int");
-
REGISTER_OP("TextureGradLinearMipmapLinear")
    .Input     ("tex: float")
    .Input     ("uv: float")
@@ -516,10 +516,10 @@ REGISTER_OP("TextureGradLinearMipmapLinear")
    .Attr      ("filter_mode: int")
    .Attr      ("boundary_mode: int")
    .Attr      ("max_mip_level: int");
-
REGISTER_KERNEL_BUILDER(Name("TextureGradNearest")             .Device(DEVICE_GPU), TextureGradOp);
REGISTER_KERNEL_BUILDER(Name("TextureGradLinear")              .Device(DEVICE_GPU), TextureGradOp);
REGISTER_KERNEL_BUILDER(Name("TextureGradLinearMipmapNearest").Device(DEVICE_GPU), TextureGradOp);
REGISTER_KERNEL_BUILDER(Name("TextureGradLinearMipmapLinear") .Device(DEVICE_GPU), TextureGradOp);
-
//------------------------------------------------------------------------

    .Attr      ("filter_mode: int")
    .Attr      ("boundary_mode: int")
    .Attr      ("max_mip_level: int");
+
REGISTER_OP("TextureGradLinearMipmapLinear")
    .Input     ("tex: float")
    .Input     ("uv: float")

    .Attr      ("filter_mode: int")
    .Attr      ("boundary_mode: int")
    .Attr      ("max_mip_level: int");
+
REGISTER_KERNEL_BUILDER(Name("TextureGradNearest")             .Device(DEVICE_GPU), TextureGradOp);
REGISTER_KERNEL_BUILDER(Name("TextureGradLinear")              .Device(DEVICE_GPU), TextureGradOp);
REGISTER_KERNEL_BUILDER(Name("TextureGradLinearMipmapNearest").Device(DEVICE_GPU), TextureGradOp);
REGISTER_KERNEL_BUILDER(Name("TextureGradLinearMipmapLinear") .Device(DEVICE_GPU), TextureGradOp);
+
//------------------------------------------------------------------------
extensions/nvdiffrast/nvdiffrast/torch/__init__.py CHANGED
@@ -6,5 +6,30 @@
  # distribution of this software and related documentation without an express
  # license agreement from NVIDIA CORPORATION is strictly prohibited.

- from .ops import RasterizeCudaContext, RasterizeGLContext, get_log_level, set_log_level, rasterize, DepthPeeler, interpolate, texture, texture_construct_mip, antialias, antialias_construct_topology_hash
- __all__ = ["RasterizeCudaContext", "RasterizeGLContext", "get_log_level", "set_log_level", "rasterize", "DepthPeeler", "interpolate", "texture", "texture_construct_mip", "antialias", "antialias_construct_topology_hash"]

  # distribution of this software and related documentation without an express
  # license agreement from NVIDIA CORPORATION is strictly prohibited.

+ from .ops import (
+     RasterizeCudaContext,
+     RasterizeGLContext,
+     get_log_level,
+     set_log_level,
+     rasterize,
+     DepthPeeler,
+     interpolate,
+     texture,
+     texture_construct_mip,
+     antialias,
+     antialias_construct_topology_hash,
+ )
+
+ __all__ = [
+     "RasterizeCudaContext",
+     "RasterizeGLContext",
+     "get_log_level",
+     "set_log_level",
+     "rasterize",
+     "DepthPeeler",
+     "interpolate",
+     "texture",
+     "texture_construct_mip",
+     "antialias",
+     "antialias_construct_topology_hash",
+ ]
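
The reformatted import list above leaves the public torch API unchanged; a minimal smoke test, assuming a CUDA-capable machine with the package installed:

    import nvdiffrast.torch as dr

    glctx = dr.RasterizeCudaContext()   # uses the current CUDA device
    print(dr.get_log_level())           # 1 by default
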
extensions/nvdiffrast/nvdiffrast/torch/ops.py CHANGED
@@ -14,13 +14,15 @@ import torch
  import torch.utils.cpp_extension
  from . import _C

- #----------------------------------------------------------------------------
  # C++/Cuda plugin compiler/loader.

  _cached_plugin = {}

  def _get_plugin(gl=False):
      assert isinstance(gl, bool)
-
      # Modified with precompiled torch CUDA extension
      if not gl:
          return _C
@@ -30,16 +32,27 @@ def _get_plugin(gl=False):
      return _cached_plugin[gl]

      # Make sure we can find the necessary compiler and libary binaries.
-     if os.name == 'nt':
          lib_dir = os.path.dirname(__file__) + r"\..\lib"
          def find_cl_path():
              import glob
              def get_sort_key(x):
                  # Primary criterion is VS version, secondary is edition, third is internal MSVC version.
-                 x = x.split('\\')[3:]
-                 x[1] = {'BuildTools': '~0', 'Community': '~1', 'Pro': '~2', 'Professional': '~3', 'Enterprise': '~4'}.get(x[1], x[1])
                  return x
-             vs_relative_path = r"\Microsoft Visual Studio\*\*\VC\Tools\MSVC\*\bin\Hostx64\x64"
              paths = glob.glob(r"C:\Program Files" + vs_relative_path)
              paths += glob.glob(r"C:\Program Files (x86)" + vs_relative_path)
              if paths:
@@ -49,104 +62,126 @@ def _get_plugin(gl=False):
          if os.system("where cl.exe >nul 2>nul") != 0:
              cl_path = find_cl_path()
              if cl_path is None:
-                 raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation")
-             os.environ['PATH'] += ';' + cl_path

      # Compiler options.
-     common_opts = ['-DNVDR_TORCH']
      cc_opts = []
-     if os.name == 'nt':
-         cc_opts += ['/wd4067', '/wd4624'] # Disable warnings in torch headers.

      # Linker options for the GL-interfacing plugin.
      ldflags = []
      if gl:
-         if os.name == 'posix':
-             ldflags = ['-lGL', '-lEGL']
-         elif os.name == 'nt':
-             libs = ['gdi32', 'opengl32', 'user32', 'setgpu']
-             ldflags = ['/LIBPATH:' + lib_dir] + ['/DEFAULTLIB:' + x for x in libs]

      # List of source files.
      if gl:
          source_files = [
-             '../common/common.cpp',
-             '../common/glutil.cpp',
-             '../common/rasterize_gl.cpp',
-             'torch_bindings_gl.cpp',
-             'torch_rasterize_gl.cpp',
          ]
      else:
          source_files = [
-             '../common/cudaraster/impl/Buffer.cpp',
-             '../common/cudaraster/impl/CudaRaster.cpp',
-             '../common/cudaraster/impl/RasterImpl.cu',
-             '../common/cudaraster/impl/RasterImpl.cpp',
-             '../common/common.cpp',
-             '../common/rasterize.cu',
-             '../common/interpolate.cu',
-             '../common/texture.cu',
-             '../common/texture.cpp',
-             '../common/antialias.cu',
-             'torch_bindings.cpp',
-             'torch_rasterize.cpp',
-             'torch_interpolate.cpp',
-             'torch_texture.cpp',
-             'torch_antialias.cpp',
          ]

      # Some containers set this to contain old architectures that won't compile. We only need the one installed in the machine.
-     os.environ['TORCH_CUDA_ARCH_LIST'] = ''

      # On Linux, show a warning if GLEW is being forcibly loaded when compiling the GL plugin.
-     if gl and (os.name == 'posix') and ('libGLEW' in os.environ.get('LD_PRELOAD', '')):
-         logging.getLogger('nvdiffrast').warning("Warning: libGLEW is being loaded via LD_PRELOAD, and will probably conflict with the OpenGL plugin")

      # Try to detect if a stray lock file is left in cache directory and show a warning. This sometimes happens on Windows if the build is interrupted at just the right moment.
-     plugin_name = 'nvdiffrast_plugin' + ('_gl' if gl else '')
      try:
-         lock_fn = os.path.join(torch.utils.cpp_extension._get_build_directory(plugin_name, False), 'lock')
          if os.path.exists(lock_fn):
-             logging.getLogger('nvdiffrast').warning("Lock file exists in build directory: '%s'" % lock_fn)
      except:
          pass

      # Speed up compilation on Windows.
-     if os.name == 'nt':
          # Skip telemetry sending step in vcvarsall.bat
-         os.environ['VSCMD_SKIP_SENDTELEMETRY'] = '1'

      # Opportunistically patch distutils to cache MSVC environments.
      try:
          import distutils._msvccompiler
          import functools
-         if not hasattr(distutils._msvccompiler._get_vc_env, '__wrapped__'):
-             distutils._msvccompiler._get_vc_env = functools.lru_cache()(distutils._msvccompiler._get_vc_env)
      except:
          pass

      # Compile and load.
      source_paths = [os.path.join(os.path.dirname(__file__), fn) for fn in source_files]
-     torch.utils.cpp_extension.load(name=plugin_name, sources=source_paths, extra_cflags=common_opts+cc_opts, extra_cuda_cflags=common_opts+['-lineinfo'], extra_ldflags=ldflags, with_cuda=True, verbose=False)

      # Import, cache, and return the compiled module.
      _cached_plugin[gl] = importlib.import_module(plugin_name)
      return _cached_plugin[gl]

- #----------------------------------------------------------------------------
  # Log level.
- #----------------------------------------------------------------------------

  def get_log_level():
-     '''Get current log level.

      Returns:
        Current log level in nvdiffrast. See `set_log_level()` for possible values.
-     '''
      return _get_plugin().get_log_level()

  def set_log_level(level):
-     '''Set log level.

      Log levels follow the convention on the C++ side of Torch:
        0 = Info,
@@ -156,19 +191,21 @@ def set_log_level(level):
      The default log level is 1.

      Args:
-       level: New log level as integer. Internal nvdiffrast messages of this
              severity or higher will be printed, while messages of lower
              severity will be silent.
-     '''
      _get_plugin().set_log_level(level)

- #----------------------------------------------------------------------------
  # CudaRaster state wrapper.
- #----------------------------------------------------------------------------

  class RasterizeCudaContext:
      def __init__(self, device=None):
-         '''Create a new Cuda rasterizer context.

          The context is deleted and internal storage is released when the object is
          destroyed.
@@ -180,7 +217,7 @@ class RasterizeCudaContext:
              device.
          Returns:
            The newly created Cuda rasterizer context.
-         '''
          if device is None:
              cuda_device_idx = torch.cuda.current_device()
          else:
@@ -190,13 +227,15 @@ class RasterizeCudaContext:
          self.output_db = True
          self.active_depth_peeler = None

- #----------------------------------------------------------------------------
  # GL state wrapper.
- #----------------------------------------------------------------------------

  class RasterizeGLContext:
-     def __init__(self, output_db=True, mode='automatic', device=None):
-         '''Create a new OpenGL rasterizer context.

          Creating an OpenGL context is a slow operation so you should usually reuse the same
          context in all calls to `rasterize()` on the same CPU thread. The OpenGL context
@@ -220,9 +259,9 @@ class RasterizeGLContext:
              device.
          Returns:
            The newly created OpenGL rasterizer context.
-         '''
          assert output_db is True or output_db is False
-         assert mode in ['automatic', 'manual']
          self.output_db = output_db
          self.mode = mode
          if device is None:
@@ -230,34 +269,42 @@ class RasterizeGLContext:
          else:
              with torch.cuda.device(device):
                  cuda_device_idx = torch.cuda.current_device()
-         self.cpp_wrapper = _get_plugin(gl=True).RasterizeGLStateWrapper(output_db, mode == 'automatic', cuda_device_idx)
-         self.active_depth_peeler = None # For error checking only.

      def set_context(self):
-         '''Set (activate) OpenGL context in the current CPU thread.
-         Only available if context was created in manual mode.
-         '''
-         assert self.mode == 'manual'
          self.cpp_wrapper.set_context()

      def release_context(self):
-         '''Release (deactivate) currently active OpenGL context.
-         Only available if context was created in manual mode.
-         '''
-         assert self.mode == 'manual'
          self.cpp_wrapper.release_context()

- #----------------------------------------------------------------------------
  # Rasterize.
- #----------------------------------------------------------------------------

  class _rasterize_func(torch.autograd.Function):
      @staticmethod
      def forward(ctx, raster_ctx, pos, tri, resolution, ranges, grad_db, peeling_idx):
          if isinstance(raster_ctx, RasterizeGLContext):
-             out, out_db = _get_plugin(gl=True).rasterize_fwd_gl(raster_ctx.cpp_wrapper, pos, tri, resolution, ranges, peeling_idx)
          else:
-             out, out_db = _get_plugin().rasterize_fwd_cuda(raster_ctx.cpp_wrapper, pos, tri, resolution, ranges, peeling_idx)
          ctx.save_for_backward(pos, tri, out)
          ctx.saved_grad_db = grad_db
          return out, out_db
@@ -271,9 +318,10 @@ class _rasterize_func(torch.autograd.Function):
          g_pos = _get_plugin().rasterize_grad(pos, tri, out, dy)
          return None, g_pos, None, None, None, None, None

  # Op wrapper.
  def rasterize(glctx, pos, tri, resolution, ranges=None, grad_db=True):
-     '''Rasterize triangles.

      All input tensors must be contiguous and reside in GPU memory except for
      the `ranges` tensor that, if specified, has to reside in CPU memory. The
@@ -301,7 +349,7 @@ def rasterize(glctx, pos, tri, resolution, ranges=None, grad_db=True):
      [minibatch_size, height, width, 4] and contain said derivatives in order
      (du/dX, du/dY, dv/dX, dv/dY). Otherwise it will be an empty tensor with shape
      [minibatch_size, height, width, 0].
-     '''
      assert isinstance(glctx, (RasterizeGLContext, RasterizeCudaContext))
      assert grad_db is True or grad_db is False
      grad_db = grad_db and glctx.output_db
@@ -310,30 +358,34 @@ def rasterize(glctx, pos, tri, resolution, ranges=None, grad_db=True):
      assert isinstance(pos, torch.Tensor) and isinstance(tri, torch.Tensor)
      resolution = tuple(resolution)
      if ranges is None:
-         ranges = torch.empty(size=(0, 2), dtype=torch.int32, device='cpu')
      else:
          assert isinstance(ranges, torch.Tensor)

      # Check that context is not currently reserved for depth peeling.
      if glctx.active_depth_peeler is not None:
-         return RuntimeError("Cannot call rasterize() during depth peeling operation, use rasterize_next_layer() instead")

      # Instantiate the function.
      return _rasterize_func.apply(glctx, pos, tri, resolution, ranges, grad_db, -1)

- #----------------------------------------------------------------------------
  # Depth peeler context manager for rasterizing multiple depth layers.
- #----------------------------------------------------------------------------

  class DepthPeeler:
      def __init__(self, glctx, pos, tri, resolution, ranges=None, grad_db=True):
-         '''Create a depth peeler object for rasterizing multiple depth layers.

          Arguments are the same as in `rasterize()`.

          Returns:
            The newly created depth peeler.
-         '''
          assert isinstance(glctx, (RasterizeGLContext, RasterizeCudaContext))
          assert grad_db is True or grad_db is False
          grad_db = grad_db and glctx.output_db
@@ -342,7 +394,7 @@ class DepthPeeler:
          assert isinstance(pos, torch.Tensor) and isinstance(tri, torch.Tensor)
          resolution = tuple(resolution)
          if ranges is None:
-             ranges = torch.empty(size=(0, 2), dtype=torch.int32, device='cpu')
          else:
              assert isinstance(ranges, torch.Tensor)

@@ -359,7 +411,9 @@ class DepthPeeler:
          if self.raster_ctx is None:
              raise RuntimeError("Cannot re-enter a terminated depth peeling operation")
          if self.raster_ctx.active_depth_peeler is not None:
-             raise RuntimeError("Cannot have multiple depth peelers active simultaneously in a rasterization context")
          self.raster_ctx.active_depth_peeler = self
          self.peeling_idx = 0
          return self
@@ -367,7 +421,9 @@ class DepthPeeler:
      def __exit__(self, *args):
          assert self.raster_ctx.active_depth_peeler is self
          self.raster_ctx.active_depth_peeler = None
-         self.raster_ctx = None # Remove all references to input tensor so they're not left dangling.
          self.pos = None
          self.tri = None
          self.resolution = None
@@ -377,29 +433,40 @@ class DepthPeeler:
          return None

      def rasterize_next_layer(self):
-         '''Rasterize next depth layer.

          Operation is equivalent to `rasterize()` except that previously reported
          surface points are culled away.

          Returns:
            A tuple of two tensors as in `rasterize()`.
-         '''
          assert self.raster_ctx.active_depth_peeler is self
          assert self.peeling_idx >= 0
-         result = _rasterize_func.apply(self.raster_ctx, self.pos, self.tri, self.resolution, self.ranges, self.grad_db, self.peeling_idx)
          self.peeling_idx += 1
          return result

- #----------------------------------------------------------------------------
  # Interpolate.
- #----------------------------------------------------------------------------

  # Output pixel differentials for at least some attributes.
  class _interpolate_func_da(torch.autograd.Function):
      @staticmethod
      def forward(ctx, attr, rast, tri, rast_db, diff_attrs_all, diff_attrs_list):
-         out, out_da = _get_plugin().interpolate_fwd_da(attr, rast, tri, rast_db, diff_attrs_all, diff_attrs_list)
          ctx.save_for_backward(attr, rast, tri, rast_db)
          ctx.saved_misc = diff_attrs_all, diff_attrs_list
          return out, out_da
@@ -408,9 +475,12 @@ class _interpolate_func_da(torch.autograd.Function):
      def backward(ctx, dy, dda):
          attr, rast, tri, rast_db = ctx.saved_tensors
          diff_attrs_all, diff_attrs_list = ctx.saved_misc
-         g_attr, g_rast, g_rast_db = _get_plugin().interpolate_grad_da(attr, rast, tri, dy, rast_db, dda, diff_attrs_all, diff_attrs_list)
          return g_attr, g_rast, None, g_rast_db, None, None

  # No pixel differential for any attribute.
  class _interpolate_func(torch.autograd.Function):
      @staticmethod
@@ -425,6 +495,7 @@ class _interpolate_func(torch.autograd.Function):
          g_attr, g_rast = _get_plugin().interpolate_grad(attr, rast, tri, dy)
          return g_attr, g_rast, None

  # Op wrapper.
  def interpolate(attr, rast, tri, rast_db=None, diff_attrs=None):
      """Interpolate vertex attributes.
@@ -433,13 +504,13 @@ def interpolate(attr, rast, tri, rast_db=None, diff_attrs=None):
      will be contiguous and reside in GPU memory.

      Args:
-         attr: Attribute tensor with dtype `torch.float32`.
-             Shape is [num_vertices, num_attributes] in range mode, or
              [minibatch_size, num_vertices, num_attributes] in instanced mode.
              Broadcasting is supported along the minibatch axis.
          rast: Main output tensor from `rasterize()`.
          tri: Triangle tensor with shape [num_triangles, 3] and dtype `torch.int32`.
-         rast_db: (Optional) Tensor containing image-space derivatives of barycentrics,
              i.e., the second output tensor from `rasterize()`. Enables computing
              image-space derivatives of attributes.
          diff_attrs: (Optional) List of attribute indices for which image-space
@@ -459,12 +530,12 @@ def interpolate(attr, rast, tri, rast_db=None, diff_attrs=None):
      # Sanitize the list of pixel differential attributes.
      if diff_attrs is None:
          diff_attrs = []
-     elif diff_attrs != 'all':
          diff_attrs = np.asarray(diff_attrs, np.int32)
          assert len(diff_attrs.shape) == 1
          diff_attrs = diff_attrs.tolist()

-     diff_attrs_all = int(diff_attrs == 'all')
      diff_attrs_list = [] if diff_attrs_all else diff_attrs

      # Check inputs.
@@ -474,18 +545,32 @@ def interpolate(attr, rast, tri, rast_db=None, diff_attrs=None):

      # Choose stub.
      if diff_attrs:
-         return _interpolate_func_da.apply(attr, rast, tri, rast_db, diff_attrs_all, diff_attrs_list)
      else:
          return _interpolate_func.apply(attr, rast, tri)

- #----------------------------------------------------------------------------
  # Texture
- #----------------------------------------------------------------------------

  # Linear-mipmap-linear and linear-mipmap-nearest: Mipmaps enabled.
  class _texture_func_mip(torch.autograd.Function):
      @staticmethod
-     def forward(ctx, filter_mode, tex, uv, uv_da, mip_level_bias, mip_wrapper, filter_mode_enum, boundary_mode_enum, *mip_stack):
          empty = torch.tensor([])
          if uv_da is None:
              uv_da = empty
@@ -493,7 +578,16 @@ class _texture_func_mip(torch.autograd.Function):
              mip_level_bias = empty
          if mip_wrapper is None:
              mip_wrapper = _get_plugin().TextureMipWrapper()
-         out = _get_plugin().texture_fwd_mip(tex, uv, uv_da, mip_level_bias, mip_wrapper, mip_stack, filter_mode_enum, boundary_mode_enum)
          ctx.save_for_backward(tex, uv, uv_da, mip_level_bias, *mip_stack)
          ctx.saved_misc = filter_mode, mip_wrapper, filter_mode_enum, boundary_mode_enum
          return out
@@ -502,12 +596,50 @@ class _texture_func_mip(torch.autograd.Function):
      def backward(ctx, dy):
          tex, uv, uv_da, mip_level_bias, *mip_stack = ctx.saved_tensors
          filter_mode, mip_wrapper, filter_mode_enum, boundary_mode_enum = ctx.saved_misc
-         if filter_mode == 'linear-mipmap-linear':
-             g_tex, g_uv, g_uv_da, g_mip_level_bias, g_mip_stack = _get_plugin().texture_grad_linear_mipmap_linear(tex, uv, dy, uv_da, mip_level_bias, mip_wrapper, mip_stack, filter_mode_enum, boundary_mode_enum)
-             return (None, g_tex, g_uv, g_uv_da, g_mip_level_bias, None, None, None) + tuple(g_mip_stack)
-         else: # linear-mipmap-nearest
-             g_tex, g_uv, g_mip_stack = _get_plugin().texture_grad_linear_mipmap_nearest(tex, uv, dy, uv_da, mip_level_bias, mip_wrapper, mip_stack, filter_mode_enum, boundary_mode_enum)
-             return (None, g_tex, g_uv, None, None, None, None, None) + tuple(g_mip_stack)

  # Linear and nearest: Mipmaps disabled.
  class _texture_func(torch.autograd.Function):
@@ -522,15 +654,29 @@ class _texture_func(torch.autograd.Function):
      def backward(ctx, dy):
          tex, uv = ctx.saved_tensors
          filter_mode, filter_mode_enum, boundary_mode_enum = ctx.saved_misc
-         if filter_mode == 'linear':
-             g_tex, g_uv = _get_plugin().texture_grad_linear(tex, uv, dy, filter_mode_enum, boundary_mode_enum)
              return None, g_tex, g_uv, None, None
-         else: # nearest
-             g_tex = _get_plugin().texture_grad_nearest(tex, uv, dy, filter_mode_enum, boundary_mode_enum)
              return None, g_tex, None, None, None

  # Op wrapper.
- def texture(tex, uv, uv_da=None, mip_level_bias=None, mip=None, filter_mode='auto', boundary_mode='wrap', max_mip_level=None):
      """Perform texture sampling.

      All input tensors must be contiguous and reside in GPU memory. The output tensor
@@ -580,8 +726,12 @@ def texture(tex, uv, uv_da=None, mip_level_bias=None, mip=None, filter_mode='aut
      """

      # Default filter mode.
-     if filter_mode == 'auto':
-         filter_mode = 'linear-mipmap-linear' if (uv_da is not None or mip_level_bias is not None) else 'linear'

      # Sanitize inputs.
      if max_mip_level is None:
@@ -592,23 +742,33 @@ def texture(tex, uv, uv_da=None, mip_level_bias=None, mip=None, filter_mode='aut

      # Check inputs.
      assert isinstance(tex, torch.Tensor) and isinstance(uv, torch.Tensor)
-     if 'mipmap' in filter_mode:
-         assert isinstance(uv_da, torch.Tensor) or isinstance(mip_level_bias, torch.Tensor)

      # If mipping disabled via max level=0, we may as well use simpler filtering internally.
-     if max_mip_level == 0 and filter_mode in ['linear-mipmap-nearest', 'linear-mipmap-linear']:
-         filter_mode = 'linear'

      # Convert filter mode to internal enumeration.
-     filter_mode_dict = {'nearest': 0, 'linear': 1, 'linear-mipmap-nearest': 2, 'linear-mipmap-linear': 3}
      filter_mode_enum = filter_mode_dict[filter_mode]

      # Convert boundary mode to internal enumeration.
-     boundary_mode_dict = {'cube': 0, 'wrap': 1, 'clamp': 2, 'zero': 3}
      boundary_mode_enum = boundary_mode_dict[boundary_mode]

      # Construct a mipmap if necessary.
-     if 'mipmap' in filter_mode:
          mip_wrapper, mip_stack = None, []
          if mip is not None:
              assert isinstance(mip, (_get_plugin().TextureMipWrapper, list))
@@ -618,13 +778,28 @@ def texture(tex, uv, uv_da=None, mip_level_bias=None, mip=None, filter_mode='aut
              else:
                  mip_wrapper = mip
          else:
-             mip_wrapper = _get_plugin().texture_construct_mip(tex, max_mip_level, boundary_mode == 'cube')

      # Choose stub.
-     if filter_mode == 'linear-mipmap-linear' or filter_mode == 'linear-mipmap-nearest':
-         return _texture_func_mip.apply(filter_mode, tex, uv, uv_da, mip_level_bias, mip_wrapper, filter_mode_enum, boundary_mode_enum, *mip_stack)
      else:
-         return _texture_func.apply(filter_mode, tex, uv, filter_mode_enum, boundary_mode_enum)

  # Mipmap precalculation for cases where the texture stays constant.
  def texture_construct_mip(tex, max_mip_level=None, cube_mode=False):
@@ -639,7 +814,7 @@ def texture_construct_mip(tex, max_mip_level=None, cube_mode=False):
      cube_mode: Must be set to True if `tex` specifies a cube map texture.

      Returns:
-         An opaque object containing the mipmap stack. This can be supplied in a call to `texture()`
          in the `mip` argument.
      """
@@ -652,14 +827,18 @@ def texture_construct_mip(tex, max_mip_level=None, cube_mode=False):
      assert max_mip_level >= 0
      return _get_plugin().texture_construct_mip(tex, max_mip_level, cube_mode)

- #----------------------------------------------------------------------------
  # Antialias.
- #----------------------------------------------------------------------------

  class _antialias_func(torch.autograd.Function):
      @staticmethod
      def forward(ctx, color, rast, pos, tri, topology_hash, pos_gradient_boost):
-         out, work_buffer = _get_plugin().antialias_fwd(color, rast, pos, tri, topology_hash)
          ctx.save_for_backward(color, rast, pos, tri)
          ctx.saved_misc = pos_gradient_boost, work_buffer
          return out
@@ -668,11 +847,14 @@ class _antialias_func(torch.autograd.Function):
      def backward(ctx, dy):
          color, rast, pos, tri = ctx.saved_tensors
          pos_gradient_boost, work_buffer = ctx.saved_misc
-         g_color, g_pos = _get_plugin().antialias_grad(color, rast, pos, tri, dy, work_buffer)
          if pos_gradient_boost != 1.0:
              g_pos = g_pos * pos_gradient_boost
          return g_color, None, g_pos, None, None, None

  # Op wrapper.
  def antialias(color, rast, pos, tri, topology_hash=None, pos_gradient_boost=1.0):
      """Perform antialiasing.
@@ -711,13 +893,16 @@ def antialias(color, rast, pos, tri, topology_hash=None, pos_gradient_boost=1.0)
          topology_hash = _get_plugin().antialias_construct_topology_hash(tri)

      # Instantiate the function.
-     return _antialias_func.apply(color, rast, pos, tri, topology_hash, pos_gradient_boost)

  # Topology hash precalculation for cases where the triangle array stays constant.
  def antialias_construct_topology_hash(tri):
      """Construct a topology hash for a triangle tensor.

-     This function can be used for constructing a topology hash for a triangle tensor that is
      known to remain constant. This avoids reconstructing it every time `antialias()` is called.

      Args:
@@ -725,10 +910,11 @@ def antialias_construct_topology_hash(tri):
          GPU memory.

      Returns:
-         An opaque object containing the topology hash. This can be supplied in a call to
          `antialias()` in the `topology_hash` argument.
      """
      assert isinstance(tri, torch.Tensor)
      return _get_plugin().antialias_construct_topology_hash(tri)

- #----------------------------------------------------------------------------
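
Usage sketch for the texture/antialias wrappers above, continuing from a rasterization that produced `rast`, `pos`, `tri`, plus interpolated texture coordinates `uv` and derivatives `uv_da` (all assumed here). With `filter_mode='auto'`, supplying `uv_da` selects 'linear-mipmap-linear'; otherwise plain 'linear' is used:

    import torch
    import nvdiffrast.torch as dr

    tex = torch.rand(1, 64, 64, 3, device="cuda")   # [minibatch, height, width, channels]
    color = dr.texture(tex, uv, uv_da=uv_da)        # mipmapped sampling ('auto' mode)
    color = dr.texture(tex, uv, filter_mode="linear",
                       boundary_mode="clamp")       # no mip stack needed
    color = dr.antialias(color, rast, pos, tri)     # smooths silhouettes, enables pos gradients
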

  import torch.utils.cpp_extension
  from . import _C

+ # ----------------------------------------------------------------------------
  # C++/Cuda plugin compiler/loader.

  _cached_plugin = {}
+
+
  def _get_plugin(gl=False):
      assert isinstance(gl, bool)
+
      # Modified with precompiled torch CUDA extension
      if not gl:
          return _C

      return _cached_plugin[gl]

      # Make sure we can find the necessary compiler and libary binaries.
+     if os.name == "nt":
          lib_dir = os.path.dirname(__file__) + r"\..\lib"
+
          def find_cl_path():
              import glob
+
              def get_sort_key(x):
                  # Primary criterion is VS version, secondary is edition, third is internal MSVC version.
+                 x = x.split("\\")[3:]
+                 x[1] = {
+                     "BuildTools": "~0",
+                     "Community": "~1",
+                     "Pro": "~2",
+                     "Professional": "~3",
+                     "Enterprise": "~4",
+                 }.get(x[1], x[1])
                  return x
+
+             vs_relative_path = (
+                 r"\Microsoft Visual Studio\*\*\VC\Tools\MSVC\*\bin\Hostx64\x64"
+             )
              paths = glob.glob(r"C:\Program Files" + vs_relative_path)
              paths += glob.glob(r"C:\Program Files (x86)" + vs_relative_path)
              if paths:

          if os.system("where cl.exe >nul 2>nul") != 0:
              cl_path = find_cl_path()
              if cl_path is None:
+                 raise RuntimeError(
+                     "Could not locate a supported Microsoft Visual C++ installation"
+                 )
+             os.environ["PATH"] += ";" + cl_path

      # Compiler options.
+     common_opts = ["-DNVDR_TORCH"]
      cc_opts = []
+     if os.name == "nt":
+         cc_opts += ["/wd4067", "/wd4624"]  # Disable warnings in torch headers.

      # Linker options for the GL-interfacing plugin.
      ldflags = []
      if gl:
+         if os.name == "posix":
+             ldflags = ["-lGL", "-lEGL"]
+         elif os.name == "nt":
+             libs = ["gdi32", "opengl32", "user32", "setgpu"]
+             ldflags = ["/LIBPATH:" + lib_dir] + ["/DEFAULTLIB:" + x for x in libs]

      # List of source files.
      if gl:
          source_files = [
+             "../common/common.cpp",
+             "../common/glutil.cpp",
+             "../common/rasterize_gl.cpp",
+             "torch_bindings_gl.cpp",
+             "torch_rasterize_gl.cpp",
          ]
      else:
          source_files = [
+             "../common/cudaraster/impl/Buffer.cpp",
+             "../common/cudaraster/impl/CudaRaster.cpp",
+             "../common/cudaraster/impl/RasterImpl.cu",
+             "../common/cudaraster/impl/RasterImpl.cpp",
+             "../common/common.cpp",
+             "../common/rasterize.cu",
+             "../common/interpolate.cu",
+             "../common/texture.cu",
+             "../common/texture.cpp",
+             "../common/antialias.cu",
+             "torch_bindings.cpp",
+             "torch_rasterize.cpp",
+             "torch_interpolate.cpp",
+             "torch_texture.cpp",
+             "torch_antialias.cpp",
          ]

      # Some containers set this to contain old architectures that won't compile. We only need the one installed in the machine.
+     os.environ["TORCH_CUDA_ARCH_LIST"] = ""

      # On Linux, show a warning if GLEW is being forcibly loaded when compiling the GL plugin.
+     if gl and (os.name == "posix") and ("libGLEW" in os.environ.get("LD_PRELOAD", "")):
+         logging.getLogger("nvdiffrast").warning(
+             "Warning: libGLEW is being loaded via LD_PRELOAD, and will probably conflict with the OpenGL plugin"
+         )

      # Try to detect if a stray lock file is left in cache directory and show a warning. This sometimes happens on Windows if the build is interrupted at just the right moment.
+     plugin_name = "nvdiffrast_plugin" + ("_gl" if gl else "")
      try:
+         lock_fn = os.path.join(
+             torch.utils.cpp_extension._get_build_directory(plugin_name, False), "lock"
+         )
          if os.path.exists(lock_fn):
+             logging.getLogger("nvdiffrast").warning(
+                 "Lock file exists in build directory: '%s'" % lock_fn
+             )
      except:
          pass

      # Speed up compilation on Windows.
+     if os.name == "nt":
          # Skip telemetry sending step in vcvarsall.bat
+         os.environ["VSCMD_SKIP_SENDTELEMETRY"] = "1"

      # Opportunistically patch distutils to cache MSVC environments.
      try:
          import distutils._msvccompiler
          import functools
+
+         if not hasattr(distutils._msvccompiler._get_vc_env, "__wrapped__"):
+             distutils._msvccompiler._get_vc_env = functools.lru_cache()(
+                 distutils._msvccompiler._get_vc_env
+             )
      except:
          pass

      # Compile and load.
      source_paths = [os.path.join(os.path.dirname(__file__), fn) for fn in source_files]
+     torch.utils.cpp_extension.load(
+         name=plugin_name,
+         sources=source_paths,
+         extra_cflags=common_opts + cc_opts,
+         extra_cuda_cflags=common_opts + ["-lineinfo"],
+         extra_ldflags=ldflags,
+         with_cuda=True,
+         verbose=False,
+     )

      # Import, cache, and return the compiled module.
      _cached_plugin[gl] = importlib.import_module(plugin_name)
      return _cached_plugin[gl]
 
168
+
169
+ # ----------------------------------------------------------------------------
170
  # Log level.
171
+ # ----------------------------------------------------------------------------
172
+
173
 
174
  def get_log_level():
175
+ """Get current log level.
176
 
177
  Returns:
178
  Current log level in nvdiffrast. See `set_log_level()` for possible values.
179
+ """
180
  return _get_plugin().get_log_level()
181
 
182
+
183
  def set_log_level(level):
184
+ """Set log level.
185
 
186
  Log levels follow the convention on the C++ side of Torch:
187
  0 = Info,
 
191
  The default log level is 1.
192
 
193
  Args:
194
+ level: New log level as integer. Internal nvdiffrast messages of this
195
  severity or higher will be printed, while messages of lower
196
  severity will be silent.
197
+ """
198
  _get_plugin().set_log_level(level)
199
 
200
+
201
+ # ----------------------------------------------------------------------------
202
  # CudaRaster state wrapper.
203
+ # ----------------------------------------------------------------------------
204
+
205
 
206
  class RasterizeCudaContext:
207
  def __init__(self, device=None):
208
+ """Create a new Cuda rasterizer context.
209
 
210
  The context is deleted and internal storage is released when the object is
211
  destroyed.
 
217
  device.
218
  Returns:
219
  The newly created Cuda rasterizer context.
220
+ """
221
  if device is None:
222
  cuda_device_idx = torch.cuda.current_device()
223
  else:
 
227
  self.output_db = True
228
  self.active_depth_peeler = None
229
 
230
+
231
+ # ----------------------------------------------------------------------------
232
  # GL state wrapper.
233
+ # ----------------------------------------------------------------------------
234
+
235
 
236
  class RasterizeGLContext:
237
+ def __init__(self, output_db=True, mode="automatic", device=None):
238
+ """Create a new OpenGL rasterizer context.
239
 
240
  Creating an OpenGL context is a slow operation so you should usually reuse the same
241
  context in all calls to `rasterize()` on the same CPU thread. The OpenGL context
 
259
  device.
260
  Returns:
261
  The newly created OpenGL rasterizer context.
262
+ """
263
  assert output_db is True or output_db is False
264
+ assert mode in ["automatic", "manual"]
265
  self.output_db = output_db
266
  self.mode = mode
267
  if device is None:
 
269
  else:
270
  with torch.cuda.device(device):
271
  cuda_device_idx = torch.cuda.current_device()
272
+ self.cpp_wrapper = _get_plugin(gl=True).RasterizeGLStateWrapper(
273
+ output_db, mode == "automatic", cuda_device_idx
274
+ )
275
+ self.active_depth_peeler = None # For error checking only.
276
 
277
  def set_context(self):
278
+ """Set (activate) OpenGL context in the current CPU thread.
279
+ Only available if context was created in manual mode.
280
+ """
281
+ assert self.mode == "manual"
282
  self.cpp_wrapper.set_context()
283
 
284
  def release_context(self):
285
+ """Release (deactivate) currently active OpenGL context.
286
+ Only available if context was created in manual mode.
287
+ """
288
+ assert self.mode == "manual"
289
  self.cpp_wrapper.release_context()
290
 
291
+
292
+ # ----------------------------------------------------------------------------
293
  # Rasterize.
294
+ # ----------------------------------------------------------------------------
295
+
296
 
297
  class _rasterize_func(torch.autograd.Function):
298
  @staticmethod
299
  def forward(ctx, raster_ctx, pos, tri, resolution, ranges, grad_db, peeling_idx):
300
  if isinstance(raster_ctx, RasterizeGLContext):
301
+ out, out_db = _get_plugin(gl=True).rasterize_fwd_gl(
302
+ raster_ctx.cpp_wrapper, pos, tri, resolution, ranges, peeling_idx
303
+ )
304
  else:
305
+ out, out_db = _get_plugin().rasterize_fwd_cuda(
306
+ raster_ctx.cpp_wrapper, pos, tri, resolution, ranges, peeling_idx
307
+ )
308
  ctx.save_for_backward(pos, tri, out)
309
  ctx.saved_grad_db = grad_db
310
  return out, out_db
 
318
  g_pos = _get_plugin().rasterize_grad(pos, tri, out, dy)
319
  return None, g_pos, None, None, None, None, None
320
 
321
+
322
  # Op wrapper.
323
  def rasterize(glctx, pos, tri, resolution, ranges=None, grad_db=True):
324
+ """Rasterize triangles.
325
 
326
  All input tensors must be contiguous and reside in GPU memory except for
327
  the `ranges` tensor that, if specified, has to reside in CPU memory. The
 
349
  [minibatch_size, height, width, 4] and contain said derivatives in order
350
  (du/dX, du/dY, dv/dX, dv/dY). Otherwise it will be an empty tensor with shape
351
  [minibatch_size, height, width, 0].
352
+ """
353
  assert isinstance(glctx, (RasterizeGLContext, RasterizeCudaContext))
354
  assert grad_db is True or grad_db is False
355
  grad_db = grad_db and glctx.output_db
 
358
  assert isinstance(pos, torch.Tensor) and isinstance(tri, torch.Tensor)
359
  resolution = tuple(resolution)
360
  if ranges is None:
361
+ ranges = torch.empty(size=(0, 2), dtype=torch.int32, device="cpu")
362
  else:
363
  assert isinstance(ranges, torch.Tensor)
364
 
365
  # Check that context is not currently reserved for depth peeling.
366
  if glctx.active_depth_peeler is not None:
367
+ return RuntimeError(
368
+ "Cannot call rasterize() during depth peeling operation, use rasterize_next_layer() instead"
369
+ )
370
 
371
  # Instantiate the function.
372
  return _rasterize_func.apply(glctx, pos, tri, resolution, ranges, grad_db, -1)
373
 
374
+
375
+ # ----------------------------------------------------------------------------
376
  # Depth peeler context manager for rasterizing multiple depth layers.
377
+ # ----------------------------------------------------------------------------
378
+
379
 
380
  class DepthPeeler:
381
  def __init__(self, glctx, pos, tri, resolution, ranges=None, grad_db=True):
382
+ """Create a depth peeler object for rasterizing multiple depth layers.
383
 
384
  Arguments are the same as in `rasterize()`.
385
 
386
  Returns:
387
  The newly created depth peeler.
388
+ """
389
  assert isinstance(glctx, (RasterizeGLContext, RasterizeCudaContext))
390
  assert grad_db is True or grad_db is False
391
  grad_db = grad_db and glctx.output_db
 
394
  assert isinstance(pos, torch.Tensor) and isinstance(tri, torch.Tensor)
395
  resolution = tuple(resolution)
396
  if ranges is None:
397
+ ranges = torch.empty(size=(0, 2), dtype=torch.int32, device="cpu")
398
  else:
399
  assert isinstance(ranges, torch.Tensor)
400
 
 
411
  if self.raster_ctx is None:
412
  raise RuntimeError("Cannot re-enter a terminated depth peeling operation")
413
  if self.raster_ctx.active_depth_peeler is not None:
414
+ raise RuntimeError(
415
+ "Cannot have multiple depth peelers active simultaneously in a rasterization context"
416
+ )
417
  self.raster_ctx.active_depth_peeler = self
418
  self.peeling_idx = 0
419
  return self
 
421
  def __exit__(self, *args):
422
  assert self.raster_ctx.active_depth_peeler is self
423
  self.raster_ctx.active_depth_peeler = None
424
+ self.raster_ctx = (
425
+ None # Remove all references to input tensors so they're not left dangling.
426
+ )
427
  self.pos = None
428
  self.tri = None
429
  self.resolution = None
 
433
  return None
434
 
435
  def rasterize_next_layer(self):
436
+ """Rasterize next depth layer.
437
 
438
  Operation is equivalent to `rasterize()` except that previously reported
439
  surface points are culled away.
440
 
441
  Returns:
442
  A tuple of two tensors as in `rasterize()`.
443
+ """
444
  assert self.raster_ctx.active_depth_peeler is self
445
  assert self.peeling_idx >= 0
446
+ result = _rasterize_func.apply(
447
+ self.raster_ctx,
448
+ self.pos,
449
+ self.tri,
450
+ self.resolution,
451
+ self.ranges,
452
+ self.grad_db,
453
+ self.peeling_idx,
454
+ )
455
  self.peeling_idx += 1
456
  return result
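A minimal sketch of peeling two depth layers with the context manager above, reusing `glctx`, `pos`, and `tri` from the earlier rasterize example:

with dr.DepthPeeler(glctx, pos, tri, resolution=[256, 256]) as peeler:
    rast_front, _ = peeler.rasterize_next_layer()   # nearest surfaces
    rast_back, _ = peeler.rasterize_next_layer()    # surfaces behind them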
457
 
458
+
459
+ # ----------------------------------------------------------------------------
460
  # Interpolate.
461
+ # ----------------------------------------------------------------------------
462
 
463
  # Output pixel differentials for at least some attributes.
464
  class _interpolate_func_da(torch.autograd.Function):
465
  @staticmethod
466
  def forward(ctx, attr, rast, tri, rast_db, diff_attrs_all, diff_attrs_list):
467
+ out, out_da = _get_plugin().interpolate_fwd_da(
468
+ attr, rast, tri, rast_db, diff_attrs_all, diff_attrs_list
469
+ )
470
  ctx.save_for_backward(attr, rast, tri, rast_db)
471
  ctx.saved_misc = diff_attrs_all, diff_attrs_list
472
  return out, out_da
 
475
  def backward(ctx, dy, dda):
476
  attr, rast, tri, rast_db = ctx.saved_tensors
477
  diff_attrs_all, diff_attrs_list = ctx.saved_misc
478
+ g_attr, g_rast, g_rast_db = _get_plugin().interpolate_grad_da(
479
+ attr, rast, tri, dy, rast_db, dda, diff_attrs_all, diff_attrs_list
480
+ )
481
  return g_attr, g_rast, None, g_rast_db, None, None
482
 
483
+
484
  # No pixel differential for any attribute.
485
  class _interpolate_func(torch.autograd.Function):
486
  @staticmethod
 
495
  g_attr, g_rast = _get_plugin().interpolate_grad(attr, rast, tri, dy)
496
  return g_attr, g_rast, None
497
 
498
+
499
  # Op wrapper.
500
  def interpolate(attr, rast, tri, rast_db=None, diff_attrs=None):
501
  """Interpolate vertex attributes.
 
504
  will be contiguous and reside in GPU memory.
505
 
506
  Args:
507
+ attr: Attribute tensor with dtype `torch.float32`.
508
+ Shape is [num_vertices, num_attributes] in range mode, or
509
  [minibatch_size, num_vertices, num_attributes] in instanced mode.
510
  Broadcasting is supported along the minibatch axis.
511
  rast: Main output tensor from `rasterize()`.
512
  tri: Triangle tensor with shape [num_triangles, 3] and dtype `torch.int32`.
513
+ rast_db: (Optional) Tensor containing image-space derivatives of barycentrics,
514
  i.e., the second output tensor from `rasterize()`. Enables computing
515
  image-space derivatives of attributes.
516
  diff_attrs: (Optional) List of attribute indices for which image-space
 
530
  # Sanitize the list of pixel differential attributes.
531
  if diff_attrs is None:
532
  diff_attrs = []
533
+ elif diff_attrs != "all":
534
  diff_attrs = np.asarray(diff_attrs, np.int32)
535
  assert len(diff_attrs.shape) == 1
536
  diff_attrs = diff_attrs.tolist()
537
 
538
+ diff_attrs_all = int(diff_attrs == "all")
539
  diff_attrs_list = [] if diff_attrs_all else diff_attrs
540
 
541
  # Check inputs.
 
545
 
546
  # Choose stub.
547
  if diff_attrs:
548
+ return _interpolate_func_da.apply(
549
+ attr, rast, tri, rast_db, diff_attrs_all, diff_attrs_list
550
+ )
551
  else:
552
  return _interpolate_func.apply(attr, rast, tri)
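A minimal sketch of interpolating per-vertex colors over the rasterizer output, assuming `rast`, `rast_db`, and `tri` from the rasterize example:

col = torch.tensor([[[1.0, 0.0, 0.0],
                     [0.0, 1.0, 0.0],
                     [0.0, 0.0, 1.0]]], device="cuda")  # [1, num_vertices, 3]
color, color_da = dr.interpolate(col, rast, tri, rast_db=rast_db, diff_attrs="all")
# color is [1, 256, 256, 3]; color_da carries its image-space derivatives.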
553
 
554
+
555
+ # ----------------------------------------------------------------------------
556
  # Texture
557
+ # ----------------------------------------------------------------------------
558
 
559
  # Linear-mipmap-linear and linear-mipmap-nearest: Mipmaps enabled.
560
  class _texture_func_mip(torch.autograd.Function):
561
  @staticmethod
562
+ def forward(
563
+ ctx,
564
+ filter_mode,
565
+ tex,
566
+ uv,
567
+ uv_da,
568
+ mip_level_bias,
569
+ mip_wrapper,
570
+ filter_mode_enum,
571
+ boundary_mode_enum,
572
+ *mip_stack
573
+ ):
574
  empty = torch.tensor([])
575
  if uv_da is None:
576
  uv_da = empty
 
578
  mip_level_bias = empty
579
  if mip_wrapper is None:
580
  mip_wrapper = _get_plugin().TextureMipWrapper()
581
+ out = _get_plugin().texture_fwd_mip(
582
+ tex,
583
+ uv,
584
+ uv_da,
585
+ mip_level_bias,
586
+ mip_wrapper,
587
+ mip_stack,
588
+ filter_mode_enum,
589
+ boundary_mode_enum,
590
+ )
591
  ctx.save_for_backward(tex, uv, uv_da, mip_level_bias, *mip_stack)
592
  ctx.saved_misc = filter_mode, mip_wrapper, filter_mode_enum, boundary_mode_enum
593
  return out
 
596
  def backward(ctx, dy):
597
  tex, uv, uv_da, mip_level_bias, *mip_stack = ctx.saved_tensors
598
  filter_mode, mip_wrapper, filter_mode_enum, boundary_mode_enum = ctx.saved_misc
599
+ if filter_mode == "linear-mipmap-linear":
600
+ (
601
+ g_tex,
602
+ g_uv,
603
+ g_uv_da,
604
+ g_mip_level_bias,
605
+ g_mip_stack,
606
+ ) = _get_plugin().texture_grad_linear_mipmap_linear(
607
+ tex,
608
+ uv,
609
+ dy,
610
+ uv_da,
611
+ mip_level_bias,
612
+ mip_wrapper,
613
+ mip_stack,
614
+ filter_mode_enum,
615
+ boundary_mode_enum,
616
+ )
617
+ return (
618
+ None,
619
+ g_tex,
620
+ g_uv,
621
+ g_uv_da,
622
+ g_mip_level_bias,
623
+ None,
624
+ None,
625
+ None,
626
+ ) + tuple(g_mip_stack)
627
+ else: # linear-mipmap-nearest
628
+ g_tex, g_uv, g_mip_stack = _get_plugin().texture_grad_linear_mipmap_nearest(
629
+ tex,
630
+ uv,
631
+ dy,
632
+ uv_da,
633
+ mip_level_bias,
634
+ mip_wrapper,
635
+ mip_stack,
636
+ filter_mode_enum,
637
+ boundary_mode_enum,
638
+ )
639
+ return (None, g_tex, g_uv, None, None, None, None, None) + tuple(
640
+ g_mip_stack
641
+ )
642
+
643
 
644
  # Linear and nearest: Mipmaps disabled.
645
  class _texture_func(torch.autograd.Function):
 
654
  def backward(ctx, dy):
655
  tex, uv = ctx.saved_tensors
656
  filter_mode, filter_mode_enum, boundary_mode_enum = ctx.saved_misc
657
+ if filter_mode == "linear":
658
+ g_tex, g_uv = _get_plugin().texture_grad_linear(
659
+ tex, uv, dy, filter_mode_enum, boundary_mode_enum
660
+ )
661
  return None, g_tex, g_uv, None, None
662
+ else: # nearest
663
+ g_tex = _get_plugin().texture_grad_nearest(
664
+ tex, uv, dy, filter_mode_enum, boundary_mode_enum
665
+ )
666
  return None, g_tex, None, None, None
667
 
668
+
669
  # Op wrapper.
670
+ def texture(
671
+ tex,
672
+ uv,
673
+ uv_da=None,
674
+ mip_level_bias=None,
675
+ mip=None,
676
+ filter_mode="auto",
677
+ boundary_mode="wrap",
678
+ max_mip_level=None,
679
+ ):
680
  """Perform texture sampling.
681
 
682
  All input tensors must be contiguous and reside in GPU memory. The output tensor
 
726
  """
727
 
728
  # Default filter mode.
729
+ if filter_mode == "auto":
730
+ filter_mode = (
731
+ "linear-mipmap-linear"
732
+ if (uv_da is not None or mip_level_bias is not None)
733
+ else "linear"
734
+ )
735
 
736
  # Sanitize inputs.
737
  if max_mip_level is None:
 
742
 
743
  # Check inputs.
744
  assert isinstance(tex, torch.Tensor) and isinstance(uv, torch.Tensor)
745
+ if "mipmap" in filter_mode:
746
+ assert isinstance(uv_da, torch.Tensor) or isinstance(
747
+ mip_level_bias, torch.Tensor
748
+ )
749
 
750
  # If mipping disabled via max level=0, we may as well use simpler filtering internally.
751
+ if max_mip_level == 0 and filter_mode in [
752
+ "linear-mipmap-nearest",
753
+ "linear-mipmap-linear",
754
+ ]:
755
+ filter_mode = "linear"
756
 
757
  # Convert filter mode to internal enumeration.
758
+ filter_mode_dict = {
759
+ "nearest": 0,
760
+ "linear": 1,
761
+ "linear-mipmap-nearest": 2,
762
+ "linear-mipmap-linear": 3,
763
+ }
764
  filter_mode_enum = filter_mode_dict[filter_mode]
765
 
766
  # Convert boundary mode to internal enumeration.
767
+ boundary_mode_dict = {"cube": 0, "wrap": 1, "clamp": 2, "zero": 3}
768
  boundary_mode_enum = boundary_mode_dict[boundary_mode]
769
 
770
  # Construct a mipmap if necessary.
771
+ if "mipmap" in filter_mode:
772
  mip_wrapper, mip_stack = None, []
773
  if mip is not None:
774
  assert isinstance(mip, (_get_plugin().TextureMipWrapper, list))
 
778
  else:
779
  mip_wrapper = mip
780
  else:
781
+ mip_wrapper = _get_plugin().texture_construct_mip(
782
+ tex, max_mip_level, boundary_mode == "cube"
783
+ )
784
 
785
  # Choose stub.
786
+ if filter_mode == "linear-mipmap-linear" or filter_mode == "linear-mipmap-nearest":
787
+ return _texture_func_mip.apply(
788
+ filter_mode,
789
+ tex,
790
+ uv,
791
+ uv_da,
792
+ mip_level_bias,
793
+ mip_wrapper,
794
+ filter_mode_enum,
795
+ boundary_mode_enum,
796
+ *mip_stack
797
+ )
798
  else:
799
+ return _texture_func.apply(
800
+ filter_mode, tex, uv, filter_mode_enum, boundary_mode_enum
801
+ )
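A minimal sketch of sampling a texture with automatic mip selection, again assuming `rast`, `rast_db`, and `tri` from the rasterize example; the per-vertex UVs are placeholders:

uv_attr = torch.tensor([[[0.0, 0.0], [1.0, 0.0], [0.5, 1.0]]], device="cuda")
uv, uv_da = dr.interpolate(uv_attr, rast, tri, rast_db=rast_db, diff_attrs="all")
tex = torch.rand(1, 64, 64, 3, device="cuda")  # [minibatch, height, width, channels]
color = dr.texture(tex, uv, uv_da=uv_da, filter_mode="auto", boundary_mode="wrap")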
802
+
803
 
804
  # Mipmap precalculation for cases where the texture stays constant.
805
  def texture_construct_mip(tex, max_mip_level=None, cube_mode=False):
 
814
  cube_mode: Must be set to True if `tex` specifies a cube map texture.
815
 
816
  Returns:
817
+ An opaque object containing the mipmap stack. This can be supplied in a call to `texture()`
818
  in the `mip` argument.
819
  """
820
 
 
827
  assert max_mip_level >= 0
828
  return _get_plugin().texture_construct_mip(tex, max_mip_level, cube_mode)
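When the texture is constant, the mip stack can be built once and reused; a minimal sketch, reusing `tex`, `uv`, and `uv_da` from the texture example above:

mip = dr.texture_construct_mip(tex, max_mip_level=4)
for _ in range(10):  # e.g. an optimization loop
    color = dr.texture(tex, uv, uv_da=uv_da, mip=mip,
                       filter_mode="linear-mipmap-linear")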
829
 
830
+
831
+ # ----------------------------------------------------------------------------
832
  # Antialias.
833
+ # ----------------------------------------------------------------------------
834
+
835
 
836
  class _antialias_func(torch.autograd.Function):
837
  @staticmethod
838
  def forward(ctx, color, rast, pos, tri, topology_hash, pos_gradient_boost):
839
+ out, work_buffer = _get_plugin().antialias_fwd(
840
+ color, rast, pos, tri, topology_hash
841
+ )
842
  ctx.save_for_backward(color, rast, pos, tri)
843
  ctx.saved_misc = pos_gradient_boost, work_buffer
844
  return out
 
847
  def backward(ctx, dy):
848
  color, rast, pos, tri = ctx.saved_tensors
849
  pos_gradient_boost, work_buffer = ctx.saved_misc
850
+ g_color, g_pos = _get_plugin().antialias_grad(
851
+ color, rast, pos, tri, dy, work_buffer
852
+ )
853
  if pos_gradient_boost != 1.0:
854
  g_pos = g_pos * pos_gradient_boost
855
  return g_color, None, g_pos, None, None, None
856
 
857
+
858
  # Op wrapper.
859
  def antialias(color, rast, pos, tri, topology_hash=None, pos_gradient_boost=1.0):
860
  """Perform antialiasing.
 
893
  topology_hash = _get_plugin().antialias_construct_topology_hash(tri)
894
 
895
  # Instantiate the function.
896
+ return _antialias_func.apply(
897
+ color, rast, pos, tri, topology_hash, pos_gradient_boost
898
+ )
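A minimal sketch of antialiasing a shaded image, assuming `color`, `rast`, `pos`, and `tri` from the examples above:

color_aa = dr.antialias(color, rast, pos, tri)  # topology hash built internally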
899
+
900
 
901
  # Topology hash precalculation for cases where the triangle array stays constant.
902
  def antialias_construct_topology_hash(tri):
903
  """Construct a topology hash for a triangle tensor.
904
 
905
+ This function can be used for constructing a topology hash for a triangle tensor that is
906
  known to remain constant. This avoids reconstructing it every time `antialias()` is called.
907
 
908
  Args:
 
910
  GPU memory.
911
 
912
  Returns:
913
+ An opaque object containing the topology hash. This can be supplied in a call to
914
  `antialias()` in the `topology_hash` argument.
915
  """
916
  assert isinstance(tri, torch.Tensor)
917
  return _get_plugin().antialias_construct_topology_hash(tri)
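A minimal sketch of precomputing the hash once for a static mesh and reusing it across iterations, with the same tensors as above:

topology_hash = dr.antialias_construct_topology_hash(tri)
for _ in range(10):
    color_aa = dr.antialias(color, rast, pos, tri, topology_hash=topology_hash)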
918
 
919
+
920
+ # ----------------------------------------------------------------------------
extensions/nvdiffrast/setup copy.py CHANGED
@@ -24,28 +24,31 @@ setuptools.setup(
24
  url="https://github.com/NVlabs/nvdiffrast",
25
  packages=setuptools.find_packages(),
26
  package_data={
27
- 'nvdiffrast': [
28
- 'common/*.h',
29
- 'common/*.inl',
30
- 'common/*.cu',
31
- 'common/*.cpp',
32
- 'common/cudaraster/*.hpp',
33
- 'common/cudaraster/impl/*.cpp',
34
- 'common/cudaraster/impl/*.hpp',
35
- 'common/cudaraster/impl/*.inl',
36
- 'common/cudaraster/impl/*.cu',
37
- 'lib/*.h',
38
- 'torch/*.h',
39
- 'torch/*.inl',
40
- 'torch/*.cpp',
41
- 'tensorflow/*.cu',
42
- ] + (['lib/*.lib'] if os.name == 'nt' else [])
 
43
  },
44
  include_package_data=True,
45
- install_requires=['numpy'], # note: can't require torch here as it will install torch even for a TensorFlow container
 
 
46
  classifiers=[
47
  "Programming Language :: Python :: 3",
48
  "Operating System :: OS Independent",
49
  ],
50
- python_requires='>=3.6',
51
  )
 
24
  url="https://github.com/NVlabs/nvdiffrast",
25
  packages=setuptools.find_packages(),
26
  package_data={
27
+ "nvdiffrast": [
28
+ "common/*.h",
29
+ "common/*.inl",
30
+ "common/*.cu",
31
+ "common/*.cpp",
32
+ "common/cudaraster/*.hpp",
33
+ "common/cudaraster/impl/*.cpp",
34
+ "common/cudaraster/impl/*.hpp",
35
+ "common/cudaraster/impl/*.inl",
36
+ "common/cudaraster/impl/*.cu",
37
+ "lib/*.h",
38
+ "torch/*.h",
39
+ "torch/*.inl",
40
+ "torch/*.cpp",
41
+ "tensorflow/*.cu",
42
+ ]
43
+ + (["lib/*.lib"] if os.name == "nt" else [])
44
  },
45
  include_package_data=True,
46
+ install_requires=[
47
+ "numpy"
48
+ ], # note: can't require torch here as it will install torch even for a TensorFlow container
49
  classifiers=[
50
  "Programming Language :: Python :: 3",
51
  "Operating System :: OS Independent",
52
  ],
53
+ python_requires=">=3.6",
54
  )
extensions/nvdiffrast/setup.py CHANGED
@@ -48,35 +48,35 @@ setuptools.setup(
48
  CUDAExtension(
49
  name="nvdiffrast.torch._C",
50
  sources=[
51
- 'nvdiffrast/common/cudaraster/impl/Buffer.cpp',
52
- 'nvdiffrast/common/cudaraster/impl/CudaRaster.cpp',
53
- 'nvdiffrast/common/cudaraster/impl/RasterImpl_.cu',
54
- 'nvdiffrast/common/cudaraster/impl/RasterImpl.cpp',
55
- 'nvdiffrast/common/common.cpp',
56
- 'nvdiffrast/common/rasterize.cu',
57
- 'nvdiffrast/common/interpolate.cu',
58
- 'nvdiffrast/common/texture_.cu',
59
- 'nvdiffrast/common/texture.cpp',
60
- 'nvdiffrast/common/antialias.cu',
61
- 'nvdiffrast/torch/torch_bindings.cpp',
62
- 'nvdiffrast/torch/torch_rasterize.cpp',
63
- 'nvdiffrast/torch/torch_interpolate.cpp',
64
- 'nvdiffrast/torch/torch_texture.cpp',
65
- 'nvdiffrast/torch/torch_antialias.cpp',
66
  ],
67
  extra_compile_args={
68
- 'cxx': ['-DNVDR_TORCH'],
69
- 'nvcc': ['-DNVDR_TORCH', '-lineinfo'],
70
  },
71
  )
72
  ],
73
- cmdclass={
74
- 'build_ext': BuildExtension
75
- },
76
- install_requires=['numpy'], # note: can't require torch here as it will install torch even for a TensorFlow container
77
  classifiers=[
78
  "Programming Language :: Python :: 3",
79
  "Operating System :: OS Independent",
80
  ],
81
- python_requires='>=3.6',
82
  )
 
48
  CUDAExtension(
49
  name="nvdiffrast.torch._C",
50
  sources=[
51
+ "nvdiffrast/common/cudaraster/impl/Buffer.cpp",
52
+ "nvdiffrast/common/cudaraster/impl/CudaRaster.cpp",
53
+ "nvdiffrast/common/cudaraster/impl/RasterImpl_.cu",
54
+ "nvdiffrast/common/cudaraster/impl/RasterImpl.cpp",
55
+ "nvdiffrast/common/common.cpp",
56
+ "nvdiffrast/common/rasterize.cu",
57
+ "nvdiffrast/common/interpolate.cu",
58
+ "nvdiffrast/common/texture_.cu",
59
+ "nvdiffrast/common/texture.cpp",
60
+ "nvdiffrast/common/antialias.cu",
61
+ "nvdiffrast/torch/torch_bindings.cpp",
62
+ "nvdiffrast/torch/torch_rasterize.cpp",
63
+ "nvdiffrast/torch/torch_interpolate.cpp",
64
+ "nvdiffrast/torch/torch_texture.cpp",
65
+ "nvdiffrast/torch/torch_antialias.cpp",
66
  ],
67
  extra_compile_args={
68
+ "cxx": ["-DNVDR_TORCH"],
69
+ "nvcc": ["-DNVDR_TORCH", "-lineinfo"],
70
  },
71
  )
72
  ],
73
+ cmdclass={"build_ext": BuildExtension},
74
+ install_requires=[
75
+ "numpy"
76
+ ], # note: can't require torch here as it will install torch even for a TensorFlow container
77
  classifiers=[
78
  "Programming Language :: Python :: 3",
79
  "Operating System :: OS Independent",
80
  ],
81
+ python_requires=">=3.6",
82
  )
requirements.txt CHANGED
@@ -26,4 +26,4 @@ https://huggingface.co/spaces/JeffreyXiang/TRELLIS/resolve/main/wheels/diff_gaus
26
  https://huggingface.co/spaces/JeffreyXiang/TRELLIS/resolve/main/wheels/nvdiffrast-0.3.3-cp310-cp310-linux_x86_64.whl?download=true
27
  spaces
28
  plyfile==1.1
29
- utils3d
 
26
  https://huggingface.co/spaces/JeffreyXiang/TRELLIS/resolve/main/wheels/nvdiffrast-0.3.3-cp310-cp310-linux_x86_64.whl?download=true
27
  spaces
28
  plyfile==1.1
29
+ utils3d
trellis/models/__init__.py CHANGED
@@ -1,20 +1,21 @@
1
  import importlib
2
 
3
  __attributes = {
4
- 'SparseStructureEncoder': 'sparse_structure_vae',
5
- 'SparseStructureDecoder': 'sparse_structure_vae',
6
- 'SparseStructureFlowModel': 'sparse_structure_flow',
7
- 'SLatEncoder': 'structured_latent_vae',
8
- 'SLatGaussianDecoder': 'structured_latent_vae',
9
- 'SLatRadianceFieldDecoder': 'structured_latent_vae',
10
- 'SLatMeshDecoder': 'structured_latent_vae',
11
- 'SLatFlowModel': 'structured_latent_flow',
12
  }
13
 
14
  __submodules = []
15
 
16
  __all__ = list(__attributes.keys()) + __submodules
17
 
 
18
  def __getattr__(name):
19
  if name not in globals():
20
  if name in __attributes:
@@ -41,6 +42,7 @@ def from_pretrained(path: str, **kwargs):
41
  import os
42
  import json
43
  from safetensors.torch import load_file
 
44
  is_local = os.path.exists(f"{path}.json") and os.path.exists(f"{path}.safetensors")
45
 
46
  if is_local:
@@ -48,23 +50,29 @@ def from_pretrained(path: str, **kwargs):
48
  model_file = f"{path}.safetensors"
49
  else:
50
  from huggingface_hub import hf_hub_download
51
- path_parts = path.split('/')
52
- repo_id = f'{path_parts[0]}/{path_parts[1]}'
53
- model_name = '/'.join(path_parts[2:])
 
54
  config_file = hf_hub_download(repo_id, f"{model_name}.json")
55
  model_file = hf_hub_download(repo_id, f"{model_name}.safetensors")
56
 
57
- with open(config_file, 'r') as f:
58
  config = json.load(f)
59
- model = __getattr__(config['name'])(**config['args'], **kwargs)
60
  model.load_state_dict(load_file(model_file))
61
 
62
  return model
63
 
64
 
65
  # For Pylance
66
- if __name__ == '__main__':
67
  from .sparse_structure_vae import SparseStructureEncoder, SparseStructureDecoder
68
  from .sparse_structure_flow import SparseStructureFlowModel
69
- from .structured_latent_vae import SLatEncoder, SLatGaussianDecoder, SLatRadianceFieldDecoder, SLatMeshDecoder
 
 
 
 
 
70
  from .structured_latent_flow import SLatFlowModel
 
1
  import importlib
2
 
3
  __attributes = {
4
+ "SparseStructureEncoder": "sparse_structure_vae",
5
+ "SparseStructureDecoder": "sparse_structure_vae",
6
+ "SparseStructureFlowModel": "sparse_structure_flow",
7
+ "SLatEncoder": "structured_latent_vae",
8
+ "SLatGaussianDecoder": "structured_latent_vae",
9
+ "SLatRadianceFieldDecoder": "structured_latent_vae",
10
+ "SLatMeshDecoder": "structured_latent_vae",
11
+ "SLatFlowModel": "structured_latent_flow",
12
  }
13
 
14
  __submodules = []
15
 
16
  __all__ = list(__attributes.keys()) + __submodules
17
 
18
+
19
  def __getattr__(name):
20
  if name not in globals():
21
  if name in __attributes:
 
42
  import os
43
  import json
44
  from safetensors.torch import load_file
45
+
46
  is_local = os.path.exists(f"{path}.json") and os.path.exists(f"{path}.safetensors")
47
 
48
  if is_local:
 
50
  model_file = f"{path}.safetensors"
51
  else:
52
  from huggingface_hub import hf_hub_download
53
+
54
+ path_parts = path.split("/")
55
+ repo_id = f"{path_parts[0]}/{path_parts[1]}"
56
+ model_name = "/".join(path_parts[2:])
57
  config_file = hf_hub_download(repo_id, f"{model_name}.json")
58
  model_file = hf_hub_download(repo_id, f"{model_name}.safetensors")
59
 
60
+ with open(config_file, "r") as f:
61
  config = json.load(f)
62
+ model = __getattr__(config["name"])(**config["args"], **kwargs)
63
  model.load_state_dict(load_file(model_file))
64
 
65
  return model
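A minimal usage sketch of the loader above; the path "user/repo/ckpts/model" is hypothetical, with the first two segments taken as the Hugging Face repo id and the remainder as the model name:

from trellis import models

model = models.from_pretrained("user/repo/ckpts/model")  # hypothetical path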
66
 
67
 
68
  # For Pylance
69
+ if __name__ == "__main__":
70
  from .sparse_structure_vae import SparseStructureEncoder, SparseStructureDecoder
71
  from .sparse_structure_flow import SparseStructureFlowModel
72
+ from .structured_latent_vae import (
73
+ SLatEncoder,
74
+ SLatGaussianDecoder,
75
+ SLatRadianceFieldDecoder,
76
+ SLatMeshDecoder,
77
+ )
78
  from .structured_latent_flow import SLatFlowModel
trellis/models/sparse_structure_flow.py CHANGED
@@ -4,7 +4,10 @@ import torch.nn as nn
4
  import torch.nn.functional as F
5
  import numpy as np
6
  from ..modules.utils import convert_module_to_f16, convert_module_to_f32
7
- from ..modules.transformer import AbsolutePositionEmbedder, ModulatedTransformerCrossBlock
 
 
 
8
  from ..modules.spatial import patchify, unpatchify
9
 
10
 
@@ -12,6 +15,7 @@ class TimestepEmbedder(nn.Module):
12
  """
13
  Embeds scalar timesteps into vector representations.
14
  """
 
15
  def __init__(self, hidden_size, frequency_embedding_size=256):
16
  super().__init__()
17
  self.mlp = nn.Sequential(
@@ -38,12 +42,16 @@ class TimestepEmbedder(nn.Module):
38
  # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
39
  half = dim // 2
40
  freqs = torch.exp(
41
- -np.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
 
 
42
  ).to(device=t.device)
43
  args = t[:, None].float() * freqs[None]
44
  embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
45
  if dim % 2:
46
- embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
 
 
47
  return embedding
48
 
49
  def forward(self, t):
@@ -93,34 +101,41 @@ class SparseStructureFlowModel(nn.Module):
93
  self.t_embedder = TimestepEmbedder(model_channels)
94
  if share_mod:
95
  self.adaLN_modulation = nn.Sequential(
96
- nn.SiLU(),
97
- nn.Linear(model_channels, 6 * model_channels, bias=True)
98
  )
99
 
100
  if pe_mode == "ape":
101
  pos_embedder = AbsolutePositionEmbedder(model_channels, 3)
102
- coords = torch.meshgrid(*[torch.arange(res, device=self.device) for res in [resolution // patch_size] * 3], indexing='ij')
 
 
 
 
 
 
103
  coords = torch.stack(coords, dim=-1).reshape(-1, 3)
104
  pos_emb = pos_embedder(coords)
105
  self.register_buffer("pos_emb", pos_emb)
106
 
107
  self.input_layer = nn.Linear(in_channels * patch_size**3, model_channels)
108
-
109
- self.blocks = nn.ModuleList([
110
- ModulatedTransformerCrossBlock(
111
- model_channels,
112
- cond_channels,
113
- num_heads=self.num_heads,
114
- mlp_ratio=self.mlp_ratio,
115
- attn_mode='full',
116
- use_checkpoint=self.use_checkpoint,
117
- use_rope=(pe_mode == "rope"),
118
- share_mod=share_mod,
119
- qk_rms_norm=self.qk_rms_norm,
120
- qk_rms_norm_cross=self.qk_rms_norm_cross,
121
- )
122
- for _ in range(num_blocks)
123
- ])
 
 
124
 
125
  self.out_layer = nn.Linear(model_channels, out_channels * patch_size**3)
126
 
@@ -154,6 +169,7 @@ class SparseStructureFlowModel(nn.Module):
154
  torch.nn.init.xavier_uniform_(module.weight)
155
  if module.bias is not None:
156
  nn.init.constant_(module.bias, 0)
 
157
  self.apply(_basic_init)
158
 
159
  # Initialize timestep embedding MLP:
@@ -173,9 +189,14 @@ class SparseStructureFlowModel(nn.Module):
173
  nn.init.constant_(self.out_layer.weight, 0)
174
  nn.init.constant_(self.out_layer.bias, 0)
175
 
176
- def forward(self, x: torch.Tensor, t: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
177
- assert [*x.shape] == [x.shape[0], self.in_channels, *[self.resolution] * 3], \
178
- f"Input shape mismatch, got {x.shape}, expected {[x.shape[0], self.in_channels, *[self.resolution] * 3]}"
 
 
 
 
 
179
 
180
  h = patchify(x, self.patch_size)
181
  h = h.view(*h.shape[:2], -1).permute(0, 2, 1).contiguous()
@@ -194,7 +215,9 @@ class SparseStructureFlowModel(nn.Module):
194
  h = F.layer_norm(h, h.shape[-1:])
195
  h = self.out_layer(h)
196
 
197
- h = h.permute(0, 2, 1).view(h.shape[0], h.shape[2], *[self.resolution // self.patch_size] * 3)
 
 
198
  h = unpatchify(h, self.patch_size).contiguous()
199
 
200
  return h
 
4
  import torch.nn.functional as F
5
  import numpy as np
6
  from ..modules.utils import convert_module_to_f16, convert_module_to_f32
7
+ from ..modules.transformer import (
8
+ AbsolutePositionEmbedder,
9
+ ModulatedTransformerCrossBlock,
10
+ )
11
  from ..modules.spatial import patchify, unpatchify
12
 
13
 
 
15
  """
16
  Embeds scalar timesteps into vector representations.
17
  """
18
+
19
  def __init__(self, hidden_size, frequency_embedding_size=256):
20
  super().__init__()
21
  self.mlp = nn.Sequential(
 
42
  # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
43
  half = dim // 2
44
  freqs = torch.exp(
45
+ -np.log(max_period)
46
+ * torch.arange(start=0, end=half, dtype=torch.float32)
47
+ / half
48
  ).to(device=t.device)
49
  args = t[:, None].float() * freqs[None]
50
  embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
51
  if dim % 2:
52
+ embedding = torch.cat(
53
+ [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
54
+ )
55
  return embedding
56
 
57
  def forward(self, t):
 
101
  self.t_embedder = TimestepEmbedder(model_channels)
102
  if share_mod:
103
  self.adaLN_modulation = nn.Sequential(
104
+ nn.SiLU(), nn.Linear(model_channels, 6 * model_channels, bias=True)
 
105
  )
106
 
107
  if pe_mode == "ape":
108
  pos_embedder = AbsolutePositionEmbedder(model_channels, 3)
109
+ coords = torch.meshgrid(
110
+ *[
111
+ torch.arange(res, device=self.device)
112
+ for res in [resolution // patch_size] * 3
113
+ ],
114
+ indexing="ij",
115
+ )
116
  coords = torch.stack(coords, dim=-1).reshape(-1, 3)
117
  pos_emb = pos_embedder(coords)
118
  self.register_buffer("pos_emb", pos_emb)
119
 
120
  self.input_layer = nn.Linear(in_channels * patch_size**3, model_channels)
121
+
122
+ self.blocks = nn.ModuleList(
123
+ [
124
+ ModulatedTransformerCrossBlock(
125
+ model_channels,
126
+ cond_channels,
127
+ num_heads=self.num_heads,
128
+ mlp_ratio=self.mlp_ratio,
129
+ attn_mode="full",
130
+ use_checkpoint=self.use_checkpoint,
131
+ use_rope=(pe_mode == "rope"),
132
+ share_mod=share_mod,
133
+ qk_rms_norm=self.qk_rms_norm,
134
+ qk_rms_norm_cross=self.qk_rms_norm_cross,
135
+ )
136
+ for _ in range(num_blocks)
137
+ ]
138
+ )
139
 
140
  self.out_layer = nn.Linear(model_channels, out_channels * patch_size**3)
141
 
 
169
  torch.nn.init.xavier_uniform_(module.weight)
170
  if module.bias is not None:
171
  nn.init.constant_(module.bias, 0)
172
+
173
  self.apply(_basic_init)
174
 
175
  # Initialize timestep embedding MLP:
 
189
  nn.init.constant_(self.out_layer.weight, 0)
190
  nn.init.constant_(self.out_layer.bias, 0)
191
 
192
+ def forward(
193
+ self, x: torch.Tensor, t: torch.Tensor, cond: torch.Tensor
194
+ ) -> torch.Tensor:
195
+ assert [*x.shape] == [
196
+ x.shape[0],
197
+ self.in_channels,
198
+ *[self.resolution] * 3,
199
+ ], f"Input shape mismatch, got {x.shape}, expected {[x.shape[0], self.in_channels, *[self.resolution] * 3]}"
200
 
201
  h = patchify(x, self.patch_size)
202
  h = h.view(*h.shape[:2], -1).permute(0, 2, 1).contiguous()
 
215
  h = F.layer_norm(h, h.shape[-1:])
216
  h = self.out_layer(h)
217
 
218
+ h = h.permute(0, 2, 1).view(
219
+ h.shape[0], h.shape[2], *[self.resolution // self.patch_size] * 3
220
+ )
221
  h = unpatchify(h, self.patch_size).contiguous()
222
 
223
  return h
trellis/models/sparse_structure_vae.py CHANGED
@@ -33,9 +33,15 @@ class ResBlock3d(nn.Module):
33
  self.norm1 = norm_layer(norm_type, channels)
34
  self.norm2 = norm_layer(norm_type, self.out_channels)
35
  self.conv1 = nn.Conv3d(channels, self.out_channels, 3, padding=1)
36
- self.conv2 = zero_module(nn.Conv3d(self.out_channels, self.out_channels, 3, padding=1))
37
- self.skip_connection = nn.Conv3d(channels, self.out_channels, 1) if channels != self.out_channels else nn.Identity()
38
-
 
 
 
 
 
 
39
  def forward(self, x: torch.Tensor) -> torch.Tensor:
40
  h = self.norm1(x)
41
  h = F.silu(h)
@@ -63,7 +69,9 @@ class DownsampleBlock3d(nn.Module):
63
  if mode == "conv":
64
  self.conv = nn.Conv3d(in_channels, out_channels, 2, stride=2)
65
  elif mode == "avgpool":
66
- assert in_channels == out_channels, "Pooling mode requires in_channels to be equal to out_channels"
 
 
67
 
68
  def forward(self, x: torch.Tensor) -> torch.Tensor:
69
  if hasattr(self, "conv"):
@@ -86,9 +94,11 @@ class UpsampleBlock3d(nn.Module):
86
  self.out_channels = out_channels
87
 
88
  if mode == "conv":
89
- self.conv = nn.Conv3d(in_channels, out_channels*8, 3, padding=1)
90
  elif mode == "nearest":
91
- assert in_channels == out_channels, "Nearest mode requires in_channels to be equal to out_channels"
 
 
92
 
93
  def forward(self, x: torch.Tensor) -> torch.Tensor:
94
  if hasattr(self, "conv"):
@@ -96,12 +106,12 @@ class UpsampleBlock3d(nn.Module):
96
  return pixel_shuffle_3d(x, 2)
97
  else:
98
  return F.interpolate(x, scale_factor=2, mode="nearest")
99
-
100
 
101
  class SparseStructureEncoder(nn.Module):
102
  """
103
  Encoder for Sparse Structure (\mathcal{E}_S in the paper Sec. 3.3).
104
-
105
  Args:
106
  in_channels (int): Channels of the input.
107
  latent_channels (int): Channels of the latent representation.
@@ -111,6 +121,7 @@ class SparseStructureEncoder(nn.Module):
111
  norm_type (Literal["group", "layer"]): Type of normalization layer.
112
  use_fp16 (bool): Whether to use FP16.
113
  """
 
114
  def __init__(
115
  self,
116
  in_channels: int,
@@ -135,24 +146,21 @@ class SparseStructureEncoder(nn.Module):
135
 
136
  self.blocks = nn.ModuleList([])
137
  for i, ch in enumerate(channels):
138
- self.blocks.extend([
139
- ResBlock3d(ch, ch)
140
- for _ in range(num_res_blocks)
141
- ])
142
  if i < len(channels) - 1:
143
- self.blocks.append(
144
- DownsampleBlock3d(ch, channels[i+1])
145
- )
146
-
147
- self.middle_block = nn.Sequential(*[
148
- ResBlock3d(channels[-1], channels[-1])
149
- for _ in range(num_res_blocks_middle)
150
- ])
151
 
152
  self.out_layer = nn.Sequential(
153
  norm_layer(norm_type, channels[-1]),
154
  nn.SiLU(),
155
- nn.Conv3d(channels[-1], latent_channels*2, 3, padding=1)
156
  )
157
 
158
  if use_fp16:
@@ -183,7 +191,9 @@ class SparseStructureEncoder(nn.Module):
183
  self.blocks.apply(convert_module_to_f32)
184
  self.middle_block.apply(convert_module_to_f32)
185
 
186
- def forward(self, x: torch.Tensor, sample_posterior: bool = False, return_raw: bool = False) -> torch.Tensor:
 
 
187
  h = self.input_layer(x)
188
  h = h.type(self.dtype)
189
 
@@ -201,16 +211,16 @@ class SparseStructureEncoder(nn.Module):
201
  z = mean + std * torch.randn_like(std)
202
  else:
203
  z = mean
204
-
205
  if return_raw:
206
  return z, mean, logvar
207
  return z
208
-
209
 
210
  class SparseStructureDecoder(nn.Module):
211
  """
212
  Decoder for Sparse Structure (\mathcal{D}_S in the paper Sec. 3.3).
213
-
214
  Args:
215
  out_channels (int): Channels of the output.
216
  latent_channels (int): Channels of the latent representation.
@@ -219,7 +229,8 @@ class SparseStructureDecoder(nn.Module):
219
  num_res_blocks_middle (int): Number of residual blocks in the middle.
220
  norm_type (Literal["group", "layer"]): Type of normalization layer.
221
  use_fp16 (bool): Whether to use FP16.
222
- """
 
223
  def __init__(
224
  self,
225
  out_channels: int,
@@ -242,26 +253,23 @@ class SparseStructureDecoder(nn.Module):
242
 
243
  self.input_layer = nn.Conv3d(latent_channels, channels[0], 3, padding=1)
244
 
245
- self.middle_block = nn.Sequential(*[
246
- ResBlock3d(channels[0], channels[0])
247
- for _ in range(num_res_blocks_middle)
248
- ])
 
 
249
 
250
  self.blocks = nn.ModuleList([])
251
  for i, ch in enumerate(channels):
252
- self.blocks.extend([
253
- ResBlock3d(ch, ch)
254
- for _ in range(num_res_blocks)
255
- ])
256
  if i < len(channels) - 1:
257
- self.blocks.append(
258
- UpsampleBlock3d(ch, channels[i+1])
259
- )
260
 
261
  self.out_layer = nn.Sequential(
262
  norm_layer(norm_type, channels[-1]),
263
  nn.SiLU(),
264
- nn.Conv3d(channels[-1], out_channels, 3, padding=1)
265
  )
266
 
267
  if use_fp16:
@@ -273,7 +281,7 @@ class SparseStructureDecoder(nn.Module):
273
  Return the device of the model.
274
  """
275
  return next(self.parameters()).device
276
-
277
  def convert_to_fp16(self) -> None:
278
  """
279
  Convert the torso of the model to float16.
@@ -291,12 +299,12 @@ class SparseStructureDecoder(nn.Module):
291
  self.dtype = torch.float32
292
  self.blocks.apply(convert_module_to_f32)
293
  self.middle_block.apply(convert_module_to_f32)
294
-
295
  def forward(self, x: torch.Tensor) -> torch.Tensor:
296
  h = self.input_layer(x)
297
-
298
  h = h.type(self.dtype)
299
-
300
  h = self.middle_block(h)
301
  for block in self.blocks:
302
  h = block(h)
 
33
  self.norm1 = norm_layer(norm_type, channels)
34
  self.norm2 = norm_layer(norm_type, self.out_channels)
35
  self.conv1 = nn.Conv3d(channels, self.out_channels, 3, padding=1)
36
+ self.conv2 = zero_module(
37
+ nn.Conv3d(self.out_channels, self.out_channels, 3, padding=1)
38
+ )
39
+ self.skip_connection = (
40
+ nn.Conv3d(channels, self.out_channels, 1)
41
+ if channels != self.out_channels
42
+ else nn.Identity()
43
+ )
44
+
45
  def forward(self, x: torch.Tensor) -> torch.Tensor:
46
  h = self.norm1(x)
47
  h = F.silu(h)
 
69
  if mode == "conv":
70
  self.conv = nn.Conv3d(in_channels, out_channels, 2, stride=2)
71
  elif mode == "avgpool":
72
+ assert (
73
+ in_channels == out_channels
74
+ ), "Pooling mode requires in_channels to be equal to out_channels"
75
 
76
  def forward(self, x: torch.Tensor) -> torch.Tensor:
77
  if hasattr(self, "conv"):
 
94
  self.out_channels = out_channels
95
 
96
  if mode == "conv":
97
+ self.conv = nn.Conv3d(in_channels, out_channels * 8, 3, padding=1)
98
  elif mode == "nearest":
99
+ assert (
100
+ in_channels == out_channels
101
+ ), "Nearest mode requires in_channels to be equal to out_channels"
102
 
103
  def forward(self, x: torch.Tensor) -> torch.Tensor:
104
  if hasattr(self, "conv"):
 
106
  return pixel_shuffle_3d(x, 2)
107
  else:
108
  return F.interpolate(x, scale_factor=2, mode="nearest")
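The `out_channels * 8` in conv mode exists because a factor-2 3D pixel shuffle folds 2*2*2 = 8 channel groups into space; a standalone sketch of that rearrangement (the actual `pixel_shuffle_3d` helper imported by this file may differ in detail):

import torch

x = torch.rand(1, 4 * 8, 16, 16, 16)     # [N, C*8, D, H, W]
x = x.view(1, 4, 2, 2, 2, 16, 16, 16)    # split the 8 into (2, 2, 2)
x = x.permute(0, 1, 5, 2, 6, 3, 7, 4)    # interleave factors with D, H, W
x = x.reshape(1, 4, 32, 32, 32)          # [N, C, 2D, 2H, 2W]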
109
+
110
 
111
  class SparseStructureEncoder(nn.Module):
112
  """
113
  Encoder for Sparse Structure (\mathcal{E}_S in the paper Sec. 3.3).
114
+
115
  Args:
116
  in_channels (int): Channels of the input.
117
  latent_channels (int): Channels of the latent representation.
 
121
  norm_type (Literal["group", "layer"]): Type of normalization layer.
122
  use_fp16 (bool): Whether to use FP16.
123
  """
124
+
125
  def __init__(
126
  self,
127
  in_channels: int,
 
146
 
147
  self.blocks = nn.ModuleList([])
148
  for i, ch in enumerate(channels):
149
+ self.blocks.extend([ResBlock3d(ch, ch) for _ in range(num_res_blocks)])
 
 
 
150
  if i < len(channels) - 1:
151
+ self.blocks.append(DownsampleBlock3d(ch, channels[i + 1]))
152
+
153
+ self.middle_block = nn.Sequential(
154
+ *[
155
+ ResBlock3d(channels[-1], channels[-1])
156
+ for _ in range(num_res_blocks_middle)
157
+ ]
158
+ )
159
 
160
  self.out_layer = nn.Sequential(
161
  norm_layer(norm_type, channels[-1]),
162
  nn.SiLU(),
163
+ nn.Conv3d(channels[-1], latent_channels * 2, 3, padding=1),
164
  )
165
 
166
  if use_fp16:
 
191
  self.blocks.apply(convert_module_to_f32)
192
  self.middle_block.apply(convert_module_to_f32)
193
 
194
+ def forward(
195
+ self, x: torch.Tensor, sample_posterior: bool = False, return_raw: bool = False
196
+ ) -> torch.Tensor:
197
  h = self.input_layer(x)
198
  h = h.type(self.dtype)
199
 
 
211
  z = mean + std * torch.randn_like(std)
212
  else:
213
  z = mean
214
+
215
  if return_raw:
216
  return z, mean, logvar
217
  return z
218
+
219
 
220
  class SparseStructureDecoder(nn.Module):
221
  """
222
  Decoder for Sparse Structure (\mathcal{D}_S in the paper Sec. 3.3).
223
+
224
  Args:
225
  out_channels (int): Channels of the output.
226
  latent_channels (int): Channels of the latent representation.
 
229
  num_res_blocks_middle (int): Number of residual blocks in the middle.
230
  norm_type (Literal["group", "layer"]): Type of normalization layer.
231
  use_fp16 (bool): Whether to use FP16.
232
+ """
233
+
234
  def __init__(
235
  self,
236
  out_channels: int,
 
253
 
254
  self.input_layer = nn.Conv3d(latent_channels, channels[0], 3, padding=1)
255
 
256
+ self.middle_block = nn.Sequential(
257
+ *[
258
+ ResBlock3d(channels[0], channels[0])
259
+ for _ in range(num_res_blocks_middle)
260
+ ]
261
+ )
262
 
263
  self.blocks = nn.ModuleList([])
264
  for i, ch in enumerate(channels):
265
+ self.blocks.extend([ResBlock3d(ch, ch) for _ in range(num_res_blocks)])
 
 
 
266
  if i < len(channels) - 1:
267
+ self.blocks.append(UpsampleBlock3d(ch, channels[i + 1]))
 
 
268
 
269
  self.out_layer = nn.Sequential(
270
  norm_layer(norm_type, channels[-1]),
271
  nn.SiLU(),
272
+ nn.Conv3d(channels[-1], out_channels, 3, padding=1),
273
  )
274
 
275
  if use_fp16:
 
281
  Return the device of the model.
282
  """
283
  return next(self.parameters()).device
284
+
285
  def convert_to_fp16(self) -> None:
286
  """
287
  Convert the torso of the model to float16.
 
299
  self.dtype = torch.float32
300
  self.blocks.apply(convert_module_to_f32)
301
  self.middle_block.apply(convert_module_to_f32)
302
+
303
  def forward(self, x: torch.Tensor) -> torch.Tensor:
304
  h = self.input_layer(x)
305
+
306
  h = h.type(self.dtype)
307
+
308
  h = self.middle_block(h)
309
  for block in self.blocks:
310
  h = block(h)
trellis/models/structured_latent_flow.py CHANGED
@@ -26,18 +26,26 @@ class SparseResBlock3d(nn.Module):
26
  self.out_channels = out_channels or channels
27
  self.downsample = downsample
28
  self.upsample = upsample
29
-
30
- assert not (downsample and upsample), "Cannot downsample and upsample at the same time"
 
 
31
 
32
  self.norm1 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6)
33
  self.norm2 = LayerNorm32(self.out_channels, elementwise_affine=False, eps=1e-6)
34
  self.conv1 = sp.SparseConv3d(channels, self.out_channels, 3)
35
- self.conv2 = zero_module(sp.SparseConv3d(self.out_channels, self.out_channels, 3))
 
 
36
  self.emb_layers = nn.Sequential(
37
  nn.SiLU(),
38
  nn.Linear(emb_channels, 2 * self.out_channels, bias=True),
39
  )
40
- self.skip_connection = sp.SparseLinear(channels, self.out_channels) if channels != self.out_channels else nn.Identity()
 
 
 
 
41
  self.updown = None
42
  if self.downsample:
43
  self.updown = sp.SparseDownsample(2)
@@ -63,7 +71,7 @@ class SparseResBlock3d(nn.Module):
63
  h = h + self.skip_connection(x)
64
 
65
  return h
66
-
67
 
68
  class SLatFlowModel(nn.Module):
69
  def __init__(
@@ -109,14 +117,17 @@ class SLatFlowModel(nn.Module):
109
  self.qk_rms_norm_cross = qk_rms_norm_cross
110
  self.dtype = torch.float16 if use_fp16 else torch.float32
111
 
112
- assert int(np.log2(patch_size)) == np.log2(patch_size), "Patch size must be a power of 2"
113
- assert np.log2(patch_size) == len(io_block_channels), "Number of IO ResBlocks must match the number of stages"
 
 
 
 
114
 
115
  self.t_embedder = TimestepEmbedder(model_channels)
116
  if share_mod:
117
  self.adaLN_modulation = nn.Sequential(
118
- nn.SiLU(),
119
- nn.Linear(model_channels, 6 * model_channels, bias=True)
120
  )
121
 
122
  if pe_mode == "ape":
@@ -124,15 +135,19 @@ class SLatFlowModel(nn.Module):
124
 
125
  self.input_layer = sp.SparseLinear(in_channels, io_block_channels[0])
126
  self.input_blocks = nn.ModuleList([])
127
- for chs, next_chs in zip(io_block_channels, io_block_channels[1:] + [model_channels]):
128
- self.input_blocks.extend([
129
- SparseResBlock3d(
130
- chs,
131
- model_channels,
132
- out_channels=chs,
133
- )
134
- for _ in range(num_io_res_blocks-1)
135
- ])
 
 
 
 
136
  self.input_blocks.append(
137
  SparseResBlock3d(
138
  chs,
@@ -141,25 +156,30 @@ class SLatFlowModel(nn.Module):
141
  downsample=True,
142
  )
143
  )
144
-
145
- self.blocks = nn.ModuleList([
146
- ModulatedSparseTransformerCrossBlock(
147
- model_channels,
148
- cond_channels,
149
- num_heads=self.num_heads,
150
- mlp_ratio=self.mlp_ratio,
151
- attn_mode='full',
152
- use_checkpoint=self.use_checkpoint,
153
- use_rope=(pe_mode == "rope"),
154
- share_mod=self.share_mod,
155
- qk_rms_norm=self.qk_rms_norm,
156
- qk_rms_norm_cross=self.qk_rms_norm_cross,
157
- )
158
- for _ in range(num_blocks)
159
- ])
 
 
160
 
161
  self.out_blocks = nn.ModuleList([])
162
- for chs, prev_chs in zip(reversed(io_block_channels), [model_channels] + list(reversed(io_block_channels[1:]))):
 
 
 
163
  self.out_blocks.append(
164
  SparseResBlock3d(
165
  prev_chs * 2 if self.use_skip_connection else prev_chs,
@@ -168,14 +188,16 @@ class SLatFlowModel(nn.Module):
168
  upsample=True,
169
  )
170
  )
171
- self.out_blocks.extend([
172
- SparseResBlock3d(
173
- chs * 2 if self.use_skip_connection else chs,
174
- model_channels,
175
- out_channels=chs,
176
- )
177
- for _ in range(num_io_res_blocks-1)
178
- ])
 
 
179
  self.out_layer = sp.SparseLinear(io_block_channels[0], out_channels)
180
 
181
  self.initialize_weights()
@@ -212,6 +234,7 @@ class SLatFlowModel(nn.Module):
212
  torch.nn.init.xavier_uniform_(module.weight)
213
  if module.bias is not None:
214
  nn.init.constant_(module.bias, 0)
 
215
  self.apply(_basic_init)
216
 
217
  # Initialize timestep embedding MLP:
@@ -231,7 +254,9 @@ class SLatFlowModel(nn.Module):
231
  nn.init.constant_(self.out_layer.weight, 0)
232
  nn.init.constant_(self.out_layer.bias, 0)
233
 
234
- def forward(self, x: sp.SparseTensor, t: torch.Tensor, cond: torch.Tensor) -> sp.SparseTensor:
 
 
235
  h = self.input_layer(x).type(self.dtype)
236
  t_emb = self.t_embedder(t)
237
  if self.share_mod:
@@ -244,7 +269,7 @@ class SLatFlowModel(nn.Module):
244
  for block in self.input_blocks:
245
  h = block(h, t_emb)
246
  skips.append(h.feats)
247
-
248
  if self.pe_mode == "ape":
249
  h = h + self.pos_embedder(h.coords[:, 1:]).type(self.dtype)
250
  for block in self.blocks:
 
26
  self.out_channels = out_channels or channels
27
  self.downsample = downsample
28
  self.upsample = upsample
29
+
30
+ assert not (
31
+ downsample and upsample
32
+ ), "Cannot downsample and upsample at the same time"
33
 
34
  self.norm1 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6)
35
  self.norm2 = LayerNorm32(self.out_channels, elementwise_affine=False, eps=1e-6)
36
  self.conv1 = sp.SparseConv3d(channels, self.out_channels, 3)
37
+ self.conv2 = zero_module(
38
+ sp.SparseConv3d(self.out_channels, self.out_channels, 3)
39
+ )
40
  self.emb_layers = nn.Sequential(
41
  nn.SiLU(),
42
  nn.Linear(emb_channels, 2 * self.out_channels, bias=True),
43
  )
44
+ self.skip_connection = (
45
+ sp.SparseLinear(channels, self.out_channels)
46
+ if channels != self.out_channels
47
+ else nn.Identity()
48
+ )
49
  self.updown = None
50
  if self.downsample:
51
  self.updown = sp.SparseDownsample(2)
 
71
  h = h + self.skip_connection(x)
72
 
73
  return h
74
+
75
 
76
  class SLatFlowModel(nn.Module):
77
  def __init__(
 
117
  self.qk_rms_norm_cross = qk_rms_norm_cross
118
  self.dtype = torch.float16 if use_fp16 else torch.float32
119
 
120
+ assert int(np.log2(patch_size)) == np.log2(
121
+ patch_size
122
+ ), "Patch size must be a power of 2"
123
+ assert np.log2(patch_size) == len(
124
+ io_block_channels
125
+ ), "Number of IO ResBlocks must match the number of stages"
126
 
127
  self.t_embedder = TimestepEmbedder(model_channels)
128
  if share_mod:
129
  self.adaLN_modulation = nn.Sequential(
130
+ nn.SiLU(), nn.Linear(model_channels, 6 * model_channels, bias=True)
 
131
  )
132
 
133
  if pe_mode == "ape":
 
135
 
136
  self.input_layer = sp.SparseLinear(in_channels, io_block_channels[0])
137
  self.input_blocks = nn.ModuleList([])
138
+ for chs, next_chs in zip(
139
+ io_block_channels, io_block_channels[1:] + [model_channels]
140
+ ):
141
+ self.input_blocks.extend(
142
+ [
143
+ SparseResBlock3d(
144
+ chs,
145
+ model_channels,
146
+ out_channels=chs,
147
+ )
148
+ for _ in range(num_io_res_blocks - 1)
149
+ ]
150
+ )
151
  self.input_blocks.append(
152
  SparseResBlock3d(
153
  chs,
 
156
  downsample=True,
157
  )
158
  )
159
+
160
+ self.blocks = nn.ModuleList(
161
+ [
162
+ ModulatedSparseTransformerCrossBlock(
163
+ model_channels,
164
+ cond_channels,
165
+ num_heads=self.num_heads,
166
+ mlp_ratio=self.mlp_ratio,
167
+ attn_mode="full",
168
+ use_checkpoint=self.use_checkpoint,
169
+ use_rope=(pe_mode == "rope"),
170
+ share_mod=self.share_mod,
171
+ qk_rms_norm=self.qk_rms_norm,
172
+ qk_rms_norm_cross=self.qk_rms_norm_cross,
173
+ )
174
+ for _ in range(num_blocks)
175
+ ]
176
+ )
177
 
178
  self.out_blocks = nn.ModuleList([])
179
+ for chs, prev_chs in zip(
180
+ reversed(io_block_channels),
181
+ [model_channels] + list(reversed(io_block_channels[1:])),
182
+ ):
183
  self.out_blocks.append(
184
  SparseResBlock3d(
185
  prev_chs * 2 if self.use_skip_connection else prev_chs,
 
188
  upsample=True,
189
  )
190
  )
191
+ self.out_blocks.extend(
192
+ [
193
+ SparseResBlock3d(
194
+ chs * 2 if self.use_skip_connection else chs,
195
+ model_channels,
196
+ out_channels=chs,
197
+ )
198
+ for _ in range(num_io_res_blocks - 1)
199
+ ]
200
+ )
201
  self.out_layer = sp.SparseLinear(io_block_channels[0], out_channels)
202
 
203
  self.initialize_weights()
 
234
  torch.nn.init.xavier_uniform_(module.weight)
235
  if module.bias is not None:
236
  nn.init.constant_(module.bias, 0)
237
+
238
  self.apply(_basic_init)
239
 
240
  # Initialize timestep embedding MLP:
 
254
  nn.init.constant_(self.out_layer.weight, 0)
255
  nn.init.constant_(self.out_layer.bias, 0)
256
 
257
+ def forward(
258
+ self, x: sp.SparseTensor, t: torch.Tensor, cond: torch.Tensor
259
+ ) -> sp.SparseTensor:
260
  h = self.input_layer(x).type(self.dtype)
261
  t_emb = self.t_embedder(t)
262
  if self.share_mod:
 
269
  for block in self.input_blocks:
270
  h = block(h, t_emb)
271
  skips.append(h.feats)
272
+
273
  if self.pe_mode == "ape":
274
  h = h + self.pos_embedder(h.coords[:, 1:]).type(self.dtype)
275
  for block in self.blocks:
trellis/models/structured_latent_vae/base.py CHANGED
@@ -13,15 +13,23 @@ def block_attn_config(self):
13
  """
14
  for i in range(self.num_blocks):
15
  if self.attn_mode == "shift_window":
16
- yield "serialized", self.window_size, 0, (16 * (i % 2),) * 3, sp.SerializeMode.Z_ORDER
 
 
17
  elif self.attn_mode == "shift_sequence":
18
- yield "serialized", self.window_size, self.window_size // 2 * (i % 2), (0, 0, 0), sp.SerializeMode.Z_ORDER
 
 
 
 
19
  elif self.attn_mode == "shift_order":
20
  yield "serialized", self.window_size, 0, (0, 0, 0), sp.SerializeModes[i % 4]
21
  elif self.attn_mode == "full":
22
  yield "full", None, None, None, None
23
  elif self.attn_mode == "swin":
24
- yield "windowed", self.window_size, None, self.window_size // 2 * (i % 2), None
 
 
25
 
26
 
27
  class SparseTransformerBase(nn.Module):
@@ -29,6 +37,7 @@ class SparseTransformerBase(nn.Module):
29
  Sparse Transformer without output layers.
30
  Serves as the base class for the encoder and decoder.
31
  """
 
32
  def __init__(
33
  self,
34
  in_channels: int,
@@ -37,7 +46,9 @@ class SparseTransformerBase(nn.Module):
37
  num_heads: Optional[int] = None,
38
  num_head_channels: Optional[int] = 64,
39
  mlp_ratio: float = 4.0,
40
- attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "full",
 
 
41
  window_size: Optional[int] = None,
42
  pe_mode: Literal["ape", "rope"] = "ape",
43
  use_fp16: bool = False,
@@ -62,22 +73,26 @@ class SparseTransformerBase(nn.Module):
62
  self.pos_embedder = AbsolutePositionEmbedder(model_channels)
63
 
64
  self.input_layer = sp.SparseLinear(in_channels, model_channels)
65
- self.blocks = nn.ModuleList([
66
- SparseTransformerBlock(
67
- model_channels,
68
- num_heads=self.num_heads,
69
- mlp_ratio=self.mlp_ratio,
70
- attn_mode=attn_mode,
71
- window_size=window_size,
72
- shift_sequence=shift_sequence,
73
- shift_window=shift_window,
74
- serialize_mode=serialize_mode,
75
- use_checkpoint=self.use_checkpoint,
76
- use_rope=(pe_mode == "rope"),
77
- qk_rms_norm=self.qk_rms_norm,
78
- )
79
- for attn_mode, window_size, shift_sequence, shift_window, serialize_mode in block_attn_config(self)
80
- ])
 
 
 
 
81
 
82
  @property
83
  def device(self) -> torch.device:
@@ -105,6 +120,7 @@ class SparseTransformerBase(nn.Module):
105
  torch.nn.init.xavier_uniform_(module.weight)
106
  if module.bias is not None:
107
  nn.init.constant_(module.bias, 0)
 
108
  self.apply(_basic_init)
109
 
110
  def forward(self, x: sp.SparseTensor) -> sp.SparseTensor:
 
13
  """
14
  for i in range(self.num_blocks):
15
  if self.attn_mode == "shift_window":
16
+ yield "serialized", self.window_size, 0, (
17
+ 16 * (i % 2),
18
+ ) * 3, sp.SerializeMode.Z_ORDER
19
  elif self.attn_mode == "shift_sequence":
20
+ yield "serialized", self.window_size, self.window_size // 2 * (i % 2), (
21
+ 0,
22
+ 0,
23
+ 0,
24
+ ), sp.SerializeMode.Z_ORDER
25
  elif self.attn_mode == "shift_order":
26
  yield "serialized", self.window_size, 0, (0, 0, 0), sp.SerializeModes[i % 4]
27
  elif self.attn_mode == "full":
28
  yield "full", None, None, None, None
29
  elif self.attn_mode == "swin":
30
+ yield "windowed", self.window_size, None, self.window_size // 2 * (
31
+ i % 2
32
+ ), None
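A minimal sketch of the alternating window shift the "swin" branch produces across blocks, assuming `window_size = 8`:

window_size = 8
shifts = [window_size // 2 * (i % 2) for i in range(4)]
# shifts == [0, 4, 0, 4]: every other block shifts its windows by half.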
33
 
34
 
35
  class SparseTransformerBase(nn.Module):
 
37
  Sparse Transformer without output layers.
38
  Serves as the base class for the encoder and decoder.
39
  """
40
+
41
  def __init__(
42
  self,
43
  in_channels: int,
 
46
  num_heads: Optional[int] = None,
47
  num_head_channels: Optional[int] = 64,
48
  mlp_ratio: float = 4.0,
49
+ attn_mode: Literal[
50
+ "full", "shift_window", "shift_sequence", "shift_order", "swin"
51
+ ] = "full",
52
  window_size: Optional[int] = None,
53
  pe_mode: Literal["ape", "rope"] = "ape",
54
  use_fp16: bool = False,
 
73
  self.pos_embedder = AbsolutePositionEmbedder(model_channels)
74
 
75
  self.input_layer = sp.SparseLinear(in_channels, model_channels)
76
+ self.blocks = nn.ModuleList(
77
+ [
78
+ SparseTransformerBlock(
79
+ model_channels,
80
+ num_heads=self.num_heads,
81
+ mlp_ratio=self.mlp_ratio,
82
+ attn_mode=attn_mode,
83
+ window_size=window_size,
84
+ shift_sequence=shift_sequence,
85
+ shift_window=shift_window,
86
+ serialize_mode=serialize_mode,
87
+ use_checkpoint=self.use_checkpoint,
88
+ use_rope=(pe_mode == "rope"),
89
+ qk_rms_norm=self.qk_rms_norm,
90
+ )
91
+ for attn_mode, window_size, shift_sequence, shift_window, serialize_mode in block_attn_config(
92
+ self
93
+ )
94
+ ]
95
+ )
96
 
97
  @property
98
  def device(self) -> torch.device:
 
120
  torch.nn.init.xavier_uniform_(module.weight)
121
  if module.bias is not None:
122
  nn.init.constant_(module.bias, 0)
123
+
124
  self.apply(_basic_init)
125
 
126
  def forward(self, x: sp.SparseTensor) -> sp.SparseTensor:
trellis/models/structured_latent_vae/decoder_gs.py CHANGED
@@ -18,7 +18,9 @@ class SLatGaussianDecoder(SparseTransformerBase):
18
  num_heads: Optional[int] = None,
19
  num_head_channels: Optional[int] = 64,
20
  mlp_ratio: float = 4,
21
- attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "swin",
 
 
22
  window_size: int = 8,
23
  pe_mode: Literal["ape", "rope"] = "ape",
24
  use_fp16: bool = False,
@@ -57,26 +59,44 @@ class SLatGaussianDecoder(SparseTransformerBase):
57
  nn.init.constant_(self.out_layer.bias, 0)
58
 
59
  def _build_perturbation(self) -> None:
60
- perturbation = [hammersley_sequence(3, i, self.rep_config['num_gaussians']) for i in range(self.rep_config['num_gaussians'])]
 
 
 
61
  perturbation = torch.tensor(perturbation).float() * 2 - 1
62
- perturbation = perturbation / self.rep_config['voxel_size']
63
  perturbation = torch.atanh(perturbation).to(self.device)
64
- self.register_buffer('offset_perturbation', perturbation)
65
 
66
  def _calc_layout(self) -> None:
67
  self.layout = {
68
- '_xyz' : {'shape': (self.rep_config['num_gaussians'], 3), 'size': self.rep_config['num_gaussians'] * 3},
69
- '_features_dc' : {'shape': (self.rep_config['num_gaussians'], 1, 3), 'size': self.rep_config['num_gaussians'] * 3},
70
- '_scaling' : {'shape': (self.rep_config['num_gaussians'], 3), 'size': self.rep_config['num_gaussians'] * 3},
71
- '_rotation' : {'shape': (self.rep_config['num_gaussians'], 4), 'size': self.rep_config['num_gaussians'] * 4},
72
- '_opacity' : {'shape': (self.rep_config['num_gaussians'], 1), 'size': self.rep_config['num_gaussians']},
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
  }
74
  start = 0
75
  for k, v in self.layout.items():
76
- v['range'] = (start, start + v['size'])
77
- start += v['size']
78
  self.out_channels = start
79
-
80
  def to_representation(self, x: sp.SparseTensor) -> List[Gaussian]:
81
  """
82
  Convert a batch of network outputs to 3D representations.
@@ -92,24 +112,35 @@ class SLatGaussianDecoder(SparseTransformerBase):
92
  representation = Gaussian(
93
  sh_degree=0,
94
  aabb=[-0.5, -0.5, -0.5, 1.0, 1.0, 1.0],
95
- mininum_kernel_size = self.rep_config['3d_filter_kernel_size'],
96
- scaling_bias = self.rep_config['scaling_bias'],
97
- opacity_bias = self.rep_config['opacity_bias'],
98
- scaling_activation = self.rep_config['scaling_activation']
99
  )
100
  xyz = (x.coords[x.layout[i]][:, 1:].float() + 0.5) / self.resolution
101
  for k, v in self.layout.items():
102
- if k == '_xyz':
103
- offset = x.feats[x.layout[i]][:, v['range'][0]:v['range'][1]].reshape(-1, *v['shape'])
104
- offset = offset * self.rep_config['lr'][k]
105
- if self.rep_config['perturb_offset']:
 
 
106
  offset = offset + self.offset_perturbation
107
- offset = torch.tanh(offset) / self.resolution * 0.5 * self.rep_config['voxel_size']
 
 
 
 
 
108
  _xyz = xyz.unsqueeze(1) + offset
109
  setattr(representation, k, _xyz.flatten(0, 1))
110
  else:
111
- feats = x.feats[x.layout[i]][:, v['range'][0]:v['range'][1]].reshape(-1, *v['shape']).flatten(0, 1)
112
- feats = feats * self.rep_config['lr'][k]
 
 
 
 
113
  setattr(representation, k, feats)
114
  ret.append(representation)
115
  return ret
 
         num_heads: Optional[int] = None,
         num_head_channels: Optional[int] = 64,
         mlp_ratio: float = 4,
+        attn_mode: Literal[
+            "full", "shift_window", "shift_sequence", "shift_order", "swin"
+        ] = "swin",
         window_size: int = 8,
         pe_mode: Literal["ape", "rope"] = "ape",
         use_fp16: bool = False,

         nn.init.constant_(self.out_layer.bias, 0)

     def _build_perturbation(self) -> None:
+        perturbation = [
+            hammersley_sequence(3, i, self.rep_config["num_gaussians"])
+            for i in range(self.rep_config["num_gaussians"])
+        ]
         perturbation = torch.tensor(perturbation).float() * 2 - 1
+        perturbation = perturbation / self.rep_config["voxel_size"]
         perturbation = torch.atanh(perturbation).to(self.device)
+        self.register_buffer("offset_perturbation", perturbation)

     def _calc_layout(self) -> None:
         self.layout = {
+            "_xyz": {
+                "shape": (self.rep_config["num_gaussians"], 3),
+                "size": self.rep_config["num_gaussians"] * 3,
+            },
+            "_features_dc": {
+                "shape": (self.rep_config["num_gaussians"], 1, 3),
+                "size": self.rep_config["num_gaussians"] * 3,
+            },
+            "_scaling": {
+                "shape": (self.rep_config["num_gaussians"], 3),
+                "size": self.rep_config["num_gaussians"] * 3,
+            },
+            "_rotation": {
+                "shape": (self.rep_config["num_gaussians"], 4),
+                "size": self.rep_config["num_gaussians"] * 4,
+            },
+            "_opacity": {
+                "shape": (self.rep_config["num_gaussians"], 1),
+                "size": self.rep_config["num_gaussians"],
+            },
         }
         start = 0
         for k, v in self.layout.items():
+            v["range"] = (start, start + v["size"])
+            start += v["size"]
         self.out_channels = start
+
     def to_representation(self, x: sp.SparseTensor) -> List[Gaussian]:
         """
         Convert a batch of network outputs to 3D representations.

             representation = Gaussian(
                 sh_degree=0,
                 aabb=[-0.5, -0.5, -0.5, 1.0, 1.0, 1.0],
+                mininum_kernel_size=self.rep_config["3d_filter_kernel_size"],
+                scaling_bias=self.rep_config["scaling_bias"],
+                opacity_bias=self.rep_config["opacity_bias"],
+                scaling_activation=self.rep_config["scaling_activation"],
             )
             xyz = (x.coords[x.layout[i]][:, 1:].float() + 0.5) / self.resolution
             for k, v in self.layout.items():
+                if k == "_xyz":
+                    offset = x.feats[x.layout[i]][
+                        :, v["range"][0] : v["range"][1]
+                    ].reshape(-1, *v["shape"])
+                    offset = offset * self.rep_config["lr"][k]
+                    if self.rep_config["perturb_offset"]:
                         offset = offset + self.offset_perturbation
+                    offset = (
+                        torch.tanh(offset)
+                        / self.resolution
+                        * 0.5
+                        * self.rep_config["voxel_size"]
+                    )
                     _xyz = xyz.unsqueeze(1) + offset
                     setattr(representation, k, _xyz.flatten(0, 1))
                 else:
+                    feats = (
+                        x.feats[x.layout[i]][:, v["range"][0] : v["range"][1]]
+                        .reshape(-1, *v["shape"])
+                        .flatten(0, 1)
+                    )
+                    feats = feats * self.rep_config["lr"][k]
                     setattr(representation, k, feats)
             ret.append(representation)
         return ret
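Note on decoder_gs.py: the `_calc_layout` change above is purely cosmetic (black formatting), so it is easy to sanity-check the packing it computes. A standalone sketch of the same range arithmetic; `num_gaussians = 32` is a made-up example value, not taken from the commit:

# Standalone sketch of SLatGaussianDecoder._calc_layout's packing arithmetic.
num_gaussians = 32  # hypothetical example value
layout = {
    "_xyz": {"shape": (num_gaussians, 3), "size": num_gaussians * 3},
    "_features_dc": {"shape": (num_gaussians, 1, 3), "size": num_gaussians * 3},
    "_scaling": {"shape": (num_gaussians, 3), "size": num_gaussians * 3},
    "_rotation": {"shape": (num_gaussians, 4), "size": num_gaussians * 4},
    "_opacity": {"shape": (num_gaussians, 1), "size": num_gaussians},
}
start = 0
for k, v in layout.items():
    v["range"] = (start, start + v["size"])  # half-open slice into the channel dim
    start += v["size"]
out_channels = start
assert out_channels == 32 * (3 + 3 + 3 + 4 + 1)  # 448 channels per voxel

Each decoded voxel thus carries `out_channels` features that `to_representation` slices back into per-Gaussian attributes via `v["range"]`.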
trellis/models/structured_latent_vae/decoder_mesh.py CHANGED
@@ -19,12 +19,13 @@ class SparseSubdivideBlock3d(nn.Module):
         out_channels: if specified, the number of output channels.
         num_groups: the number of groups for the group norm.
     """
     def __init__(
         self,
         channels: int,
         resolution: int,
         out_channels: Optional[int] = None,
-        num_groups: int = 32
     ):
         super().__init__()
         self.channels = channels
@@ -33,24 +34,34 @@
         self.out_channels = out_channels or channels

         self.act_layers = nn.Sequential(
-            sp.SparseGroupNorm32(num_groups, channels),
-            sp.SparseSiLU()
         )
-
         self.sub = sp.SparseSubdivide()
-
         self.out_layers = nn.Sequential(
-            sp.SparseConv3d(channels, self.out_channels, 3, indice_key=f"res_{self.out_resolution}"),
             sp.SparseGroupNorm32(num_groups, self.out_channels),
             sp.SparseSiLU(),
-            zero_module(sp.SparseConv3d(self.out_channels, self.out_channels, 3, indice_key=f"res_{self.out_resolution}")),
         )
-
         if self.out_channels == channels:
             self.skip_connection = nn.Identity()
         else:
-            self.skip_connection = sp.SparseConv3d(channels, self.out_channels, 1, indice_key=f"res_{self.out_resolution}")
-
     def forward(self, x: sp.SparseTensor) -> sp.SparseTensor:
         """
         Apply the block to a Tensor, conditioned on a timestep embedding.
@@ -78,7 +89,9 @@ class SLatMeshDecoder(SparseTransformerBase):
         num_heads: Optional[int] = None,
         num_head_channels: Optional[int] = 64,
         mlp_ratio: float = 4,
-        attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "swin",
         window_size: int = 8,
         pe_mode: Literal["ape", "rope"] = "ape",
         use_fp16: bool = False,
@@ -102,20 +115,24 @@ class SLatMeshDecoder(SparseTransformerBase):
         )
         self.resolution = resolution
         self.rep_config = representation_config
-        self.mesh_extractor = SparseFeatures2Mesh(res=self.resolution*4, use_color=self.rep_config.get('use_color', False))
         self.out_channels = self.mesh_extractor.feats_channels
-        self.upsample = nn.ModuleList([
-            SparseSubdivideBlock3d(
-                channels=model_channels,
-                resolution=resolution,
-                out_channels=model_channels // 4
-            ),
-            SparseSubdivideBlock3d(
-                channels=model_channels // 4,
-                resolution=resolution * 2,
-                out_channels=model_channels // 8
-            )
-        ])
         self.out_layer = sp.SparseLinear(model_channels // 8, self.out_channels)

         self.initialize_weights()
@@ -140,8 +157,8 @@
         Convert the torso of the model to float32.
         """
         super().convert_to_fp32()
-        self.upsample.apply(convert_module_to_f32)
-
     def to_representation(self, x: sp.SparseTensor) -> List[MeshExtractResult]:
         """
         Convert a batch of network outputs to 3D representations.
 
         out_channels: if specified, the number of output channels.
         num_groups: the number of groups for the group norm.
     """
+
     def __init__(
         self,
         channels: int,
         resolution: int,
         out_channels: Optional[int] = None,
+        num_groups: int = 32,
     ):
         super().__init__()
         self.channels = channels

         self.out_channels = out_channels or channels

         self.act_layers = nn.Sequential(
+            sp.SparseGroupNorm32(num_groups, channels), sp.SparseSiLU()
         )
+
         self.sub = sp.SparseSubdivide()
+
         self.out_layers = nn.Sequential(
+            sp.SparseConv3d(
+                channels, self.out_channels, 3, indice_key=f"res_{self.out_resolution}"
+            ),
             sp.SparseGroupNorm32(num_groups, self.out_channels),
             sp.SparseSiLU(),
+            zero_module(
+                sp.SparseConv3d(
+                    self.out_channels,
+                    self.out_channels,
+                    3,
+                    indice_key=f"res_{self.out_resolution}",
+                )
+            ),
         )
+
         if self.out_channels == channels:
             self.skip_connection = nn.Identity()
         else:
+            self.skip_connection = sp.SparseConv3d(
+                channels, self.out_channels, 1, indice_key=f"res_{self.out_resolution}"
+            )
+
     def forward(self, x: sp.SparseTensor) -> sp.SparseTensor:
         """
         Apply the block to a Tensor, conditioned on a timestep embedding.

         num_heads: Optional[int] = None,
         num_head_channels: Optional[int] = 64,
         mlp_ratio: float = 4,
+        attn_mode: Literal[
+            "full", "shift_window", "shift_sequence", "shift_order", "swin"
+        ] = "swin",
         window_size: int = 8,
         pe_mode: Literal["ape", "rope"] = "ape",
         use_fp16: bool = False,

         )
         self.resolution = resolution
         self.rep_config = representation_config
+        self.mesh_extractor = SparseFeatures2Mesh(
+            res=self.resolution * 4, use_color=self.rep_config.get("use_color", False)
+        )
         self.out_channels = self.mesh_extractor.feats_channels
+        self.upsample = nn.ModuleList(
+            [
+                SparseSubdivideBlock3d(
+                    channels=model_channels,
+                    resolution=resolution,
+                    out_channels=model_channels // 4,
+                ),
+                SparseSubdivideBlock3d(
+                    channels=model_channels // 4,
+                    resolution=resolution * 2,
+                    out_channels=model_channels // 8,
+                ),
+            ]
+        )
         self.out_layer = sp.SparseLinear(model_channels // 8, self.out_channels)

         self.initialize_weights()

         Convert the torso of the model to float32.
         """
         super().convert_to_fp32()
+        self.upsample.apply(convert_module_to_f32)
+
     def to_representation(self, x: sp.SparseTensor) -> List[MeshExtractResult]:
         """
         Convert a batch of network outputs to 3D representations.
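Note on decoder_mesh.py: the two `SparseSubdivideBlock3d` stages halve the channel count twice while each `sp.SparseSubdivide` doubles the grid resolution. A quick bookkeeping check under hypothetical values (`model_channels = 768`, `resolution = 64`; plain Python, no sparse ops):

# Channel/resolution bookkeeping for the SLatMeshDecoder upsample stack.
# model_channels and resolution are made-up example values.
model_channels, resolution = 768, 64

stages = [
    {"in": model_channels, "out": model_channels // 4, "res_in": resolution},
    {"in": model_channels // 4, "out": model_channels // 8, "res_in": resolution * 2},
]
for s in stages:
    s["res_out"] = s["res_in"] * 2  # sp.SparseSubdivide doubles each spatial axis

assert stages[-1]["res_out"] == resolution * 4   # matches SparseFeatures2Mesh(res=resolution * 4)
assert stages[-1]["out"] == model_channels // 8  # matches sp.SparseLinear(model_channels // 8, ...)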
trellis/models/structured_latent_vae/decoder_rf.py CHANGED
@@ -18,7 +18,9 @@ class SLatRadianceFieldDecoder(SparseTransformerBase):
         num_heads: Optional[int] = None,
         num_head_channels: Optional[int] = 64,
         mlp_ratio: float = 4,
-        attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "swin",
         window_size: int = 8,
         pe_mode: Literal["ape", "rope"] = "ape",
         use_fp16: bool = False,
@@ -57,16 +59,25 @@ class SLatRadianceFieldDecoder(SparseTransformerBase):

     def _calc_layout(self) -> None:
         self.layout = {
-            'trivec': {'shape': (self.rep_config['rank'], 3, self.rep_config['dim']), 'size': self.rep_config['rank'] * 3 * self.rep_config['dim']},
-            'density': {'shape': (self.rep_config['rank'],), 'size': self.rep_config['rank']},
-            'features_dc': {'shape': (self.rep_config['rank'], 1, 3), 'size': self.rep_config['rank'] * 3},
         }
         start = 0
         for k, v in self.layout.items():
-            v['range'] = (start, start + v['size'])
-            start += v['size']
-        self.out_channels = start
-
     def to_representation(self, x: sp.SparseTensor) -> List[Strivec]:
         """
         Convert a batch of network outputs to 3D representations.
@@ -83,15 +94,28 @@ class SLatRadianceFieldDecoder(SparseTransformerBase):
                 sh_degree=0,
                 resolution=self.resolution,
                 aabb=[-0.5, -0.5, -0.5, 1, 1, 1],
-                rank=self.rep_config['rank'],
-                dim=self.rep_config['dim'],
-                device='cuda',
             )
             representation.density_shift = 0.0
-            representation.position = (x.coords[x.layout[i]][:, 1:].float() + 0.5) / self.resolution
-            representation.depth = torch.full((representation.position.shape[0], 1), int(np.log2(self.resolution)), dtype=torch.uint8, device='cuda')
             for k, v in self.layout.items():
-                setattr(representation, k, x.feats[x.layout[i]][:, v['range'][0]:v['range'][1]].reshape(-1, *v['shape']))
             representation.trivec = representation.trivec + 1
             ret.append(representation)
         return ret
 
         num_heads: Optional[int] = None,
         num_head_channels: Optional[int] = 64,
         mlp_ratio: float = 4,
+        attn_mode: Literal[
+            "full", "shift_window", "shift_sequence", "shift_order", "swin"
+        ] = "swin",
         window_size: int = 8,
         pe_mode: Literal["ape", "rope"] = "ape",
         use_fp16: bool = False,

     def _calc_layout(self) -> None:
         self.layout = {
+            "trivec": {
+                "shape": (self.rep_config["rank"], 3, self.rep_config["dim"]),
+                "size": self.rep_config["rank"] * 3 * self.rep_config["dim"],
+            },
+            "density": {
+                "shape": (self.rep_config["rank"],),
+                "size": self.rep_config["rank"],
+            },
+            "features_dc": {
+                "shape": (self.rep_config["rank"], 1, 3),
+                "size": self.rep_config["rank"] * 3,
+            },
         }
         start = 0
         for k, v in self.layout.items():
+            v["range"] = (start, start + v["size"])
+            start += v["size"]
+        self.out_channels = start
+
     def to_representation(self, x: sp.SparseTensor) -> List[Strivec]:
         """
         Convert a batch of network outputs to 3D representations.

                 sh_degree=0,
                 resolution=self.resolution,
                 aabb=[-0.5, -0.5, -0.5, 1, 1, 1],
+                rank=self.rep_config["rank"],
+                dim=self.rep_config["dim"],
+                device="cuda",
             )
             representation.density_shift = 0.0
+            representation.position = (
+                x.coords[x.layout[i]][:, 1:].float() + 0.5
+            ) / self.resolution
+            representation.depth = torch.full(
+                (representation.position.shape[0], 1),
+                int(np.log2(self.resolution)),
+                dtype=torch.uint8,
+                device="cuda",
+            )
             for k, v in self.layout.items():
+                setattr(
+                    representation,
+                    k,
+                    x.feats[x.layout[i]][:, v["range"][0] : v["range"][1]].reshape(
+                        -1, *v["shape"]
+                    ),
+                )
             representation.trivec = representation.trivec + 1
             ret.append(representation)
         return ret
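Note on decoder_rf.py: the position/depth lines reflowed above convert integer voxel indices into voxel-center coordinates plus an octree depth. A minimal sketch of that math (CPU tensors and made-up values; the real code reads coords from the sparse tensor and allocates on CUDA):

import math
import torch

# Voxel-center math from SLatRadianceFieldDecoder.to_representation.
resolution = 64  # hypothetical example value
coords = torch.tensor([[0, 0, 0], [63, 63, 63]])  # integer voxel indices

position = (coords.float() + 0.5) / resolution    # voxel centers, strictly inside (0, 1)
depth = torch.full((coords.shape[0], 1), int(math.log2(resolution)), dtype=torch.uint8)

assert position.min() > 0 and position.max() < 1
assert int(depth[0]) == 6                         # octree depth of a 64^3 grid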
trellis/models/structured_latent_vae/encoder.py CHANGED
@@ -17,7 +17,9 @@ class SLatEncoder(SparseTransformerBase):
         num_heads: Optional[int] = None,
         num_head_channels: Optional[int] = 64,
         mlp_ratio: float = 4,
-        attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "swin",
         window_size: int = 8,
         pe_mode: Literal["ape", "rope"] = "ape",
         use_fp16: bool = False,
@@ -56,7 +58,7 @@
         h = h.type(x.dtype)
         h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:]))
         h = self.out_layer(h)
-
         # Sample from the posterior distribution
         mean, logvar = h.feats.chunk(2, dim=-1)
         if sample_posterior:
@@ -65,7 +67,7 @@
         else:
             z = mean
         z = h.replace(z)
-
         if return_raw:
             return z, mean, logvar
         else:
 
         num_heads: Optional[int] = None,
         num_head_channels: Optional[int] = 64,
         mlp_ratio: float = 4,
+        attn_mode: Literal[
+            "full", "shift_window", "shift_sequence", "shift_order", "swin"
+        ] = "swin",
         window_size: int = 8,
         pe_mode: Literal["ape", "rope"] = "ape",
         use_fp16: bool = False,

         h = h.type(x.dtype)
         h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:]))
         h = self.out_layer(h)
+
         # Sample from the posterior distribution
         mean, logvar = h.feats.chunk(2, dim=-1)
         if sample_posterior:

         else:
             z = mean
         z = h.replace(z)
+
         if return_raw:
             return z, mean, logvar
         else:
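Note on encoder.py: the lines elided between `if sample_posterior:` and `z = mean` are, by all appearances, the standard VAE reparameterization step. A minimal sketch with plain tensors (shapes are made up; the real code operates on the sparse tensor's feature matrix):

import torch

# Reparameterization trick as used between the mean/logvar chunk and z.
# The 128x16 shape is a made-up example; the encoder emits 2 * latent_dim channels.
h = torch.randn(128, 16)
mean, logvar = h.chunk(2, dim=-1)         # each [128, 8]

sample_posterior = True
if sample_posterior:
    std = torch.exp(0.5 * logvar)         # logvar parameterizes the diagonal variance
    z = mean + std * torch.randn_like(std)
else:
    z = mean                              # deterministic encoding for eval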
trellis/modules/attention/__init__.py CHANGED
@@ -1,32 +1,39 @@
 from typing import *

-BACKEND = 'flash_attn'
 DEBUG = False

 def __from_env():
     import os
-
     global BACKEND
     global DEBUG
-
-    env_attn_backend = os.environ.get('ATTN_BACKEND')
-    env_sttn_debug = os.environ.get('ATTN_DEBUG')
-
-    if env_attn_backend is not None and env_attn_backend in ['xformers', 'flash_attn', 'sdpa', 'naive']:
         BACKEND = env_attn_backend
     if env_sttn_debug is not None:
-        DEBUG = env_sttn_debug == '1'

     print(f"[ATTENTION] Using backend: {BACKEND}")
-

 __from_env()
-

-def set_backend(backend: Literal['xformers', 'flash_attn']):
     global BACKEND
     BACKEND = backend

 def set_debug(debug: bool):
     global DEBUG
     DEBUG = debug
 
 from typing import *

+BACKEND = "flash_attn"
 DEBUG = False

+
 def __from_env():
     import os
+
     global BACKEND
     global DEBUG
+
+    env_attn_backend = os.environ.get("ATTN_BACKEND")
+    env_sttn_debug = os.environ.get("ATTN_DEBUG")
+
+    if env_attn_backend is not None and env_attn_backend in [
+        "xformers",
+        "flash_attn",
+        "sdpa",
+        "naive",
+    ]:
         BACKEND = env_attn_backend
     if env_sttn_debug is not None:
+        DEBUG = env_sttn_debug == "1"

     print(f"[ATTENTION] Using backend: {BACKEND}")
+

 __from_env()

+
+def set_backend(backend: Literal["xformers", "flash_attn"]):
     global BACKEND
     BACKEND = backend

+
 def set_debug(debug: bool):
     global DEBUG
     DEBUG = debug
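Note on attention/__init__.py: because full_attn.py branches on `BACKEND` at import time (see its `if BACKEND == ...` import block below), the environment variable must be set before the first import; calling `set_backend` afterwards will not re-run those imports. A usage sketch, assuming the package is importable as `trellis.modules.attention`:

import os

# Choose the backend before anything imports trellis.modules.attention;
# the value must be one of: xformers, flash_attn, sdpa, naive.
os.environ["ATTN_BACKEND"] = "sdpa"
os.environ["ATTN_DEBUG"] = "1"

from trellis.modules import attention  # prints "[ATTENTION] Using backend: sdpa"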
trellis/modules/attention/full_attn.py CHANGED
@@ -3,20 +3,20 @@ import torch
 import math
 from . import DEBUG, BACKEND

-if BACKEND == 'xformers':
     import xformers.ops as xops
-elif BACKEND == 'flash_attn':
     import flash_attn
-elif BACKEND == 'sdpa':
     from torch.nn.functional import scaled_dot_product_attention as sdpa
-elif BACKEND == 'naive':
     pass
 else:
     raise ValueError(f"Unknown attention backend: {BACKEND}")


 __all__ = [
-    'scaled_dot_product_attention',
 ]


@@ -24,14 +24,14 @@ def _naive_sdpa(q, k, v):
     """
     Naive implementation of scaled dot product attention.
     """
-    q = q.permute(0, 2, 1, 3)   # [N, H, L, C]
-    k = k.permute(0, 2, 1, 3)   # [N, H, L, C]
-    v = v.permute(0, 2, 1, 3)   # [N, H, L, C]
     scale_factor = 1 / math.sqrt(q.size(-1))
     attn_weight = q @ k.transpose(-2, -1) * scale_factor
     attn_weight = torch.softmax(attn_weight, dim=-1)
     out = attn_weight @ v
-    out = out.permute(0, 2, 1, 3)   # [N, L, H, C]
     return out


@@ -45,6 +45,7 @@ def scaled_dot_product_attention(qkv: torch.Tensor) -> torch.Tensor:
     """
     ...

 @overload
 def scaled_dot_product_attention(q: torch.Tensor, kv: torch.Tensor) -> torch.Tensor:
     """
@@ -56,8 +57,11 @@ def scaled_dot_product_attention(q: torch.Tensor, kv: torch.Tensor) -> torch.Ten
     """
     ...

 @overload
-def scaled_dot_product_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
     """
     Apply scaled dot product attention.

@@ -71,64 +75,79 @@ def scaled_dot_product_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tens
     """
     ...

 def scaled_dot_product_attention(*args, **kwargs):
-    arg_names_dict = {
-        1: ['qkv'],
-        2: ['q', 'kv'],
-        3: ['q', 'k', 'v']
-    }
     num_all_args = len(args) + len(kwargs)
-    assert num_all_args in arg_names_dict, f"Invalid number of arguments, got {num_all_args}, expected 1, 2, or 3"
-    for key in arg_names_dict[num_all_args][len(args):]:
         assert key in kwargs, f"Missing argument {key}"

     if num_all_args == 1:
-        qkv = args[0] if len(args) > 0 else kwargs['qkv']
-        assert len(qkv.shape) == 5 and qkv.shape[2] == 3, f"Invalid shape for qkv, got {qkv.shape}, expected [N, L, 3, H, C]"
         device = qkv.device

     elif num_all_args == 2:
-        q = args[0] if len(args) > 0 else kwargs['q']
-        kv = args[1] if len(args) > 1 else kwargs['kv']
-        assert q.shape[0] == kv.shape[0], f"Batch size mismatch, got {q.shape[0]} and {kv.shape[0]}"
-        assert len(q.shape) == 4, f"Invalid shape for q, got {q.shape}, expected [N, L, H, C]"
-        assert len(kv.shape) == 5, f"Invalid shape for kv, got {kv.shape}, expected [N, L, 2, H, C]"
         device = q.device

     elif num_all_args == 3:
-        q = args[0] if len(args) > 0 else kwargs['q']
-        k = args[1] if len(args) > 1 else kwargs['k']
-        v = args[2] if len(args) > 2 else kwargs['v']
-        assert q.shape[0] == k.shape[0] == v.shape[0], f"Batch size mismatch, got {q.shape[0]}, {k.shape[0]}, and {v.shape[0]}"
-        assert len(q.shape) == 4, f"Invalid shape for q, got {q.shape}, expected [N, L, H, Ci]"
-        assert len(k.shape) == 4, f"Invalid shape for k, got {k.shape}, expected [N, L, H, Ci]"
-        assert len(v.shape) == 4, f"Invalid shape for v, got {v.shape}, expected [N, L, H, Co]"
-        device = q.device
-
-    if BACKEND == 'xformers':
         if num_all_args == 1:
             q, k, v = qkv.unbind(dim=2)
         elif num_all_args == 2:
             k, v = kv.unbind(dim=2)
         out = xops.memory_efficient_attention(q, k, v)
-    elif BACKEND == 'flash_attn':
         if num_all_args == 1:
             out = flash_attn.flash_attn_qkvpacked_func(qkv)
         elif num_all_args == 2:
             out = flash_attn.flash_attn_kvpacked_func(q, kv)
         elif num_all_args == 3:
             out = flash_attn.flash_attn_func(q, k, v)
-    elif BACKEND == 'sdpa':
         if num_all_args == 1:
             q, k, v = qkv.unbind(dim=2)
         elif num_all_args == 2:
             k, v = kv.unbind(dim=2)
-        q = q.permute(0, 2, 1, 3)   # [N, H, L, C]
-        k = k.permute(0, 2, 1, 3)   # [N, H, L, C]
-        v = v.permute(0, 2, 1, 3)   # [N, H, L, C]
-        out = sdpa(q, k, v)         # [N, H, L, C]
-        out = out.permute(0, 2, 1, 3)   # [N, L, H, C]
-    elif BACKEND == 'naive':
         if num_all_args == 1:
             q, k, v = qkv.unbind(dim=2)
         elif num_all_args == 2:
@@ -136,5 +155,5 @@ def scaled_dot_product_attention(*args, **kwargs):
         out = _naive_sdpa(q, k, v)
     else:
         raise ValueError(f"Unknown attention module: {BACKEND}")
-
     return out
 
 import math
 from . import DEBUG, BACKEND

+if BACKEND == "xformers":
     import xformers.ops as xops
+elif BACKEND == "flash_attn":
     import flash_attn
+elif BACKEND == "sdpa":
     from torch.nn.functional import scaled_dot_product_attention as sdpa
+elif BACKEND == "naive":
     pass
 else:
     raise ValueError(f"Unknown attention backend: {BACKEND}")


 __all__ = [
+    "scaled_dot_product_attention",
 ]


     """
     Naive implementation of scaled dot product attention.
     """
+    q = q.permute(0, 2, 1, 3)  # [N, H, L, C]
+    k = k.permute(0, 2, 1, 3)  # [N, H, L, C]
+    v = v.permute(0, 2, 1, 3)  # [N, H, L, C]
     scale_factor = 1 / math.sqrt(q.size(-1))
     attn_weight = q @ k.transpose(-2, -1) * scale_factor
     attn_weight = torch.softmax(attn_weight, dim=-1)
     out = attn_weight @ v
+    out = out.permute(0, 2, 1, 3)  # [N, L, H, C]
     return out


     """
     ...

+
 @overload
 def scaled_dot_product_attention(q: torch.Tensor, kv: torch.Tensor) -> torch.Tensor:
     """

     """
     ...

+
 @overload
+def scaled_dot_product_attention(
+    q: torch.Tensor, k: torch.Tensor, v: torch.Tensor
+) -> torch.Tensor:
     """
     Apply scaled dot product attention.


     """
     ...

+
 def scaled_dot_product_attention(*args, **kwargs):
+    arg_names_dict = {1: ["qkv"], 2: ["q", "kv"], 3: ["q", "k", "v"]}
     num_all_args = len(args) + len(kwargs)
+    assert (
+        num_all_args in arg_names_dict
+    ), f"Invalid number of arguments, got {num_all_args}, expected 1, 2, or 3"
+    for key in arg_names_dict[num_all_args][len(args) :]:
         assert key in kwargs, f"Missing argument {key}"

     if num_all_args == 1:
+        qkv = args[0] if len(args) > 0 else kwargs["qkv"]
+        assert (
+            len(qkv.shape) == 5 and qkv.shape[2] == 3
+        ), f"Invalid shape for qkv, got {qkv.shape}, expected [N, L, 3, H, C]"
         device = qkv.device

     elif num_all_args == 2:
+        q = args[0] if len(args) > 0 else kwargs["q"]
+        kv = args[1] if len(args) > 1 else kwargs["kv"]
+        assert (
+            q.shape[0] == kv.shape[0]
+        ), f"Batch size mismatch, got {q.shape[0]} and {kv.shape[0]}"
+        assert (
+            len(q.shape) == 4
+        ), f"Invalid shape for q, got {q.shape}, expected [N, L, H, C]"
+        assert (
+            len(kv.shape) == 5
+        ), f"Invalid shape for kv, got {kv.shape}, expected [N, L, 2, H, C]"
         device = q.device

     elif num_all_args == 3:
+        q = args[0] if len(args) > 0 else kwargs["q"]
+        k = args[1] if len(args) > 1 else kwargs["k"]
+        v = args[2] if len(args) > 2 else kwargs["v"]
+        assert (
+            q.shape[0] == k.shape[0] == v.shape[0]
+        ), f"Batch size mismatch, got {q.shape[0]}, {k.shape[0]}, and {v.shape[0]}"
+        assert (
+            len(q.shape) == 4
+        ), f"Invalid shape for q, got {q.shape}, expected [N, L, H, Ci]"
+        assert (
+            len(k.shape) == 4
+        ), f"Invalid shape for k, got {k.shape}, expected [N, L, H, Ci]"
+        assert (
+            len(v.shape) == 4
+        ), f"Invalid shape for v, got {v.shape}, expected [N, L, H, Co]"
+        device = q.device
+
+    if BACKEND == "xformers":
         if num_all_args == 1:
             q, k, v = qkv.unbind(dim=2)
         elif num_all_args == 2:
             k, v = kv.unbind(dim=2)
         out = xops.memory_efficient_attention(q, k, v)
+    elif BACKEND == "flash_attn":
         if num_all_args == 1:
             out = flash_attn.flash_attn_qkvpacked_func(qkv)
         elif num_all_args == 2:
             out = flash_attn.flash_attn_kvpacked_func(q, kv)
         elif num_all_args == 3:
             out = flash_attn.flash_attn_func(q, k, v)
+    elif BACKEND == "sdpa":
         if num_all_args == 1:
             q, k, v = qkv.unbind(dim=2)
         elif num_all_args == 2:
             k, v = kv.unbind(dim=2)
+        q = q.permute(0, 2, 1, 3)  # [N, H, L, C]
+        k = k.permute(0, 2, 1, 3)  # [N, H, L, C]
+        v = v.permute(0, 2, 1, 3)  # [N, H, L, C]
+        out = sdpa(q, k, v)  # [N, H, L, C]
+        out = out.permute(0, 2, 1, 3)  # [N, L, H, C]
+    elif BACKEND == "naive":
         if num_all_args == 1:
             q, k, v = qkv.unbind(dim=2)
         elif num_all_args == 2:

         out = _naive_sdpa(q, k, v)
     else:
         raise ValueError(f"Unknown attention module: {BACKEND}")
+
     return out
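Note on full_attn.py: the `naive` path is a handy reference when validating the other backends. A self-contained check (random CPU tensors, made-up shapes) that the permute-softmax-permute implementation agrees with `torch.nn.functional.scaled_dot_product_attention`:

import math
import torch
import torch.nn.functional as F

def naive_sdpa(q, k, v):
    # Mirrors _naive_sdpa: inputs/outputs are [N, L, H, C]; math runs in [N, H, L, C].
    q, k, v = (t.permute(0, 2, 1, 3) for t in (q, k, v))
    attn = torch.softmax(q @ k.transpose(-2, -1) / math.sqrt(q.size(-1)), dim=-1)
    return (attn @ v).permute(0, 2, 1, 3)

q, k, v = (torch.randn(2, 16, 4, 32) for _ in range(3))  # [N, L, H, C]
ref = F.scaled_dot_product_attention(
    q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3)
).permute(0, 2, 1, 3)
assert torch.allclose(naive_sdpa(q, k, v), ref, atol=1e-5)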
trellis/modules/attention/modules.py CHANGED
@@ -8,11 +8,11 @@ from .full_attn import scaled_dot_product_attention
 class MultiHeadRMSNorm(nn.Module):
     def __init__(self, dim: int, heads: int):
         super().__init__()
-        self.scale = dim ** 0.5
         self.gamma = nn.Parameter(torch.ones(heads, dim))

     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        return (F.normalize(x.float(), dim = -1) * self.gamma * self.scale).to(x.dtype)


 class RotaryPositionEmbedder(nn.Module):
@@ -23,21 +23,25 @@ class RotaryPositionEmbedder(nn.Module):
         self.in_channels = in_channels
         self.freq_dim = hidden_size // in_channels // 2
         self.freqs = torch.arange(self.freq_dim, dtype=torch.float32) / self.freq_dim
-        self.freqs = 1.0 / (10000 ** self.freqs)
-
     def _get_phases(self, indices: torch.Tensor) -> torch.Tensor:
         self.freqs = self.freqs.to(indices.device)
         phases = torch.outer(indices, self.freqs)
         phases = torch.polar(torch.ones_like(phases), phases)
         return phases
-
     def _rotary_embedding(self, x: torch.Tensor, phases: torch.Tensor) -> torch.Tensor:
         x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
         x_rotated = x_complex * phases
-        x_embed = torch.view_as_real(x_rotated).reshape(*x_rotated.shape[:-1], -1).to(x.dtype)
         return x_embed
-
-    def forward(self, q: torch.Tensor, k: torch.Tensor, indices: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         Args:
             q (sp.SparseTensor): [..., N, D] tensor of queries
@@ -48,24 +52,38 @@ class RotaryPositionEmbedder(nn.Module):
             indices = torch.arange(q.shape[-2], device=q.device)
         if len(q.shape) > 2:
             indices = indices.unsqueeze(0).expand(q.shape[:-2] + (-1,))
-
         phases = self._get_phases(indices.reshape(-1)).reshape(*indices.shape[:-1], -1)
         if phases.shape[1] < self.hidden_size // 2:
-            phases = torch.cat([phases, torch.polar(
-                torch.ones(*phases.shape[:-1], self.hidden_size // 2 - phases.shape[1], device=phases.device),
-                torch.zeros(*phases.shape[:-1], self.hidden_size // 2 - phases.shape[1], device=phases.device)
-            )], dim=-1)
         q_embed = self._rotary_embedding(q, phases)
         k_embed = self._rotary_embedding(k, phases)
         return q_embed, k_embed
-

 class MultiHeadAttention(nn.Module):
     def __init__(
         self,
         channels: int,
         num_heads: int,
-        ctx_channels: Optional[int]=None,
         type: Literal["self", "cross"] = "self",
         attn_mode: Literal["full", "windowed"] = "full",
         window_size: Optional[int] = None,
@@ -78,11 +96,13 @@ class MultiHeadAttention(nn.Module):
         assert channels % num_heads == 0
         assert type in ["self", "cross"], f"Invalid attention type: {type}"
         assert attn_mode in ["full", "windowed"], f"Invalid attention mode: {attn_mode}"
-        assert type == "self" or attn_mode == "full", "Cross-attention only supports full attention"
-
         if attn_mode == "windowed":
             raise NotImplementedError("Windowed attention is not yet implemented")
-
         self.channels = channels
         self.head_dim = channels // num_heads
         self.ctx_channels = ctx_channels if ctx_channels is not None else channels
@@ -99,17 +119,22 @@ class MultiHeadAttention(nn.Module):
         else:
             self.to_q = nn.Linear(channels, channels, bias=qkv_bias)
             self.to_kv = nn.Linear(self.ctx_channels, channels * 2, bias=qkv_bias)
-
         if self.qk_rms_norm:
             self.q_rms_norm = MultiHeadRMSNorm(self.head_dim, num_heads)
             self.k_rms_norm = MultiHeadRMSNorm(self.head_dim, num_heads)
-
         self.to_out = nn.Linear(channels, channels)

         if use_rope:
             self.rope = RotaryPositionEmbedder(channels)
-
-    def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None, indices: Optional[torch.Tensor] = None) -> torch.Tensor:
         B, L, C = x.shape
         if self._type == "self":
             qkv = self.to_qkv(x)
 
 class MultiHeadRMSNorm(nn.Module):
     def __init__(self, dim: int, heads: int):
         super().__init__()
+        self.scale = dim**0.5
         self.gamma = nn.Parameter(torch.ones(heads, dim))

     def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return (F.normalize(x.float(), dim=-1) * self.gamma * self.scale).to(x.dtype)


 class RotaryPositionEmbedder(nn.Module):

         self.in_channels = in_channels
         self.freq_dim = hidden_size // in_channels // 2
         self.freqs = torch.arange(self.freq_dim, dtype=torch.float32) / self.freq_dim
+        self.freqs = 1.0 / (10000**self.freqs)
+
     def _get_phases(self, indices: torch.Tensor) -> torch.Tensor:
         self.freqs = self.freqs.to(indices.device)
         phases = torch.outer(indices, self.freqs)
         phases = torch.polar(torch.ones_like(phases), phases)
         return phases
+
     def _rotary_embedding(self, x: torch.Tensor, phases: torch.Tensor) -> torch.Tensor:
         x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
         x_rotated = x_complex * phases
+        x_embed = (
+            torch.view_as_real(x_rotated).reshape(*x_rotated.shape[:-1], -1).to(x.dtype)
+        )
         return x_embed
+
+    def forward(
+        self, q: torch.Tensor, k: torch.Tensor, indices: Optional[torch.Tensor] = None
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         Args:
             q (sp.SparseTensor): [..., N, D] tensor of queries

             indices = torch.arange(q.shape[-2], device=q.device)
         if len(q.shape) > 2:
             indices = indices.unsqueeze(0).expand(q.shape[:-2] + (-1,))
+
         phases = self._get_phases(indices.reshape(-1)).reshape(*indices.shape[:-1], -1)
         if phases.shape[1] < self.hidden_size // 2:
+            phases = torch.cat(
+                [
+                    phases,
+                    torch.polar(
+                        torch.ones(
+                            *phases.shape[:-1],
+                            self.hidden_size // 2 - phases.shape[1],
+                            device=phases.device,
+                        ),
+                        torch.zeros(
+                            *phases.shape[:-1],
+                            self.hidden_size // 2 - phases.shape[1],
+                            device=phases.device,
+                        ),
+                    ),
+                ],
+                dim=-1,
+            )
         q_embed = self._rotary_embedding(q, phases)
         k_embed = self._rotary_embedding(k, phases)
         return q_embed, k_embed
+

 class MultiHeadAttention(nn.Module):
     def __init__(
         self,
         channels: int,
         num_heads: int,
+        ctx_channels: Optional[int] = None,
         type: Literal["self", "cross"] = "self",
         attn_mode: Literal["full", "windowed"] = "full",
         window_size: Optional[int] = None,

         assert channels % num_heads == 0
         assert type in ["self", "cross"], f"Invalid attention type: {type}"
         assert attn_mode in ["full", "windowed"], f"Invalid attention mode: {attn_mode}"
+        assert (
+            type == "self" or attn_mode == "full"
+        ), "Cross-attention only supports full attention"
+
         if attn_mode == "windowed":
             raise NotImplementedError("Windowed attention is not yet implemented")
+
         self.channels = channels
         self.head_dim = channels // num_heads
         self.ctx_channels = ctx_channels if ctx_channels is not None else channels

         else:
             self.to_q = nn.Linear(channels, channels, bias=qkv_bias)
             self.to_kv = nn.Linear(self.ctx_channels, channels * 2, bias=qkv_bias)
+
         if self.qk_rms_norm:
             self.q_rms_norm = MultiHeadRMSNorm(self.head_dim, num_heads)
             self.k_rms_norm = MultiHeadRMSNorm(self.head_dim, num_heads)
+
         self.to_out = nn.Linear(channels, channels)

         if use_rope:
             self.rope = RotaryPositionEmbedder(channels)
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        context: Optional[torch.Tensor] = None,
+        indices: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
         B, L, C = x.shape
         if self._type == "self":
             qkv = self.to_qkv(x)
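Note on attention/modules.py: the rotary embedding treats each channel pair as a complex number and rotates it by a position-dependent phase. A minimal sketch of that core (standalone, small made-up sizes; the padding of `phases` up to `hidden_size // 2` is omitted):

import torch

# Core of RotaryPositionEmbedder: rotate channel pairs by position-dependent phases.
freq_dim, seq_len = 8, 5  # made-up sizes; embedded dim = 2 * freq_dim
freqs = 1.0 / (10000 ** (torch.arange(freq_dim, dtype=torch.float32) / freq_dim))

indices = torch.arange(seq_len, dtype=torch.float32)
phases = torch.polar(torch.ones(seq_len, freq_dim), torch.outer(indices, freqs))

x = torch.randn(seq_len, 2 * freq_dim)  # e.g. one attention head's queries
x_complex = torch.view_as_complex(x.reshape(seq_len, freq_dim, 2))
x_rot = torch.view_as_real(x_complex * phases).reshape(seq_len, -1)

# Rotation preserves norms, so attention logits depend only on relative position.
assert torch.allclose(x.norm(dim=-1), x_rot.norm(dim=-1), atol=1e-5)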
trellis/modules/norm.py CHANGED
@@ -5,21 +5,21 @@ import torch.nn as nn
 class LayerNorm32(nn.LayerNorm):
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         return super().forward(x.float()).type(x.dtype)
-

 class GroupNorm32(nn.GroupNorm):
     """
     A GroupNorm layer that converts to float32 before the forward pass.
     """
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         return super().forward(x.float()).type(x.dtype)
-
-
 class ChannelLayerNorm32(LayerNorm32):
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         DIM = x.dim()
         x = x.permute(0, *range(2, DIM), 1).contiguous()
         x = super().forward(x)
-        x = x.permute(0, DIM-1, *range(1, DIM-1)).contiguous()
         return x
-
 
 class LayerNorm32(nn.LayerNorm):
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         return super().forward(x.float()).type(x.dtype)
+

 class GroupNorm32(nn.GroupNorm):
     """
     A GroupNorm layer that converts to float32 before the forward pass.
     """
+
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         return super().forward(x.float()).type(x.dtype)
+
+
 class ChannelLayerNorm32(LayerNorm32):
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         DIM = x.dim()
         x = x.permute(0, *range(2, DIM), 1).contiguous()
         x = super().forward(x)
+        x = x.permute(0, DIM - 1, *range(1, DIM - 1)).contiguous()
         return x
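Note on norm.py: `ChannelLayerNorm32` normalizes the channel dim of a channels-first tensor by shuffling channels last and back. A quick shape check of the two permutes (made-up 5D tensor):

import torch

# Shape round-trip used by ChannelLayerNorm32 for channels-first tensors.
x = torch.randn(2, 16, 4, 4, 4)                # [N, C, D, H, W], made-up sizes
DIM = x.dim()                                  # 5

y = x.permute(0, *range(2, DIM), 1)            # -> [N, D, H, W, C]: channels last
assert y.shape == (2, 4, 4, 4, 16)             # LayerNorm now acts on C

z = y.permute(0, DIM - 1, *range(1, DIM - 1))  # -> [N, C, D, H, W]: channels first again
assert torch.equal(z, x)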
 
trellis/modules/sparse/__init__.py CHANGED
@@ -1,81 +1,88 @@
 from typing import *

-BACKEND = 'spconv'
 DEBUG = False
-ATTN = 'flash_attn'

 def __from_env():
     import os
-
     global BACKEND
     global DEBUG
     global ATTN
-
-    env_sparse_backend = os.environ.get('SPARSE_BACKEND')
-    env_sparse_debug = os.environ.get('SPARSE_DEBUG')
-    env_sparse_attn = os.environ.get('SPARSE_ATTN_BACKEND')
     if env_sparse_attn is None:
-        env_sparse_attn = os.environ.get('ATTN_BACKEND')

-    if env_sparse_backend is not None and env_sparse_backend in ['spconv', 'torchsparse']:
         BACKEND = env_sparse_backend
     if env_sparse_debug is not None:
-        DEBUG = env_sparse_debug == '1'
-    if env_sparse_attn is not None and env_sparse_attn in ['xformers', 'flash_attn']:
         ATTN = env_sparse_attn
-
     print(f"[SPARSE] Backend: {BACKEND}, Attention: {ATTN}")
-

 __from_env()
-

-def set_backend(backend: Literal['spconv', 'torchsparse']):
     global BACKEND
     BACKEND = backend

 def set_debug(debug: bool):
     global DEBUG
     DEBUG = debug

-def set_attn(attn: Literal['xformers', 'flash_attn']):
     global ATTN
     ATTN = attn
-
-
 import importlib

 __attributes = {
-    'SparseTensor': 'basic',
-    'sparse_batch_broadcast': 'basic',
-    'sparse_batch_op': 'basic',
-    'sparse_cat': 'basic',
-    'sparse_unbind': 'basic',
-    'SparseGroupNorm': 'norm',
-    'SparseLayerNorm': 'norm',
-    'SparseGroupNorm32': 'norm',
-    'SparseLayerNorm32': 'norm',
-    'SparseReLU': 'nonlinearity',
-    'SparseSiLU': 'nonlinearity',
-    'SparseGELU': 'nonlinearity',
-    'SparseActivation': 'nonlinearity',
-    'SparseLinear': 'linear',
-    'sparse_scaled_dot_product_attention': 'attention',
-    'SerializeMode': 'attention',
-    'sparse_serialized_scaled_dot_product_self_attention': 'attention',
-    'sparse_windowed_scaled_dot_product_self_attention': 'attention',
-    'SparseMultiHeadAttention': 'attention',
-    'SparseConv3d': 'conv',
-    'SparseInverseConv3d': 'conv',
-    'SparseDownsample': 'spatial',
-    'SparseUpsample': 'spatial',
-    'SparseSubdivide' : 'spatial'
 }

-__submodules = ['transformer']

 __all__ = list(__attributes.keys()) + __submodules

 def __getattr__(name):
     if name not in globals():
         if name in __attributes:
@@ -91,7 +98,7 @@


 # For Pylance
-if __name__ == '__main__':
     from .basic import *
     from .norm import *
     from .nonlinearity import *
 
 from typing import *

+BACKEND = "spconv"
 DEBUG = False
+ATTN = "flash_attn"
+

 def __from_env():
     import os
+
     global BACKEND
     global DEBUG
     global ATTN
+
+    env_sparse_backend = os.environ.get("SPARSE_BACKEND")
+    env_sparse_debug = os.environ.get("SPARSE_DEBUG")
+    env_sparse_attn = os.environ.get("SPARSE_ATTN_BACKEND")
     if env_sparse_attn is None:
+        env_sparse_attn = os.environ.get("ATTN_BACKEND")

+    if env_sparse_backend is not None and env_sparse_backend in [
+        "spconv",
+        "torchsparse",
+    ]:
         BACKEND = env_sparse_backend
     if env_sparse_debug is not None:
+        DEBUG = env_sparse_debug == "1"
+    if env_sparse_attn is not None and env_sparse_attn in ["xformers", "flash_attn"]:
         ATTN = env_sparse_attn
+
     print(f"[SPARSE] Backend: {BACKEND}, Attention: {ATTN}")
+

 __from_env()

+
+def set_backend(backend: Literal["spconv", "torchsparse"]):
     global BACKEND
     BACKEND = backend

+
 def set_debug(debug: bool):
     global DEBUG
     DEBUG = debug

+
+def set_attn(attn: Literal["xformers", "flash_attn"]):
     global ATTN
     ATTN = attn
+
+
 import importlib

 __attributes = {
+    "SparseTensor": "basic",
+    "sparse_batch_broadcast": "basic",
+    "sparse_batch_op": "basic",
+    "sparse_cat": "basic",
+    "sparse_unbind": "basic",
+    "SparseGroupNorm": "norm",
+    "SparseLayerNorm": "norm",
+    "SparseGroupNorm32": "norm",
+    "SparseLayerNorm32": "norm",
+    "SparseReLU": "nonlinearity",
+    "SparseSiLU": "nonlinearity",
+    "SparseGELU": "nonlinearity",
+    "SparseActivation": "nonlinearity",
+    "SparseLinear": "linear",
+    "sparse_scaled_dot_product_attention": "attention",
+    "SerializeMode": "attention",
+    "sparse_serialized_scaled_dot_product_self_attention": "attention",
+    "sparse_windowed_scaled_dot_product_self_attention": "attention",
+    "SparseMultiHeadAttention": "attention",
+    "SparseConv3d": "conv",
+    "SparseInverseConv3d": "conv",
+    "SparseDownsample": "spatial",
+    "SparseUpsample": "spatial",
+    "SparseSubdivide": "spatial",
 }

+__submodules = ["transformer"]

 __all__ = list(__attributes.keys()) + __submodules

+
 def __getattr__(name):
     if name not in globals():
         if name in __attributes:


 # For Pylance
+if __name__ == "__main__":
     from .basic import *
     from .norm import *
     from .nonlinearity import *
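Note on sparse/__init__.py: the `__attributes` table drives lazy loading through the module-level `__getattr__` hook (PEP 562): a submodule is imported only when one of its names is first touched. A minimal sketch of the pattern, written as the body of a hypothetical package `__init__.py` with a `basic` submodule exporting `SparseTensor` (the cache-into-globals step is an assumption about the elided body, not copied from the commit):

# Minimal sketch of PEP 562 lazy exports, in the style of trellis/modules/sparse.
import importlib

__attributes = {"SparseTensor": "basic"}  # exported name -> defining submodule

def __getattr__(name):
    if name in __attributes:
        module = importlib.import_module(f".{__attributes[name]}", __name__)
        value = getattr(module, name)
        globals()[name] = value  # cache so __getattr__ runs at most once per name
        return value
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")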
trellis/modules/sparse/attention/full_attn.py CHANGED
@@ -3,16 +3,16 @@ import torch
 from .. import SparseTensor
 from .. import DEBUG, ATTN

-if ATTN == 'xformers':
     import xformers.ops as xops
-elif ATTN == 'flash_attn':
     import flash_attn
 else:
     raise ValueError(f"Unknown attention module: {ATTN}")


 __all__ = [
-    'sparse_scaled_dot_product_attention',
 ]


@@ -26,8 +26,11 @@ def sparse_scaled_dot_product_attention(qkv: SparseTensor) -> SparseTensor:
     """
     ...

 @overload
-def sparse_scaled_dot_product_attention(q: SparseTensor, kv: Union[SparseTensor, torch.Tensor]) -> SparseTensor:
     """
     Apply scaled dot product attention to a sparse tensor.

@@ -37,8 +40,11 @@
     """
     ...

 @overload
-def sparse_scaled_dot_product_attention(q: torch.Tensor, kv: SparseTensor) -> torch.Tensor:
     """
     Apply scaled dot product attention to a sparse tensor.

@@ -48,8 +54,11 @@
     """
     ...

 @overload
-def sparse_scaled_dot_product_attention(q: SparseTensor, k: SparseTensor, v: SparseTensor) -> SparseTensor:
     """
     Apply scaled dot product attention to a sparse tensor.

@@ -63,8 +72,11 @@
     """
     ...

 @overload
-def sparse_scaled_dot_product_attention(q: SparseTensor, k: torch.Tensor, v: torch.Tensor) -> SparseTensor:
     """
     Apply scaled dot product attention to a sparse tensor.

@@ -75,8 +87,11 @@
     """
     ...

 @overload
-def sparse_scaled_dot_product_attention(q: torch.Tensor, k: SparseTensor, v: SparseTensor) -> torch.Tensor:
     """
     Apply scaled dot product attention to a sparse tensor.

@@ -87,106 +102,158 @@
     """
     ...

 def sparse_scaled_dot_product_attention(*args, **kwargs):
-    arg_names_dict = {
-        1: ['qkv'],
-        2: ['q', 'kv'],
-        3: ['q', 'k', 'v']
-    }
     num_all_args = len(args) + len(kwargs)
-    assert num_all_args in arg_names_dict, f"Invalid number of arguments, got {num_all_args}, expected 1, 2, or 3"
-    for key in arg_names_dict[num_all_args][len(args):]:
         assert key in kwargs, f"Missing argument {key}"

     if num_all_args == 1:
-        qkv = args[0] if len(args) > 0 else kwargs['qkv']
-        assert isinstance(qkv, SparseTensor), f"qkv must be a SparseTensor, got {type(qkv)}"
-        assert len(qkv.shape) == 4 and qkv.shape[1] == 3, f"Invalid shape for qkv, got {qkv.shape}, expected [N, *, 3, H, C]"
         device = qkv.device

         s = qkv
-        q_seqlen = [qkv.layout[i].stop - qkv.layout[i].start for i in range(qkv.shape[0])]
         kv_seqlen = q_seqlen
-        qkv = qkv.feats     # [T, 3, H, C]

     elif num_all_args == 2:
-        q = args[0] if len(args) > 0 else kwargs['q']
-        kv = args[1] if len(args) > 1 else kwargs['kv']
-        assert isinstance(q, SparseTensor) and isinstance(kv, (SparseTensor, torch.Tensor)) or \
-               isinstance(q, torch.Tensor) and isinstance(kv, SparseTensor), \
-               f"Invalid types, got {type(q)} and {type(kv)}"
-        assert q.shape[0] == kv.shape[0], f"Batch size mismatch, got {q.shape[0]} and {kv.shape[0]}"
         device = q.device

         if isinstance(q, SparseTensor):
-            assert len(q.shape) == 3, f"Invalid shape for q, got {q.shape}, expected [N, *, H, C]"
             s = q
             q_seqlen = [q.layout[i].stop - q.layout[i].start for i in range(q.shape[0])]
-            q = q.feats     # [T_Q, H, C]
         else:
-            assert len(q.shape) == 4, f"Invalid shape for q, got {q.shape}, expected [N, L, H, C]"
             s = None
             N, L, H, C = q.shape
             q_seqlen = [L] * N
-            q = q.reshape(N * L, H, C)   # [T_Q, H, C]

         if isinstance(kv, SparseTensor):
-            assert len(kv.shape) == 4 and kv.shape[1] == 2, f"Invalid shape for kv, got {kv.shape}, expected [N, *, 2, H, C]"
-            kv_seqlen = [kv.layout[i].stop - kv.layout[i].start for i in range(kv.shape[0])]
-            kv = kv.feats     # [T_KV, 2, H, C]
         else:
-            assert len(kv.shape) == 5, f"Invalid shape for kv, got {kv.shape}, expected [N, L, 2, H, C]"
             N, L, _, H, C = kv.shape
             kv_seqlen = [L] * N
-            kv = kv.reshape(N * L, 2, H, C)   # [T_KV, 2, H, C]

     elif num_all_args == 3:
-        q = args[0] if len(args) > 0 else kwargs['q']
-        k = args[1] if len(args) > 1 else kwargs['k']
-        v = args[2] if len(args) > 2 else kwargs['v']
-        assert isinstance(q, SparseTensor) and isinstance(k, (SparseTensor, torch.Tensor)) and type(k) == type(v) or \
-               isinstance(q, torch.Tensor) and isinstance(k, SparseTensor) and isinstance(v, SparseTensor), \
-               f"Invalid types, got {type(q)}, {type(k)}, and {type(v)}"
-        assert q.shape[0] == k.shape[0] == v.shape[0], f"Batch size mismatch, got {q.shape[0]}, {k.shape[0]}, and {v.shape[0]}"
         device = q.device

         if isinstance(q, SparseTensor):
-            assert len(q.shape) == 3, f"Invalid shape for q, got {q.shape}, expected [N, *, H, Ci]"
             s = q
             q_seqlen = [q.layout[i].stop - q.layout[i].start for i in range(q.shape[0])]
-            q = q.feats     # [T_Q, H, Ci]
         else:
-            assert len(q.shape) == 4, f"Invalid shape for q, got {q.shape}, expected [N, L, H, Ci]"
             s = None
             N, L, H, CI = q.shape
             q_seqlen = [L] * N
             q = q.reshape(N * L, H, CI)   # [T_Q, H, Ci]

         if isinstance(k, SparseTensor):
-            assert len(k.shape) == 3, f"Invalid shape for k, got {k.shape}, expected [N, *, H, Ci]"
-            assert len(v.shape) == 3, f"Invalid shape for v, got {v.shape}, expected [N, *, H, Co]"
-            kv_seqlen = [k.layout[i].stop - k.layout[i].start for i in range(k.shape[0])]
-            k = k.feats     # [T_KV, H, Ci]
-            v = v.feats     # [T_KV, H, Co]
         else:
-            assert len(k.shape) == 4, f"Invalid shape for k, got {k.shape}, expected [N, L, H, Ci]"
-            assert len(v.shape) == 4, f"Invalid shape for v, got {v.shape}, expected [N, L, H, Co]"
             N, L, H, CI, CO = *k.shape, v.shape[-1]
             kv_seqlen = [L] * N
-            k = k.reshape(N * L, H, CI)   # [T_KV, H, Ci]
-            v = v.reshape(N * L, H, CO)   # [T_KV, H, Co]

     if DEBUG:
         if s is not None:
             for i in range(s.shape[0]):
-                assert (s.coords[s.layout[i]] == i).all(), f"SparseScaledDotProductSelfAttention: batch index mismatch"
         if num_all_args in [2, 3]:
-            assert q.shape[:2] == [1, sum(q_seqlen)], f"SparseScaledDotProductSelfAttention: q shape mismatch"
         if num_all_args == 3:
-            assert k.shape[:2] == [1, sum(kv_seqlen)], f"SparseScaledDotProductSelfAttention: k shape mismatch"
-            assert v.shape[:2] == [1, sum(kv_seqlen)], f"SparseScaledDotProductSelfAttention: v shape mismatch"

-    if ATTN == 'xformers':
         if num_all_args == 1:
             q, k, v = qkv.unbind(dim=1)
         elif num_all_args == 2:
@@ -196,19 +263,35 @@
         v = v.unsqueeze(0)
         mask = xops.fmha.BlockDiagonalMask.from_seqlens(q_seqlen, kv_seqlen)
         out = xops.memory_efficient_attention(q, k, v, mask)[0]
-    elif ATTN == 'flash_attn':
-        cu_seqlens_q = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(q_seqlen), dim=0)]).int().to(device)
         if num_all_args in [2, 3]:
-            cu_seqlens_kv = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(kv_seqlen), dim=0)]).int().to(device)
         if num_all_args == 1:
-            out = flash_attn.flash_attn_varlen_qkvpacked_func(qkv, cu_seqlens_q, max(q_seqlen))
         elif num_all_args == 2:
-            out = flash_attn.flash_attn_varlen_kvpacked_func(q, kv, cu_seqlens_q, cu_seqlens_kv, max(q_seqlen), max(kv_seqlen))
         elif num_all_args == 3:
-            out = flash_attn.flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, max(q_seqlen), max(kv_seqlen))
     else:
         raise ValueError(f"Unknown attention module: {ATTN}")
-
     if s is not None:
         return s.replace(out)
     else:
 
 from .. import SparseTensor
 from .. import DEBUG, ATTN

+if ATTN == "xformers":
     import xformers.ops as xops
+elif ATTN == "flash_attn":
     import flash_attn
 else:
     raise ValueError(f"Unknown attention module: {ATTN}")


 __all__ = [
+    "sparse_scaled_dot_product_attention",
 ]


     """
     ...

+
 @overload
+def sparse_scaled_dot_product_attention(
+    q: SparseTensor, kv: Union[SparseTensor, torch.Tensor]
+) -> SparseTensor:
     """
     Apply scaled dot product attention to a sparse tensor.


     """
     ...

+
 @overload
+def sparse_scaled_dot_product_attention(
+    q: torch.Tensor, kv: SparseTensor
+) -> torch.Tensor:
     """
     Apply scaled dot product attention to a sparse tensor.


     """
     ...

+
 @overload
+def sparse_scaled_dot_product_attention(
+    q: SparseTensor, k: SparseTensor, v: SparseTensor
+) -> SparseTensor:
     """
     Apply scaled dot product attention to a sparse tensor.


     """
     ...

+
 @overload
+def sparse_scaled_dot_product_attention(
+    q: SparseTensor, k: torch.Tensor, v: torch.Tensor
+) -> SparseTensor:
     """
     Apply scaled dot product attention to a sparse tensor.


     """
     ...

+
 @overload
+def sparse_scaled_dot_product_attention(
+    q: torch.Tensor, k: SparseTensor, v: SparseTensor
+) -> torch.Tensor:
     """
     Apply scaled dot product attention to a sparse tensor.


     """
     ...

+
 def sparse_scaled_dot_product_attention(*args, **kwargs):
+    arg_names_dict = {1: ["qkv"], 2: ["q", "kv"], 3: ["q", "k", "v"]}
     num_all_args = len(args) + len(kwargs)
+    assert (
+        num_all_args in arg_names_dict
+    ), f"Invalid number of arguments, got {num_all_args}, expected 1, 2, or 3"
+    for key in arg_names_dict[num_all_args][len(args) :]:
         assert key in kwargs, f"Missing argument {key}"

     if num_all_args == 1:
+        qkv = args[0] if len(args) > 0 else kwargs["qkv"]
+        assert isinstance(
+            qkv, SparseTensor
+        ), f"qkv must be a SparseTensor, got {type(qkv)}"
+        assert (
+            len(qkv.shape) == 4 and qkv.shape[1] == 3
+        ), f"Invalid shape for qkv, got {qkv.shape}, expected [N, *, 3, H, C]"
         device = qkv.device

         s = qkv
+        q_seqlen = [
+            qkv.layout[i].stop - qkv.layout[i].start for i in range(qkv.shape[0])
+        ]
         kv_seqlen = q_seqlen
+        qkv = qkv.feats  # [T, 3, H, C]

     elif num_all_args == 2:
+        q = args[0] if len(args) > 0 else kwargs["q"]
+        kv = args[1] if len(args) > 1 else kwargs["kv"]
+        assert (
+            isinstance(q, SparseTensor)
+            and isinstance(kv, (SparseTensor, torch.Tensor))
+            or isinstance(q, torch.Tensor)
+            and isinstance(kv, SparseTensor)
+        ), f"Invalid types, got {type(q)} and {type(kv)}"
+        assert (
+            q.shape[0] == kv.shape[0]
+        ), f"Batch size mismatch, got {q.shape[0]} and {kv.shape[0]}"
         device = q.device

         if isinstance(q, SparseTensor):
+            assert (
+                len(q.shape) == 3
+            ), f"Invalid shape for q, got {q.shape}, expected [N, *, H, C]"
             s = q
             q_seqlen = [q.layout[i].stop - q.layout[i].start for i in range(q.shape[0])]
+            q = q.feats  # [T_Q, H, C]
         else:
+            assert (
+                len(q.shape) == 4
+            ), f"Invalid shape for q, got {q.shape}, expected [N, L, H, C]"
             s = None
             N, L, H, C = q.shape
             q_seqlen = [L] * N
+            q = q.reshape(N * L, H, C)  # [T_Q, H, C]

         if isinstance(kv, SparseTensor):
+            assert (
+                len(kv.shape) == 4 and kv.shape[1] == 2
+            ), f"Invalid shape for kv, got {kv.shape}, expected [N, *, 2, H, C]"
+            kv_seqlen = [
+                kv.layout[i].stop - kv.layout[i].start for i in range(kv.shape[0])
+            ]
+            kv = kv.feats  # [T_KV, 2, H, C]
         else:
+            assert (
+                len(kv.shape) == 5
+            ), f"Invalid shape for kv, got {kv.shape}, expected [N, L, 2, H, C]"
             N, L, _, H, C = kv.shape
             kv_seqlen = [L] * N
+            kv = kv.reshape(N * L, 2, H, C)  # [T_KV, 2, H, C]

     elif num_all_args == 3:
+        q = args[0] if len(args) > 0 else kwargs["q"]
+        k = args[1] if len(args) > 1 else kwargs["k"]
+        v = args[2] if len(args) > 2 else kwargs["v"]
+        assert (
+            isinstance(q, SparseTensor)
+            and isinstance(k, (SparseTensor, torch.Tensor))
+            and type(k) == type(v)
+            or isinstance(q, torch.Tensor)
+            and isinstance(k, SparseTensor)
+            and isinstance(v, SparseTensor)
+        ), f"Invalid types, got {type(q)}, {type(k)}, and {type(v)}"
+        assert (
+            q.shape[0] == k.shape[0] == v.shape[0]
+        ), f"Batch size mismatch, got {q.shape[0]}, {k.shape[0]}, and {v.shape[0]}"
         device = q.device

         if isinstance(q, SparseTensor):
+            assert (
+                len(q.shape) == 3
+            ), f"Invalid shape for q, got {q.shape}, expected [N, *, H, Ci]"
             s = q
             q_seqlen = [q.layout[i].stop - q.layout[i].start for i in range(q.shape[0])]
+            q = q.feats  # [T_Q, H, Ci]
         else:
+            assert (
+                len(q.shape) == 4
+            ), f"Invalid shape for q, got {q.shape}, expected [N, L, H, Ci]"
             s = None
             N, L, H, CI = q.shape
             q_seqlen = [L] * N
             q = q.reshape(N * L, H, CI)  # [T_Q, H, Ci]

         if isinstance(k, SparseTensor):
+            assert (
+                len(k.shape) == 3
+            ), f"Invalid shape for k, got {k.shape}, expected [N, *, H, Ci]"
+            assert (
+                len(v.shape) == 3
+            ), f"Invalid shape for v, got {v.shape}, expected [N, *, H, Co]"
+            kv_seqlen = [
+                k.layout[i].stop - k.layout[i].start for i in range(k.shape[0])
+            ]
+            k = k.feats  # [T_KV, H, Ci]
+            v = v.feats  # [T_KV, H, Co]
         else:
+            assert (
+                len(k.shape) == 4
+            ), f"Invalid shape for k, got {k.shape}, expected [N, L, H, Ci]"
+            assert (
+                len(v.shape) == 4
+            ), f"Invalid shape for v, got {v.shape}, expected [N, L, H, Co]"
             N, L, H, CI, CO = *k.shape, v.shape[-1]
             kv_seqlen = [L] * N
+            k = k.reshape(N * L, H, CI)  # [T_KV, H, Ci]
+            v = v.reshape(N * L, H, CO)  # [T_KV, H, Co]

     if DEBUG:
         if s is not None:
             for i in range(s.shape[0]):
+                assert (
+                    s.coords[s.layout[i]] == i
+                ).all(), f"SparseScaledDotProductSelfAttention: batch index mismatch"
         if num_all_args in [2, 3]:
+            assert q.shape[:2] == [
+                1,
+                sum(q_seqlen),
+            ], f"SparseScaledDotProductSelfAttention: q shape mismatch"
         if num_all_args == 3:
+            assert k.shape[:2] == [
+                1,
+                sum(kv_seqlen),
+            ], f"SparseScaledDotProductSelfAttention: k shape mismatch"
+            assert v.shape[:2] == [
+                1,
+                sum(kv_seqlen),
+            ], f"SparseScaledDotProductSelfAttention: v shape mismatch"

+    if ATTN == "xformers":
         if num_all_args == 1:
             q, k, v = qkv.unbind(dim=1)
         elif num_all_args == 2:

         v = v.unsqueeze(0)
         mask = xops.fmha.BlockDiagonalMask.from_seqlens(q_seqlen, kv_seqlen)
         out = xops.memory_efficient_attention(q, k, v, mask)[0]
+    elif ATTN == "flash_attn":
+        cu_seqlens_q = (
+            torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(q_seqlen), dim=0)])
+            .int()
+            .to(device)
+        )
         if num_all_args in [2, 3]:
+            cu_seqlens_kv = (
+                torch.cat(
+                    [torch.tensor([0]), torch.cumsum(torch.tensor(kv_seqlen), dim=0)]
+                )
+                .int()
+                .to(device)
+            )
         if num_all_args == 1:
+            out = flash_attn.flash_attn_varlen_qkvpacked_func(
+                qkv, cu_seqlens_q, max(q_seqlen)
+            )
         elif num_all_args == 2:
+            out = flash_attn.flash_attn_varlen_kvpacked_func(
+                q, kv, cu_seqlens_q, cu_seqlens_kv, max(q_seqlen), max(kv_seqlen)
+            )
         elif num_all_args == 3:
+            out = flash_attn.flash_attn_varlen_func(
+                q, k, v, cu_seqlens_q, cu_seqlens_kv, max(q_seqlen), max(kv_seqlen)
+            )
     else:
         raise ValueError(f"Unknown attention module: {ATTN}")
+
     if s is not None:
         return s.replace(out)
     else:
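Note on sparse/attention/full_attn.py: the flash-attn varlen path flattens all batch items onto one token axis and marks batch boundaries with cumulative sequence lengths. A small sketch of the `cu_seqlens` construction (made-up lengths, CPU tensors):

import torch

# cu_seqlens as built for the flash_attn varlen path: for per-sample lengths
# [3, 5, 2], tokens 0-2, 3-7, and 8-9 of the flattened axis belong to
# samples 0, 1, and 2 respectively.
q_seqlen = [3, 5, 2]  # made-up per-sample token counts
cu_seqlens_q = torch.cat(
    [torch.tensor([0]), torch.cumsum(torch.tensor(q_seqlen), dim=0)]
).int()

print(cu_seqlens_q)   # tensor([ 0,  3,  8, 10], dtype=torch.int32)
print(max(q_seqlen))  # 5 -> passed as the max_seqlen argument to flash_attn_varlen_*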
trellis/modules/sparse/attention/modules.py CHANGED
@@ -4,7 +4,10 @@ import torch.nn as nn
4
  import torch.nn.functional as F
5
  from .. import SparseTensor
6
  from .full_attn import sparse_scaled_dot_product_attention
7
- from .serialized_attn import SerializeMode, sparse_serialized_scaled_dot_product_self_attention
 
 
 
8
  from .windowed_attn import sparse_windowed_scaled_dot_product_self_attention
9
  from ...attention import RotaryPositionEmbedder
10
 
@@ -12,16 +15,18 @@ from ...attention import RotaryPositionEmbedder
12
  class SparseMultiHeadRMSNorm(nn.Module):
13
  def __init__(self, dim: int, heads: int):
14
  super().__init__()
15
- self.scale = dim ** 0.5
16
  self.gamma = nn.Parameter(torch.ones(heads, dim))
17
 
18
- def forward(self, x: Union[SparseTensor, torch.Tensor]) -> Union[SparseTensor, torch.Tensor]:
 
 
19
  x_type = x.dtype
20
  x = x.float()
21
  if isinstance(x, SparseTensor):
22
  x = x.replace(F.normalize(x.feats, dim=-1))
23
  else:
24
- x = F.normalize(x, dim=-1)
25
  return (x * self.gamma * self.scale).to(x_type)
26
 
27
 
@@ -44,9 +49,17 @@ class SparseMultiHeadAttention(nn.Module):
44
  super().__init__()
45
  assert channels % num_heads == 0
46
  assert type in ["self", "cross"], f"Invalid attention type: {type}"
47
- assert attn_mode in ["full", "serialized", "windowed"], f"Invalid attention mode: {attn_mode}"
48
- assert type == "self" or attn_mode == "full", "Cross-attention only supports full attention"
49
- assert type == "self" or use_rope is False, "Rotary position embeddings only supported for self-attention"
 
 
 
 
 
 
 
 
50
  self.channels = channels
51
  self.ctx_channels = ctx_channels if ctx_channels is not None else channels
52
  self.num_heads = num_heads
@@ -64,31 +77,37 @@ class SparseMultiHeadAttention(nn.Module):
64
  else:
65
  self.to_q = nn.Linear(channels, channels, bias=qkv_bias)
66
  self.to_kv = nn.Linear(self.ctx_channels, channels * 2, bias=qkv_bias)
67
-
68
  if self.qk_rms_norm:
69
  self.q_rms_norm = SparseMultiHeadRMSNorm(channels // num_heads, num_heads)
70
  self.k_rms_norm = SparseMultiHeadRMSNorm(channels // num_heads, num_heads)
71
-
72
  self.to_out = nn.Linear(channels, channels)
73
 
74
  if use_rope:
75
  self.rope = RotaryPositionEmbedder(channels)
76
 
77
  @staticmethod
78
- def _linear(module: nn.Linear, x: Union[SparseTensor, torch.Tensor]) -> Union[SparseTensor, torch.Tensor]:
 
 
79
  if isinstance(x, SparseTensor):
80
  return x.replace(module(x.feats))
81
  else:
82
  return module(x)
83
 
84
  @staticmethod
85
- def _reshape_chs(x: Union[SparseTensor, torch.Tensor], shape: Tuple[int, ...]) -> Union[SparseTensor, torch.Tensor]:
86
  if isinstance(x, SparseTensor):
87
  return x.reshape(*shape)
88
  else:
89
  return x.reshape(*x.shape[:2], *shape)
90
 
91
- def _fused_pre(self, x: Union[SparseTensor, torch.Tensor], num_fused: int) -> Union[SparseTensor, torch.Tensor]:
92
  if isinstance(x, SparseTensor):
93
  x_feats = x.feats.unsqueeze(0)
94
  else:
@@ -97,12 +116,16 @@ class SparseMultiHeadAttention(nn.Module):
97
  return x.replace(x_feats.squeeze(0)) if isinstance(x, SparseTensor) else x_feats
98
 
99
  def _rope(self, qkv: SparseTensor) -> SparseTensor:
100
- q, k, v = qkv.feats.unbind(dim=1) # [T, H, C]
101
  q, k = self.rope(q, k, qkv.coords[:, 1:])
102
- qkv = qkv.replace(torch.stack([q, k, v], dim=1))
103
  return qkv
104
-
105
- def forward(self, x: Union[SparseTensor, torch.Tensor], context: Optional[Union[SparseTensor, torch.Tensor]] = None) -> Union[SparseTensor, torch.Tensor]:
106
  if self._type == "self":
107
  qkv = self._linear(self.to_qkv, x)
108
  qkv = self._fused_pre(qkv, num_fused=3)
@@ -117,7 +140,11 @@ class SparseMultiHeadAttention(nn.Module):
117
  h = sparse_scaled_dot_product_attention(qkv)
118
  elif self.attn_mode == "serialized":
119
  h = sparse_serialized_scaled_dot_product_self_attention(
120
- qkv, self.window_size, serialize_mode=self.serialize_mode, shift_sequence=self.shift_sequence, shift_window=self.shift_window
121
  )
122
  elif self.attn_mode == "windowed":
123
  h = sparse_windowed_scaled_dot_product_self_attention(
 
4
  import torch.nn.functional as F
5
  from .. import SparseTensor
6
  from .full_attn import sparse_scaled_dot_product_attention
7
+ from .serialized_attn import (
8
+ SerializeMode,
9
+ sparse_serialized_scaled_dot_product_self_attention,
10
+ )
11
  from .windowed_attn import sparse_windowed_scaled_dot_product_self_attention
12
  from ...attention import RotaryPositionEmbedder
13
 
 
15
  class SparseMultiHeadRMSNorm(nn.Module):
16
  def __init__(self, dim: int, heads: int):
17
  super().__init__()
18
+ self.scale = dim**0.5
19
  self.gamma = nn.Parameter(torch.ones(heads, dim))
20
 
21
+ def forward(
22
+ self, x: Union[SparseTensor, torch.Tensor]
23
+ ) -> Union[SparseTensor, torch.Tensor]:
24
  x_type = x.dtype
25
  x = x.float()
26
  if isinstance(x, SparseTensor):
27
  x = x.replace(F.normalize(x.feats, dim=-1))
28
  else:
29
+ x = F.normalize(x, dim=-1)
30
  return (x * self.gamma * self.scale).to(x_type)
31
 
32
 
 
49
  super().__init__()
50
  assert channels % num_heads == 0
51
  assert type in ["self", "cross"], f"Invalid attention type: {type}"
52
+ assert attn_mode in [
53
+ "full",
54
+ "serialized",
55
+ "windowed",
56
+ ], f"Invalid attention mode: {attn_mode}"
57
+ assert (
58
+ type == "self" or attn_mode == "full"
59
+ ), "Cross-attention only supports full attention"
60
+ assert (
61
+ type == "self" or use_rope is False
62
+ ), "Rotary position embeddings only supported for self-attention"
63
  self.channels = channels
64
  self.ctx_channels = ctx_channels if ctx_channels is not None else channels
65
  self.num_heads = num_heads
 
77
  else:
78
  self.to_q = nn.Linear(channels, channels, bias=qkv_bias)
79
  self.to_kv = nn.Linear(self.ctx_channels, channels * 2, bias=qkv_bias)
80
+
81
  if self.qk_rms_norm:
82
  self.q_rms_norm = SparseMultiHeadRMSNorm(channels // num_heads, num_heads)
83
  self.k_rms_norm = SparseMultiHeadRMSNorm(channels // num_heads, num_heads)
84
+
85
  self.to_out = nn.Linear(channels, channels)
86
 
87
  if use_rope:
88
  self.rope = RotaryPositionEmbedder(channels)
89
 
90
  @staticmethod
91
+ def _linear(
92
+ module: nn.Linear, x: Union[SparseTensor, torch.Tensor]
93
+ ) -> Union[SparseTensor, torch.Tensor]:
94
  if isinstance(x, SparseTensor):
95
  return x.replace(module(x.feats))
96
  else:
97
  return module(x)
98
 
99
  @staticmethod
100
+ def _reshape_chs(
101
+ x: Union[SparseTensor, torch.Tensor], shape: Tuple[int, ...]
102
+ ) -> Union[SparseTensor, torch.Tensor]:
103
  if isinstance(x, SparseTensor):
104
  return x.reshape(*shape)
105
  else:
106
  return x.reshape(*x.shape[:2], *shape)
107
 
108
+ def _fused_pre(
109
+ self, x: Union[SparseTensor, torch.Tensor], num_fused: int
110
+ ) -> Union[SparseTensor, torch.Tensor]:
111
  if isinstance(x, SparseTensor):
112
  x_feats = x.feats.unsqueeze(0)
113
  else:
 
116
  return x.replace(x_feats.squeeze(0)) if isinstance(x, SparseTensor) else x_feats
117
 
118
  def _rope(self, qkv: SparseTensor) -> SparseTensor:
119
+ q, k, v = qkv.feats.unbind(dim=1) # [T, H, C]
120
  q, k = self.rope(q, k, qkv.coords[:, 1:])
121
+ qkv = qkv.replace(torch.stack([q, k, v], dim=1))
122
  return qkv
123
+
124
+ def forward(
125
+ self,
126
+ x: Union[SparseTensor, torch.Tensor],
127
+ context: Optional[Union[SparseTensor, torch.Tensor]] = None,
128
+ ) -> Union[SparseTensor, torch.Tensor]:
129
  if self._type == "self":
130
  qkv = self._linear(self.to_qkv, x)
131
  qkv = self._fused_pre(qkv, num_fused=3)
 
140
  h = sparse_scaled_dot_product_attention(qkv)
141
  elif self.attn_mode == "serialized":
142
  h = sparse_serialized_scaled_dot_product_self_attention(
143
+ qkv,
144
+ self.window_size,
145
+ serialize_mode=self.serialize_mode,
146
+ shift_sequence=self.shift_sequence,
147
+ shift_window=self.shift_window,
148
  )
149
  elif self.attn_mode == "windowed":
150
  h = sparse_windowed_scaled_dot_product_self_attention(
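
Note: the qk RMS norm above L2-normalizes each head's channel dimension and rescales by a learned gamma times sqrt(dim), so normalized queries and keys keep the magnitude attention expects. A dense-tensor sketch of the same computation (shapes are illustrative):

    import torch
    import torch.nn.functional as F

    heads, dim = 4, 16
    gamma = torch.ones(heads, dim)       # the learnable nn.Parameter in the module
    x = torch.randn(2, 10, heads, dim)   # [B, N, H, C]

    # Normalize in float32 for stability, then rescale; with gamma == 1 every
    # vector ends up with norm sqrt(dim) instead of norm 1.
    y = (F.normalize(x.float(), dim=-1) * gamma * dim**0.5).to(x.dtype)
    print(y.norm(dim=-1).mean())         # ~4.0 == sqrt(16)
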
trellis/modules/sparse/attention/serialized_attn.py CHANGED
@@ -5,16 +5,16 @@ import math
5
  from .. import SparseTensor
6
  from .. import DEBUG, ATTN
7
 
8
- if ATTN == 'xformers':
9
  import xformers.ops as xops
10
- elif ATTN == 'flash_attn':
11
  import flash_attn
12
  else:
13
  raise ValueError(f"Unknown attention module: {ATTN}")
14
 
15
 
16
  __all__ = [
17
- 'sparse_serialized_scaled_dot_product_self_attention',
18
  ]
19
 
20
 
@@ -29,7 +29,7 @@ SerializeModes = [
29
  SerializeMode.Z_ORDER,
30
  SerializeMode.Z_ORDER_TRANSPOSED,
31
  SerializeMode.HILBERT,
32
- SerializeMode.HILBERT_TRANSPOSED
33
  ]
34
 
35
 
@@ -38,7 +38,7 @@ def calc_serialization(
38
  window_size: int,
39
  serialize_mode: SerializeMode = SerializeMode.Z_ORDER,
40
  shift_sequence: int = 0,
41
- shift_window: Tuple[int, int, int] = (0, 0, 0)
42
  ) -> Tuple[torch.Tensor, torch.Tensor, List[int]]:
43
  """
44
  Calculate serialization and partitioning for a set of coordinates.
@@ -58,32 +58,38 @@ def calc_serialization(
58
  seq_lens = []
59
  seq_batch_indices = []
60
  offsets = [0]
61
-
62
- if 'vox2seq' not in globals():
63
  import vox2seq
64
 
65
  # Serialize the input
66
  serialize_coords = tensor.coords[:, 1:].clone()
67
- serialize_coords += torch.tensor(shift_window, dtype=torch.int32, device=tensor.device).reshape(1, 3)
68
  if serialize_mode == SerializeMode.Z_ORDER:
69
- code = vox2seq.encode(serialize_coords, mode='z_order', permute=[0, 1, 2])
70
  elif serialize_mode == SerializeMode.Z_ORDER_TRANSPOSED:
71
- code = vox2seq.encode(serialize_coords, mode='z_order', permute=[1, 0, 2])
72
  elif serialize_mode == SerializeMode.HILBERT:
73
- code = vox2seq.encode(serialize_coords, mode='hilbert', permute=[0, 1, 2])
74
  elif serialize_mode == SerializeMode.HILBERT_TRANSPOSED:
75
- code = vox2seq.encode(serialize_coords, mode='hilbert', permute=[1, 0, 2])
76
  else:
77
  raise ValueError(f"Unknown serialize mode: {serialize_mode}")
78
-
79
  for bi, s in enumerate(tensor.layout):
80
  num_points = s.stop - s.start
81
  num_windows = (num_points + window_size - 1) // window_size
82
  valid_window_size = num_points / num_windows
83
- to_ordered = torch.argsort(code[s.start:s.stop])
84
  if num_windows == 1:
85
  fwd_indices.append(to_ordered)
86
- bwd_indices.append(torch.zeros_like(to_ordered).scatter_(0, to_ordered, torch.arange(num_points, device=tensor.device)))
87
  fwd_indices[-1] += s.start
88
  bwd_indices[-1] += offsets[-1]
89
  seq_lens.append(num_points)
@@ -92,18 +98,39 @@ def calc_serialization(
92
  else:
93
  # Partition the input
94
  offset = 0
95
- mids = [(i + 0.5) * valid_window_size + shift_sequence for i in range(num_windows)]
96
- split = [math.floor(i * valid_window_size + shift_sequence) for i in range(num_windows + 1)]
97
- bwd_index = torch.zeros((num_points,), dtype=torch.int64, device=tensor.device)
98
  for i in range(num_windows):
99
  mid = mids[i]
100
  valid_start = split[i]
101
  valid_end = split[i + 1]
102
  padded_start = math.floor(mid - 0.5 * window_size)
103
  padded_end = padded_start + window_size
104
- fwd_indices.append(to_ordered[torch.arange(padded_start, padded_end, device=tensor.device) % num_points])
105
  offset += valid_start - padded_start
106
- bwd_index.scatter_(0, fwd_indices[-1][valid_start-padded_start:valid_end-padded_start], torch.arange(offset, offset + valid_end - valid_start, device=tensor.device))
107
  offset += padded_end - valid_start
108
  fwd_indices[-1] += s.start
109
  seq_lens.extend([window_size] * num_windows)
@@ -115,14 +142,14 @@ def calc_serialization(
115
  bwd_indices = torch.cat(bwd_indices)
116
 
117
  return fwd_indices, bwd_indices, seq_lens, seq_batch_indices
118
-
119
 
120
  def sparse_serialized_scaled_dot_product_self_attention(
121
  qkv: SparseTensor,
122
  window_size: int,
123
  serialize_mode: SerializeMode = SerializeMode.Z_ORDER,
124
  shift_sequence: int = 0,
125
- shift_window: Tuple[int, int, int] = (0, 0, 0)
126
  ) -> SparseTensor:
127
  """
128
  Apply serialized scaled dot product self attention to a sparse tensor.
@@ -135,59 +162,89 @@ def sparse_serialized_scaled_dot_product_self_attention(
135
  shift_window (Tuple[int, int, int]): The shift of serialized coordinates.
136
  shift (int): The shift to use.
137
  """
138
- assert len(qkv.shape) == 4 and qkv.shape[1] == 3, f"Invalid shape for qkv, got {qkv.shape}, expected [N, *, 3, H, C]"
139
-
140
- serialization_spatial_cache_name = f'serialization_{serialize_mode}_{window_size}_{shift_sequence}_{shift_window}'
141
- serialization_spatial_cache = qkv.get_spatial_cache(serialization_spatial_cache_name)
142
  if serialization_spatial_cache is None:
143
- fwd_indices, bwd_indices, seq_lens, seq_batch_indices = calc_serialization(qkv, window_size, serialize_mode, shift_sequence, shift_window)
144
- qkv.register_spatial_cache(serialization_spatial_cache_name, (fwd_indices, bwd_indices, seq_lens, seq_batch_indices))
145
  else:
146
- fwd_indices, bwd_indices, seq_lens, seq_batch_indices = serialization_spatial_cache
147
 
148
  M = fwd_indices.shape[0]
149
  T = qkv.feats.shape[0]
150
  H = qkv.feats.shape[2]
151
  C = qkv.feats.shape[3]
152
-
153
- qkv_feats = qkv.feats[fwd_indices] # [M, 3, H, C]
154
 
155
  if DEBUG:
156
  start = 0
157
  qkv_coords = qkv.coords[fwd_indices]
158
  for i in range(len(seq_lens)):
159
- assert (qkv_coords[start:start+seq_lens[i], 0] == seq_batch_indices[i]).all(), f"SparseWindowedScaledDotProductSelfAttention: batch index mismatch"
160
  start += seq_lens[i]
161
 
162
  if all([seq_len == window_size for seq_len in seq_lens]):
163
  B = len(seq_lens)
164
  N = window_size
165
  qkv_feats = qkv_feats.reshape(B, N, 3, H, C)
166
- if ATTN == 'xformers':
167
- q, k, v = qkv_feats.unbind(dim=2) # [B, N, H, C]
168
- out = xops.memory_efficient_attention(q, k, v) # [B, N, H, C]
169
- elif ATTN == 'flash_attn':
170
- out = flash_attn.flash_attn_qkvpacked_func(qkv_feats) # [B, N, H, C]
171
  else:
172
  raise ValueError(f"Unknown attention module: {ATTN}")
173
- out = out.reshape(B * N, H, C) # [M, H, C]
174
  else:
175
- if ATTN == 'xformers':
176
- q, k, v = qkv_feats.unbind(dim=1) # [M, H, C]
177
- q = q.unsqueeze(0) # [1, M, H, C]
178
- k = k.unsqueeze(0) # [1, M, H, C]
179
- v = v.unsqueeze(0) # [1, M, H, C]
180
  mask = xops.fmha.BlockDiagonalMask.from_seqlens(seq_lens)
181
- out = xops.memory_efficient_attention(q, k, v, mask)[0] # [M, H, C]
182
- elif ATTN == 'flash_attn':
183
- cu_seqlens = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(seq_lens), dim=0)], dim=0) \
184
- .to(qkv.device).int()
185
- out = flash_attn.flash_attn_varlen_qkvpacked_func(qkv_feats, cu_seqlens, max(seq_lens)) # [M, H, C]
186
-
187
- out = out[bwd_indices] # [T, H, C]
188
 
189
  if DEBUG:
190
  qkv_coords = qkv_coords[bwd_indices]
191
- assert torch.equal(qkv_coords, qkv.coords), "SparseWindowedScaledDotProductSelfAttention: coordinate mismatch"
192
 
193
  return qkv.replace(out)
 
5
  from .. import SparseTensor
6
  from .. import DEBUG, ATTN
7
 
8
+ if ATTN == "xformers":
9
  import xformers.ops as xops
10
+ elif ATTN == "flash_attn":
11
  import flash_attn
12
  else:
13
  raise ValueError(f"Unknown attention module: {ATTN}")
14
 
15
 
16
  __all__ = [
17
+ "sparse_serialized_scaled_dot_product_self_attention",
18
  ]
19
 
20
 
 
29
  SerializeMode.Z_ORDER,
30
  SerializeMode.Z_ORDER_TRANSPOSED,
31
  SerializeMode.HILBERT,
32
+ SerializeMode.HILBERT_TRANSPOSED,
33
  ]
34
 
35
 
 
38
  window_size: int,
39
  serialize_mode: SerializeMode = SerializeMode.Z_ORDER,
40
  shift_sequence: int = 0,
41
+ shift_window: Tuple[int, int, int] = (0, 0, 0),
42
  ) -> Tuple[torch.Tensor, torch.Tensor, List[int]]:
43
  """
44
  Calculate serialization and partitioning for a set of coordinates.
 
58
  seq_lens = []
59
  seq_batch_indices = []
60
  offsets = [0]
61
+
62
+ if "vox2seq" not in globals():
63
  import vox2seq
64
 
65
  # Serialize the input
66
  serialize_coords = tensor.coords[:, 1:].clone()
67
+ serialize_coords += torch.tensor(
68
+ shift_window, dtype=torch.int32, device=tensor.device
69
+ ).reshape(1, 3)
70
  if serialize_mode == SerializeMode.Z_ORDER:
71
+ code = vox2seq.encode(serialize_coords, mode="z_order", permute=[0, 1, 2])
72
  elif serialize_mode == SerializeMode.Z_ORDER_TRANSPOSED:
73
+ code = vox2seq.encode(serialize_coords, mode="z_order", permute=[1, 0, 2])
74
  elif serialize_mode == SerializeMode.HILBERT:
75
+ code = vox2seq.encode(serialize_coords, mode="hilbert", permute=[0, 1, 2])
76
  elif serialize_mode == SerializeMode.HILBERT_TRANSPOSED:
77
+ code = vox2seq.encode(serialize_coords, mode="hilbert", permute=[1, 0, 2])
78
  else:
79
  raise ValueError(f"Unknown serialize mode: {serialize_mode}")
80
+
81
  for bi, s in enumerate(tensor.layout):
82
  num_points = s.stop - s.start
83
  num_windows = (num_points + window_size - 1) // window_size
84
  valid_window_size = num_points / num_windows
85
+ to_ordered = torch.argsort(code[s.start : s.stop])
86
  if num_windows == 1:
87
  fwd_indices.append(to_ordered)
88
+ bwd_indices.append(
89
+ torch.zeros_like(to_ordered).scatter_(
90
+ 0, to_ordered, torch.arange(num_points, device=tensor.device)
91
+ )
92
+ )
93
  fwd_indices[-1] += s.start
94
  bwd_indices[-1] += offsets[-1]
95
  seq_lens.append(num_points)
 
98
  else:
99
  # Partition the input
100
  offset = 0
101
+ mids = [
102
+ (i + 0.5) * valid_window_size + shift_sequence
103
+ for i in range(num_windows)
104
+ ]
105
+ split = [
106
+ math.floor(i * valid_window_size + shift_sequence)
107
+ for i in range(num_windows + 1)
108
+ ]
109
+ bwd_index = torch.zeros(
110
+ (num_points,), dtype=torch.int64, device=tensor.device
111
+ )
112
  for i in range(num_windows):
113
  mid = mids[i]
114
  valid_start = split[i]
115
  valid_end = split[i + 1]
116
  padded_start = math.floor(mid - 0.5 * window_size)
117
  padded_end = padded_start + window_size
118
+ fwd_indices.append(
119
+ to_ordered[
120
+ torch.arange(padded_start, padded_end, device=tensor.device)
121
+ % num_points
122
+ ]
123
+ )
124
  offset += valid_start - padded_start
125
+ bwd_index.scatter_(
126
+ 0,
127
+ fwd_indices[-1][
128
+ valid_start - padded_start : valid_end - padded_start
129
+ ],
130
+ torch.arange(
131
+ offset, offset + valid_end - valid_start, device=tensor.device
132
+ ),
133
+ )
134
  offset += padded_end - valid_start
135
  fwd_indices[-1] += s.start
136
  seq_lens.extend([window_size] * num_windows)
 
142
  bwd_indices = torch.cat(bwd_indices)
143
 
144
  return fwd_indices, bwd_indices, seq_lens, seq_batch_indices
145
+
146
 
147
  def sparse_serialized_scaled_dot_product_self_attention(
148
  qkv: SparseTensor,
149
  window_size: int,
150
  serialize_mode: SerializeMode = SerializeMode.Z_ORDER,
151
  shift_sequence: int = 0,
152
+ shift_window: Tuple[int, int, int] = (0, 0, 0),
153
  ) -> SparseTensor:
154
  """
155
  Apply serialized scaled dot product self attention to a sparse tensor.
 
162
  shift_window (Tuple[int, int, int]): The shift of serialized coordinates.
163
  shift (int): The shift to use.
164
  """
165
+ assert (
166
+ len(qkv.shape) == 4 and qkv.shape[1] == 3
167
+ ), f"Invalid shape for qkv, got {qkv.shape}, expected [N, *, 3, H, C]"
168
+
169
+ serialization_spatial_cache_name = (
170
+ f"serialization_{serialize_mode}_{window_size}_{shift_sequence}_{shift_window}"
171
+ )
172
+ serialization_spatial_cache = qkv.get_spatial_cache(
173
+ serialization_spatial_cache_name
174
+ )
175
  if serialization_spatial_cache is None:
176
+ fwd_indices, bwd_indices, seq_lens, seq_batch_indices = calc_serialization(
177
+ qkv, window_size, serialize_mode, shift_sequence, shift_window
178
+ )
179
+ qkv.register_spatial_cache(
180
+ serialization_spatial_cache_name,
181
+ (fwd_indices, bwd_indices, seq_lens, seq_batch_indices),
182
+ )
183
  else:
184
+ (
185
+ fwd_indices,
186
+ bwd_indices,
187
+ seq_lens,
188
+ seq_batch_indices,
189
+ ) = serialization_spatial_cache
190
 
191
  M = fwd_indices.shape[0]
192
  T = qkv.feats.shape[0]
193
  H = qkv.feats.shape[2]
194
  C = qkv.feats.shape[3]
195
+
196
+ qkv_feats = qkv.feats[fwd_indices] # [M, 3, H, C]
197
 
198
  if DEBUG:
199
  start = 0
200
  qkv_coords = qkv.coords[fwd_indices]
201
  for i in range(len(seq_lens)):
202
+ assert (
203
+ qkv_coords[start : start + seq_lens[i], 0] == seq_batch_indices[i]
204
+ ).all(), (
205
+ f"SparseWindowedScaledDotProductSelfAttention: batch index mismatch"
206
+ )
207
  start += seq_lens[i]
208
 
209
  if all([seq_len == window_size for seq_len in seq_lens]):
210
  B = len(seq_lens)
211
  N = window_size
212
  qkv_feats = qkv_feats.reshape(B, N, 3, H, C)
213
+ if ATTN == "xformers":
214
+ q, k, v = qkv_feats.unbind(dim=2) # [B, N, H, C]
215
+ out = xops.memory_efficient_attention(q, k, v) # [B, N, H, C]
216
+ elif ATTN == "flash_attn":
217
+ out = flash_attn.flash_attn_qkvpacked_func(qkv_feats) # [B, N, H, C]
218
  else:
219
  raise ValueError(f"Unknown attention module: {ATTN}")
220
+ out = out.reshape(B * N, H, C) # [M, H, C]
221
  else:
222
+ if ATTN == "xformers":
223
+ q, k, v = qkv_feats.unbind(dim=1) # [M, H, C]
224
+ q = q.unsqueeze(0) # [1, M, H, C]
225
+ k = k.unsqueeze(0) # [1, M, H, C]
226
+ v = v.unsqueeze(0) # [1, M, H, C]
227
  mask = xops.fmha.BlockDiagonalMask.from_seqlens(seq_lens)
228
+ out = xops.memory_efficient_attention(q, k, v, mask)[0] # [M, H, C]
229
+ elif ATTN == "flash_attn":
230
+ cu_seqlens = (
231
+ torch.cat(
232
+ [torch.tensor([0]), torch.cumsum(torch.tensor(seq_lens), dim=0)],
233
+ dim=0,
234
+ )
235
+ .to(qkv.device)
236
+ .int()
237
+ )
238
+ out = flash_attn.flash_attn_varlen_qkvpacked_func(
239
+ qkv_feats, cu_seqlens, max(seq_lens)
240
+ ) # [M, H, C]
241
+
242
+ out = out[bwd_indices] # [T, H, C]
243
 
244
  if DEBUG:
245
  qkv_coords = qkv_coords[bwd_indices]
246
+ assert torch.equal(
247
+ qkv_coords, qkv.coords
248
+ ), "SparseWindowedScaledDotProductSelfAttention: coordinate mismatch"
249
 
250
  return qkv.replace(out)
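
Note: the serialization step orders voxels along a space-filling curve so that spatially close voxels land in the same attention window. A self-contained sketch of a z-order (Morton) key, standing in for vox2seq.encode(mode="z_order") — the real CUDA extension may use a different bit layout or permutation:

    import torch

    def morton_encode_3d(coords: torch.Tensor) -> torch.Tensor:
        # Interleave the bits of (x, y, z) into one key; coords are assumed
        # to lie in [0, 1023], matching the 10 bits interleaved per axis.
        code = torch.zeros(coords.shape[0], dtype=torch.int64)
        for bit in range(10):
            for axis in range(3):
                code |= ((coords[:, axis].long() >> bit) & 1) << (3 * bit + axis)
        return code

    coords = torch.randint(0, 1024, (8, 3), dtype=torch.int32)
    order = torch.argsort(morton_encode_3d(coords))  # serialization order
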
trellis/modules/sparse/attention/windowed_attn.py CHANGED
@@ -4,23 +4,23 @@ import math
4
  from .. import SparseTensor
5
  from .. import DEBUG, ATTN
6
 
7
- if ATTN == 'xformers':
8
  import xformers.ops as xops
9
- elif ATTN == 'flash_attn':
10
  import flash_attn
11
  else:
12
  raise ValueError(f"Unknown attention module: {ATTN}")
13
 
14
 
15
  __all__ = [
16
- 'sparse_windowed_scaled_dot_product_self_attention',
17
  ]
18
 
19
 
20
  def calc_window_partition(
21
  tensor: SparseTensor,
22
  window_size: Union[int, Tuple[int, ...]],
23
- shift_window: Union[int, Tuple[int, ...]] = 0
24
  ) -> Tuple[torch.Tensor, torch.Tensor, List[int], List[int]]:
25
  """
26
  Calculate serialization and partitioning for a set of coordinates.
@@ -37,33 +37,43 @@ def calc_window_partition(
37
  (List[int]): Sequence batch indices.
38
  """
39
  DIM = tensor.coords.shape[1] - 1
40
- shift_window = (shift_window,) * DIM if isinstance(shift_window, int) else shift_window
41
  window_size = (window_size,) * DIM if isinstance(window_size, int) else window_size
42
  shifted_coords = tensor.coords.clone().detach()
43
- shifted_coords[:, 1:] += torch.tensor(shift_window, device=tensor.device, dtype=torch.int32).unsqueeze(0)
44
 
45
  MAX_COORDS = shifted_coords[:, 1:].max(dim=0).values.tolist()
46
  NUM_WINDOWS = [math.ceil((mc + 1) / ws) for mc, ws in zip(MAX_COORDS, window_size)]
47
  OFFSET = torch.cumprod(torch.tensor([1] + NUM_WINDOWS[::-1]), dim=0).tolist()[::-1]
48
 
49
- shifted_coords[:, 1:] //= torch.tensor(window_size, device=tensor.device, dtype=torch.int32).unsqueeze(0)
50
- shifted_indices = (shifted_coords * torch.tensor(OFFSET, device=tensor.device, dtype=torch.int32).unsqueeze(0)).sum(dim=1)
51
  fwd_indices = torch.argsort(shifted_indices)
52
  bwd_indices = torch.empty_like(fwd_indices)
53
  bwd_indices[fwd_indices] = torch.arange(fwd_indices.shape[0], device=tensor.device)
54
  seq_lens = torch.bincount(shifted_indices)
55
- seq_batch_indices = torch.arange(seq_lens.shape[0], device=tensor.device, dtype=torch.int32) // OFFSET[0]
56
  mask = seq_lens != 0
57
  seq_lens = seq_lens[mask].tolist()
58
  seq_batch_indices = seq_batch_indices[mask].tolist()
59
 
60
  return fwd_indices, bwd_indices, seq_lens, seq_batch_indices
61
-
62
 
63
  def sparse_windowed_scaled_dot_product_self_attention(
64
- qkv: SparseTensor,
65
- window_size: int,
66
- shift_window: Tuple[int, int, int] = (0, 0, 0)
67
  ) -> SparseTensor:
68
  """
69
  Apply windowed scaled dot product self attention to a sparse tensor.
@@ -74,62 +84,95 @@ def sparse_windowed_scaled_dot_product_self_attention(
74
  shift_window (Tuple[int, int, int]): The shift of serialized coordinates.
75
  shift (int): The shift to use.
76
  """
77
- assert len(qkv.shape) == 4 and qkv.shape[1] == 3, f"Invalid shape for qkv, got {qkv.shape}, expected [N, *, 3, H, C]"
78
-
79
- serialization_spatial_cache_name = f'window_partition_{window_size}_{shift_window}'
80
- serialization_spatial_cache = qkv.get_spatial_cache(serialization_spatial_cache_name)
81
  if serialization_spatial_cache is None:
82
- fwd_indices, bwd_indices, seq_lens, seq_batch_indices = calc_window_partition(qkv, window_size, shift_window)
83
- qkv.register_spatial_cache(serialization_spatial_cache_name, (fwd_indices, bwd_indices, seq_lens, seq_batch_indices))
84
  else:
85
- fwd_indices, bwd_indices, seq_lens, seq_batch_indices = serialization_spatial_cache
86
 
87
  M = fwd_indices.shape[0]
88
  T = qkv.feats.shape[0]
89
  H = qkv.feats.shape[2]
90
  C = qkv.feats.shape[3]
91
-
92
- qkv_feats = qkv.feats[fwd_indices] # [M, 3, H, C]
93
 
94
  if DEBUG:
95
  start = 0
96
  qkv_coords = qkv.coords[fwd_indices]
97
  for i in range(len(seq_lens)):
98
- seq_coords = qkv_coords[start:start+seq_lens[i]]
99
- assert (seq_coords[:, 0] == seq_batch_indices[i]).all(), f"SparseWindowedScaledDotProductSelfAttention: batch index mismatch"
100
- assert (seq_coords[:, 1:].max(dim=0).values - seq_coords[:, 1:].min(dim=0).values < window_size).all(), \
101
- f"SparseWindowedScaledDotProductSelfAttention: window size exceeded"
 
 
 
 
 
 
 
 
 
102
  start += seq_lens[i]
103
 
104
  if all([seq_len == window_size for seq_len in seq_lens]):
105
  B = len(seq_lens)
106
  N = window_size
107
  qkv_feats = qkv_feats.reshape(B, N, 3, H, C)
108
- if ATTN == 'xformers':
109
- q, k, v = qkv_feats.unbind(dim=2) # [B, N, H, C]
110
- out = xops.memory_efficient_attention(q, k, v) # [B, N, H, C]
111
- elif ATTN == 'flash_attn':
112
- out = flash_attn.flash_attn_qkvpacked_func(qkv_feats) # [B, N, H, C]
113
  else:
114
  raise ValueError(f"Unknown attention module: {ATTN}")
115
- out = out.reshape(B * N, H, C) # [M, H, C]
116
  else:
117
- if ATTN == 'xformers':
118
- q, k, v = qkv_feats.unbind(dim=1) # [M, H, C]
119
- q = q.unsqueeze(0) # [1, M, H, C]
120
- k = k.unsqueeze(0) # [1, M, H, C]
121
- v = v.unsqueeze(0) # [1, M, H, C]
122
  mask = xops.fmha.BlockDiagonalMask.from_seqlens(seq_lens)
123
- out = xops.memory_efficient_attention(q, k, v, mask)[0] # [M, H, C]
124
- elif ATTN == 'flash_attn':
125
- cu_seqlens = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(seq_lens), dim=0)], dim=0) \
126
- .to(qkv.device).int()
127
- out = flash_attn.flash_attn_varlen_qkvpacked_func(qkv_feats, cu_seqlens, max(seq_lens)) # [M, H, C]
128
-
129
- out = out[bwd_indices] # [T, H, C]
 
 
 
 
 
 
 
 
130
 
131
  if DEBUG:
132
  qkv_coords = qkv_coords[bwd_indices]
133
- assert torch.equal(qkv_coords, qkv.coords), "SparseWindowedScaledDotProductSelfAttention: coordinate mismatch"
134
 
135
  return qkv.replace(out)
 
4
  from .. import SparseTensor
5
  from .. import DEBUG, ATTN
6
 
7
+ if ATTN == "xformers":
8
  import xformers.ops as xops
9
+ elif ATTN == "flash_attn":
10
  import flash_attn
11
  else:
12
  raise ValueError(f"Unknown attention module: {ATTN}")
13
 
14
 
15
  __all__ = [
16
+ "sparse_windowed_scaled_dot_product_self_attention",
17
  ]
18
 
19
 
20
  def calc_window_partition(
21
  tensor: SparseTensor,
22
  window_size: Union[int, Tuple[int, ...]],
23
+ shift_window: Union[int, Tuple[int, ...]] = 0,
24
  ) -> Tuple[torch.Tensor, torch.Tensor, List[int], List[int]]:
25
  """
26
  Calculate serialization and partitioning for a set of coordinates.
 
37
  (List[int]): Sequence batch indices.
38
  """
39
  DIM = tensor.coords.shape[1] - 1
40
+ shift_window = (
41
+ (shift_window,) * DIM if isinstance(shift_window, int) else shift_window
42
+ )
43
  window_size = (window_size,) * DIM if isinstance(window_size, int) else window_size
44
  shifted_coords = tensor.coords.clone().detach()
45
+ shifted_coords[:, 1:] += torch.tensor(
46
+ shift_window, device=tensor.device, dtype=torch.int32
47
+ ).unsqueeze(0)
48
 
49
  MAX_COORDS = shifted_coords[:, 1:].max(dim=0).values.tolist()
50
  NUM_WINDOWS = [math.ceil((mc + 1) / ws) for mc, ws in zip(MAX_COORDS, window_size)]
51
  OFFSET = torch.cumprod(torch.tensor([1] + NUM_WINDOWS[::-1]), dim=0).tolist()[::-1]
52
 
53
+ shifted_coords[:, 1:] //= torch.tensor(
54
+ window_size, device=tensor.device, dtype=torch.int32
55
+ ).unsqueeze(0)
56
+ shifted_indices = (
57
+ shifted_coords
58
+ * torch.tensor(OFFSET, device=tensor.device, dtype=torch.int32).unsqueeze(0)
59
+ ).sum(dim=1)
60
  fwd_indices = torch.argsort(shifted_indices)
61
  bwd_indices = torch.empty_like(fwd_indices)
62
  bwd_indices[fwd_indices] = torch.arange(fwd_indices.shape[0], device=tensor.device)
63
  seq_lens = torch.bincount(shifted_indices)
64
+ seq_batch_indices = (
65
+ torch.arange(seq_lens.shape[0], device=tensor.device, dtype=torch.int32)
66
+ // OFFSET[0]
67
+ )
68
  mask = seq_lens != 0
69
  seq_lens = seq_lens[mask].tolist()
70
  seq_batch_indices = seq_batch_indices[mask].tolist()
71
 
72
  return fwd_indices, bwd_indices, seq_lens, seq_batch_indices
73
+
74
 
75
  def sparse_windowed_scaled_dot_product_self_attention(
76
+ qkv: SparseTensor, window_size: int, shift_window: Tuple[int, int, int] = (0, 0, 0)
77
  ) -> SparseTensor:
78
  """
79
  Apply windowed scaled dot product self attention to a sparse tensor.
 
84
  shift_window (Tuple[int, int, int]): The shift of serialized coordinates.
85
  shift (int): The shift to use.
86
  """
87
+ assert (
88
+ len(qkv.shape) == 4 and qkv.shape[1] == 3
89
+ ), f"Invalid shape for qkv, got {qkv.shape}, expected [N, *, 3, H, C]"
90
+
91
+ serialization_spatial_cache_name = f"window_partition_{window_size}_{shift_window}"
92
+ serialization_spatial_cache = qkv.get_spatial_cache(
93
+ serialization_spatial_cache_name
94
+ )
95
  if serialization_spatial_cache is None:
96
+ fwd_indices, bwd_indices, seq_lens, seq_batch_indices = calc_window_partition(
97
+ qkv, window_size, shift_window
98
+ )
99
+ qkv.register_spatial_cache(
100
+ serialization_spatial_cache_name,
101
+ (fwd_indices, bwd_indices, seq_lens, seq_batch_indices),
102
+ )
103
  else:
104
+ (
105
+ fwd_indices,
106
+ bwd_indices,
107
+ seq_lens,
108
+ seq_batch_indices,
109
+ ) = serialization_spatial_cache
110
 
111
  M = fwd_indices.shape[0]
112
  T = qkv.feats.shape[0]
113
  H = qkv.feats.shape[2]
114
  C = qkv.feats.shape[3]
115
+
116
+ qkv_feats = qkv.feats[fwd_indices] # [M, 3, H, C]
117
 
118
  if DEBUG:
119
  start = 0
120
  qkv_coords = qkv.coords[fwd_indices]
121
  for i in range(len(seq_lens)):
122
+ seq_coords = qkv_coords[start : start + seq_lens[i]]
123
+ assert (
124
+ seq_coords[:, 0] == seq_batch_indices[i]
125
+ ).all(), (
126
+ f"SparseWindowedScaledDotProductSelfAttention: batch index mismatch"
127
+ )
128
+ assert (
129
+ seq_coords[:, 1:].max(dim=0).values
130
+ - seq_coords[:, 1:].min(dim=0).values
131
+ < window_size
132
+ ).all(), (
133
+ f"SparseWindowedScaledDotProductSelfAttention: window size exceeded"
134
+ )
135
  start += seq_lens[i]
136
 
137
  if all([seq_len == window_size for seq_len in seq_lens]):
138
  B = len(seq_lens)
139
  N = window_size
140
  qkv_feats = qkv_feats.reshape(B, N, 3, H, C)
141
+ if ATTN == "xformers":
142
+ q, k, v = qkv_feats.unbind(dim=2) # [B, N, H, C]
143
+ out = xops.memory_efficient_attention(q, k, v) # [B, N, H, C]
144
+ elif ATTN == "flash_attn":
145
+ out = flash_attn.flash_attn_qkvpacked_func(qkv_feats) # [B, N, H, C]
146
  else:
147
  raise ValueError(f"Unknown attention module: {ATTN}")
148
+ out = out.reshape(B * N, H, C) # [M, H, C]
149
  else:
150
+ if ATTN == "xformers":
151
+ q, k, v = qkv_feats.unbind(dim=1) # [M, H, C]
152
+ q = q.unsqueeze(0) # [1, M, H, C]
153
+ k = k.unsqueeze(0) # [1, M, H, C]
154
+ v = v.unsqueeze(0) # [1, M, H, C]
155
  mask = xops.fmha.BlockDiagonalMask.from_seqlens(seq_lens)
156
+ out = xops.memory_efficient_attention(q, k, v, mask)[0] # [M, H, C]
157
+ elif ATTN == "flash_attn":
158
+ cu_seqlens = (
159
+ torch.cat(
160
+ [torch.tensor([0]), torch.cumsum(torch.tensor(seq_lens), dim=0)],
161
+ dim=0,
162
+ )
163
+ .to(qkv.device)
164
+ .int()
165
+ )
166
+ out = flash_attn.flash_attn_varlen_qkvpacked_func(
167
+ qkv_feats, cu_seqlens, max(seq_lens)
168
+ ) # [M, H, C]
169
+
170
+ out = out[bwd_indices] # [T, H, C]
171
 
172
  if DEBUG:
173
  qkv_coords = qkv_coords[bwd_indices]
174
+ assert torch.equal(
175
+ qkv_coords, qkv.coords
176
+ ), "SparseWindowedScaledDotProductSelfAttention: coordinate mismatch"
177
 
178
  return qkv.replace(out)
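
Note: calc_window_partition hashes every voxel to a window id by integer-dividing its coordinates by the window size and ravelling (batch, wx, wy, wz) into one index; an argsort on that id then makes each window contiguous. A small sketch of the same hashing (sizes are illustrative):

    import torch

    coords = torch.randint(0, 16, (20, 4))     # [N, 4] rows of (batch, x, y, z)
    coords[:, 0] = torch.randint(0, 2, (20,))  # two batches
    window_size = 4

    win = coords.clone()
    win[:, 1:] //= window_size                        # window-grid coordinates
    nw = (win[:, 1:].max(dim=0).values + 1).tolist()  # windows per axis
    strides = torch.tensor([nw[1] * nw[2], nw[2], 1]) # row-major ravel of (wx, wy, wz)
    win_id = win[:, 0] * (nw[0] * nw[1] * nw[2]) + (win[:, 1:] * strides).sum(dim=1)

    fwd = torch.argsort(win_id)          # voxels of one window become contiguous
    counts = torch.bincount(win_id)
    seq_lens = counts[counts != 0]       # per-window sequence lengths
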
trellis/modules/sparse/basic.py CHANGED
@@ -2,22 +2,23 @@ from typing import *
2
  import torch
3
  import torch.nn as nn
4
  from . import BACKEND, DEBUG
5
- SparseTensorData = None # Lazy import
6
 
7
 
8
  __all__ = [
9
- 'SparseTensor',
10
- 'sparse_batch_broadcast',
11
- 'sparse_batch_op',
12
- 'sparse_cat',
13
- 'sparse_unbind',
14
  ]
15
 
16
 
17
  class SparseTensor:
18
  """
19
  Sparse tensor with support for both torchsparse and spconv backends.
20
-
21
  Parameters:
22
  - feats (torch.Tensor): Features of the sparse tensor.
23
  - coords (torch.Tensor): Coordinates of the sparse tensor.
@@ -29,64 +30,89 @@ class SparseTensor:
29
  - Data corresponding to a same batch should be contiguous.
30
  - Coords should be in [0, 1023]
31
  """
 
32
  @overload
33
- def __init__(self, feats: torch.Tensor, coords: torch.Tensor, shape: Optional[torch.Size] = None, layout: Optional[List[slice]] = None, **kwargs): ...
34
 
35
  @overload
36
- def __init__(self, data, shape: Optional[torch.Size] = None, layout: Optional[List[slice]] = None, **kwargs): ...
37
 
38
  def __init__(self, *args, **kwargs):
39
  # Lazy import of sparse tensor backend
40
  global SparseTensorData
41
  if SparseTensorData is None:
42
  import importlib
43
- if BACKEND == 'torchsparse':
44
- SparseTensorData = importlib.import_module('torchsparse').SparseTensor
45
- elif BACKEND == 'spconv':
46
- SparseTensorData = importlib.import_module('spconv.pytorch').SparseConvTensor
47
-
48
  method_id = 0
49
  if len(args) != 0:
50
  method_id = 0 if isinstance(args[0], torch.Tensor) else 1
51
  else:
52
- method_id = 1 if 'data' in kwargs else 0
53
 
54
  if method_id == 0:
55
  feats, coords, shape, layout = args + (None,) * (4 - len(args))
56
- if 'feats' in kwargs:
57
- feats = kwargs['feats']
58
- del kwargs['feats']
59
- if 'coords' in kwargs:
60
- coords = kwargs['coords']
61
- del kwargs['coords']
62
- if 'shape' in kwargs:
63
- shape = kwargs['shape']
64
- del kwargs['shape']
65
- if 'layout' in kwargs:
66
- layout = kwargs['layout']
67
- del kwargs['layout']
68
 
69
  if shape is None:
70
  shape = self.__cal_shape(feats, coords)
71
  if layout is None:
72
  layout = self.__cal_layout(coords, shape[0])
73
- if BACKEND == 'torchsparse':
74
  self.data = SparseTensorData(feats, coords, **kwargs)
75
- elif BACKEND == 'spconv':
76
  spatial_shape = list(coords.max(0)[0] + 1)[1:]
77
- self.data = SparseTensorData(feats.reshape(feats.shape[0], -1), coords, spatial_shape, shape[0], **kwargs)
78
  self.data._features = feats
79
  elif method_id == 1:
80
  data, shape, layout = args + (None,) * (3 - len(args))
81
- if 'data' in kwargs:
82
- data = kwargs['data']
83
- del kwargs['data']
84
- if 'shape' in kwargs:
85
- shape = kwargs['shape']
86
- del kwargs['shape']
87
- if 'layout' in kwargs:
88
- layout = kwargs['layout']
89
- del kwargs['layout']
90
 
91
  self.data = data
92
  if shape is None:
@@ -96,73 +122,84 @@ class SparseTensor:
96
 
97
  self._shape = shape
98
  self._layout = layout
99
- self._scale = kwargs.get('scale', (1, 1, 1))
100
- self._spatial_cache = kwargs.get('spatial_cache', {})
101
 
102
  if DEBUG:
103
  try:
104
- assert self.feats.shape[0] == self.coords.shape[0], f"Invalid feats shape: {self.feats.shape}, coords shape: {self.coords.shape}"
105
- assert self.shape == self.__cal_shape(self.feats, self.coords), f"Invalid shape: {self.shape}"
106
- assert self.layout == self.__cal_layout(self.coords, self.shape[0]), f"Invalid layout: {self.layout}"
107
  for i in range(self.shape[0]):
108
- assert torch.all(self.coords[self.layout[i], 0] == i), f"The data of batch {i} is not contiguous"
109
  except Exception as e:
110
- print('Debugging information:')
111
  print(f"- Shape: {self.shape}")
112
  print(f"- Layout: {self.layout}")
113
  print(f"- Scale: {self._scale}")
114
  print(f"- Coords: {self.coords}")
115
  raise e
116
-
117
  def __cal_shape(self, feats, coords):
118
  shape = []
119
  shape.append(coords[:, 0].max().item() + 1)
120
  shape.extend([*feats.shape[1:]])
121
  return torch.Size(shape)
122
-
123
  def __cal_layout(self, coords, batch_size):
124
  seq_len = torch.bincount(coords[:, 0], minlength=batch_size)
125
- offset = torch.cumsum(seq_len, dim=0)
126
- layout = [slice((offset[i] - seq_len[i]).item(), offset[i].item()) for i in range(batch_size)]
127
  return layout
128
-
129
  @property
130
  def shape(self) -> torch.Size:
131
  return self._shape
132
-
133
  def dim(self) -> int:
134
  return len(self.shape)
135
-
136
  @property
137
  def layout(self) -> List[slice]:
138
  return self._layout
139
 
140
  @property
141
  def feats(self) -> torch.Tensor:
142
- if BACKEND == 'torchsparse':
143
  return self.data.F
144
- elif BACKEND == 'spconv':
145
  return self.data.features
146
-
147
  @feats.setter
148
  def feats(self, value: torch.Tensor):
149
- if BACKEND == 'torchsparse':
150
  self.data.F = value
151
- elif BACKEND == 'spconv':
152
  self.data.features = value
153
 
154
  @property
155
  def coords(self) -> torch.Tensor:
156
- if BACKEND == 'torchsparse':
157
  return self.data.C
158
- elif BACKEND == 'spconv':
159
  return self.data.indices
160
-
161
  @coords.setter
162
  def coords(self, value: torch.Tensor):
163
- if BACKEND == 'torchsparse':
164
  self.data.C = value
165
- elif BACKEND == 'spconv':
166
  self.data.indices = value
167
 
168
  @property
@@ -174,12 +211,18 @@ class SparseTensor:
174
  return self.feats.device
175
 
176
  @overload
177
- def to(self, dtype: torch.dtype) -> 'SparseTensor': ...
178
 
179
  @overload
180
- def to(self, device: Optional[Union[str, torch.device]] = None, dtype: Optional[torch.dtype] = None) -> 'SparseTensor': ...
181
-
182
- def to(self, *args, **kwargs) -> 'SparseTensor':
183
  device = None
184
  dtype = None
185
  if len(args) == 2:
@@ -189,13 +232,13 @@ class SparseTensor:
189
  dtype = args[0]
190
  else:
191
  device = args[0]
192
- if 'dtype' in kwargs:
193
  assert dtype is None, "to() received multiple values for argument 'dtype'"
194
- dtype = kwargs['dtype']
195
- if 'device' in kwargs:
196
  assert device is None, "to() received multiple values for argument 'device'"
197
- device = kwargs['device']
198
-
199
  new_feats = self.feats.to(device=device, dtype=dtype)
200
  new_coords = self.coords.to(device=device)
201
  return self.replace(new_feats, new_coords)
@@ -204,46 +247,48 @@ class SparseTensor:
204
  new_feats = self.feats.type(dtype)
205
  return self.replace(new_feats)
206
 
207
- def cpu(self) -> 'SparseTensor':
208
  new_feats = self.feats.cpu()
209
  new_coords = self.coords.cpu()
210
  return self.replace(new_feats, new_coords)
211
-
212
- def cuda(self) -> 'SparseTensor':
213
  new_feats = self.feats.cuda()
214
  new_coords = self.coords.cuda()
215
  return self.replace(new_feats, new_coords)
216
 
217
- def half(self) -> 'SparseTensor':
218
  new_feats = self.feats.half()
219
  return self.replace(new_feats)
220
-
221
- def float(self) -> 'SparseTensor':
222
  new_feats = self.feats.float()
223
  return self.replace(new_feats)
224
-
225
- def detach(self) -> 'SparseTensor':
226
  new_coords = self.coords.detach()
227
  new_feats = self.feats.detach()
228
  return self.replace(new_feats, new_coords)
229
 
230
  def dense(self) -> torch.Tensor:
231
- if BACKEND == 'torchsparse':
232
  return self.data.dense()
233
- elif BACKEND == 'spconv':
234
  return self.data.dense()
235
 
236
- def reshape(self, *shape) -> 'SparseTensor':
237
  new_feats = self.feats.reshape(self.feats.shape[0], *shape)
238
  return self.replace(new_feats)
239
-
240
- def unbind(self, dim: int) -> List['SparseTensor']:
241
  return sparse_unbind(self, dim)
242
 
243
- def replace(self, feats: torch.Tensor, coords: Optional[torch.Tensor] = None) -> 'SparseTensor':
244
  new_shape = [self.shape[0]]
245
  new_shape.extend(feats.shape[1:])
246
- if BACKEND == 'torchsparse':
247
  new_data = SparseTensorData(
248
  feats=feats,
249
  coords=self.data.coords if coords is None else coords,
@@ -251,7 +296,7 @@ class SparseTensor:
251
  spatial_range=self.data.spatial_range,
252
  )
253
  new_data._caches = self.data._caches
254
- elif BACKEND == 'spconv':
255
  new_data = SparseTensorData(
256
  self.data.features.reshape(self.data.features.shape[0], -1),
257
  self.data.indices,
@@ -259,7 +304,7 @@ class SparseTensor:
259
  self.data.batch_size,
260
  self.data.grid,
261
  self.data.voxel_num,
262
- self.data.indice_dict
263
  )
264
  new_data._features = feats
265
  new_data.benchmark = self.data.benchmark
@@ -270,26 +315,39 @@ class SparseTensor:
270
  new_data.int8_scale = self.data.int8_scale
271
  if coords is not None:
272
  new_data.indices = coords
273
- new_tensor = SparseTensor(new_data, shape=torch.Size(new_shape), layout=self.layout, scale=self._scale, spatial_cache=self._spatial_cache)
274
  return new_tensor
275
 
276
  @staticmethod
277
- def full(aabb, dim, value, dtype=torch.float32, device=None) -> 'SparseTensor':
278
  N, C = dim
279
  x = torch.arange(aabb[0], aabb[3] + 1)
280
  y = torch.arange(aabb[1], aabb[4] + 1)
281
  z = torch.arange(aabb[2], aabb[5] + 1)
282
- coords = torch.stack(torch.meshgrid(x, y, z, indexing='ij'), dim=-1).reshape(-1, 3)
283
- coords = torch.cat([
284
- torch.arange(N).view(-1, 1).repeat(1, coords.shape[0]).view(-1, 1),
285
- coords.repeat(N, 1),
286
- ], dim=1).to(dtype=torch.int32, device=device)
287
  feats = torch.full((coords.shape[0], C), value, dtype=dtype, device=device)
288
  return SparseTensor(feats=feats, coords=coords)
289
 
290
- def __merge_sparse_cache(self, other: 'SparseTensor') -> dict:
291
  new_cache = {}
292
- for k in set(list(self._spatial_cache.keys()) + list(other._spatial_cache.keys())):
293
  if k in self._spatial_cache:
294
  new_cache[k] = self._spatial_cache[k]
295
  if k in other._spatial_cache:
@@ -299,10 +357,12 @@ class SparseTensor:
299
  new_cache[k].update(other._spatial_cache[k])
300
  return new_cache
301
 
302
- def __neg__(self) -> 'SparseTensor':
303
  return self.replace(-self.feats)
304
-
305
- def __elemwise__(self, other: Union[torch.Tensor, 'SparseTensor'], op: callable) -> 'SparseTensor':
306
  if isinstance(other, torch.Tensor):
307
  try:
308
  other = torch.broadcast_to(other, self.shape)
@@ -317,28 +377,44 @@ class SparseTensor:
317
  new_tensor._spatial_cache = self.__merge_sparse_cache(other)
318
  return new_tensor
319
 
320
- def __add__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor':
321
  return self.__elemwise__(other, torch.add)
322
 
323
- def __radd__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor':
324
  return self.__elemwise__(other, torch.add)
325
-
326
- def __sub__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor':
327
  return self.__elemwise__(other, torch.sub)
328
-
329
- def __rsub__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor':
330
  return self.__elemwise__(other, lambda x, y: torch.sub(y, x))
331
 
332
- def __mul__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor':
333
  return self.__elemwise__(other, torch.mul)
334
 
335
- def __rmul__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor':
336
  return self.__elemwise__(other, torch.mul)
337
 
338
- def __truediv__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor':
339
  return self.__elemwise__(other, torch.div)
340
 
341
- def __rtruediv__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor':
342
  return self.__elemwise__(other, lambda x, y: torch.div(y, x))
343
 
344
  def __getitem__(self, idx):
@@ -348,7 +424,9 @@ class SparseTensor:
348
  idx = range(*idx.indices(self.shape[0]))
349
  elif isinstance(idx, torch.Tensor):
350
  if idx.dtype == torch.bool:
351
- assert idx.shape == (self.shape[0],), f"Invalid index shape: {idx.shape}"
352
  idx = idx.nonzero().squeeze(1)
353
  elif idx.dtype in [torch.int32, torch.int64]:
354
  assert len(idx.shape) == 1, f"Invalid index shape: {idx.shape}"
@@ -356,7 +434,7 @@ class SparseTensor:
356
  raise ValueError(f"Unknown index type: {idx.dtype}")
357
  else:
358
  raise ValueError(f"Unknown index type: {type(idx)}")
359
-
360
  coords = []
361
  feats = []
362
  for new_idx, old_idx in enumerate(idx):
@@ -392,7 +470,7 @@ class SparseTensor:
392
  def sparse_batch_broadcast(input: SparseTensor, other: torch.Tensor) -> torch.Tensor:
393
  """
394
  Broadcast a 1D tensor to a sparse tensor along the batch dimension then perform an operation.
395
-
396
  Args:
397
  input (torch.Tensor): 1D tensor to broadcast.
398
  target (SparseTensor): Sparse tensor to broadcast to.
@@ -405,10 +483,12 @@ def sparse_batch_broadcast(input: SparseTensor, other: torch.Tensor) -> torch.Te
405
  return broadcasted
406
 
407
 
408
- def sparse_batch_op(input: SparseTensor, other: torch.Tensor, op: callable = torch.add) -> SparseTensor:
409
  """
410
  Broadcast a 1D tensor to a sparse tensor along the batch dimension then perform an operation.
411
-
412
  Args:
413
  input (torch.Tensor): 1D tensor to broadcast.
414
  target (SparseTensor): Sparse tensor to broadcast to.
@@ -420,7 +500,7 @@ def sparse_batch_op(input: SparseTensor, other: torch.Tensor, op: callable = tor
420
  def sparse_cat(inputs: List[SparseTensor], dim: int = 0) -> SparseTensor:
421
  """
422
  Concatenate a list of sparse tensors.
423
-
424
  Args:
425
  inputs (List[SparseTensor]): List of sparse tensors to concatenate.
426
  """
@@ -447,7 +527,7 @@ def sparse_cat(inputs: List[SparseTensor], dim: int = 0) -> SparseTensor:
447
  def sparse_unbind(input: SparseTensor, dim: int) -> List[SparseTensor]:
448
  """
449
  Unbind a sparse tensor along a dimension.
450
-
451
  Args:
452
  input (SparseTensor): Sparse tensor to unbind.
453
  dim (int): Dimension to unbind.
 
2
  import torch
3
  import torch.nn as nn
4
  from . import BACKEND, DEBUG
5
+
6
+ SparseTensorData = None # Lazy import
7
 
8
 
9
  __all__ = [
10
+ "SparseTensor",
11
+ "sparse_batch_broadcast",
12
+ "sparse_batch_op",
13
+ "sparse_cat",
14
+ "sparse_unbind",
15
  ]
16
 
17
 
18
  class SparseTensor:
19
  """
20
  Sparse tensor with support for both torchsparse and spconv backends.
21
+
22
  Parameters:
23
  - feats (torch.Tensor): Features of the sparse tensor.
24
  - coords (torch.Tensor): Coordinates of the sparse tensor.
 
30
  - Data corresponding to a same batch should be contiguous.
31
  - Coords should be in [0, 1023]
32
  """
33
+
34
  @overload
35
+ def __init__(
36
+ self,
37
+ feats: torch.Tensor,
38
+ coords: torch.Tensor,
39
+ shape: Optional[torch.Size] = None,
40
+ layout: Optional[List[slice]] = None,
41
+ **kwargs,
42
+ ):
43
+ ...
44
 
45
  @overload
46
+ def __init__(
47
+ self,
48
+ data,
49
+ shape: Optional[torch.Size] = None,
50
+ layout: Optional[List[slice]] = None,
51
+ **kwargs,
52
+ ):
53
+ ...
54
 
55
  def __init__(self, *args, **kwargs):
56
  # Lazy import of sparse tensor backend
57
  global SparseTensorData
58
  if SparseTensorData is None:
59
  import importlib
60
+
61
+ if BACKEND == "torchsparse":
62
+ SparseTensorData = importlib.import_module("torchsparse").SparseTensor
63
+ elif BACKEND == "spconv":
64
+ SparseTensorData = importlib.import_module(
65
+ "spconv.pytorch"
66
+ ).SparseConvTensor
67
+
68
  method_id = 0
69
  if len(args) != 0:
70
  method_id = 0 if isinstance(args[0], torch.Tensor) else 1
71
  else:
72
+ method_id = 1 if "data" in kwargs else 0
73
 
74
  if method_id == 0:
75
  feats, coords, shape, layout = args + (None,) * (4 - len(args))
76
+ if "feats" in kwargs:
77
+ feats = kwargs["feats"]
78
+ del kwargs["feats"]
79
+ if "coords" in kwargs:
80
+ coords = kwargs["coords"]
81
+ del kwargs["coords"]
82
+ if "shape" in kwargs:
83
+ shape = kwargs["shape"]
84
+ del kwargs["shape"]
85
+ if "layout" in kwargs:
86
+ layout = kwargs["layout"]
87
+ del kwargs["layout"]
88
 
89
  if shape is None:
90
  shape = self.__cal_shape(feats, coords)
91
  if layout is None:
92
  layout = self.__cal_layout(coords, shape[0])
93
+ if BACKEND == "torchsparse":
94
  self.data = SparseTensorData(feats, coords, **kwargs)
95
+ elif BACKEND == "spconv":
96
  spatial_shape = list(coords.max(0)[0] + 1)[1:]
97
+ self.data = SparseTensorData(
98
+ feats.reshape(feats.shape[0], -1),
99
+ coords,
100
+ spatial_shape,
101
+ shape[0],
102
+ **kwargs,
103
+ )
104
  self.data._features = feats
105
  elif method_id == 1:
106
  data, shape, layout = args + (None,) * (3 - len(args))
107
+ if "data" in kwargs:
108
+ data = kwargs["data"]
109
+ del kwargs["data"]
110
+ if "shape" in kwargs:
111
+ shape = kwargs["shape"]
112
+ del kwargs["shape"]
113
+ if "layout" in kwargs:
114
+ layout = kwargs["layout"]
115
+ del kwargs["layout"]
116
 
117
  self.data = data
118
  if shape is None:
 
122
 
123
  self._shape = shape
124
  self._layout = layout
125
+ self._scale = kwargs.get("scale", (1, 1, 1))
126
+ self._spatial_cache = kwargs.get("spatial_cache", {})
127
 
128
  if DEBUG:
129
  try:
130
+ assert (
131
+ self.feats.shape[0] == self.coords.shape[0]
132
+ ), f"Invalid feats shape: {self.feats.shape}, coords shape: {self.coords.shape}"
133
+ assert self.shape == self.__cal_shape(
134
+ self.feats, self.coords
135
+ ), f"Invalid shape: {self.shape}"
136
+ assert self.layout == self.__cal_layout(
137
+ self.coords, self.shape[0]
138
+ ), f"Invalid layout: {self.layout}"
139
  for i in range(self.shape[0]):
140
+ assert torch.all(
141
+ self.coords[self.layout[i], 0] == i
142
+ ), f"The data of batch {i} is not contiguous"
143
  except Exception as e:
144
+ print("Debugging information:")
145
  print(f"- Shape: {self.shape}")
146
  print(f"- Layout: {self.layout}")
147
  print(f"- Scale: {self._scale}")
148
  print(f"- Coords: {self.coords}")
149
  raise e
150
+
151
  def __cal_shape(self, feats, coords):
152
  shape = []
153
  shape.append(coords[:, 0].max().item() + 1)
154
  shape.extend([*feats.shape[1:]])
155
  return torch.Size(shape)
156
+
157
  def __cal_layout(self, coords, batch_size):
158
  seq_len = torch.bincount(coords[:, 0], minlength=batch_size)
159
+ offset = torch.cumsum(seq_len, dim=0)
160
+ layout = [
161
+ slice((offset[i] - seq_len[i]).item(), offset[i].item())
162
+ for i in range(batch_size)
163
+ ]
164
  return layout
165
+
166
  @property
167
  def shape(self) -> torch.Size:
168
  return self._shape
169
+
170
  def dim(self) -> int:
171
  return len(self.shape)
172
+
173
  @property
174
  def layout(self) -> List[slice]:
175
  return self._layout
176
 
177
  @property
178
  def feats(self) -> torch.Tensor:
179
+ if BACKEND == "torchsparse":
180
  return self.data.F
181
+ elif BACKEND == "spconv":
182
  return self.data.features
183
+
184
  @feats.setter
185
  def feats(self, value: torch.Tensor):
186
+ if BACKEND == "torchsparse":
187
  self.data.F = value
188
+ elif BACKEND == "spconv":
189
  self.data.features = value
190
 
191
  @property
192
  def coords(self) -> torch.Tensor:
193
+ if BACKEND == "torchsparse":
194
  return self.data.C
195
+ elif BACKEND == "spconv":
196
  return self.data.indices
197
+
198
  @coords.setter
199
  def coords(self, value: torch.Tensor):
200
+ if BACKEND == "torchsparse":
201
  self.data.C = value
202
+ elif BACKEND == "spconv":
203
  self.data.indices = value
204
 
205
  @property
 
211
  return self.feats.device
212
 
213
  @overload
214
+ def to(self, dtype: torch.dtype) -> "SparseTensor":
215
+ ...
216
 
217
  @overload
218
+ def to(
219
+ self,
220
+ device: Optional[Union[str, torch.device]] = None,
221
+ dtype: Optional[torch.dtype] = None,
222
+ ) -> "SparseTensor":
223
+ ...
224
+
225
+ def to(self, *args, **kwargs) -> "SparseTensor":
226
  device = None
227
  dtype = None
228
  if len(args) == 2:
 
232
  dtype = args[0]
233
  else:
234
  device = args[0]
235
+ if "dtype" in kwargs:
236
  assert dtype is None, "to() received multiple values for argument 'dtype'"
237
+ dtype = kwargs["dtype"]
238
+ if "device" in kwargs:
239
  assert device is None, "to() received multiple values for argument 'device'"
240
+ device = kwargs["device"]
241
+
242
  new_feats = self.feats.to(device=device, dtype=dtype)
243
  new_coords = self.coords.to(device=device)
244
  return self.replace(new_feats, new_coords)
 
247
  new_feats = self.feats.type(dtype)
248
  return self.replace(new_feats)
249
 
250
+ def cpu(self) -> "SparseTensor":
251
  new_feats = self.feats.cpu()
252
  new_coords = self.coords.cpu()
253
  return self.replace(new_feats, new_coords)
254
+
255
+ def cuda(self) -> "SparseTensor":
256
  new_feats = self.feats.cuda()
257
  new_coords = self.coords.cuda()
258
  return self.replace(new_feats, new_coords)
259
 
260
+ def half(self) -> "SparseTensor":
261
  new_feats = self.feats.half()
262
  return self.replace(new_feats)
263
+
264
+ def float(self) -> "SparseTensor":
265
  new_feats = self.feats.float()
266
  return self.replace(new_feats)
267
+
268
+ def detach(self) -> "SparseTensor":
269
  new_coords = self.coords.detach()
270
  new_feats = self.feats.detach()
271
  return self.replace(new_feats, new_coords)
272
 
273
  def dense(self) -> torch.Tensor:
274
+ if BACKEND == "torchsparse":
275
  return self.data.dense()
276
+ elif BACKEND == "spconv":
277
  return self.data.dense()
278
 
279
+ def reshape(self, *shape) -> "SparseTensor":
280
  new_feats = self.feats.reshape(self.feats.shape[0], *shape)
281
  return self.replace(new_feats)
282
+
283
+ def unbind(self, dim: int) -> List["SparseTensor"]:
284
  return sparse_unbind(self, dim)
285
 
286
+ def replace(
287
+ self, feats: torch.Tensor, coords: Optional[torch.Tensor] = None
288
+ ) -> "SparseTensor":
289
  new_shape = [self.shape[0]]
290
  new_shape.extend(feats.shape[1:])
291
+ if BACKEND == "torchsparse":
292
  new_data = SparseTensorData(
293
  feats=feats,
294
  coords=self.data.coords if coords is None else coords,
 
296
  spatial_range=self.data.spatial_range,
297
  )
298
  new_data._caches = self.data._caches
299
+ elif BACKEND == "spconv":
300
  new_data = SparseTensorData(
301
  self.data.features.reshape(self.data.features.shape[0], -1),
302
  self.data.indices,
 
304
  self.data.batch_size,
305
  self.data.grid,
306
  self.data.voxel_num,
307
+ self.data.indice_dict,
308
  )
309
  new_data._features = feats
310
  new_data.benchmark = self.data.benchmark
 
315
  new_data.int8_scale = self.data.int8_scale
316
  if coords is not None:
317
  new_data.indices = coords
318
+ new_tensor = SparseTensor(
319
+ new_data,
320
+ shape=torch.Size(new_shape),
321
+ layout=self.layout,
322
+ scale=self._scale,
323
+ spatial_cache=self._spatial_cache,
324
+ )
325
  return new_tensor
326
 
327
  @staticmethod
328
+ def full(aabb, dim, value, dtype=torch.float32, device=None) -> "SparseTensor":
329
  N, C = dim
330
  x = torch.arange(aabb[0], aabb[3] + 1)
331
  y = torch.arange(aabb[1], aabb[4] + 1)
332
  z = torch.arange(aabb[2], aabb[5] + 1)
333
+ coords = torch.stack(torch.meshgrid(x, y, z, indexing="ij"), dim=-1).reshape(
334
+ -1, 3
335
+ )
336
+ coords = torch.cat(
337
+ [
338
+ torch.arange(N).view(-1, 1).repeat(1, coords.shape[0]).view(-1, 1),
339
+ coords.repeat(N, 1),
340
+ ],
341
+ dim=1,
342
+ ).to(dtype=torch.int32, device=device)
343
  feats = torch.full((coords.shape[0], C), value, dtype=dtype, device=device)
344
  return SparseTensor(feats=feats, coords=coords)
345
 
346
+ def __merge_sparse_cache(self, other: "SparseTensor") -> dict:
347
  new_cache = {}
348
+ for k in set(
349
+ list(self._spatial_cache.keys()) + list(other._spatial_cache.keys())
350
+ ):
351
  if k in self._spatial_cache:
352
  new_cache[k] = self._spatial_cache[k]
353
  if k in other._spatial_cache:
 
357
  new_cache[k].update(other._spatial_cache[k])
358
  return new_cache
359
 
360
+ def __neg__(self) -> "SparseTensor":
361
  return self.replace(-self.feats)
362
+
363
+ def __elemwise__(
364
+ self, other: Union[torch.Tensor, "SparseTensor"], op: callable
365
+ ) -> "SparseTensor":
366
  if isinstance(other, torch.Tensor):
367
  try:
368
  other = torch.broadcast_to(other, self.shape)
 
377
  new_tensor._spatial_cache = self.__merge_sparse_cache(other)
378
  return new_tensor
379
 
380
+ def __add__(
381
+ self, other: Union[torch.Tensor, "SparseTensor", float]
382
+ ) -> "SparseTensor":
383
  return self.__elemwise__(other, torch.add)
384
 
385
+ def __radd__(
386
+ self, other: Union[torch.Tensor, "SparseTensor", float]
387
+ ) -> "SparseTensor":
388
  return self.__elemwise__(other, torch.add)
389
+
390
+ def __sub__(
391
+ self, other: Union[torch.Tensor, "SparseTensor", float]
392
+ ) -> "SparseTensor":
393
  return self.__elemwise__(other, torch.sub)
394
+
395
+ def __rsub__(
396
+ self, other: Union[torch.Tensor, "SparseTensor", float]
397
+ ) -> "SparseTensor":
398
  return self.__elemwise__(other, lambda x, y: torch.sub(y, x))
399
 
400
+ def __mul__(
401
+ self, other: Union[torch.Tensor, "SparseTensor", float]
402
+ ) -> "SparseTensor":
403
  return self.__elemwise__(other, torch.mul)
404
 
405
+ def __rmul__(
406
+ self, other: Union[torch.Tensor, "SparseTensor", float]
407
+ ) -> "SparseTensor":
408
  return self.__elemwise__(other, torch.mul)
409
 
410
+ def __truediv__(
411
+ self, other: Union[torch.Tensor, "SparseTensor", float]
412
+ ) -> "SparseTensor":
413
  return self.__elemwise__(other, torch.div)
414
 
415
+ def __rtruediv__(
416
+ self, other: Union[torch.Tensor, "SparseTensor", float]
417
+ ) -> "SparseTensor":
418
  return self.__elemwise__(other, lambda x, y: torch.div(y, x))
419
 
420
  def __getitem__(self, idx):
 
424
  idx = range(*idx.indices(self.shape[0]))
425
  elif isinstance(idx, torch.Tensor):
426
  if idx.dtype == torch.bool:
427
+ assert idx.shape == (
428
+ self.shape[0],
429
+ ), f"Invalid index shape: {idx.shape}"
430
  idx = idx.nonzero().squeeze(1)
431
  elif idx.dtype in [torch.int32, torch.int64]:
432
  assert len(idx.shape) == 1, f"Invalid index shape: {idx.shape}"
 
434
  raise ValueError(f"Unknown index type: {idx.dtype}")
435
  else:
436
  raise ValueError(f"Unknown index type: {type(idx)}")
437
+
438
  coords = []
439
  feats = []
440
  for new_idx, old_idx in enumerate(idx):
 
470
  def sparse_batch_broadcast(input: SparseTensor, other: torch.Tensor) -> torch.Tensor:
471
  """
472
  Broadcast a 1D tensor to a sparse tensor along the batch dimension.
473
+
474
  Args:
475
  input (SparseTensor): Sparse tensor to broadcast to.
476
  other (torch.Tensor): 1D tensor to broadcast.
 
483
  return broadcasted
484
 
485
 
486
+ def sparse_batch_op(
487
+ input: SparseTensor, other: torch.Tensor, op: callable = torch.add
488
+ ) -> SparseTensor:
489
  """
490
  Broadcast a 1D tensor to a sparse tensor along the batch dimension then perform an operation.
491
+
492
  Args:
493
  input (SparseTensor): Sparse tensor to broadcast to.
494
  other (torch.Tensor): 1D tensor to broadcast.
 
500
  def sparse_cat(inputs: List[SparseTensor], dim: int = 0) -> SparseTensor:
501
  """
502
  Concatenate a list of sparse tensors.
503
+
504
  Args:
505
  inputs (List[SparseTensor]): List of sparse tensors to concatenate.
506
  """
 
527
  def sparse_unbind(input: SparseTensor, dim: int) -> List[SparseTensor]:
528
  """
529
  Unbind a sparse tensor along a dimension.
530
+
531
  Args:
532
  input (SparseTensor): Sparse tensor to unbind.
533
  dim (int): Dimension to unbind.
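Note: every arithmetic operator in this file funnels through `__elemwise__`, which applies a dense torch op to the feature matrix and rebuilds the sparse tensor via `replace`. A minimal standalone sketch of that dispatch, using a stand-in class rather than the repo's actual `SparseTensor`:

    import torch

    class TinySparse:
        """Stand-in: (N, C) feats plus (N, 4) [batch, x, y, z] coords."""
        def __init__(self, feats, coords):
            self.feats = feats
            self.coords = coords

        def replace(self, feats):
            # new tensor with the same coordinates, mirroring SparseTensor.replace
            return TinySparse(feats, self.coords)

        def __elemwise__(self, other, op):
            # scalars/tensors broadcast against the features; another sparse
            # tensor must share the same coordinate set, so only its feats are used
            if isinstance(other, TinySparse):
                other = other.feats
            return self.replace(op(self.feats, other))

        def __add__(self, other):
            return self.__elemwise__(other, torch.add)

    coords = torch.tensor([[0, 0, 0, 0], [0, 1, 0, 2]], dtype=torch.int32)
    a = TinySparse(torch.ones(2, 3), coords)
    print((a + 2.0).feats)  # scalar broadcast over every active voxel
    print((a + a).feats)    # sparse + sparse with identical coords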
trellis/modules/sparse/conv/__init__.py CHANGED
@@ -1,21 +1,26 @@
1
  from .. import BACKEND
2
 
3
 
4
- SPCONV_ALGO = 'auto' # 'auto', 'implicit_gemm', 'native'
 
5
 
6
  def __from_env():
7
  import os
8
-
9
  global SPCONV_ALGO
10
- env_spconv_algo = os.environ.get('SPCONV_ALGO')
11
- if env_spconv_algo is not None and env_spconv_algo in ['auto', 'implicit_gemm', 'native']:
 
 
 
 
12
  SPCONV_ALGO = env_spconv_algo
13
  print(f"[SPARSE][CONV] spconv algo: {SPCONV_ALGO}")
14
-
15
 
16
  __from_env()
17
 
18
- if BACKEND == 'torchsparse':
19
  from .conv_torchsparse import *
20
- elif BACKEND == 'spconv':
21
  from .conv_spconv import *
 
1
  from .. import BACKEND
2
 
3
 
4
+ SPCONV_ALGO = "auto" # 'auto', 'implicit_gemm', 'native'
5
+
6
 
7
  def __from_env():
8
  import os
9
+
10
  global SPCONV_ALGO
11
+ env_spconv_algo = os.environ.get("SPCONV_ALGO")
12
+ if env_spconv_algo is not None and env_spconv_algo in [
13
+ "auto",
14
+ "implicit_gemm",
15
+ "native",
16
+ ]:
17
  SPCONV_ALGO = env_spconv_algo
18
  print(f"[SPARSE][CONV] spconv algo: {SPCONV_ALGO}")
19
+
20
 
21
  __from_env()
22
 
23
+ if BACKEND == "torchsparse":
24
  from .conv_torchsparse import *
25
+ elif BACKEND == "spconv":
26
  from .conv_spconv import *
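Since `__from_env()` runs at import time, the spconv algorithm must be chosen through the environment before the package is first imported. A small usage sketch (assuming the package is importable under this path):

    import os
    os.environ["SPCONV_ALGO"] = "native"  # one of: "auto", "implicit_gemm", "native"

    # The import triggers __from_env(), which prints:
    #   [SPARSE][CONV] spconv algo: native
    import trellis.modules.sparse.conv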
trellis/modules/sparse/conv/conv_spconv.py CHANGED
@@ -4,21 +4,54 @@ from .. import SparseTensor
4
  from .. import DEBUG
5
  from . import SPCONV_ALGO
6
 
 
7
  class SparseConv3d(nn.Module):
8
- def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, padding=None, bias=True, indice_key=None):
 
 
 
 
 
 
 
 
 
 
9
  super(SparseConv3d, self).__init__()
10
- if 'spconv' not in globals():
11
  import spconv.pytorch as spconv
12
  algo = None
13
- if SPCONV_ALGO == 'native':
14
  algo = spconv.ConvAlgo.Native
15
- elif SPCONV_ALGO == 'implicit_gemm':
16
  algo = spconv.ConvAlgo.MaskImplicitGemm
17
  if stride == 1 and (padding is None):
18
- self.conv = spconv.SubMConv3d(in_channels, out_channels, kernel_size, dilation=dilation, bias=bias, indice_key=indice_key, algo=algo)
 
 
 
 
 
 
 
 
19
  else:
20
- self.conv = spconv.SparseConv3d(in_channels, out_channels, kernel_size, stride=stride, dilation=dilation, padding=padding, bias=bias, indice_key=indice_key, algo=algo)
21
- self.stride = tuple(stride) if isinstance(stride, (list, tuple)) else (stride, stride, stride)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
  self.padding = padding
23
 
24
  def forward(self, x: SparseTensor) -> SparseTensor:
@@ -30,42 +63,65 @@ class SparseConv3d(nn.Module):
30
  if spatial_changed and (x.shape[0] != 1):
31
  # spconv with a non-1 stride breaks the batch-contiguous ordering of the output tensor; sort by the coords
32
  fwd = new_data.indices[:, 0].argsort()
33
- bwd = torch.zeros_like(fwd).scatter_(0, fwd, torch.arange(fwd.shape[0], device=fwd.device))
 
 
34
  sorted_feats = new_data.features[fwd]
35
  sorted_coords = new_data.indices[fwd]
36
  unsorted_data = new_data
37
  new_data = spconv.SparseConvTensor(sorted_feats, sorted_coords, unsorted_data.spatial_shape, unsorted_data.batch_size) # type: ignore
38
 
39
  out = SparseTensor(
40
- new_data, shape=torch.Size(new_shape), layout=new_layout,
 
 
41
  scale=tuple([s * stride for s, stride in zip(x._scale, self.stride)]),
42
  spatial_cache=x._spatial_cache,
43
  )
44
 
45
  if spatial_changed and (x.shape[0] != 1):
46
- out.register_spatial_cache(f'conv_{self.stride}_unsorted_data', unsorted_data)
47
- out.register_spatial_cache(f'conv_{self.stride}_sort_bwd', bwd)
48
-
 
 
49
  return out
50
 
51
 
52
  class SparseInverseConv3d(nn.Module):
53
- def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, bias=True, indice_key=None):
 
 
 
 
 
 
 
 
 
54
  super(SparseInverseConv3d, self).__init__()
55
- if 'spconv' not in globals():
56
  import spconv.pytorch as spconv
57
- self.conv = spconv.SparseInverseConv3d(in_channels, out_channels, kernel_size, bias=bias, indice_key=indice_key)
58
- self.stride = tuple(stride) if isinstance(stride, (list, tuple)) else (stride, stride, stride)
 
 
 
 
 
 
59
 
60
  def forward(self, x: SparseTensor) -> SparseTensor:
61
  spatial_changed = any(s != 1 for s in self.stride)
62
  if spatial_changed:
63
  # recover the original spconv order
64
- data = x.get_spatial_cache(f'conv_{self.stride}_unsorted_data')
65
- bwd = x.get_spatial_cache(f'conv_{self.stride}_sort_bwd')
66
  data = data.replace_feature(x.feats[bwd])
67
  if DEBUG:
68
- assert torch.equal(data.indices, x.coords[bwd]), 'Recover the original order failed'
 
 
69
  else:
70
  data = x.data
71
 
@@ -73,7 +129,9 @@ class SparseInverseConv3d(nn.Module):
73
  new_shape = [x.shape[0], self.conv.out_channels]
74
  new_layout = None if spatial_changed else x.layout
75
  out = SparseTensor(
76
- new_data, shape=torch.Size(new_shape), layout=new_layout,
 
 
77
  scale=tuple([s // stride for s, stride in zip(x._scale, self.stride)]),
78
  spatial_cache=x._spatial_cache,
79
  )
 
4
  from .. import DEBUG
5
  from . import SPCONV_ALGO
6
 
7
+
8
  class SparseConv3d(nn.Module):
9
+ def __init__(
10
+ self,
11
+ in_channels,
12
+ out_channels,
13
+ kernel_size,
14
+ stride=1,
15
+ dilation=1,
16
+ padding=None,
17
+ bias=True,
18
+ indice_key=None,
19
+ ):
20
  super(SparseConv3d, self).__init__()
21
+ if "spconv" not in globals():
22
  import spconv.pytorch as spconv
23
  algo = None
24
+ if SPCONV_ALGO == "native":
25
  algo = spconv.ConvAlgo.Native
26
+ elif SPCONV_ALGO == "implicit_gemm":
27
  algo = spconv.ConvAlgo.MaskImplicitGemm
28
  if stride == 1 and (padding is None):
29
+ self.conv = spconv.SubMConv3d(
30
+ in_channels,
31
+ out_channels,
32
+ kernel_size,
33
+ dilation=dilation,
34
+ bias=bias,
35
+ indice_key=indice_key,
36
+ algo=algo,
37
+ )
38
  else:
39
+ self.conv = spconv.SparseConv3d(
40
+ in_channels,
41
+ out_channels,
42
+ kernel_size,
43
+ stride=stride,
44
+ dilation=dilation,
45
+ padding=padding,
46
+ bias=bias,
47
+ indice_key=indice_key,
48
+ algo=algo,
49
+ )
50
+ self.stride = (
51
+ tuple(stride)
52
+ if isinstance(stride, (list, tuple))
53
+ else (stride, stride, stride)
54
+ )
55
  self.padding = padding
56
 
57
  def forward(self, x: SparseTensor) -> SparseTensor:
 
63
  if spatial_changed and (x.shape[0] != 1):
64
  # spconv with a non-1 stride breaks the batch-contiguous ordering of the output tensor; sort by the coords
65
  fwd = new_data.indices[:, 0].argsort()
66
+ bwd = torch.zeros_like(fwd).scatter_(
67
+ 0, fwd, torch.arange(fwd.shape[0], device=fwd.device)
68
+ )
69
  sorted_feats = new_data.features[fwd]
70
  sorted_coords = new_data.indices[fwd]
71
  unsorted_data = new_data
72
  new_data = spconv.SparseConvTensor(sorted_feats, sorted_coords, unsorted_data.spatial_shape, unsorted_data.batch_size) # type: ignore
73
 
74
  out = SparseTensor(
75
+ new_data,
76
+ shape=torch.Size(new_shape),
77
+ layout=new_layout,
78
  scale=tuple([s * stride for s, stride in zip(x._scale, self.stride)]),
79
  spatial_cache=x._spatial_cache,
80
  )
81
 
82
  if spatial_changed and (x.shape[0] != 1):
83
+ out.register_spatial_cache(
84
+ f"conv_{self.stride}_unsorted_data", unsorted_data
85
+ )
86
+ out.register_spatial_cache(f"conv_{self.stride}_sort_bwd", bwd)
87
+
88
  return out
89
 
90
 
91
  class SparseInverseConv3d(nn.Module):
92
+ def __init__(
93
+ self,
94
+ in_channels,
95
+ out_channels,
96
+ kernel_size,
97
+ stride=1,
98
+ dilation=1,
99
+ bias=True,
100
+ indice_key=None,
101
+ ):
102
  super(SparseInverseConv3d, self).__init__()
103
+ if "spconv" not in globals():
104
  import spconv.pytorch as spconv
105
+ self.conv = spconv.SparseInverseConv3d(
106
+ in_channels, out_channels, kernel_size, bias=bias, indice_key=indice_key
107
+ )
108
+ self.stride = (
109
+ tuple(stride)
110
+ if isinstance(stride, (list, tuple))
111
+ else (stride, stride, stride)
112
+ )
113
 
114
  def forward(self, x: SparseTensor) -> SparseTensor:
115
  spatial_changed = any(s != 1 for s in self.stride)
116
  if spatial_changed:
117
  # recover the original spconv order
118
+ data = x.get_spatial_cache(f"conv_{self.stride}_unsorted_data")
119
+ bwd = x.get_spatial_cache(f"conv_{self.stride}_sort_bwd")
120
  data = data.replace_feature(x.feats[bwd])
121
  if DEBUG:
122
+ assert torch.equal(
123
+ data.indices, x.coords[bwd]
124
+ ), "Recover the original order failed"
125
  else:
126
  data = x.data
127
 
 
129
  new_shape = [x.shape[0], self.conv.out_channels]
130
  new_layout = None if spatial_changed else x.layout
131
  out = SparseTensor(
132
+ new_data,
133
+ shape=torch.Size(new_shape),
134
+ layout=new_layout,
135
  scale=tuple([s // stride for s, stride in zip(x._scale, self.stride)]),
136
  spatial_cache=x._spatial_cache,
137
  )
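The `fwd`/`bwd` pair in `SparseConv3d.forward` is an argsort and its inverse permutation: `fwd` groups rows by batch index for spconv, and the cached `bwd` lets `SparseInverseConv3d` restore the original row order. A standalone sketch of the identity:

    import torch

    rows = torch.tensor([[1, 5], [0, 2], [1, 0], [0, 7]])  # [batch, ...] rows
    fwd = rows[:, 0].argsort()                             # sort by batch index
    bwd = torch.zeros_like(fwd).scatter_(0, fwd, torch.arange(fwd.shape[0]))

    sorted_rows = rows[fwd]                     # the order spconv sees
    assert torch.equal(sorted_rows[bwd], rows)  # bwd undoes the sort exactly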
trellis/modules/sparse/conv/conv_torchsparse.py CHANGED
@@ -4,35 +4,73 @@ from .. import SparseTensor
4
 
5
 
6
  class SparseConv3d(nn.Module):
7
- def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, bias=True, indice_key=None):
 
 
 
 
 
 
 
 
 
8
  super(SparseConv3d, self).__init__()
9
- if 'torchsparse' not in globals():
10
  import torchsparse
11
- self.conv = torchsparse.nn.Conv3d(in_channels, out_channels, kernel_size, stride, 0, dilation, bias)
 
 
12
 
13
  def forward(self, x: SparseTensor) -> SparseTensor:
14
  out = self.conv(x.data)
15
  new_shape = [x.shape[0], self.conv.out_channels]
16
- out = SparseTensor(out, shape=torch.Size(new_shape), layout=x.layout if all(s == 1 for s in self.conv.stride) else None)
 
 
 
 
17
  out._spatial_cache = x._spatial_cache
18
- out._scale = tuple([s * stride for s, stride in zip(x._scale, self.conv.stride)])
 
 
19
  return out
20
 
21
 
22
  class SparseInverseConv3d(nn.Module):
23
- def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, bias=True, indice_key=None):
 
 
 
 
 
 
 
 
 
24
  super(SparseInverseConv3d, self).__init__()
25
- if 'torchsparse' not in globals():
26
  import torchsparse
27
- self.conv = torchsparse.nn.Conv3d(in_channels, out_channels, kernel_size, stride, 0, dilation, bias, transposed=True)
 
 
 
 
 
 
 
 
 
28
 
29
  def forward(self, x: SparseTensor) -> SparseTensor:
30
- out = self.conv(x.data)
31
  new_shape = [x.shape[0], self.conv.out_channels]
32
- out = SparseTensor(out, shape=torch.Size(new_shape), layout=x.layout if all(s == 1 for s in self.conv.stride) else None)
 
 
 
 
33
  out._spatial_cache = x._spatial_cache
34
- out._scale = tuple([s // stride for s, stride in zip(x._scale, self.conv.stride)])
 
 
35
  return out
36
-
37
-
38
-
 
4
 
5
 
6
  class SparseConv3d(nn.Module):
7
+ def __init__(
8
+ self,
9
+ in_channels,
10
+ out_channels,
11
+ kernel_size,
12
+ stride=1,
13
+ dilation=1,
14
+ bias=True,
15
+ indice_key=None,
16
+ ):
17
  super(SparseConv3d, self).__init__()
18
+ if "torchsparse" not in globals():
19
  import torchsparse
20
+ self.conv = torchsparse.nn.Conv3d(
21
+ in_channels, out_channels, kernel_size, stride, 0, dilation, bias
22
+ )
23
 
24
  def forward(self, x: SparseTensor) -> SparseTensor:
25
  out = self.conv(x.data)
26
  new_shape = [x.shape[0], self.conv.out_channels]
27
+ out = SparseTensor(
28
+ out,
29
+ shape=torch.Size(new_shape),
30
+ layout=x.layout if all(s == 1 for s in self.conv.stride) else None,
31
+ )
32
  out._spatial_cache = x._spatial_cache
33
+ out._scale = tuple(
34
+ [s * stride for s, stride in zip(x._scale, self.conv.stride)]
35
+ )
36
  return out
37
 
38
 
39
  class SparseInverseConv3d(nn.Module):
40
+ def __init__(
41
+ self,
42
+ in_channels,
43
+ out_channels,
44
+ kernel_size,
45
+ stride=1,
46
+ dilation=1,
47
+ bias=True,
48
+ indice_key=None,
49
+ ):
50
  super(SparseInverseConv3d, self).__init__()
51
+ if "torchsparse" not in globals():
52
  import torchsparse
53
+ self.conv = torchsparse.nn.Conv3d(
54
+ in_channels,
55
+ out_channels,
56
+ kernel_size,
57
+ stride,
58
+ 0,
59
+ dilation,
60
+ bias,
61
+ transposed=True,
62
+ )
63
 
64
  def forward(self, x: SparseTensor) -> SparseTensor:
65
+ out = self.conv(x.data)
66
  new_shape = [x.shape[0], self.conv.out_channels]
67
+ out = SparseTensor(
68
+ out,
69
+ shape=torch.Size(new_shape),
70
+ layout=x.layout if all(s == 1 for s in self.conv.stride) else None,
71
+ )
72
  out._spatial_cache = x._spatial_cache
73
+ out._scale = tuple(
74
+ [s // stride for s, stride in zip(x._scale, self.conv.stride)]
75
+ )
76
  return out
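Both conv backends keep the same `_scale` bookkeeping: a strided forward conv multiplies each per-axis scale by the stride, and the inverse conv divides it back, so a down/up round trip is scale-neutral. In plain Python:

    scale = (1, 1, 1)
    stride = (2, 2, 2)

    after_down = tuple(s * st for s, st in zip(scale, stride))      # SparseConv3d
    after_up = tuple(s // st for s, st in zip(after_down, stride))  # SparseInverseConv3d
    assert after_up == scale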
 
 
 
trellis/modules/sparse/linear.py CHANGED
@@ -2,9 +2,7 @@ import torch
2
  import torch.nn as nn
3
  from . import SparseTensor
4
 
5
- __all__ = [
6
- 'SparseLinear'
7
- ]
8
 
9
 
10
  class SparseLinear(nn.Linear):
 
2
  import torch.nn as nn
3
  from . import SparseTensor
4
 
5
+ __all__ = ["SparseLinear"]
 
 
6
 
7
 
8
  class SparseLinear(nn.Linear):
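The class body is elided in this diff, but by analogy with the activation and norm wrappers below, a sparse linear layer presumably applies the dense projection to `.feats` and rebuilds the tensor. A hedged sketch, not necessarily the repo's exact implementation:

    import torch.nn as nn

    class SparseLinearSketch(nn.Linear):
        # sketch: apply the dense affine map to the (N, C) feature rows and
        # keep the coordinates untouched via replace()
        def forward(self, input):  # input: SparseTensor-like
            return input.replace(super().forward(input.feats))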
trellis/modules/sparse/nonlinearity.py CHANGED
@@ -2,18 +2,13 @@ import torch
2
  import torch.nn as nn
3
  from . import SparseTensor
4
 
5
- __all__ = [
6
- 'SparseReLU',
7
- 'SparseSiLU',
8
- 'SparseGELU',
9
- 'SparseActivation'
10
- ]
11
 
12
 
13
  class SparseReLU(nn.ReLU):
14
  def forward(self, input: SparseTensor) -> SparseTensor:
15
  return input.replace(super().forward(input.feats))
16
-
17
 
18
  class SparseSiLU(nn.SiLU):
19
  def forward(self, input: SparseTensor) -> SparseTensor:
@@ -32,4 +27,3 @@ class SparseActivation(nn.Module):
32
 
33
  def forward(self, input: SparseTensor) -> SparseTensor:
34
  return input.replace(self.activation(input.feats))
35
-
 
2
  import torch.nn as nn
3
  from . import SparseTensor
4
 
5
+ __all__ = ["SparseReLU", "SparseSiLU", "SparseGELU", "SparseActivation"]
 
 
 
 
 
6
 
7
 
8
  class SparseReLU(nn.ReLU):
9
  def forward(self, input: SparseTensor) -> SparseTensor:
10
  return input.replace(super().forward(input.feats))
11
+
12
 
13
  class SparseSiLU(nn.SiLU):
14
  def forward(self, input: SparseTensor) -> SparseTensor:
 
27
 
28
  def forward(self, input: SparseTensor) -> SparseTensor:
29
  return input.replace(self.activation(input.feats))
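The same lifting works for any pointwise module: apply it to `.feats`, then rebuild with `.replace()`. A sketch of a generic helper (hypothetical, not part of the repo):

    import torch.nn as nn

    def lift_pointwise(module: nn.Module):
        # hypothetical helper: wrap a dense pointwise module so it maps
        # SparseTensor -> SparseTensor by transforming only the features
        def forward(sparse_input):
            return sparse_input.replace(module(sparse_input.feats))
        return forward

    sparse_tanh = lift_pointwise(nn.Tanh())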
 
trellis/modules/sparse/norm.py CHANGED
@@ -4,10 +4,10 @@ from . import SparseTensor
4
  from . import DEBUG
5
 
6
  __all__ = [
7
- 'SparseGroupNorm',
8
- 'SparseLayerNorm',
9
- 'SparseGroupNorm32',
10
- 'SparseLayerNorm32',
11
  ]
12
 
13
 
@@ -19,7 +19,9 @@ class SparseGroupNorm(nn.GroupNorm):
19
  nfeats = torch.zeros_like(input.feats)
20
  for k in range(input.shape[0]):
21
  if DEBUG:
22
- assert (input.coords[input.layout[k], 0] == k).all(), f"SparseGroupNorm: batch index mismatch"
 
 
23
  bfeats = input.feats[input.layout[k]]
24
  bfeats = bfeats.permute(1, 0).reshape(1, input.shape[1], -1)
25
  bfeats = super().forward(bfeats)
@@ -47,12 +49,15 @@ class SparseGroupNorm32(SparseGroupNorm):
47
  """
48
  A GroupNorm layer that converts to float32 before the forward pass.
49
  """
 
50
  def forward(self, x: SparseTensor) -> SparseTensor:
51
  return super().forward(x.float()).type(x.dtype)
52
 
 
53
  class SparseLayerNorm32(SparseLayerNorm):
54
  """
55
  A LayerNorm layer that converts to float32 before the forward pass.
56
  """
 
57
  def forward(self, x: SparseTensor) -> SparseTensor:
58
  return super().forward(x.float()).type(x.dtype)
 
4
  from . import DEBUG
5
 
6
  __all__ = [
7
+ "SparseGroupNorm",
8
+ "SparseLayerNorm",
9
+ "SparseGroupNorm32",
10
+ "SparseLayerNorm32",
11
  ]
12
 
13
 
 
19
  nfeats = torch.zeros_like(input.feats)
20
  for k in range(input.shape[0]):
21
  if DEBUG:
22
+ assert (
23
+ input.coords[input.layout[k], 0] == k
24
+ ).all(), f"SparseGroupNorm: batch index mismatch"
25
  bfeats = input.feats[input.layout[k]]
26
  bfeats = bfeats.permute(1, 0).reshape(1, input.shape[1], -1)
27
  bfeats = super().forward(bfeats)
 
49
  """
50
  A GroupNorm layer that converts to float32 before the forward pass.
51
  """
52
+
53
  def forward(self, x: SparseTensor) -> SparseTensor:
54
  return super().forward(x.float()).type(x.dtype)
55
 
56
+
57
  class SparseLayerNorm32(SparseLayerNorm):
58
  """
59
  A LayerNorm layer that converts to float32 before the forward pass.
60
  """
61
+
62
  def forward(self, x: SparseTensor) -> SparseTensor:
63
  return super().forward(x.float()).type(x.dtype)
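The `*32` variants above implement a common mixed-precision pattern: normalize in float32 for numerical stability, then cast back to the input dtype. The same idea on a dense tensor, for clarity:

    import torch
    import torch.nn as nn

    class GroupNorm32(nn.GroupNorm):
        # normalize in float32, then cast back to the caller's dtype
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return super().forward(x.float()).type(x.dtype)

    x = torch.randn(2, 8, 16, dtype=torch.float16)
    y = GroupNorm32(num_groups=4, num_channels=8)(x)
    assert y.dtype == torch.float16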
trellis/modules/sparse/spatial.py CHANGED
@@ -3,11 +3,7 @@ import torch
3
  import torch.nn as nn
4
  from . import SparseTensor
5
 
6
- __all__ = [
7
- 'SparseDownsample',
8
- 'SparseUpsample',
9
- 'SparseSubdivide'
10
- ]
11
 
12
 
13
  class SparseDownsample(nn.Module):
@@ -15,6 +11,7 @@ class SparseDownsample(nn.Module):
15
  Downsample a sparse tensor by a factor of `factor`.
16
  Implemented as average pooling.
17
  """
 
18
  def __init__(self, factor: Union[int, Tuple[int, ...], List[int]]):
19
  super(SparseDownsample, self).__init__()
20
  self.factor = tuple(factor) if isinstance(factor, (list, tuple)) else factor
@@ -22,36 +19,47 @@ class SparseDownsample(nn.Module):
22
  def forward(self, input: SparseTensor) -> SparseTensor:
23
  DIM = input.coords.shape[-1] - 1
24
  factor = self.factor if isinstance(self.factor, tuple) else (self.factor,) * DIM
25
- assert DIM == len(factor), 'Input coordinates must have the same dimension as the downsample factor.'
 
 
26
 
27
  coord = list(input.coords.unbind(dim=-1))
28
  for i, f in enumerate(factor):
29
- coord[i+1] = coord[i+1] // f
30
 
31
- MAX = [coord[i+1].max().item() + 1 for i in range(DIM)]
32
  OFFSET = torch.cumprod(torch.tensor(MAX[::-1]), 0).tolist()[::-1] + [1]
33
  code = sum([c * o for c, o in zip(coord, OFFSET)])
34
  code, idx = code.unique(return_inverse=True)
35
 
36
  new_feats = torch.scatter_reduce(
37
- torch.zeros(code.shape[0], input.feats.shape[1], device=input.feats.device, dtype=input.feats.dtype),
 
 
 
 
 
38
  dim=0,
39
  index=idx.unsqueeze(1).expand(-1, input.feats.shape[1]),
40
  src=input.feats,
41
- reduce='mean'
42
  )
43
  new_coords = torch.stack(
44
- [code // OFFSET[0]] +
45
- [(code // OFFSET[i+1]) % MAX[i] for i in range(DIM)],
46
- dim=-1
 
 
 
 
 
47
  )
48
- out = SparseTensor(new_feats, new_coords, input.shape,)
49
  out._scale = tuple([s // f for s, f in zip(input._scale, factor)])
50
  out._spatial_cache = input._spatial_cache
51
 
52
- out.register_spatial_cache(f'upsample_{factor}_coords', input.coords)
53
- out.register_spatial_cache(f'upsample_{factor}_layout', input.layout)
54
- out.register_spatial_cache(f'upsample_{factor}_idx', idx)
55
 
56
  return out
57
 
@@ -61,6 +69,7 @@ class SparseUpsample(nn.Module):
61
  Upsample a sparse tensor by a factor of `factor`.
62
  Implemented as nearest neighbor interpolation.
63
  """
 
64
  def __init__(self, factor: Union[int, Tuple[int, int, int], List[int]]):
65
  super(SparseUpsample, self).__init__()
66
  self.factor = tuple(factor) if isinstance(factor, (list, tuple)) else factor
@@ -68,24 +77,30 @@ class SparseUpsample(nn.Module):
68
  def forward(self, input: SparseTensor) -> SparseTensor:
69
  DIM = input.coords.shape[-1] - 1
70
  factor = self.factor if isinstance(self.factor, tuple) else (self.factor,) * DIM
71
- assert DIM == len(factor), 'Input coordinates must have the same dimension as the upsample factor.'
 
 
72
 
73
- new_coords = input.get_spatial_cache(f'upsample_{factor}_coords')
74
- new_layout = input.get_spatial_cache(f'upsample_{factor}_layout')
75
- idx = input.get_spatial_cache(f'upsample_{factor}_idx')
76
  if any([x is None for x in [new_coords, new_layout, idx]]):
77
- raise ValueError('Upsample cache not found. SparseUpsample must be paired with SparseDownsample.')
 
 
78
  new_feats = input.feats[idx]
79
  out = SparseTensor(new_feats, new_coords, input.shape, new_layout)
80
  out._scale = tuple([s * f for s, f in zip(input._scale, factor)])
81
  out._spatial_cache = input._spatial_cache
82
  return out
83
-
 
84
  class SparseSubdivide(nn.Module):
85
  """
86
  Subdivide a sparse tensor, doubling the resolution along each spatial dimension.
87
  Each active voxel is split into 2**DIM children that inherit its features.
88
  """
 
89
  def __init__(self):
90
  super(SparseSubdivide, self).__init__()
91
 
@@ -96,15 +111,20 @@ class SparseSubdivide(nn.Module):
96
  n_coords = torch.nonzero(n_cube)
97
  n_coords = torch.cat([torch.zeros_like(n_coords[:, :1]), n_coords], dim=-1)
98
  factor = n_coords.shape[0]
99
- assert factor == 2 ** DIM
100
  # print(n_coords.shape)
101
  new_coords = input.coords.clone()
102
  new_coords[:, 1:] *= 2
103
- new_coords = new_coords.unsqueeze(1) + n_coords.unsqueeze(0).to(new_coords.dtype)
104
-
105
- new_feats = input.feats.unsqueeze(1).expand(input.feats.shape[0], factor, *input.feats.shape[1:])
106
- out = SparseTensor(new_feats.flatten(0, 1), new_coords.flatten(0, 1), input.shape)
 
 
 
 
 
 
107
  out._scale = input._scale * 2
108
  out._spatial_cache = input._spatial_cache
109
  return out
110
-
 
3
  import torch.nn as nn
4
  from . import SparseTensor
5
 
6
+ __all__ = ["SparseDownsample", "SparseUpsample", "SparseSubdivide"]
 
 
 
 
7
 
8
 
9
  class SparseDownsample(nn.Module):
 
11
  Downsample a sparse tensor by a factor of `factor`.
12
  Implemented as average pooling.
13
  """
14
+
15
  def __init__(self, factor: Union[int, Tuple[int, ...], List[int]]):
16
  super(SparseDownsample, self).__init__()
17
  self.factor = tuple(factor) if isinstance(factor, (list, tuple)) else factor
 
19
  def forward(self, input: SparseTensor) -> SparseTensor:
20
  DIM = input.coords.shape[-1] - 1
21
  factor = self.factor if isinstance(self.factor, tuple) else (self.factor,) * DIM
22
+ assert DIM == len(
23
+ factor
24
+ ), "Input coordinates must have the same dimension as the downsample factor."
25
 
26
  coord = list(input.coords.unbind(dim=-1))
27
  for i, f in enumerate(factor):
28
+ coord[i + 1] = coord[i + 1] // f
29
 
30
+ MAX = [coord[i + 1].max().item() + 1 for i in range(DIM)]
31
  OFFSET = torch.cumprod(torch.tensor(MAX[::-1]), 0).tolist()[::-1] + [1]
32
  code = sum([c * o for c, o in zip(coord, OFFSET)])
33
  code, idx = code.unique(return_inverse=True)
34
 
35
  new_feats = torch.scatter_reduce(
36
+ torch.zeros(
37
+ code.shape[0],
38
+ input.feats.shape[1],
39
+ device=input.feats.device,
40
+ dtype=input.feats.dtype,
41
+ ),
42
  dim=0,
43
  index=idx.unsqueeze(1).expand(-1, input.feats.shape[1]),
44
  src=input.feats,
45
+ reduce="mean",
46
  )
47
  new_coords = torch.stack(
48
+ [code // OFFSET[0]]
49
+ + [(code // OFFSET[i + 1]) % MAX[i] for i in range(DIM)],
50
+ dim=-1,
51
+ )
52
+ out = SparseTensor(
53
+ new_feats,
54
+ new_coords,
55
+ input.shape,
56
  )
 
57
  out._scale = tuple([s // f for s, f in zip(input._scale, factor)])
58
  out._spatial_cache = input._spatial_cache
59
 
60
+ out.register_spatial_cache(f"upsample_{factor}_coords", input.coords)
61
+ out.register_spatial_cache(f"upsample_{factor}_layout", input.layout)
62
+ out.register_spatial_cache(f"upsample_{factor}_idx", idx)
63
 
64
  return out
65
 
 
69
  Upsample a sparse tensor by a factor of `factor`.
70
  Implemented as nearest neighbor interpolation.
71
  """
72
+
73
  def __init__(self, factor: Union[int, Tuple[int, int, int], List[int]]):
74
  super(SparseUpsample, self).__init__()
75
  self.factor = tuple(factor) if isinstance(factor, (list, tuple)) else factor
 
77
  def forward(self, input: SparseTensor) -> SparseTensor:
78
  DIM = input.coords.shape[-1] - 1
79
  factor = self.factor if isinstance(self.factor, tuple) else (self.factor,) * DIM
80
+ assert DIM == len(
81
+ factor
82
+ ), "Input coordinates must have the same dimension as the upsample factor."
83
 
84
+ new_coords = input.get_spatial_cache(f"upsample_{factor}_coords")
85
+ new_layout = input.get_spatial_cache(f"upsample_{factor}_layout")
86
+ idx = input.get_spatial_cache(f"upsample_{factor}_idx")
87
  if any([x is None for x in [new_coords, new_layout, idx]]):
88
+ raise ValueError(
89
+ "Upsample cache not found. SparseUpsample must be paired with SparseDownsample."
90
+ )
91
  new_feats = input.feats[idx]
92
  out = SparseTensor(new_feats, new_coords, input.shape, new_layout)
93
  out._scale = tuple([s * f for s, f in zip(input._scale, factor)])
94
  out._spatial_cache = input._spatial_cache
95
  return out
96
+
97
+
98
  class SparseSubdivide(nn.Module):
99
  """
100
  Upsample a sparse tensor by a factor of `factor`.
101
  Implemented as nearest neighbor interpolation.
102
  """
103
+
104
  def __init__(self):
105
  super(SparseSubdivide, self).__init__()
106
 
 
111
  n_coords = torch.nonzero(n_cube)
112
  n_coords = torch.cat([torch.zeros_like(n_coords[:, :1]), n_coords], dim=-1)
113
  factor = n_coords.shape[0]
114
+ assert factor == 2**DIM
115
  # print(n_coords.shape)
116
  new_coords = input.coords.clone()
117
  new_coords[:, 1:] *= 2
118
+ new_coords = new_coords.unsqueeze(1) + n_coords.unsqueeze(0).to(
119
+ new_coords.dtype
120
+ )
121
+
122
+ new_feats = input.feats.unsqueeze(1).expand(
123
+ input.feats.shape[0], factor, *input.feats.shape[1:]
124
+ )
125
+ out = SparseTensor(
126
+ new_feats.flatten(0, 1), new_coords.flatten(0, 1), input.shape
127
+ )
128
  out._scale = input._scale * 2
129
  out._spatial_cache = input._spatial_cache
130
  return out
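`SparseDownsample` relies on a mixed-radix encoding: the per-axis extents `MAX` yield place values `OFFSET` via a reversed cumprod, every `(batch, x, y, z)` tuple maps to one unique integer `code`, and integer division plus modulo invert the encoding. A worked standalone check:

    import torch

    coords = torch.tensor([[0, 1, 2, 3], [1, 0, 1, 1]])  # rows of (batch, x, y, z)
    coord = list(coords.unbind(dim=-1))
    MAX = [coord[i + 1].max().item() + 1 for i in range(3)]  # per-axis extents
    OFFSET = torch.cumprod(torch.tensor(MAX[::-1]), 0).tolist()[::-1] + [1]

    code = sum(c * o for c, o in zip(coord, OFFSET))  # unique int per voxel
    decoded = torch.stack(
        [code // OFFSET[0]]                                      # batch index
        + [(code // OFFSET[i + 1]) % MAX[i] for i in range(3)],  # x, y, z
        dim=-1,
    )
    assert torch.equal(decoded, coords)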
 
trellis/modules/sparse/transformer/__init__.py CHANGED
@@ -1,2 +1,2 @@
1
  from .blocks import *
2
- from .modulated import *
 
1
  from .blocks import *
2
+ from .modulated import *
trellis/modules/sparse/transformer/blocks.py CHANGED
@@ -25,12 +25,15 @@ class SparseTransformerBlock(nn.Module):
25
  """
26
  Sparse Transformer block (MSA + FFN).
27
  """
 
28
  def __init__(
29
  self,
30
  channels: int,
31
  num_heads: int,
32
  mlp_ratio: float = 4.0,
33
- attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "full",
 
 
34
  window_size: Optional[int] = None,
35
  shift_sequence: Optional[int] = None,
36
  shift_window: Optional[Tuple[int, int, int]] = None,
@@ -73,7 +76,9 @@ class SparseTransformerBlock(nn.Module):
73
 
74
  def forward(self, x: SparseTensor) -> SparseTensor:
75
  if self.use_checkpoint:
76
- return torch.utils.checkpoint.checkpoint(self._forward, x, use_reentrant=False)
 
 
77
  else:
78
  return self._forward(x)
79
 
@@ -82,13 +87,16 @@ class SparseTransformerCrossBlock(nn.Module):
82
  """
83
  Sparse Transformer cross-attention block (MSA + MCA + FFN).
84
  """
 
85
  def __init__(
86
  self,
87
  channels: int,
88
  ctx_channels: int,
89
  num_heads: int,
90
  mlp_ratio: float = 4.0,
91
- attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "full",
 
 
92
  window_size: Optional[int] = None,
93
  shift_sequence: Optional[int] = None,
94
  shift_window: Optional[Tuple[int, int, int]] = None,
@@ -146,6 +154,8 @@ class SparseTransformerCrossBlock(nn.Module):
146
 
147
  def forward(self, x: SparseTensor, context: torch.Tensor):
148
  if self.use_checkpoint:
149
- return torch.utils.checkpoint.checkpoint(self._forward, x, context, use_reentrant=False)
 
 
150
  else:
151
  return self._forward(x, context)
 
25
  """
26
  Sparse Transformer block (MSA + FFN).
27
  """
28
+
29
  def __init__(
30
  self,
31
  channels: int,
32
  num_heads: int,
33
  mlp_ratio: float = 4.0,
34
+ attn_mode: Literal[
35
+ "full", "shift_window", "shift_sequence", "shift_order", "swin"
36
+ ] = "full",
37
  window_size: Optional[int] = None,
38
  shift_sequence: Optional[int] = None,
39
  shift_window: Optional[Tuple[int, int, int]] = None,
 
76
 
77
  def forward(self, x: SparseTensor) -> SparseTensor:
78
  if self.use_checkpoint:
79
+ return torch.utils.checkpoint.checkpoint(
80
+ self._forward, x, use_reentrant=False
81
+ )
82
  else:
83
  return self._forward(x)
84
 
 
87
  """
88
  Sparse Transformer cross-attention block (MSA + MCA + FFN).
89
  """
90
+
91
  def __init__(
92
  self,
93
  channels: int,
94
  ctx_channels: int,
95
  num_heads: int,
96
  mlp_ratio: float = 4.0,
97
+ attn_mode: Literal[
98
+ "full", "shift_window", "shift_sequence", "shift_order", "swin"
99
+ ] = "full",
100
  window_size: Optional[int] = None,
101
  shift_sequence: Optional[int] = None,
102
  shift_window: Optional[Tuple[int, int, int]] = None,
 
154
 
155
  def forward(self, x: SparseTensor, context: torch.Tensor):
156
  if self.use_checkpoint:
157
+ return torch.utils.checkpoint.checkpoint(
158
+ self._forward, x, context, use_reentrant=False
159
+ )
160
  else:
161
  return self._forward(x, context)
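Both blocks use the same activation-checkpointing escape hatch: wrap the real forward in `torch.utils.checkpoint.checkpoint` so intermediate activations are recomputed during backward instead of stored. A minimal dense sketch of the pattern:

    import torch
    import torch.nn as nn

    class CheckpointedBlock(nn.Module):
        def __init__(self, use_checkpoint: bool = True):
            super().__init__()
            self.use_checkpoint = use_checkpoint
            self.body = nn.Linear(8, 8)

        def _forward(self, x):
            return self.body(x)

        def forward(self, x):
            if self.use_checkpoint:
                # recompute _forward in backward; trades compute for memory
                return torch.utils.checkpoint.checkpoint(
                    self._forward, x, use_reentrant=False
                )
            return self._forward(x)

    x = torch.randn(2, 8, requires_grad=True)
    CheckpointedBlock()(x).sum().backward()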
trellis/modules/sparse/transformer/modulated.py CHANGED
@@ -11,12 +11,15 @@ class ModulatedSparseTransformerBlock(nn.Module):
11
  """
12
  Sparse Transformer block (MSA + FFN) with adaptive layer norm conditioning.
13
  """
 
14
  def __init__(
15
  self,
16
  channels: int,
17
  num_heads: int,
18
  mlp_ratio: float = 4.0,
19
- attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "full",
 
 
20
  window_size: Optional[int] = None,
21
  shift_sequence: Optional[int] = None,
22
  shift_window: Optional[Tuple[int, int, int]] = None,
@@ -50,15 +53,23 @@ class ModulatedSparseTransformerBlock(nn.Module):
50
  )
51
  if not share_mod:
52
  self.adaLN_modulation = nn.Sequential(
53
- nn.SiLU(),
54
- nn.Linear(channels, 6 * channels, bias=True)
55
  )
56
 
57
  def _forward(self, x: SparseTensor, mod: torch.Tensor) -> SparseTensor:
58
  if self.share_mod:
59
- shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mod.chunk(6, dim=1)
 
 
60
  else:
61
- shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(mod).chunk(6, dim=1)
 
 
 
 
 
 
 
62
  h = x.replace(self.norm1(x.feats))
63
  h = h * (1 + scale_msa) + shift_msa
64
  h = self.attn(h)
@@ -73,7 +84,9 @@ class ModulatedSparseTransformerBlock(nn.Module):
73
 
74
  def forward(self, x: SparseTensor, mod: torch.Tensor) -> SparseTensor:
75
  if self.use_checkpoint:
76
- return torch.utils.checkpoint.checkpoint(self._forward, x, mod, use_reentrant=False)
 
 
77
  else:
78
  return self._forward(x, mod)
79
 
@@ -82,13 +95,16 @@ class ModulatedSparseTransformerCrossBlock(nn.Module):
82
  """
83
  Sparse Transformer cross-attention block (MSA + MCA + FFN) with adaptive layer norm conditioning.
84
  """
 
85
  def __init__(
86
  self,
87
  channels: int,
88
  ctx_channels: int,
89
  num_heads: int,
90
  mlp_ratio: float = 4.0,
91
- attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "full",
 
 
92
  window_size: Optional[int] = None,
93
  shift_sequence: Optional[int] = None,
94
  shift_window: Optional[Tuple[int, int, int]] = None,
@@ -99,7 +115,6 @@ class ModulatedSparseTransformerCrossBlock(nn.Module):
99
  qk_rms_norm_cross: bool = False,
100
  qkv_bias: bool = True,
101
  share_mod: bool = False,
102
-
103
  ):
104
  super().__init__()
105
  self.use_checkpoint = use_checkpoint
@@ -135,15 +150,25 @@ class ModulatedSparseTransformerCrossBlock(nn.Module):
135
  )
136
  if not share_mod:
137
  self.adaLN_modulation = nn.Sequential(
138
- nn.SiLU(),
139
- nn.Linear(channels, 6 * channels, bias=True)
140
  )
141
 
142
- def _forward(self, x: SparseTensor, mod: torch.Tensor, context: torch.Tensor) -> SparseTensor:
 
 
143
  if self.share_mod:
144
- shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mod.chunk(6, dim=1)
 
 
145
  else:
146
- shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(mod).chunk(6, dim=1)
 
 
 
 
 
 
 
147
  h = x.replace(self.norm1(x.feats))
148
  h = h * (1 + scale_msa) + shift_msa
149
  h = self.self_attn(h)
@@ -159,8 +184,12 @@ class ModulatedSparseTransformerCrossBlock(nn.Module):
159
  x = x + h
160
  return x
161
 
162
- def forward(self, x: SparseTensor, mod: torch.Tensor, context: torch.Tensor) -> SparseTensor:
 
 
163
  if self.use_checkpoint:
164
- return torch.utils.checkpoint.checkpoint(self._forward, x, mod, context, use_reentrant=False)
 
 
165
  else:
166
  return self._forward(x, mod, context)
 
11
  """
12
  Sparse Transformer block (MSA + FFN) with adaptive layer norm conditioning.
13
  """
14
+
15
  def __init__(
16
  self,
17
  channels: int,
18
  num_heads: int,
19
  mlp_ratio: float = 4.0,
20
+ attn_mode: Literal[
21
+ "full", "shift_window", "shift_sequence", "shift_order", "swin"
22
+ ] = "full",
23
  window_size: Optional[int] = None,
24
  shift_sequence: Optional[int] = None,
25
  shift_window: Optional[Tuple[int, int, int]] = None,
 
53
  )
54
  if not share_mod:
55
  self.adaLN_modulation = nn.Sequential(
56
+ nn.SiLU(), nn.Linear(channels, 6 * channels, bias=True)
 
57
  )
58
 
59
  def _forward(self, x: SparseTensor, mod: torch.Tensor) -> SparseTensor:
60
  if self.share_mod:
61
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mod.chunk(
62
+ 6, dim=1
63
+ )
64
  else:
65
+ (
66
+ shift_msa,
67
+ scale_msa,
68
+ gate_msa,
69
+ shift_mlp,
70
+ scale_mlp,
71
+ gate_mlp,
72
+ ) = self.adaLN_modulation(mod).chunk(6, dim=1)
73
  h = x.replace(self.norm1(x.feats))
74
  h = h * (1 + scale_msa) + shift_msa
75
  h = self.attn(h)
 
84
 
85
  def forward(self, x: SparseTensor, mod: torch.Tensor) -> SparseTensor:
86
  if self.use_checkpoint:
87
+ return torch.utils.checkpoint.checkpoint(
88
+ self._forward, x, mod, use_reentrant=False
89
+ )
90
  else:
91
  return self._forward(x, mod)
92
 
 
95
  """
96
  Sparse Transformer cross-attention block (MSA + MCA + FFN) with adaptive layer norm conditioning.
97
  """
98
+
99
  def __init__(
100
  self,
101
  channels: int,
102
  ctx_channels: int,
103
  num_heads: int,
104
  mlp_ratio: float = 4.0,
105
+ attn_mode: Literal[
106
+ "full", "shift_window", "shift_sequence", "shift_order", "swin"
107
+ ] = "full",
108
  window_size: Optional[int] = None,
109
  shift_sequence: Optional[int] = None,
110
  shift_window: Optional[Tuple[int, int, int]] = None,
 
115
  qk_rms_norm_cross: bool = False,
116
  qkv_bias: bool = True,
117
  share_mod: bool = False,
 
118
  ):
119
  super().__init__()
120
  self.use_checkpoint = use_checkpoint
 
150
  )
151
  if not share_mod:
152
  self.adaLN_modulation = nn.Sequential(
153
+ nn.SiLU(), nn.Linear(channels, 6 * channels, bias=True)
 
154
  )
155
 
156
+ def _forward(
157
+ self, x: SparseTensor, mod: torch.Tensor, context: torch.Tensor
158
+ ) -> SparseTensor:
159
  if self.share_mod:
160
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mod.chunk(
161
+ 6, dim=1
162
+ )
163
  else:
164
+ (
165
+ shift_msa,
166
+ scale_msa,
167
+ gate_msa,
168
+ shift_mlp,
169
+ scale_mlp,
170
+ gate_mlp,
171
+ ) = self.adaLN_modulation(mod).chunk(6, dim=1)
172
  h = x.replace(self.norm1(x.feats))
173
  h = h * (1 + scale_msa) + shift_msa
174
  h = self.self_attn(h)
 
184
  x = x + h
185
  return x
186
 
187
+ def forward(
188
+ self, x: SparseTensor, mod: torch.Tensor, context: torch.Tensor
189
+ ) -> SparseTensor:
190
  if self.use_checkpoint:
191
+ return torch.utils.checkpoint.checkpoint(
192
+ self._forward, x, mod, context, use_reentrant=False
193
+ )
194
  else:
195
  return self._forward(x, mod, context)
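The modulated blocks follow the DiT-style adaLN recipe: a single projection of the conditioning vector yields six chunks, used as shift/scale before each branch and as a gate on each residual update. A dense sketch of one branch; the gating lines are elided in the diff above, so the residual form here is the standard recipe rather than a verbatim copy:

    import torch
    import torch.nn as nn

    channels = 16
    adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(channels, 6 * channels))

    cond = torch.randn(2, channels)  # conditioning, e.g. a timestep embedding
    shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
        adaLN_modulation(cond).chunk(6, dim=1)
    )

    x = torch.randn(2, channels)
    norm = nn.LayerNorm(channels, elementwise_affine=False)
    h = norm(x) * (1 + scale_msa) + shift_msa  # modulate the normalized input
    # h = attn(h) would go here; identity keeps the sketch self-contained
    x = x + gate_msa * h                       # gated residual update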