""" THis is the main file for the gradio web demo. It uses the CogVideoX-5B model to generate videos gradio web demo. set environment variable OPENAI_API_KEY to use the OpenAI API to enhance the prompt. Usage: OpenAI_API_KEY=your_openai_api_key OPENAI_BASE_URL=https://api.openai.com/v1 python inference/gradio_web_demo.py """ import math import os import random import threading import time import spaces import cv2 import tempfile import imageio_ffmpeg import gradio as gr import torch from PIL import Image # from diffusers import ( # CogVideoXPipeline, # CogVideoXDPMScheduler, # CogVideoXVideoToVideoPipeline, # CogVideoXImageToVideoPipeline, # CogVideoXTransformer3DModel, # ) from typing import Union, List from CogVideoX.pipeline_rgba import CogVideoXPipeline from CogVideoX.rgba_utils import * from diffusers import CogVideoXDPMScheduler from diffusers.utils import load_video, load_image, export_to_video from datetime import datetime, timedelta from diffusers.image_processor import VaeImageProcessor import moviepy.editor as mp import numpy as np from huggingface_hub import hf_hub_download, snapshot_download import gc device = "cuda" if torch.cuda.is_available() else "cpu" hf_hub_download(repo_id="wileewang/TransPixar", filename="cogvideox_rgba_lora.safetensors", local_dir="model_cogvideox_rgba_lora") pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5B", torch_dtype=torch.bfloat16) # pipe.enable_sequential_cpu_offload() pipe.vae.enable_slicing() pipe.vae.enable_tiling() pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing") seq_length = 2 * ( (480 // pipe.vae_scale_factor_spatial // 2) * (720 // pipe.vae_scale_factor_spatial // 2) * ((13 - 1) // pipe.vae_scale_factor_temporal + 1) ) prepare_for_rgba_inference( pipe.transformer, rgba_weights_path="model_cogvideox_rgba_lora/cogvideox_rgba_lora.safetensors", device=device, dtype=torch.bfloat16, text_length=226, seq_length=seq_length, # this is for the creation of attention mask. 
)
# pipe.transformer.to(memory_format=torch.channels_last)
# pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune", fullgraph=True)
# pipe_image.transformer.to(memory_format=torch.channels_last)
# pipe_image.transformer = torch.compile(pipe_image.transformer, mode="max-autotune", fullgraph=True)

os.makedirs("./output", exist_ok=True)
os.makedirs("./gradio_tmp", exist_ok=True)

# upscale_model = utils.load_sd_upscale("model_real_esran/RealESRGAN_x4.pth", device)
# frame_interpolation_model = load_rife_model("model_rife")


def save_video(tensor: Union[List[np.ndarray], List[Image.Image]], fps: int = 8, prefix='rgb'):
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    video_path = f"./output/{prefix}_{timestamp}.mp4"
    os.makedirs(os.path.dirname(video_path), exist_ok=True)
    export_to_video(tensor, video_path, fps=fps)
    return video_path


def resize_if_unfit(input_video, progress=gr.Progress(track_tqdm=True)):
    width, height = get_video_dimensions(input_video)

    if width == 720 and height == 480:
        processed_video = input_video
    else:
        processed_video = center_crop_resize(input_video)
    return processed_video


def get_video_dimensions(input_video_path):
    reader = imageio_ffmpeg.read_frames(input_video_path)
    metadata = next(reader)
    return metadata["size"]


def center_crop_resize(input_video_path, target_width=720, target_height=480):
    cap = cv2.VideoCapture(input_video_path)

    orig_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    orig_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    orig_fps = cap.get(cv2.CAP_PROP_FPS)
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

    width_factor = target_width / orig_width
    height_factor = target_height / orig_height
    resize_factor = max(width_factor, height_factor)

    inter_width = int(orig_width * resize_factor)
    inter_height = int(orig_height * resize_factor)

    target_fps = 8
    ideal_skip = max(0, math.ceil(orig_fps / target_fps) - 1)
    skip = min(5, ideal_skip)  # Cap at 5

    while (total_frames / (skip + 1)) < 49 and skip > 0:
        skip -= 1

    processed_frames = []
    frame_count = 0
    total_read = 0

    while frame_count < 49 and total_read < total_frames:
        ret, frame = cap.read()
        if not ret:
            break

        if total_read % (skip + 1) == 0:
            resized = cv2.resize(frame, (inter_width, inter_height), interpolation=cv2.INTER_AREA)

            start_x = (inter_width - target_width) // 2
            start_y = (inter_height - target_height) // 2
            cropped = resized[start_y : start_y + target_height, start_x : start_x + target_width]

            processed_frames.append(cropped)
            frame_count += 1

        total_read += 1

    cap.release()

    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as temp_file:
        temp_video_path = temp_file.name
        fourcc = cv2.VideoWriter_fourcc(*"mp4v")
        out = cv2.VideoWriter(temp_video_path, fourcc, target_fps, (target_width, target_height))

        for frame in processed_frames:
            out.write(frame)

        out.release()

    return temp_video_path


@spaces.GPU(duration=300)
def infer(
    prompt: str,
    num_inference_steps: int,
    guidance_scale: float,
    seed: int = -1,
    progress=gr.Progress(track_tqdm=True),
):
    if seed == -1:
        seed = random.randint(0, 2**8 - 1)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    pipe.to(device)
    video_pt = pipe(
        prompt=prompt + ", isolated background",
        num_videos_per_prompt=1,
        num_inference_steps=num_inference_steps,
        num_frames=13,
        use_dynamic_cfg=True,
        output_type="latent",
        guidance_scale=guidance_scale,
        generator=torch.Generator(device=device).manual_seed(int(seed)),
    ).frames
    # pipe.to("cpu")
    gc.collect()
    return (video_pt, seed)


def convert_to_gif(video_path):
    clip = mp.VideoFileClip(video_path)
    clip = clip.set_fps(8)
    clip = clip.resize(height=240)
    gif_path = video_path.replace(".mp4", ".gif")
    clip.write_gif(gif_path, fps=8)
    return gif_path


def delete_old_files():
    while True:
        now = datetime.now()
        cutoff = now - timedelta(minutes=10)
        directories = ["./output", "./gradio_tmp"]

        for directory in directories:
            for filename in os.listdir(directory):
                file_path = os.path.join(directory, filename)
                if os.path.isfile(file_path):
                    file_mtime = datetime.fromtimestamp(os.path.getmtime(file_path))
                    if file_mtime < cutoff:
                        os.remove(file_path)
        time.sleep(600)


threading.Thread(target=delete_old_files, daemon=True).start()

# examples_videos = [["example_videos/horse.mp4"], ["example_videos/kitten.mp4"], ["example_videos/train_running.mp4"]]
# examples_images = [["example_images/beach.png"], ["example_images/street.png"], ["example_images/camping.png"]]

with gr.Blocks() as demo:
    gr.HTML("""
        TransPixar + CogVideoX-5B Huggingface Space🤗
        🤗 TransPixar LoRA Hub | 🌐 Github | 📜 arxiv
        ⚠️ This demo is for academic research and experiential use only.
""") with gr.Row(): with gr.Column(): prompt = gr.Textbox(label="Prompt (Less than 200 Words)", placeholder="Enter your prompt here", lines=5) with gr.Group(): with gr.Column(): with gr.Row(): seed_param = gr.Number( label="Inference Seed (Enter a positive number, -1 for random)", value=-1 ) generate_button = gr.Button("🎬 Generate Video") # Add the note at the bottom-left with gr.Row(): gr.Markdown( """ **Note:** The output RGB is a premultiplied version to avoid the color decontamination problem. It can directly composite with a background using: ``` composite = rgb + (1 - alpha) * background ``` Due to limited online resources, we have restricted the inference steps to 25 and the number of frames to 13, which may affect the generation quality to some extent. For a better experience, we recommend visiting our GitHub repository and running the method locally by following the provided setup instructions. """ ) with gr.Column(): rgb_video_output = gr.Video(label="Generated RGB Video", width=720, height=480) alpha_video_output = gr.Video(label="Generated Alpha Video", width=720, height=480) with gr.Row(): download_rgb_video_button = gr.File(label="📥 Download RGB Video", visible=False) download_alpha_video_button = gr.File(label="📥 Download Alpha Video", visible=False) seed_text = gr.Number(label="Seed Used for Video Generation", visible=False) @spaces.GPU(duration=300) def generate( prompt, seed_value, progress=gr.Progress(track_tqdm=True) ): latents, seed = infer( prompt, num_inference_steps=25, # NOT Changed guidance_scale=7.0, # NOT Changed seed=seed_value, progress=progress, ) latents_rgb, latents_alpha = latents.chunk(2, dim=1) frames_rgb = decode_latents(pipe, latents_rgb) frames_alpha = decode_latents(pipe, latents_alpha) pooled_alpha = np.max(frames_alpha, axis=-1, keepdims=True) frames_alpha_pooled = np.repeat(pooled_alpha, 3, axis=-1) premultiplied_rgb = frames_rgb * frames_alpha_pooled rgb_video_path = save_video(premultiplied_rgb[0], fps=8, prefix='rgb') rgb_video_update = gr.update(visible=True, value=rgb_video_path) alpha_video_path = save_video(frames_alpha_pooled[0], fps=8, prefix='alpha') alpha_video_update = gr.update(visible=True, value=alpha_video_path) seed_update = gr.update(visible=True, value=seed) pipe.to("cpu") return rgb_video_path, alpha_video_path, rgb_video_update, alpha_video_update, seed_update generate_button.click( generate, inputs=[prompt, seed_param], outputs=[rgb_video_output, alpha_video_output, download_rgb_video_button, download_alpha_video_button, seed_text], ) if __name__ == "__main__": demo.queue(max_size=15) demo.launch()