import os
import uuid

import cv2
import gradio as gr
import numpy as np
import torch
import torchvision.transforms as T
from imagen_pytorch import Unet3D, ElucidatedImagen, ImagenTrainer, Imagen
from omegaconf import OmegaConf
from PIL import Image
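# Live demo of "Feature-Conditioned Cascaded Video Diffusion Models for Precise
# Echocardiogram Synthesis": generates 4-chamber cardiac ultrasound videos from
# a single anatomy frame and a target LVEF (the 1SCM model from the paper).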


device = "cuda" if torch.cuda.is_available() else "cpu"
exp_path = "model"

class BetterCenterCrop(T.CenterCrop):
    """Center-crop the largest possible square, ignoring the constructor size.

    The Resize transform applied afterwards brings the square to the target
    resolution, so no image content is lost to a fixed-size crop.
    """

    def __call__(self, img):
        h = img.shape[-2]
        w = img.shape[-1]
        dim = min(h, w)

        return T.functional.center_crop(img, dim)

class ImageLoader:
    """Serves random conditioning frames from the on-disk image pool."""

    def __init__(self, path) -> None:
        self.path = path
        self.all_files = os.listdir(path)
        self.transform = T.Compose([
            T.ToTensor(),
            BetterCenterCrop((112, 112)),
            T.Resize((112, 112)),
        ])
        
    def get_image(self):
        idx = np.random.randint(0, len(self.all_files))
        img = Image.open(os.path.join(self.path, self.all_files[idx]))
        return img

class Context:
    """Loads config and weights, rebuilds the diffusion cascade, and runs sampling."""

    def __init__(self, path, device):
        self.path = path
        self.config_path = os.path.join(path, "config.yaml")
        self.weight_path = os.path.join(path, "merged.pt")
        
        self.config = OmegaConf.load(self.config_path)

        # Number of frames to sample = fps * clip duration from the dataset config.
        self.config.dataset.num_frames = int(self.config.dataset.fps * self.config.dataset.duration)

        self.im_load = ImageLoader("echo_images")

        # Build the cascade: the first U-Net is the base model; later U-Nets are
        # super-resolution stages conditioned on the low-resolution output.
        unets = []
        for i, (_, v) in enumerate(self.config.unets.items()):
            unets.append(Unet3D(**v, lowres_cond=(i > 0)))  # type: ignore

        imagen_klass = ElucidatedImagen if self.config.imagen.elucidated else Imagen
        del self.config.imagen.elucidated  # not a constructor argument
        imagen = imagen_klass(
            unets = unets,
            **OmegaConf.to_container(self.config.imagen), # type: ignore
        )

        self.trainer = ImagenTrainer(
            imagen = imagen,
            **self.config.trainer
        ).to(device)

        print("Loading weights from", self.weight_path)
        additional_data = self.trainer.load(self.weight_path)
        print("Loaded weights from", self.weight_path)
    
    def reshape_image(self, image):
        # Crop and resize the uploaded image to the model's 112x112 input format.
        try:
            image = self.im_load.transform(image).multiply(255).byte().permute(1, 2, 0).numpy()
            return image
        except Exception:
            # Invalid or cleared image component: leave the output empty.
            return None
    
    def load_random_image(self):
        print("Loading random image")        
        image = self.im_load.get_image()
        return image
    
    def generate_video(self, image, lvef, cond_scale):
        print("Generating video")
        print(f"lvef: {lvef}, cond_scale: {cond_scale}")

        image = self.im_load.transform(image).unsqueeze(0)

        sample_kwargs = {
            # LVEF is rescaled to [0, 1] and passed as a 1-dim "text" embedding.
            "text_embeds": torch.tensor([[[lvef / 100.0]]]),
            # Classifier-free guidance scale; values above 1 strengthen the
            # conditioning but slow sampling (see the slider label).
            "cond_scale": cond_scale,
            # The selected frame conditions the anatomy of the generated video.
            "cond_images": image,
        }
            
        self.trainer.eval()
        with torch.no_grad():
            video = self.trainer.sample(
                batch_size=1, 
                video_frames=self.config.dataset.num_frames,
                **sample_kwargs,
                use_tqdm = True,
            ).detach().cpu()  # B x C x F x H x W
        # Resize to the canonical 64 frames at 112x112 if sampled at another shape.
        if video.shape[-3:] != (64, 112, 112):
            video = torch.nn.functional.interpolate(video, size=(64, 112, 112), mode="trilinear", align_corners=False)
        video = video.repeat((1, 1, 5, 1, 1))  # loop the clip 5 times - easier to watch
        os.makedirs("tmp", exist_ok=True)  # cv2.VideoWriter fails silently if the directory is missing
        uid = uuid.uuid4().hex  # unique name so concurrent users cannot overwrite each other
        path = f"tmp/{uid}.mp4"
        # B x C x F x H x W -> F x H x W x C uint8 frames for OpenCV
        video = video.multiply(255).byte().squeeze(0).permute(1, 2, 3, 0).contiguous().numpy()
        out = cv2.VideoWriter(path, cv2.VideoWriter_fourcc(*"mp4v"), 32, (112, 112))
        for frame in video:
            out.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))  # OpenCV expects BGR channel order
        out.release()
        return path

context = Context(exp_path, device)
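# A hypothetical direct call, bypassing the UI (names as defined above):
#   img = context.load_random_image()
#   clip_path = context.generate_video(img, lvef=60, cond_scale=1)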

with gr.Blocks(css="style.css") as demo:

    with gr.Row():
        gr.Label("Feature-Conditioned Cascaded Video Diffusion Models for Precise Echocardiogram Synthesis")

    with gr.Row():
        with gr.Column():
            with gr.Row():
                with gr.Column(scale=3, variant="panel"):
                    text = gr.Markdown(value="This is a live demo of our work on cardiac ultrasound video generation. The model is trained on 4-chamber cardiac ultrasound videos and can generate realistic 4-chamber videos given a target Left Ventricle Ejection Fraction (LVEF). Please start by sampling a random frame from the pool of 100 images taken from the EchoNet-Dynamic dataset; it acts as the conditioning image and defines the anatomy of the video. Then set the target LVEF and click the button to generate a video. The process takes 30 to 60 seconds. The model running here corresponds to the 1SCM from the paper. **Click on the video to play it.** [Code is available here](https://github.com/HReynaud/EchoDiffusion)")
                with gr.Column(scale=1, min_width=226):
                    image = gr.Image(interactive=True)
                with gr.Column(scale=1, min_width=226):
                    video = gr.Video(interactive=False)

            slider_ef = gr.Slider(minimum=10, maximum=90, step=1, label="Target LVEF", value=60, interactive=True)
            slider_cond = gr.Slider(minimum=0, maximum=20, step=1, label="Conditioning scale (values above 1 increase generation time to ~60s)", value=1, interactive=True)

            with gr.Row():
                img_btn = gr.Button(value="❶ Get a random cardiac ultrasound image (4Ch)")
                run_btn = gr.Button(value="❷ Generate a video (~30s) 🚀")

    # Wire the UI: normalize uploads, sample a random frame, and run generation.
    image.change(context.reshape_image, inputs=[image], outputs=[image])
    img_btn.click(context.load_random_image, inputs=[], outputs=[image])
    run_btn.click(context.generate_video, inputs=[image, slider_ef, slider_cond], outputs=[video])
 
if __name__ == "__main__":
    demo.queue()  # queue requests so long-running generations do not block the UI
    demo.launch()
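
# To run the demo locally (assuming this file is saved as app.py, with weights in
# ./model and conditioning frames in ./echo_images, as the paths above expect):
#   python app.py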