# depth-compare / gradio_app.py
import gradio as gr
import numpy as np
import torch
from monopriors.relative_depth_models import (
RelativeDepthPrediction,
get_relative_predictor,
RELATIVE_PREDICTORS,
)
from monopriors.relative_depth_models.base_relative_depth import BaseRelativePredictor
from monopriors.rr_logging_utils import (
log_relative_pred,
create_relative_depth_blueprint,
)
import rerun as rr
from gradio_rerun import Rerun
from pathlib import Path
from typing import Literal, get_args
import gc
from jaxtyping import UInt8
import mmcv

try:
    import spaces  # type: ignore

    IN_SPACES = True
except ImportError:
    print("Not running on ZeroGPU")
    IN_SPACES = False

title = "# Depth Comparison"
description1 = """Demo for comparing different depth models, covering both Scale | Shift Invariant and Metric depth types."""
description2 = """Invariant models have no true scale and are only relative, whereas Metric models have a true scale and are absolute (in meters)."""
description3 = """Check out the [Github Repo](https://github.com/pablovela5620/monoprior) [![GitHub Repo stars](https://img.shields.io/github/stars/pablovela5620/monoprior)](https://github.com/pablovela5620/monoprior)"""
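

# Illustrative sketch (not used by the app): a scale | shift invariant
# prediction can be aligned to a metric reference by fitting a single scale s
# and shift t with least squares, minimizing ||s * pred + t - ref||^2.
# The helper name and arguments below are hypothetical, for exposition only.
def _align_scale_shift(pred: np.ndarray, metric_ref: np.ndarray) -> np.ndarray:
    p = pred.reshape(-1)
    g = metric_ref.reshape(-1)
    # solve for [s, t] in s * p + t ≈ g via least squares
    A = np.stack([p, np.ones_like(p)], axis=1)
    (s, t), *_ = np.linalg.lstsq(A, g, rcond=None)
    return s * pred + t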


model_load_status: str = "Models loaded and ready to use!"
DEVICE: Literal["cuda", "cpu"] = "cuda" if torch.cuda.is_available() else "cpu"
MODELS_TO_SKIP: list[str] = []
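
# gr.NO_RELOAD guards the initial model construction so gradio's hot-reload
# mode does not re-instantiate the (heavy) default predictors on every reload.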
if gr.NO_RELOAD:
MODEL_1 = get_relative_predictor("DepthAnythingV2Predictor")(device=DEVICE)
MODEL_2 = get_relative_predictor("UniDepthRelativePredictor")(device=DEVICE)


def predict_depth(
model: BaseRelativePredictor, rgb: UInt8[np.ndarray, "h w 3"]
) -> RelativeDepthPrediction:
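    """Move the model to DEVICE and run it on an RGB image (no intrinsics provided)."""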
model.set_model_device(device=DEVICE)
relative_pred: RelativeDepthPrediction = model(rgb, None)
return relative_pred


if IN_SPACES:
    predict_depth = spaces.GPU(predict_depth)
    # skip models that are known to fail on ZeroGPU Spaces
    MODELS_TO_SKIP.extend(["Metric3DRelativePredictor"])


def load_models(
model_1: RELATIVE_PREDICTORS,
model_2: RELATIVE_PREDICTORS,
progress=gr.Progress(),
) -> str:
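    """Swap the two global predictors, freeing the old models' GPU memory first."""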
    models: list[RELATIVE_PREDICTORS] = [model_1, model_2]
    # reject models that are known to fail on ZeroGPU before loading anything
    if any(model in MODELS_TO_SKIP for model in models):
        raise gr.Error(
            f"Model not supported on ZeroGPU, please try another model: {MODELS_TO_SKIP}"
        )
    global MODEL_1, MODEL_2
    # drop references to the previous models, then collect and release GPU memory
    if "MODEL_1" in globals():
        del MODEL_1
    if "MODEL_2" in globals():
        del MODEL_2
    gc.collect()
    torch.cuda.empty_cache()
progress(0, desc="Loading Models please wait...")
loaded_models = []
    for i, model in enumerate(models):
        loaded_models.append(get_relative_predictor(model)(device=DEVICE))
        progress((i + 1) / len(models), desc=f"Loaded {model}")
progress(1, desc="Models Loaded")
MODEL_1, MODEL_2 = loaded_models
return model_load_status


@rr.thread_local_stream("depth")
def on_submit(rgb: UInt8[np.ndarray, "h w 3"], remove_flying_pixels: bool):
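    """Run both loaded models on the image and stream predictions to the rerun viewer."""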
stream: rr.BinaryStream = rr.binary_stream()
models_list = [MODEL_1, MODEL_2]
blueprint = create_relative_depth_blueprint(models_list)
rr.send_blueprint(blueprint)
    # resize so the longest side is at most 1024 px (imrescale keeps aspect ratio)
max_dim = 1024
current_dim = max(rgb.shape[0], rgb.shape[1])
if current_dim > max_dim:
scale_factor = max_dim / current_dim
rgb = mmcv.imrescale(img=rgb, scale=scale_factor)
try:
for model in models_list:
# get the name of the model
parent_log_path = Path(f"{model.__class__.__name__}")
rr.log(f"{parent_log_path}", rr.ViewCoordinates.RDF, timeless=True)
relative_pred: RelativeDepthPrediction = predict_depth(model, rgb)
log_relative_pred(
parent_log_path=parent_log_path,
relative_pred=relative_pred,
rgb_hw3=rgb,
remove_flying_pixels=remove_flying_pixels,
)
yield stream.read()
except Exception as e:
raise gr.Error(f"Error with model {model.__class__.__name__}: {e}")


with gr.Blocks() as demo:
gr.Markdown(title)
gr.Markdown(description1)
gr.Markdown(description2)
gr.Markdown(description3)
gr.Markdown("### Depth Prediction demo")
with gr.Row():
input_image = gr.Image(
label="Input Image",
type="numpy",
height=300,
)
with gr.Column():
gr.Radio(
choices=["Scale | Shift Invariant", "Metric (TODO)"],
label="Depth Type",
value="Scale | Shift Invariant",
interactive=True,
)
remove_flying_pixels = gr.Checkbox(
label="Remove Flying Pixels",
value=True,
interactive=True,
)
with gr.Row():
model_1_dropdown = gr.Dropdown(
choices=list(get_args(RELATIVE_PREDICTORS)),
label="Model1",
value="DepthAnythingV2Predictor",
)
model_2_dropdown = gr.Dropdown(
choices=list(get_args(RELATIVE_PREDICTORS)),
label="Model2",
value="UniDepthRelativePredictor",
)
model_status = gr.Textbox(
label="Model Status",
value=model_load_status,
interactive=False,
)
with gr.Row():
submit = gr.Button(value="Compute Depth")
load_models_btn = gr.Button(value="Load Models")
rr_viewer = Rerun(streaming=True, height=800)
submit.click(
on_submit,
inputs=[input_image, remove_flying_pixels],
outputs=[rr_viewer],
)
load_models_btn.click(
load_models,
inputs=[model_1_dropdown, model_2_dropdown],
outputs=[model_status],
)
examples_paths = Path("examples").glob("*.jpeg")
examples_list = sorted([str(path) for path in examples_paths])
examples = gr.Examples(
examples=examples_list,
inputs=[input_image],
outputs=[rr_viewer],
fn=on_submit,
cache_examples=False,
)


if __name__ == "__main__":
demo.launch()