pablovela5620 commited on
Commit
8f47742
·
verified ·
1 Parent(s): 18f7500

Upload gradio_app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. gradio_app.py +7 -191
gradio_app.py CHANGED
@@ -1,134 +1,12 @@
1
  import gradio as gr
2
- import numpy as np
3
- import torch
4
- from monopriors.relative_depth_models import (
5
- RelativeDepthPrediction,
6
- get_relative_predictor,
7
- RELATIVE_PREDICTORS,
8
- )
9
- from monopriors.relative_depth_models.base_relative_depth import BaseRelativePredictor
10
- from monopriors.rr_logging_utils import (
11
- log_relative_pred,
12
- create_relative_depth_blueprint,
13
- )
14
- import rerun as rr
15
- from gradio_rerun import Rerun
16
- from pathlib import Path
17
- from typing import Literal, get_args
18
- import gc
19
-
20
- from jaxtyping import UInt8
21
- import mmcv
22
-
23
- try:
24
- import spaces # type: ignore
25
-
26
- IN_SPACES = True
27
- except ImportError:
28
- print("Not running on Zero")
29
- IN_SPACES = False
30
 
31
  title = "# Depth Comparison"
32
  description1 = """Demo to help compare different depth models. Including both Scale | Shift Invariant and Metric Depth types."""
33
  description2 = """Invariant models mean they have no true scale and are only relative, where as Metric models have a true scale and are absolute (meters)."""
34
  description3 = """Checkout the [Github Repo](https://github.com/pablovela5620/monoprior) [![GitHub Repo stars](https://img.shields.io/github/stars/pablovela5620/monoprior)](https://github.com/pablovela5620/monoprior)"""
35
  model_load_status: str = "Models loaded and ready to use!"
36
- DEVICE: Literal["cuda"] | Literal["cpu"] = (
37
- "cuda" if torch.cuda.is_available() else "cpu"
38
- )
39
- MODELS_TO_SKIP: list[str] = []
40
- if gr.NO_RELOAD:
41
- MODEL_1 = get_relative_predictor("DepthAnythingV2Predictor")(device=DEVICE)
42
- MODEL_2 = get_relative_predictor("UniDepthRelativePredictor")(device=DEVICE)
43
-
44
-
45
- def predict_depth(
46
- model: BaseRelativePredictor, rgb: UInt8[np.ndarray, "h w 3"]
47
- ) -> RelativeDepthPrediction:
48
- model.set_model_device(device=DEVICE)
49
- relative_pred: RelativeDepthPrediction = model(rgb, None)
50
- return relative_pred
51
-
52
-
53
- if IN_SPACES:
54
- predict_depth = spaces.GPU(predict_depth)
55
- # remove any model that fails on zerogpu spaces
56
- MODELS_TO_SKIP.extend(["Metric3DRelativePredictor"])
57
-
58
-
59
- def load_models(
60
- model_1: RELATIVE_PREDICTORS,
61
- model_2: RELATIVE_PREDICTORS,
62
- progress=gr.Progress(),
63
- ) -> str:
64
- models: list[int] = [model_1, model_2]
65
- # check if the models are in the list of models to skip
66
- if any(model in MODELS_TO_SKIP for model in models):
67
- raise gr.Error(
68
- f"Model not supported on ZeroGPU, please try another model: {MODELS_TO_SKIP}"
69
- )
70
-
71
- global MODEL_1, MODEL_2
72
- # delete the previous models and clear gpu memory
73
- if "MODEL_1" in globals():
74
- del MODEL_1
75
- if "MODEL_2" in globals():
76
- del MODEL_2
77
- torch.cuda.empty_cache()
78
- gc.collect()
79
-
80
- progress(0, desc="Loading Models please wait...")
81
-
82
- loaded_models = []
83
-
84
- for model in models:
85
- loaded_models.append(get_relative_predictor(model)(device=DEVICE))
86
- progress(0.5, desc=f"Loaded {model}")
87
-
88
- progress(1, desc="Models Loaded")
89
- MODEL_1, MODEL_2 = loaded_models
90
-
91
- return model_load_status
92
-
93
-
94
- @rr.thread_local_stream("depth")
95
- def on_submit(
96
- rgb: UInt8[np.ndarray, "h w 3"],
97
- remove_flying_pixels: bool,
98
- depth_map_threshold: float,
99
- ):
100
- stream: rr.BinaryStream = rr.binary_stream()
101
- models_list = [MODEL_1, MODEL_2]
102
- blueprint = create_relative_depth_blueprint(models_list)
103
- rr.send_blueprint(blueprint)
104
-
105
- # resize the image to have a max dim of 1024
106
- max_dim = 1024
107
- current_dim = max(rgb.shape[0], rgb.shape[1])
108
- if current_dim > max_dim:
109
- scale_factor = max_dim / current_dim
110
- rgb = mmcv.imrescale(img=rgb, scale=scale_factor)
111
-
112
- try:
113
- for model in models_list:
114
- # get the name of the model
115
- parent_log_path = Path(f"{model.__class__.__name__}")
116
- rr.log(f"{parent_log_path}", rr.ViewCoordinates.RDF, timeless=True)
117
-
118
- relative_pred: RelativeDepthPrediction = predict_depth(model, rgb)
119
-
120
- log_relative_pred(
121
- parent_log_path=parent_log_path,
122
- relative_pred=relative_pred,
123
- rgb_hw3=rgb,
124
- remove_flying_pixels=remove_flying_pixels,
125
- depth_edge_threshold=depth_map_threshold,
126
- )
127
-
128
- yield stream.read()
129
- except Exception as e:
130
- raise gr.Error(f"Error with model {model.__class__.__name__}: {e}")
131
-
132
 
133
  with gr.Blocks() as demo:
134
  gr.Markdown(title)
@@ -136,72 +14,10 @@ with gr.Blocks() as demo:
136
  gr.Markdown(description2)
137
  gr.Markdown(description3)
138
  gr.Markdown("### Depth Prediction demo")
139
-
140
- with gr.Row():
141
- input_image = gr.Image(
142
- label="Input Image",
143
- type="numpy",
144
- height=300,
145
- )
146
- with gr.Column():
147
- with gr.Row():
148
- remove_flying_pixels = gr.Checkbox(
149
- label="Remove Flying Pixels",
150
- value=True,
151
- interactive=True,
152
- )
153
- depth_map_threshold = gr.Slider(
154
- label="⬇️ number == more pruning ⬆️ less pruning",
155
- minimum=0.05,
156
- maximum=0.95,
157
- step=0.05,
158
- value=0.1,
159
- )
160
- with gr.Row():
161
- model_1_dropdown = gr.Dropdown(
162
- choices=list(get_args(RELATIVE_PREDICTORS)),
163
- label="Model1",
164
- value="DepthAnythingV2Predictor",
165
- )
166
- model_2_dropdown = gr.Dropdown(
167
- choices=list(get_args(RELATIVE_PREDICTORS)),
168
- label="Model2",
169
- value="UniDepthRelativePredictor",
170
- )
171
- model_status = gr.Textbox(
172
- label="Model Status",
173
- value=model_load_status,
174
- interactive=False,
175
- )
176
-
177
- with gr.Row():
178
- submit = gr.Button(value="Compute Depth")
179
- load_models_btn = gr.Button(value="Load Models")
180
- rr_viewer = Rerun(streaming=True, height=800)
181
-
182
- submit.click(
183
- on_submit,
184
- inputs=[input_image, remove_flying_pixels, depth_map_threshold],
185
- outputs=[rr_viewer],
186
- )
187
-
188
- load_models_btn.click(
189
- load_models,
190
- inputs=[model_1_dropdown, model_2_dropdown],
191
- outputs=[model_status],
192
- )
193
-
194
- # get all jpegs in examples path
195
- examples_paths = Path("examples").glob("*.jpeg")
196
- # set the examples to be the sorted list of input parameterss (path, remove_flying_pixels, depth_map_threshold)
197
- examples_list = sorted([[str(path), True, 0.1] for path in examples_paths])
198
- examples = gr.Examples(
199
- examples=examples_list,
200
- inputs=[input_image, remove_flying_pixels, depth_map_threshold],
201
- outputs=[rr_viewer],
202
- fn=on_submit,
203
- cache_examples=False,
204
- )
205
 
206
  if __name__ == "__main__":
207
- demo.launch()
 
1
  import gradio as gr
2
+ from monopriors.gradio_ui.depth_inference_ui import depth_inference_block
3
+ from monopriors.gradio_ui.depth_compare_ui import relative_compare_block
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
 
5
  title = "# Depth Comparison"
6
  description1 = """Demo to help compare different depth models. Including both Scale | Shift Invariant and Metric Depth types."""
7
  description2 = """Invariant models mean they have no true scale and are only relative, where as Metric models have a true scale and are absolute (meters)."""
8
  description3 = """Checkout the [Github Repo](https://github.com/pablovela5620/monoprior) [![GitHub Repo stars](https://img.shields.io/github/stars/pablovela5620/monoprior)](https://github.com/pablovela5620/monoprior)"""
9
  model_load_status: str = "Models loaded and ready to use!"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
  with gr.Blocks() as demo:
12
  gr.Markdown(title)
 
14
  gr.Markdown(description2)
15
  gr.Markdown(description3)
16
  gr.Markdown("### Depth Prediction demo")
17
+ with gr.Tab(label="Depth Comparison"):
18
+ relative_compare_block.render()
19
+ with gr.Tab(label="Depth Inference"):
20
+ depth_inference_block.render()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
  if __name__ == "__main__":
23
+ demo.queue().launch()