Spaces:
Runtime error
Runtime error
import os | |
import sys | |
import gradio as gr | |
import shlex

# Fetch the source-only dependencies and the diffusion checkpoint. Each
# clone/download is skipped when its output already exists, so restarting the
# app neither re-clones the repositories nor re-downloads the checkpoint.
# Failures are deliberately not fatal here (os.system only returns a status
# code, which is ignored); a missing dependency surfaces at the imports below.
if not os.path.isdir('CLIP'):
    os.system('git clone https://github.com/openai/CLIP')
if not os.path.isdir('guided-diffusion'):
    os.system('git clone https://github.com/crowsonkb/guided-diffusion')

# Install into the interpreter that is running this script rather than
# whichever `pip` happens to be first on PATH.
_pip_install = f'{shlex.quote(sys.executable)} -m pip install'
os.system(f'{_pip_install} -e ./CLIP')
os.system(f'{_pip_install} -e ./guided-diffusion')
os.system(f'{_pip_install} lpips')

if not os.path.isfile('256x256_diffusion_uncond.pt'):
    # Download to a temporary name and rename only on success: `-f` makes curl
    # fail on HTTP errors instead of saving the error page, and an interrupted
    # transfer is retried on the next start rather than mistaken for a
    # complete checkpoint.
    os.system(
        "curl -fL -o 256x256_diffusion_uncond.pt.part "
        "'https://openaipublic.blob.core.windows.net/diffusion/jul-2021/256x256_diffusion_uncond.pt' "
        "&& mv 256x256_diffusion_uncond.pt.part 256x256_diffusion_uncond.pt"
    )
import io | |
import math | |
import sys | |
import lpips | |
from PIL import Image | |
import requests | |
import torch | |
from torch import nn | |
from torch.nn import functional as F | |
from torchvision import transforms | |
from torchvision.transforms import functional as TF | |
from tqdm.notebook import tqdm | |
# Make the freshly cloned CLIP and guided-diffusion checkouts importable
# (searched after the existing entries, in this order).
sys.path.extend(['./CLIP', './guided-diffusion'])
import clip | |
from guided_diffusion.script_util import create_model_and_diffusion, model_and_diffusion_defaults | |
import numpy as np | |
import imageio | |
# Example image referenced by the demo's example row. Only fetched when it is
# not already on disk, so restarts do not hit the network again.
if not os.path.exists('coralreef.jpeg'):
    torch.hub.download_url_to_file('https://images.pexels.com/photos/68767/divers-underwater-ocean-swim-68767.jpeg', 'coralreef.jpeg')
def fetch(url_or_path, timeout=60):
    """Open a local file or download an HTTP(S) URL as a binary file object.

    Args:
        url_or_path: A local filesystem path, or anything whose ``str()``
            starts with ``http://`` or ``https://``.
        timeout: Seconds passed to ``requests.get``; without a timeout a
            stalled server would block the request forever.

    Returns:
        A readable binary file object positioned at the start: an in-memory
        ``io.BytesIO`` holding the response body for URLs, or a file opened in
        ``'rb'`` mode for paths. The caller owns it and should close it.

    Raises:
        requests.HTTPError: if the server responds with an error status.
        OSError: if the local file cannot be opened.
    """
    if str(url_or_path).startswith(('http://', 'https://')):
        response = requests.get(url_or_path, timeout=timeout)
        response.raise_for_status()
        return io.BytesIO(response.content)
    return open(url_or_path, 'rb')
def parse_prompt(prompt):
    """Split ``"<text or URL>[:<weight>]"`` into ``(text, weight)``.

    The weight is whatever follows the last ``:`` when it parses as a float.
    Otherwise the colon is treated as part of the prompt itself (a URL scheme
    or port, or ordinary prose), and the whole string is returned with weight
    1.0 instead of raising ``ValueError``.

    Args:
        prompt: Prompt text, image path, or URL, optionally suffixed with
            ``:weight``.

    Returns:
        tuple[str, float]: The prompt without its weight suffix, and the weight.
    """
    head, sep, tail = prompt.rpartition(':')
    if sep:
        try:
            return head, float(tail)
        except ValueError:
            pass  # The colon belongs to the prompt, not to a weight suffix.
    return prompt, 1.0
class MakeCutouts(nn.Module):
    """Draw random square crops from an image batch and pool each to a fixed size.

    Args:
        cut_size: Side length every cutout is resized to (via adaptive average
            pooling).
        cutn: Number of cutouts drawn per forward pass.
        cut_pow: Exponent applied to the uniform draw that picks each crop's
            side length; values above 1 favour smaller crops, below 1 larger.
    """

    def __init__(self, cut_size, cutn, cut_pow=1.):
        super().__init__()
        self.cut_size = cut_size
        self.cutn = cutn
        self.cut_pow = cut_pow

    def forward(self, input):
        """Return ``cutn`` cutouts of ``input`` concatenated along dim 0.

        Expects a 4-D tensor whose last two dims are height and width. Each
        crop side lies between min(height, width, cut_size) and
        min(height, width).
        """
        height, width = input.shape[2:4]
        largest = min(width, height)
        smallest = min(width, height, self.cut_size)
        span = largest - smallest

        def one_cutout():
            # RNG draw order (size, x offset, y offset) is significant for
            # reproducibility under a fixed seed.
            side = int(torch.rand([]) ** self.cut_pow * span + smallest)
            left = torch.randint(0, width - side + 1, ())
            top = torch.randint(0, height - side + 1, ())
            crop = input[:, :, top:top + side, left:left + side]
            return F.adaptive_avg_pool2d(crop, self.cut_size)

        return torch.cat([one_cutout() for _ in range(self.cutn)])
def spherical_dist_loss(x, y):
    """Half the squared angle between x and y, computed along the last dim.

    Both inputs are L2-normalised along the last dimension first. For unit
    vectors the chord length is 2*sin(theta/2), so 2*arcsin(chord/2)**2 equals
    theta**2 / 2. The inputs broadcast against each other as usual.
    """
    x_unit = F.normalize(x, dim=-1)
    y_unit = F.normalize(y, dim=-1)
    half_chord = (x_unit - y_unit).norm(dim=-1) / 2
    return 2 * torch.arcsin(half_chord) ** 2
def tv_loss(input):
    """L2 total variation loss, as in Mahendran et al.

    Squared horizontal plus vertical neighbour differences, averaged over
    dims 1-3 (one value per leading index). The right and bottom edges are
    replicate-padded, so differences across the border are zero.
    """
    padded = F.pad(input, (0, 1, 0, 1), 'replicate')
    origin = padded[..., :-1, :-1]
    dx = padded[..., :-1, 1:] - origin
    dy = padded[..., 1:, :-1] - origin
    return (dx.pow(2) + dy.pow(2)).mean(dim=(1, 2, 3))
def range_loss(input):
    """Mean squared amount by which values fall outside [-1, 1], over dims 1-3."""
    excess = input - torch.clamp(input, -1, 1)
    return torch.mean(excess ** 2, dim=(1, 2, 3))
# Networks loaded by inference(). They are reused across Gradio requests so
# the diffusion checkpoint, CLIP and LPIPS are not reloaded on every call.
# At most one diffusion configuration is kept resident at a time.
_MODEL_CACHE = {}


def _get_device():
    """Return the first CUDA device when one is available, otherwise the CPU."""
    return torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')


def _make_model_config(timestep_respacing, use_fp16):
    """Return the guided-diffusion settings for the 256x256 unconditional checkpoint."""
    model_config = model_and_diffusion_defaults()
    model_config.update({
        'attention_resolutions': '32, 16, 8',
        'class_cond': False,
        'diffusion_steps': 1000,
        'rescale_timesteps': True,
        # Number of sampling steps, e.g. '50'; a 'ddim' prefix selects the
        # DDIM sampler in inference().
        'timestep_respacing': timestep_respacing,
        'image_size': 256,
        'learn_sigma': True,
        'noise_schedule': 'linear',
        'num_channels': 256,
        'num_head_channels': 64,
        'num_res_blocks': 2,
        'resblock_updown': True,
        'use_fp16': use_fp16,
        'use_scale_shift_norm': True,
    })
    return model_config


def _load_diffusion(model_config, device):
    """Return (model, diffusion) for model_config, loading the checkpoint only on a cache miss."""
    key = ('diffusion', model_config['timestep_respacing'], model_config['use_fp16'], str(device))
    if key not in _MODEL_CACHE:
        # Evict a UNet built for another configuration before loading a new one.
        for stale in [k for k in _MODEL_CACHE if k[0] == 'diffusion']:
            del _MODEL_CACHE[stale]
        model, diffusion = create_model_and_diffusion(**model_config)
        model.load_state_dict(torch.load('256x256_diffusion_uncond.pt', map_location='cpu'))
        model.requires_grad_(False).eval().to(device)
        # Kept from the original code: re-enable requires_grad on these
        # parameter groups. NOTE(review): cond_fn only differentiates w.r.t. x;
        # the reason for this is not visible here.
        for name, param in model.named_parameters():
            if 'qkv' in name or 'norm' in name or 'proj' in name:
                param.requires_grad_()
        if model_config['use_fp16']:
            model.convert_to_fp16()
        _MODEL_CACHE[key] = (model, diffusion)
    return _MODEL_CACHE[key]


def _load_guidance_models(device):
    """Return (clip_model, lpips_model) on device, loading them on first use only."""
    key = ('guidance', str(device))
    if key not in _MODEL_CACHE:
        clip_model = clip.load('ViT-B/16', jit=False)[0].eval().requires_grad_(False).to(device)
        lpips_model = lpips.LPIPS(net='vgg').to(device)
        _MODEL_CACHE[key] = (clip_model, lpips_model)
    return _MODEL_CACHE[key]


def _write_video(frames, path, fps=5):
    """Encode a sequence of PIL images as a video file at path."""
    # The context manager closes the writer even if encoding fails midway.
    with imageio.get_writer(path, fps=fps) as writer:
        for frame in frames:
            writer.append_data(np.array(frame))


def inference(text, init_image, skip_timesteps, clip_guidance_scale, tv_scale, range_scale, init_scale, seed, image_prompts, timestep_respacing, cutn):
    """Generate an image from a text prompt with CLIP-guided diffusion.

    Args:
        text: Text prompt, optionally suffixed with ``:weight`` (see parse_prompt).
        init_image: Optional upload whose ``.name`` is a path to the starting
            image; falsy for none.
        skip_timesteps: Number of initial sampling steps to skip. Must be less
            than the number of respaced steps.
        clip_guidance_scale: Weight of the CLIP prompt loss (how closely the
            image should match the prompt).
        tv_scale: Weight of the total-variation smoothness loss.
        range_scale: Weight of the out-of-[-1, 1] range loss.
        init_scale: Weight of the LPIPS loss against the init image (only used
            when an init image is given).
        seed: Passed to ``torch.manual_seed`` unless None.
        image_prompts: Optional upload whose ``.name`` is a path to an image
            prompt; falsy for none.
        timestep_respacing: Number of sampling steps, or a string starting
            with 'ddim' to use the DDIM sampler.
        cutn: Number of CLIP cutouts per guidance evaluation.

    Returns:
        tuple: (last predicted image as a PIL image, path 'video.mp4' of the
        progress video).

    Raises:
        ValueError: if skip_timesteps is outside [0, number of sampling steps).
        RuntimeError: if the prompt weights sum to (almost) zero, or sampling
            yields no frames.
    """
    device = _get_device()
    print('Using device:', device)

    # Gradio numeric widgets may deliver floats. cur_t (derived from
    # skip_timesteps) is used as an index in cond_fn, cutn sizes a view(), and
    # the respacing string should read '50' rather than '50.0'.
    skip_timesteps = int(skip_timesteps)
    cutn = int(cutn)
    if not isinstance(timestep_respacing, str):
        timestep_respacing = int(timestep_respacing)

    # fp16 only on CUDA: half-precision kernels are typically unavailable on
    # CPU, so converting the UNet there makes sampling fail.
    model_config = _make_model_config(str(timestep_respacing), use_fp16=device.type == 'cuda')
    model, diffusion = _load_diffusion(model_config, device)
    clip_model, lpips_model = _load_guidance_models(device)
    clip_size = clip_model.visual.input_resolution
    normalize = transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                                     std=[0.26862954, 0.26130258, 0.27577711])

    if not 0 <= skip_timesteps < diffusion.num_timesteps:
        raise ValueError(
            f'skip_timesteps must be in [0, {diffusion.num_timesteps}) for this '
            f'timestep respacing, got {skip_timesteps}')

    prompts = [text]
    image_prompts = [image_prompts.name] if image_prompts else []
    init_image = init_image.name if init_image else None
    batch_size = 1
    n_batches = 1

    if seed is not None:
        torch.manual_seed(seed)

    make_cutouts = MakeCutouts(clip_size, cutn)
    side_x = side_y = model_config['image_size']

    # CLIP targets: one embedding per text prompt and cutn per image prompt.
    # Each image cutout gets an equal share of that prompt's weight.
    target_embeds, weights = [], []
    for prompt in prompts:
        txt, weight = parse_prompt(prompt)
        target_embeds.append(clip_model.encode_text(clip.tokenize(txt).to(device)).float())
        weights.append(weight)
    for prompt in image_prompts:
        path, weight = parse_prompt(prompt)
        with fetch(path) as fd, Image.open(fd) as pil:
            img = pil.convert('RGB')
        img = TF.resize(img, min(side_x, side_y, *img.size), transforms.InterpolationMode.LANCZOS)
        batch = make_cutouts(TF.to_tensor(img).unsqueeze(0).to(device))
        target_embeds.append(clip_model.encode_image(normalize(batch)).float())
        weights.extend([weight / cutn] * cutn)
    target_embeds = torch.cat(target_embeds)
    weights = torch.tensor(weights, device=device)
    if weights.sum().abs() < 1e-3:
        raise RuntimeError('The weights must not sum to 0.')
    weights /= weights.sum().abs()

    init = None
    if init_image is not None:
        with fetch(init_image) as fd, Image.open(fd) as pil:
            init = pil.convert('RGB')
        init = init.resize((side_x, side_y), Image.LANCZOS)
        # Pixel values [0, 1] -> [-1, 1], the range the outputs are decoded from below.
        init = TF.to_tensor(init).to(device).unsqueeze(0).mul(2).sub(1)

    cur_t = None

    def cond_fn(x, t, y=None):
        """Return the negated gradient of the guidance loss w.r.t. the sample x."""
        # Uses the enclosing cur_t, which the sampling loop below updates,
        # rather than the t argument.
        with torch.enable_grad():
            x = x.detach().requires_grad_()
            n = x.shape[0]
            my_t = torch.ones([n], device=device, dtype=torch.long) * cur_t
            out = diffusion.p_mean_variance(model, x, my_t, clip_denoised=False, model_kwargs={'y': y})
            fac = diffusion.sqrt_one_minus_alphas_cumprod[cur_t]
            # Blend the predicted clean image with the current noisy sample.
            x_in = out['pred_xstart'] * fac + x * (1 - fac)
            clip_in = normalize(make_cutouts(x_in.add(1).div(2)))
            image_embeds = clip_model.encode_image(clip_in).float()
            dists = spherical_dist_loss(image_embeds.unsqueeze(1), target_embeds.unsqueeze(0))
            dists = dists.view([cutn, n, -1])
            losses = dists.mul(weights).sum(2).mean(0)
            tv_losses = tv_loss(x_in)
            range_losses = range_loss(out['pred_xstart'])
            loss = losses.sum() * clip_guidance_scale + tv_losses.sum() * tv_scale + range_losses.sum() * range_scale
            if init is not None and init_scale:
                init_losses = lpips_model(x_in, init)
                loss = loss + init_losses.sum() * init_scale
            return -torch.autograd.grad(loss, x)[0]

    if model_config['timestep_respacing'].startswith('ddim'):
        sample_fn = diffusion.ddim_sample_loop_progressive
    else:
        sample_fn = diffusion.p_sample_loop_progressive

    all_frames = []
    for i in range(n_batches):
        cur_t = diffusion.num_timesteps - skip_timesteps - 1
        samples = sample_fn(
            model,
            (batch_size, 3, side_y, side_x),
            clip_denoised=False,
            model_kwargs={},
            cond_fn=cond_fn,
            progress=True,
            skip_timesteps=skip_timesteps,
            init_image=init,
            randomize_class=True,
        )
        for j, sample in enumerate(samples):
            # Decrement after each yielded step; cond_fn reads cur_t.
            cur_t -= 1
            print()
            for k, image in enumerate(sample['pred_xstart']):
                all_frames.append(TF.to_pil_image(image.add(1).div(2).clamp(0, 1)))
                tqdm.write(f'Batch {i}, step {j}, output {k}:')

    if not all_frames:
        raise RuntimeError('Sampling produced no frames.')
    _write_video(all_frames, 'video.mp4', fps=5)
    # The final predicted image, plus the progress video of every step.
    return all_frames[-1], 'video.mp4'
# --- Gradio user interface --------------------------------------------------
title = "CLIP Guided Diffusion HQ"
description = "Gradio demo for CLIP Guided Diffusion. To use it, simply add your text, or click one of the examples to load them. Read more at the links below."
article = "<p style='text-align: center'> By Katherine Crowson (https://github.com/crowsonkb, https://twitter.com/RiversHaveWings). It uses OpenAI's 256x256 unconditional ImageNet diffusion model (https://github.com/openai/guided-diffusion) together with CLIP (https://github.com/openai/CLIP) to connect text prompts with images. | <a href='https://colab.research.google.com/drive/12a_Wrfi2_gwwAuN3VvMTwVMz9TfqctNj' target='_blank'>Colab</a></p>"

# One widget per inference() parameter, in the same positional order.
input_widgets = [
    "text",
    gr.inputs.Image(type="file", label='initial image (optional)', optional=True),
    gr.inputs.Slider(minimum=0, maximum=45, step=1, default=10, label="skip_timesteps"),
    gr.inputs.Slider(minimum=0, maximum=3000, step=1, default=600, label="clip guidance scale (Controls how much the image should look like the prompt)"),
    gr.inputs.Slider(minimum=0, maximum=1000, step=1, default=0, label="tv_scale (Controls the smoothness of the final output)"),
    gr.inputs.Slider(minimum=0, maximum=1000, step=1, default=0, label="range_scale (Controls how far out of range RGB values are allowed to be)"),
    gr.inputs.Slider(minimum=0, maximum=1000, step=1, default=0, label="init_scale (This enhances the effect of the init image)"),
    gr.inputs.Number(default=0, label="Seed"),
    gr.inputs.Image(type="file", label='image prompt (optional)', optional=True),
    gr.inputs.Slider(minimum=50, maximum=500, step=1, default=50, label="timestep respacing"),
    gr.inputs.Slider(minimum=1, maximum=64, step=1, default=32, label="cutn"),
]

# A single example row, matching input_widgets position for position.
example_rows = [
    ["coral reef city by artistation artists", "coralreef.jpeg", 0, 1000, 150, 50, 0, 0, "coralreef.jpeg", 90, 32],
]

iface = gr.Interface(
    inference,
    inputs=input_widgets,
    outputs=["image", "video"],
    title=title,
    description=description,
    article=article,
    examples=example_rows,
)
iface.launch()