import gradio as gr
import os
import json
import torch
import torch.nn as nn
import diffusers
from einops import rearrange
from PIL import Image
from omegaconf import OmegaConf
from tqdm import tqdm
import cv2

NUM_STEPS = 64  # default number of diffusion sampling steps
FRAMES = 192    # number of frames in the generated video
FPS = 32        # frame rate of the generated video

mycss = """
.contain { width: 1000px; margin: 0 auto; }
.svelte-1pijsyv { width: 448px; }
.arrow { display: flex; align-items: center; margin: 7px 0; }
.arrow-tail {
    width: 270px;
    height: 50px;
    background-color: black;
    transition: background-color 0.3s;
}
.arrow-head {
    width: 0;
    height: 0;
    border-top: 70px solid transparent;
    border-bottom: 70px solid transparent;
    border-left: 120px solid black;
    transition: border-left-color 0.3s;
}
@media (prefers-color-scheme: dark) {
    .arrow-tail { background-color: white; }
    .arrow-head { border-left-color: white; }
}
"""

# CSS-only arrow pointing from the image column to the video column
myhtml = """
<div class="arrow">
    <div class="arrow-tail"></div>
    <div class="arrow-head"></div>
</div>
""" myjs = """ function setLoopTrue() { let videos = document.getElementsByTagName('video'); if (videos.length > 0) { document.getElementsByTagName('video')[0].loop = true; } setTimeout(setLoopTrue, 3000); } """ def load_model(path): # find config.json json_path = os.path.join(path, "config.json") assert os.path.exists(json_path), f"Could not find config.json at {json_path}" with open(json_path, "r") as f: config = json.load(f) # instantiate class klass_name = config["_class_name"] klass = getattr(diffusers, klass_name, None) if klass is None: klass = globals().get(klass_name, None) assert klass is not None, f"Could not find class {klass_name} in diffusers or global scope." assert getattr(klass, "from_pretrained", None) is not None, f"Class {klass_name} does not support 'from_pretrained'." # load checkpoint model = klass.from_pretrained(path) return model, config def load_scheduler(config): scheduler_kwargs = OmegaConf.to_container(config.noise_scheduler) scheduler_klass_name = scheduler_kwargs.pop("_class_name") scheduler_klass = getattr(diffusers, scheduler_klass_name, None) scheduler = scheduler_klass(**scheduler_kwargs) return scheduler def padf(tensor, mult=3): pad = 2**mult - (tensor.shape[-1] % 2**mult) pad = pad//2 tensor = nn.functional.pad(tensor, (pad, pad, pad, pad, 0, 0), mode='replicate') return tensor, pad def unpadf(tensor, pad=1): return tensor[..., pad:-pad, pad:-pad] def pad_reshape(tensor, mult=3): tensor, pad = padf(tensor, mult=mult) tensor = rearrange(tensor, "b c t h w -> b t c h w") return tensor, pad def unpad_reshape(tensor, pad=1): tensor = rearrange(tensor, "b t c h w -> b c t h w") tensor = unpadf(tensor, pad=pad) return tensor class Context: def __init__(self, lidm_path, lvdm_path, vae_path, config_path): self.lidm, self.lidm_config = load_model(lidm_path) self.lvdm, self.lvdm_config = load_model(lvdm_path) self.vae, self.vae_config = load_model(vae_path) self.config = OmegaConf.load(config_path) self.models = [self.lidm, self.lvdm, self.vae] self.scheduler = load_scheduler(self.config) self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") self.dtype = torch.float32 for model in self.models: model.to(self.device, dtype=self.dtype) model.eval() print("Models loaded") def get_img(self, steps): print("generating image") self.scheduler.set_timesteps(steps) with torch.no_grad(): B, C, H, W = 1, self.lidm_config["in_channels"], self.lidm_config["sample_size"], self.lidm_config["sample_size"] timesteps = self.scheduler.timesteps forward_kwargs = {} latents = torch.randn((B, C, H, W), device=self.device, dtype=self.dtype) with torch.autocast("cuda"): for t in tqdm(timesteps): forward_kwargs["timestep"] = t latent_model_input = latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, timestep=t) latent_model_input, padding = padf(latent_model_input, mult=3) noise_pred = self.lidm(latent_model_input, **forward_kwargs).sample noise_pred = unpadf(noise_pred, pad=padding) latents = self.scheduler.step(noise_pred, t, latents).prev_sample # latent shape[B,C,H,W] latents = latents / self.vae.config.scaling_factor img = self.vae.decode(latents).sample img = (img + 1) * 128 # [-1, 1] -> [0, 256] img = img.mean(1).unsqueeze(1).repeat([1, 3, 1, 1]) img = img.clamp(0, 255).to(torch.uint8).cpu().numpy() img = img[0].transpose(1, 2, 0) img = Image.fromarray(img) return img, latents def get_vid(self, lvef: int, ref_latent: torch.Tensor, steps: int): print("generating video") self.scheduler.set_timesteps(steps) with torch.no_grad(): B, C, 
T, H, W = 1, 4, self.lvdm_config["num_frames"], self.lvdm_config["sample_size"], self.lvdm_config["sample_size"] if FRAMES > T: OT = T//2 # overlap 64//2 TR = (FRAMES - T) / 32 # total frames (192 - 64) / 32 = 4 TR = int(TR + 1) # total repetitions NT = (T-OT) * TR + OT else: OT = 0 TR = 1 NT = T timesteps = self.scheduler.timesteps lvef = lvef / 100 lvef = torch.tensor([lvef]*TR, device=self.device, dtype=self.dtype) lvef = lvef[:, None, None] print(lvef.shape) forward_kwargs = {} forward_kwargs["added_time_ids"] = torch.zeros((B*TR, self.config.unet.addition_time_embed_dim), device=self.device, dtype=self.dtype) forward_kwargs["encoder_hidden_states"] = lvef print(forward_kwargs["added_time_ids"].shape) latent_cond_images = ref_latent * self.vae.config.scaling_factor latent_cond_images = latent_cond_images[:,:,None,:,:].repeat([1, 1, NT, 1, 1]).to(self.device, dtype=self.dtype) print(latent_cond_images.shape) latents = torch.randn((B, C, NT, H, W), device=self.device, dtype=self.dtype) print(latents.shape) with torch.autocast("cuda"): for t in tqdm(timesteps): forward_kwargs["timestep"] = t latent_model_input = latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, timestep=t) latent_model_input = torch.cat((latent_model_input, latent_cond_images), dim=1) # B x 2C x T x H x W latent_model_input, padding = pad_reshape(latent_model_input, mult=3) # B x T x 2C x H+P x W+P inputs = torch.cat([latent_model_input[:,r*(T-OT):r*(T-OT)+T] for r in range(TR)], dim=0) # B*TR x T x 2C x H+P x W+P noise_pred = self.lvdm(inputs, **forward_kwargs).sample outputs = torch.chunk(noise_pred, TR, dim=0) # TR x B x T x C x H x W noise_predictions = [] for r in range(TR): noise_predictions.append(outputs[r] if r == 0 else outputs[r][:,OT:]) noise_pred = torch.cat(noise_predictions, dim=1) # B x NT x C x H x W noise_pred = unpad_reshape(noise_pred, pad=padding) latents = self.scheduler.step(noise_pred, t, latents).prev_sample print("done generating noise") # latent shape[B,C,T,H,W] latents = latents / self.vae.config.scaling_factor latents = rearrange(latents, "b c t h w -> (b t) c h w") chunk_size = 16 chunked_latents = torch.split(latents, chunk_size, dim=0) decoded_chunks = [] for chunk in chunked_latents: decoded_chunks.append(self.vae.decode(chunk.float().cuda()).sample.cpu()) video = torch.cat(decoded_chunks, dim=0) # (B*T) x H x W x C video = rearrange(video, "(b t) c h w -> b t h w c", b=B)[0] # T H W C video = (video + 1) * 128 # [-1, 1] -> [0, 256] video = video.mean(-1).unsqueeze(-1).repeat([1, 1, 1, 3]) # T H W 3 video = video.clamp(0, 255).to(torch.uint8).cpu().numpy() out = cv2.VideoWriter('output.mp4', cv2.VideoWriter_fourcc(*'mp4v'), FPS, (112, 112)) for img in video: out.write(img) out.release() return "output.mp4" ctx = Context( lidm_path="resources/lidm", lvdm_path="resources/lvdm", vae_path="resources/ivae", config_path="resources/config.yaml" ) with gr.Blocks(css=mycss, js=myjs) as demo: with gr.Row(): # Greet user with an explanation of the demo gr.Markdown(""" # EchoNet-Synthetic: Privacy-preserving Video Generation for Safe Medical Data Sharing This demo is attached to a paper under review at MICCAI 2024, and is targeted at the reviewers of that paper. 1. Start by generating an image using the "Generate Image" button. This will generate a random image, similar to the EchoNet-Dynamic dataset. 2. Adjust the "Ejection Fraction Score" slider to change the ejection fraction of the generated image. 3. Generate a video using the "Generate Video" button. 
        This will generate a video from the generated image, with the ejection fraction score you chose.

        We leave the ejection fraction input completely open, so you can see how the video generation changes with different ejection fraction scores, even unrealistic ones. The normal ejection fraction range is roughly 50-75%.
        We recommend 64 sampling steps for ideal quality, but you can adjust this to see how it affects image and video generation.
        """)

    with gr.Row():
        # core activity: 3 columns
        with gr.Column():
            # Image generation goes here
            img = gr.Image(interactive=False, label="Generated Image")
            img_btn = gr.Button("Generate Image")

        with gr.Column():
            # LVEF slider goes here, with a big arrow graphic for show
            gr.HTML(myhtml)
            efslider = gr.Slider(minimum=0, maximum=100, value=65, step=1, label="Ejection Fraction Score (%)")
            dsslider = gr.Slider(minimum=1, maximum=999, value=64, step=1, label="Sampling Steps")

        with gr.Column():
            # Video generation goes here
            vid = gr.Video(interactive=False, autoplay=True, label="Generated Video")
            vid_btn = gr.Button("Generate Video")

    with gr.Row():
        # Additional information: pre-computed examples
        gr.Examples(
            examples=[[f"resources/examples/ef{i}.png", f"resources/examples/ef{i}.mp4", i, 64] for i in [20, 30, 40, 50, 60, 70, 80, 90]],
            inputs=[img, vid, efslider, dsslider],
            outputs=None,
            fn=None,
            cache_examples=False,
        )

    ltt_img = gr.State()  # latent of the generated image, reused as conditioning for the video

    img.change()  # placeholder: apply center-cropping
    img_btn.click(fn=ctx.get_img, inputs=[dsslider], outputs=[img, ltt_img])            # generate image with the LIDM
    vid_btn.click(fn=ctx.get_vid, inputs=[efslider, ltt_img, dsslider], outputs=[vid])  # generate video with the LVDM

demo.launch(share=False)
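
# Headless usage sketch (not part of the demo above): the three UI steps can also
# be reproduced programmatically with the Context class. This assumes the same
# "resources/" checkpoints and a CUDA device; the filename "sample_echo.png" is
# hypothetical. Uncomment and run in place of demo.launch() if desired.
#
# image, image_latent = ctx.get_img(NUM_STEPS)            # step 1: sample an echo image
# image.save("sample_echo.png")                           # hypothetical output filename
# video_path = ctx.get_vid(65, image_latent, NUM_STEPS)   # steps 2-3: generate a video at 65% EF
# print(f"Video written to {video_path}")                 # get_vid writes output.mp4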