import gradio as gr
import os
from omegaconf import OmegaConf
from imagen_pytorch import Unet3D, ElucidatedImagen, ImagenTrainer, ElucidatedImagenConfig, NullUnet, Imagen
import torch
import numpy as np
import cv2
from PIL import Image
import torchvision.transforms as T
device = "cuda" if torch.cuda.is_available() else "cpu"
exp_path = "model"
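
# Center-crops a frame to the largest square it allows (the size passed to the
# constructor is ignored); the result is resized to 112x112 by the loader below.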
class BetterCenterCrop(T.CenterCrop):
    def __call__(self, img):
        h = img.shape[-2]
        w = img.shape[-1]
        dim = min(h, w)
        return T.functional.center_crop(img, dim)
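
# Serves random frames from the local pool of EchoNet-Dynamic images that act as
# anatomy-conditioning images for the generated videos.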
class ImageLoader:
    def __init__(self, path) -> None:
        self.path = path
        self.all_files = os.listdir(path)
        self.transform = T.Compose([
            T.ToTensor(),
            BetterCenterCrop((112, 112)),
            T.Resize((112, 112)),
        ])

    def get_image(self):
        idx = np.random.randint(0, len(self.all_files))
        img = Image.open(os.path.join(self.path, self.all_files[idx]))
        return img
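
# Builds the cascaded video diffusion model from the config in `model/`, loads the
# merged checkpoint, and exposes the callbacks used by the Gradio interface below.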
class Context:
    def __init__(self, path, device):
        self.path = path
        self.config_path = os.path.join(path, "config.yaml")
        self.weight_path = os.path.join(path, "merged.pt")
        self.config = OmegaConf.load(self.config_path)
        # Number of frames to sample is derived from the frame rate and clip duration in the config.
        self.config.dataset.num_frames = int(self.config.dataset.fps * self.config.dataset.duration)
        self.im_load = ImageLoader("echo_images")

        # One Unet3D per stage of the cascade; stages after the first are conditioned on the low-resolution output.
        unets = []
        for i, (k, v) in enumerate(self.config.unets.items()):
            unets.append(Unet3D(**v, lowres_cond=(i > 0)))  # type: ignore

        imagen_klass = ElucidatedImagen if self.config.imagen.elucidated else Imagen
        del self.config.imagen.elucidated
        imagen = imagen_klass(
            unets=unets,
            **OmegaConf.to_container(self.config.imagen),  # type: ignore
        )

        self.trainer = ImagenTrainer(
            imagen=imagen,
            **self.config.trainer,
        ).to(device)

        print("Loading weights from", self.weight_path)
        additional_data = self.trainer.load(self.weight_path)
        print("Loaded weights from", self.weight_path)

    def reshape_image(self, image):
        # Normalize a user-provided frame to the 112x112 uint8 H x W x C format shown in the UI.
        try:
            image = self.im_load.transform(image).multiply(255).byte().permute(1, 2, 0).numpy()
            return image
        except Exception:
            return None

    def load_random_image(self):
        print("Loading random image")
        image = self.im_load.get_image()
        return image

    def generate_video(self, image, lvef, cond_scale):
        print("Generating video")
        print(f"lvef: {lvef}, cond_scale: {cond_scale}")
        image = self.im_load.transform(image).unsqueeze(0)
        # The target LVEF is passed as a single scalar "text" embedding in [0, 1];
        # the conditioning image fixes the anatomy of the generated video.
        sample_kwargs = {
            "text_embeds": torch.tensor([[[lvef / 100.0]]]),
            "cond_scale": cond_scale,
            "cond_images": image,
        }
        self.trainer.eval()
        with torch.no_grad():
            video = self.trainer.sample(
                batch_size=1,
                video_frames=self.config.dataset.num_frames,
                use_tqdm=True,
                **sample_kwargs,
            ).detach().cpu()  # B x C x F x H x W

        if video.shape[-3:] != (64, 112, 112):
            video = torch.nn.functional.interpolate(video, size=(64, 112, 112), mode='trilinear', align_corners=False)

        video = video.repeat((1, 1, 5, 1, 1))  # make the video loop 5 times - easier to see

        uid = np.random.randint(0, 10)  # prevent overwriting if multiple users are using the app
        path = f"tmp/{uid}.mp4"
        video = video.multiply(255).byte().squeeze(0).permute(1, 2, 3, 0).numpy()  # F x H x W x C
        out = cv2.VideoWriter(path, cv2.VideoWriter_fourcc(*'mp4v'), 32, (112, 112))
        for frame in video:
            out.write(frame)
        out.release()
        return path
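
# Instantiate the model once at startup so every Gradio callback reuses the loaded weights.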
context = Context(exp_path, device)
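
# Demo layout: description panel, conditioning image, generated video, LVEF and
# conditioning-scale sliders, and the two action buttons.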
with gr.Blocks(css="style.css") as demo:
    with gr.Row():
        gr.Label("Feature-Conditioned Cascaded Video Diffusion Models for Precise Echocardiogram Synthesis")
    with gr.Row():
        with gr.Column():
            with gr.Row():
                with gr.Column(scale=3, variant="panel"):
                    text = gr.Markdown(value="This is a live demo of our work on cardiac ultrasound video generation. The model is trained on 4-chamber cardiac ultrasound videos and can generate realistic 4-chamber videos given a target Left Ventricular Ejection Fraction (LVEF). Please start by sampling a random frame from the pool of 100 images taken from the EchoNet-Dynamic dataset, which will act as the conditional image representing the anatomy of the video. Then set the target LVEF and click the button to generate a video. The process takes 30s to 60s. The model running here corresponds to the 1SCM from the paper. **Click on the video to play it.** [Code is available here](https://github.com/HReynaud/EchoDiffusion)")
                with gr.Column(scale=1, min_width=226):
                    image = gr.Image(interactive=True)
                with gr.Column(scale=1, min_width=226):
                    video = gr.Video(interactive=False)
            slider_ef = gr.Slider(minimum=10, maximum=90, step=1, label="Target LVEF", value=60, interactive=True)
            slider_cond = gr.Slider(minimum=0, maximum=20, step=1, label="Conditioning scale (if set to more than 1, generation time is ~60s)", value=1, interactive=True)
            with gr.Row():
                img_btn = gr.Button(value="❶ Get a random cardiac ultrasound image (4Ch)")
                run_btn = gr.Button(value="❷ Generate a video (~30s) 🚀")

    # Wire the UI events to the model callbacks.
    image.change(context.reshape_image, inputs=[image], outputs=[image])
    img_btn.click(context.load_random_image, inputs=[], outputs=[image])
    run_btn.click(context.generate_video, inputs=[image, slider_ef, slider_cond], outputs=[video])

if __name__ == "__main__":
    demo.queue()
    demo.launch()