import gradio as gr
import numpy as np
import torch
from monopriors.relative_depth_models import (
    RelativeDepthPrediction,
    get_relative_predictor,
    RELATIVE_PREDICTORS,
)
from monopriors.relative_depth_models.base_relative_depth import BaseRelativePredictor
from monopriors.rr_logging_utils import (
    log_relative_pred,
    create_relative_depth_blueprint,
)
import rerun as rr
from gradio_rerun import Rerun
from pathlib import Path
from typing import Literal, get_args
import gc
from jaxtyping import UInt8
import mmcv
try:
    import spaces  # type: ignore

    IN_SPACES = True
except ImportError:
    print("Not running on Zero")
    IN_SPACES = False
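# The `spaces` package is available on Hugging Face ZeroGPU Spaces, so a failed
# import doubles as a sentinel that we are running locally (CPU or plain CUDA).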
title = "# Depth Comparison"
description1 = """Demo to help compare different depth models, including both Scale | Shift Invariant and Metric depth types."""
description2 = """Invariant models have no true scale and are only relative, whereas Metric models have a true scale and are absolute (in meters)."""
description3 = """Check out the [Github Repo](https://github.com/pablovela5620/monoprior) [![GitHub Repo stars](https://img.shields.io/github/stars/pablovela5620/monoprior)](https://github.com/pablovela5620/monoprior)"""

model_load_status: str = "Models loaded and ready to use!"
DEVICE: Literal["cuda", "cpu"] = "cuda" if torch.cuda.is_available() else "cpu"
MODELS_TO_SKIP: list[str] = []
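# Code inside a gr.NO_RELOAD block is not re-executed when Gradio's auto-reload
# mode (`gradio app.py`) reloads the file, so the heavy model weights load only once.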
if gr.NO_RELOAD:
    MODEL_1 = get_relative_predictor("DepthAnythingV2Predictor")(device=DEVICE)
    MODEL_2 = get_relative_predictor("UniDepthRelativePredictor")(device=DEVICE)
def predict_depth(
    model: BaseRelativePredictor, rgb: UInt8[np.ndarray, "h w 3"]
) -> RelativeDepthPrediction:
    model.set_model_device(device=DEVICE)
    relative_pred: RelativeDepthPrediction = model(rgb, None)
    return relative_pred
if IN_SPACES:
    predict_depth = spaces.GPU(predict_depth)
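    # spaces.GPU wraps predict_depth so ZeroGPU attaches a GPU only for the
    # duration of each call; outside Spaces the plain function is used as-is.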
    # remove any model that fails on ZeroGPU Spaces
    MODELS_TO_SKIP.extend(["Metric3DRelativePredictor"])
def load_models(
    model_1: RELATIVE_PREDICTORS,
    model_2: RELATIVE_PREDICTORS,
    progress=gr.Progress(),
) -> str:
    models: list[RELATIVE_PREDICTORS] = [model_1, model_2]
    # check if any requested model is in the list of models to skip
    if any(model in MODELS_TO_SKIP for model in models):
        raise gr.Error(
            f"Model not supported on ZeroGPU, please try another model: {MODELS_TO_SKIP}"
        )
    global MODEL_1, MODEL_2
    # delete the previous models and clear gpu memory
    if "MODEL_1" in globals():
        del MODEL_1
    if "MODEL_2" in globals():
        del MODEL_2
    torch.cuda.empty_cache()
    gc.collect()
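    # NOTE: the del statements must run before empty_cache(); CUDA memory is
    # only returned to the device once no Python references to the weights remain.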
    progress(0, desc="Loading Models please wait...")
    loaded_models = []
    for model in models:
        loaded_models.append(get_relative_predictor(model)(device=DEVICE))
        progress(0.5, desc=f"Loaded {model}")
    progress(1, desc="Models Loaded")
    MODEL_1, MODEL_2 = loaded_models
    return model_load_status
def on_submit(rgb: UInt8[np.ndarray, "h w 3"], remove_flying_pixels: bool):
    stream: rr.BinaryStream = rr.binary_stream()
    models_list = [MODEL_1, MODEL_2]
    blueprint = create_relative_depth_blueprint(models_list)
    rr.send_blueprint(blueprint)

    # resize the image to have a max dim of 1024
    max_dim = 1024
    current_dim = max(rgb.shape[0], rgb.shape[1])
    if current_dim > max_dim:
        scale_factor = max_dim / current_dim
        rgb = mmcv.imrescale(img=rgb, scale=scale_factor)

    try:
        for model in models_list:
            # log each model's prediction under its own named path
            parent_log_path = Path(f"{model.__class__.__name__}")
            rr.log(f"{parent_log_path}", rr.ViewCoordinates.RDF, timeless=True)
            relative_pred: RelativeDepthPrediction = predict_depth(model, rgb)
            log_relative_pred(
                parent_log_path=parent_log_path,
                relative_pred=relative_pred,
                rgb_hw3=rgb,
                remove_flying_pixels=remove_flying_pixels,
            )
            yield stream.read()
    except Exception as e:
        raise gr.Error(f"Error with model {model.__class__.__name__}: {e}")
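# on_submit is a generator: each `yield stream.read()` flushes the Rerun
# messages logged so far, so the viewer updates as soon as each model finishes
# instead of waiting for both predictions.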
with gr.Blocks() as demo:
    gr.Markdown(title)
    gr.Markdown(description1)
    gr.Markdown(description2)
    gr.Markdown(description3)
    gr.Markdown("### Depth Prediction demo")
    with gr.Row():
        input_image = gr.Image(
            label="Input Image",
            type="numpy",
            height=300,
        )
        with gr.Column():
            gr.Radio(
                choices=["Scale | Shift Invariant", "Metric (TODO)"],
                label="Depth Type",
                value="Scale | Shift Invariant",
                interactive=True,
            )
            remove_flying_pixels = gr.Checkbox(
                label="Remove Flying Pixels",
                value=True,
                interactive=True,
            )
            with gr.Row():
                model_1_dropdown = gr.Dropdown(
                    choices=list(get_args(RELATIVE_PREDICTORS)),
                    label="Model1",
                    value="DepthAnythingV2Predictor",
                )
                model_2_dropdown = gr.Dropdown(
                    choices=list(get_args(RELATIVE_PREDICTORS)),
                    label="Model2",
                    value="UniDepthRelativePredictor",
                )
            model_status = gr.Textbox(
                label="Model Status",
                value=model_load_status,
                interactive=False,
            )
    with gr.Row():
        submit = gr.Button(value="Compute Depth")
        load_models_btn = gr.Button(value="Load Models")
    rr_viewer = Rerun(streaming=True, height=800)
    submit.click(
        on_submit,
        inputs=[input_image, remove_flying_pixels],
        outputs=[rr_viewer],
    )
    load_models_btn.click(
        load_models,
        inputs=[model_1_dropdown, model_2_dropdown],
        outputs=[model_status],
    )
    examples_paths = Path("examples").glob("*.jpeg")
    examples_list = sorted([str(path) for path in examples_paths])
    examples = gr.Examples(
        examples=examples_list,
        inputs=[input_image],
        outputs=[rr_viewer],
        fn=on_submit,
        cache_examples=False,
    )
if __name__ == "__main__":
    demo.launch()
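# A minimal sketch of how to run this demo, assuming the file is saved as app.py:
#   python app.py    # standard launch
#   gradio app.py    # dev mode with auto-reload (the gr.NO_RELOAD block is skipped on reload)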