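# EchoNet-Synthetic demo app (Gradio, Hugging Face Spaces).
# A latent image diffusion model (LIDM) samples a synthetic echocardiogram frame,
# the VAE decodes it for display, and a latent video diffusion model (LVDM) animates
# that frame into an echo video conditioned on the requested ejection fraction (LVEF).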
import gradio as gr
import os
import json
import torch
import torch.nn as nn
import diffusers
from einops import rearrange
from PIL import Image
from omegaconf import OmegaConf
from tqdm import tqdm
import cv2

NUM_STEPS = 64
FRAMES = 192  # number of frames in the generated video
FPS = 32      # frame rate of the exported mp4
mycss = """ | |
.contain { | |
width: 1000px; | |
margin: 0 auto; | |
} | |
.svelte-1pijsyv { | |
width: 448px; | |
} | |
.arrow { | |
display: flex; | |
align-items: center; | |
margin: 7px 0; | |
} | |
.arrow-tail { | |
width: 270px; | |
height: 50px; | |
background-color: black; | |
transition: background-color 0.3s; | |
} | |
.arrow-head { | |
width: 0; | |
height: 0; | |
border-top: 70px solid transparent; | |
border-bottom: 70px solid transparent; | |
border-left: 120px solid black; | |
transition: border-left-color 0.3s; | |
} | |
@media (prefers-color-scheme: dark) { | |
.arrow-tail { | |
background-color: white; | |
} | |
.arrow-head { | |
border-left-color: white; | |
} | |
} | |
""" | |
myhtml = """ | |
<div class="arrow"> | |
<div class="arrow-tail"></div> | |
<div class="arrow-head"></div> | |
</div> | |
""" | |
myjs = """ | |
function setLoopTrue() { | |
let videos = document.getElementsByTagName('video'); | |
if (videos.length > 0) { | |
document.getElementsByTagName('video')[0].loop = true; | |
} | |
setTimeout(setLoopTrue, 3000); | |
} | |
""" | |
def load_model(path):
    """Load a diffusers-style model from `path`, using the class named in its config.json."""
    # find config.json
    json_path = os.path.join(path, "config.json")
    assert os.path.exists(json_path), f"Could not find config.json at {json_path}"
    with open(json_path, "r") as f:
        config = json.load(f)
    # instantiate class
    klass_name = config["_class_name"]
    klass = getattr(diffusers, klass_name, None)
    if klass is None:
        klass = globals().get(klass_name, None)
    assert klass is not None, f"Could not find class {klass_name} in diffusers or global scope."
    assert getattr(klass, "from_pretrained", None) is not None, f"Class {klass_name} does not support 'from_pretrained'."
    # load checkpoint
    model = klass.from_pretrained(path)
    return model, config
def load_scheduler(config):
    """Instantiate the diffusers noise scheduler described in the OmegaConf config."""
    scheduler_kwargs = OmegaConf.to_container(config.noise_scheduler)
    scheduler_klass_name = scheduler_kwargs.pop("_class_name")
    scheduler_klass = getattr(diffusers, scheduler_klass_name, None)
    scheduler = scheduler_klass(**scheduler_kwargs)
    return scheduler
def padf(tensor, mult=3):
    pad = 2**mult - (tensor.shape[-1] % 2**mult)
    pad = pad // 2
    tensor = nn.functional.pad(tensor, (pad, pad, pad, pad, 0, 0), mode='replicate')
    return tensor, pad

def unpadf(tensor, pad=1):
    return tensor[..., pad:-pad, pad:-pad]

def pad_reshape(tensor, mult=3):
    tensor, pad = padf(tensor, mult=mult)
    tensor = rearrange(tensor, "b c t h w -> b t c h w")
    return tensor, pad

def unpad_reshape(tensor, pad=1):
    tensor = rearrange(tensor, "b t c h w -> b c t h w")
    tensor = unpadf(tensor, pad=pad)
    return tensor
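# The UNet inputs are padded so their spatial sides divide cleanly through the
# down/upsampling path: padf adds symmetric replicate padding toward a multiple of
# 2**mult and returns the per-side amount, which unpadf crops back off. For example,
# with mult=3 a 28-wide latent gets pad = (8 - 28 % 8) // 2 = 2 pixels per side -> 32.
# pad_reshape/unpad_reshape additionally swap between the "b c t h w" layout used here
# and the "b t c h w" layout expected by the video UNet.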
class Context:
    """Holds the loaded models, scheduler and device, and runs image and video sampling."""

    def __init__(self, lidm_path, lvdm_path, vae_path, config_path):
        self.lidm, self.lidm_config = load_model(lidm_path)
        self.lvdm, self.lvdm_config = load_model(lvdm_path)
        self.vae, self.vae_config = load_model(vae_path)
        self.config = OmegaConf.load(config_path)
        self.models = [self.lidm, self.lvdm, self.vae]
        self.scheduler = load_scheduler(self.config)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.dtype = torch.float32
        for model in self.models:
            model.to(self.device, dtype=self.dtype)
            model.eval()
        print("Models loaded")
    def get_img(self, steps):
        print("generating image")
        self.scheduler.set_timesteps(steps)
        with torch.no_grad():
            B, C, H, W = 1, self.lidm_config["in_channels"], self.lidm_config["sample_size"], self.lidm_config["sample_size"]
            timesteps = self.scheduler.timesteps
            forward_kwargs = {}
            latents = torch.randn((B, C, H, W), device=self.device, dtype=self.dtype)
            with torch.autocast("cuda"):
                for t in tqdm(timesteps):
                    forward_kwargs["timestep"] = t
                    latent_model_input = latents
                    latent_model_input = self.scheduler.scale_model_input(latent_model_input, timestep=t)
                    latent_model_input, padding = padf(latent_model_input, mult=3)
                    noise_pred = self.lidm(latent_model_input, **forward_kwargs).sample
                    noise_pred = unpadf(noise_pred, pad=padding)
                    latents = self.scheduler.step(noise_pred, t, latents).prev_sample
            # latents shape: [B, C, H, W]
            latents = latents / self.vae.config.scaling_factor
            img = self.vae.decode(latents).sample
            img = (img + 1) * 128  # [-1, 1] -> [0, 256]
            img = img.mean(1).unsqueeze(1).repeat([1, 3, 1, 1])  # collapse to grayscale, replicate to 3 channels
            img = img.clamp(0, 255).to(torch.uint8).cpu().numpy()
            img = img[0].transpose(1, 2, 0)
            img = Image.fromarray(img)
        return img, latents
    def get_vid(self, lvef: int, ref_latent: torch.Tensor, steps: int):
        print("generating video")
        self.scheduler.set_timesteps(steps)
        with torch.no_grad():
            B, C, T, H, W = 1, 4, self.lvdm_config["num_frames"], self.lvdm_config["sample_size"], self.lvdm_config["sample_size"]
            if FRAMES > T:
                OT = T // 2  # overlap between windows, e.g. 64 // 2 = 32
                TR = (FRAMES - T) / 32  # additional windows needed: (192 - 64) / 32 = 4
                TR = int(TR + 1)  # total number of windows (repetitions)
                NT = (T - OT) * TR + OT  # total number of latent frames
            else:
                OT = 0
                TR = 1
                NT = T
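            # Long clips are denoised as TR overlapping temporal windows of T frames
            # (OT frames of overlap). With the values in the comments above (T = 64,
            # FRAMES = 192) this gives OT = 32, TR = 5 and NT = 192 latent frames;
            # the leading OT frames of every window after the first are dropped when
            # the noise predictions are stitched back together below.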
            timesteps = self.scheduler.timesteps
            lvef = lvef / 100
            lvef = torch.tensor([lvef] * TR, device=self.device, dtype=self.dtype)
            lvef = lvef[:, None, None]
            print(lvef.shape)
            forward_kwargs = {}
            forward_kwargs["added_time_ids"] = torch.zeros((B * TR, self.config.unet.addition_time_embed_dim), device=self.device, dtype=self.dtype)
            forward_kwargs["encoder_hidden_states"] = lvef
            print(forward_kwargs["added_time_ids"].shape)
            latent_cond_images = ref_latent * self.vae.config.scaling_factor
            latent_cond_images = latent_cond_images[:, :, None, :, :].repeat([1, 1, NT, 1, 1]).to(self.device, dtype=self.dtype)
            print(latent_cond_images.shape)
            latents = torch.randn((B, C, NT, H, W), device=self.device, dtype=self.dtype)
            print(latents.shape)
            with torch.autocast("cuda"):
                for t in tqdm(timesteps):
                    forward_kwargs["timestep"] = t
                    latent_model_input = latents
                    latent_model_input = self.scheduler.scale_model_input(latent_model_input, timestep=t)
                    latent_model_input = torch.cat((latent_model_input, latent_cond_images), dim=1)  # B x 2C x NT x H x W
                    latent_model_input, padding = pad_reshape(latent_model_input, mult=3)  # B x NT x 2C x H+P x W+P
                    inputs = torch.cat([latent_model_input[:, r * (T - OT):r * (T - OT) + T] for r in range(TR)], dim=0)  # B*TR x T x 2C x H+P x W+P
                    noise_pred = self.lvdm(inputs, **forward_kwargs).sample
                    outputs = torch.chunk(noise_pred, TR, dim=0)  # TR chunks of B x T x C x H x W
                    noise_predictions = []
                    for r in range(TR):
                        noise_predictions.append(outputs[r] if r == 0 else outputs[r][:, OT:])
                    noise_pred = torch.cat(noise_predictions, dim=1)  # B x NT x C x H x W
                    noise_pred = unpad_reshape(noise_pred, pad=padding)
                    latents = self.scheduler.step(noise_pred, t, latents).prev_sample
print("done generating noise") | |
# latent shape[B,C,T,H,W] | |
latents = latents / self.vae.config.scaling_factor | |
latents = rearrange(latents, "b c t h w -> (b t) c h w") | |
chunk_size = 16 | |
chunked_latents = torch.split(latents, chunk_size, dim=0) | |
decoded_chunks = [] | |
for chunk in chunked_latents: | |
decoded_chunks.append(self.vae.decode(chunk.float().cuda()).sample.cpu()) | |
video = torch.cat(decoded_chunks, dim=0) # (B*T) x H x W x C | |
video = rearrange(video, "(b t) c h w -> b t h w c", b=B)[0] # T H W C | |
video = (video + 1) * 128 # [-1, 1] -> [0, 256] | |
video = video.mean(-1).unsqueeze(-1).repeat([1, 1, 1, 3]) # T H W 3 | |
video = video.clamp(0, 255).to(torch.uint8).cpu().numpy() | |
out = cv2.VideoWriter('output.mp4', cv2.VideoWriter_fourcc(*'mp4v'), FPS, (112, 112)) | |
for img in video: | |
out.write(img) | |
out.release() | |
return "output.mp4" | |
ctx = Context(
    lidm_path="resources/lidm",
    lvdm_path="resources/lvdm",
    vae_path="resources/ivae",
    config_path="resources/config.yaml",
)
with gr.Blocks(css=mycss, js=myjs) as demo:
    with gr.Row():
        # Greet the user with an explanation of the demo
        gr.Markdown("""
# EchoNet-Synthetic: Privacy-preserving Video Generation for Safe Medical Data Sharing
This demo accompanies a paper under review at MICCAI 2024 and is intended for its reviewers.
1. Start by generating an image with the "Generate Image" button. This produces a random image similar to those in the EchoNet-Dynamic dataset.
2. Adjust the "Ejection Fraction Score" slider to set the target ejection fraction of the generated video.
3. Generate a video with the "Generate Video" button. This animates the generated image at the ejection fraction you chose.
We leave the ejection fraction input completely open, so you can see how the video generation behaves with different scores, even unrealistic ones. The normal ejection fraction range is 50-75.<br>
We recommend 64 sampling steps for the best quality, but you can adjust this to see how it affects the generation.
""")
    with gr.Row():
        # core activity: 3 columns
        with gr.Column():
            # Image generation goes here
            img = gr.Image(interactive=False, label="Generated Image")  # user uploads are disabled (interactive=False)
            img_btn = gr.Button("Generate Image")
        with gr.Column():
            # LVEF slider goes here, plus a big arrow graphic for show
            gr.HTML(myhtml)
            efslider = gr.Slider(minimum=0, maximum=100, value=65, step=1, label="Ejection Fraction Score (%)")
            dsslider = gr.Slider(minimum=1, maximum=999, value=64, step=1, label="Sampling Steps")
        with gr.Column():
            # Video generation goes here
            vid = gr.Video(interactive=False, autoplay=True, label="Generated Video")
            vid_btn = gr.Button("Generate Video")
    with gr.Row():
        # Pre-computed examples for reference
        gr.Examples(
            examples=[[f"resources/examples/ef{i}.png", f"resources/examples/ef{i}.mp4", i, 64] for i in [20, 30, 40, 50, 60, 70, 80, 90]],
            inputs=[img, vid, efslider, dsslider],
            outputs=None,
            fn=None,
            cache_examples=False,
        )

    ltt_img = gr.State()  # holds the latent of the last generated image
    img.change()  # placeholder for center-cropping uploaded images (no callback attached)
    img_btn.click(fn=ctx.get_img, inputs=[dsslider], outputs=[img, ltt_img])  # generate an image with the LIDM
    vid_btn.click(fn=ctx.get_vid, inputs=[efslider, ltt_img, dsslider], outputs=[vid])  # generate a video with the LVDM

demo.launch(share=False)