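"""Gradio demo for DepthMaster monocular depth estimation (Hugging Face Space app)."""
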
import gradio as gr
import os
import torch
from PIL import Image
from diffusers import (
AutoencoderKL,
)
from transformers import CLIPTextModel, CLIPTokenizer
from depthmaster import DepthMasterPipeline
from depthmaster.modules.unet_2d_condition import UNet2DConditionModel
def load_example(example_image):
    # Return the selected example image
    return example_image
device = "cuda" if torch.cuda.is_available() else "cpu"
model_repo_id = "zysong212/DepthMaster"  # Replace with the model repo you would like to use
torch_dtype = torch.float32
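
# Load the pretrained DepthMaster components from the Hub and assemble the pipeline.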
vae = AutoencoderKL.from_pretrained(model_repo_id, subfolder="vae", torch_dtype=torch_dtype, allow_pickle=False)
unet = UNet2DConditionModel.from_pretrained(model_repo_id, subfolder="unet", torch_dtype=torch_dtype, allow_pickle=False)
text_encoder = CLIPTextModel.from_pretrained(model_repo_id, subfolder="text_encoder", torch_dtype=torch_dtype)
tokenizer = CLIPTokenizer.from_pretrained(model_repo_id, subfolder="tokenizer")  # tokenizers have no dtype
pipe = DepthMasterPipeline(vae=vae, unet=unet, text_encoder=text_encoder, tokenizer=tokenizer)
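# Use xformers memory-efficient attention if it is installed.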
try:
pipe.enable_xformers_memory_efficient_attention()
except ImportError:
pass # run without xformers
pipe = pipe.to(device)
# @spaces.GPU  # [uncomment to use ZeroGPU; also add `import spaces` at the top]
def infer(
input_image,
progress=gr.Progress(track_tqdm=True),
):
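    """Run DepthMaster on one input image and return the colorized depth map (PIL Image)."""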
pipe_out = pipe(
input_image,
processing_res=768,
match_input_res=True,
batch_size=1,
color_map="Spectral",
show_progress_bar=True,
resample_method="bilinear",
)
# depth_pred: np.ndarray = pipe_out.depth_np
depth_colored: Image.Image = pipe_out.depth_colored
return depth_colored
# Default example image paths (unused below; the examples gallery is built from wild_example/ via os.listdir)
example_images = [
"wild_example/000000000776.jpg",
"wild_example/800x.jpg",
"wild_example/000000055950.jpg",
"wild_example/53441037037_c2cbd91ad2_k.jpg",
"wild_example/53501906161_6109e3da29_b.jpg",
"wild_example/m_1e31af1c.jpg",
"wild_example/sg-11134201-7rd5x-lvlh48byidbqca.jpg"
]
# css = """
# #col-container {
# margin: 0 auto;
# max-width: 640px;
# }
# #example-gallery {
# height: 80px; /* 设置缩略图高度 */
# width: auto; /* 保持宽高比 */
# margin: 0 auto; /* 图片间距 */
# cursor: pointer; /* 鼠标指针变为手型 */
# }
# """
css = """
#img-display-container {
max-height: 100vh;
}
#img-display-input {
max-height: 80vh;
}
#img-display-output {
max-height: 80vh;
}
#download {
height: 62px;
}
"""
title = "# DepthMaster"
description = """**Official demo for DepthMaster**.
Please refer to our [paper](https://arxiv.org/abs/2501.02576), [project page](https://indu1ge.github.io/DepthMaster_page/), and [github](https://github.com/indu1ge/DepthMaster) for more details."""
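
# Build the Gradio UI.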
with gr.Blocks(css=css) as demo:
gr.Markdown(title)
gr.Markdown(description)
gr.Markdown(" ### Depth Estimation with DepthMaster.")
# with gr.Column(elem_id="col-container"):
# gr.Markdown(" # Depth Estimation")
with gr.Row():
with gr.Column():
input_image = gr.Image(label="Input Image", type="pil", elem_id="img-display-input")
with gr.Column():
            # depth_img_slider = ImageSlider(label="Depth Map with Slider View", elem_id="img-display-output", position=0.5)
            depth_map = gr.Image(label="Depth Map", type="pil", interactive=False, elem_id="depth-map")
    # Compute button
    compute_button = gr.Button(value="Compute Depth")

    # Wire the compute button to the inference function
    compute_button.click(
        fn=infer,              # callback function
        inputs=[input_image],  # inputs
        outputs=[depth_map],   # outputs
    )
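
    # Populate the examples gallery from images in wild_example/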
example_files = os.listdir('wild_example')
example_files.sort()
example_files = [os.path.join('wild_example', filename) for filename in example_files]
examples = gr.Examples(examples=example_files, inputs=[input_image], outputs=[depth_map], fn=infer)
# Launch the Gradio app
demo.queue().launch(share=True)