from dataclasses import dataclass
from typing import Any, Dict

import asyncio
import logging
import random

import torch
from diffusers import HunyuanVideoPipeline
from varnish import Varnish

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


@dataclass
class GenerationConfig:
    """Configuration for video generation."""

    # Content settings
    prompt: str
    negative_prompt: str = ""

    # Model settings
    num_frames: int = 49  # Must follow the 4k + 1 format expected by HunyuanVideo
    height: int = 320
    width: int = 576
    num_inference_steps: int = 50
    guidance_scale: float = 7.0

    # Reproducibility (-1 means "pick a random seed")
    seed: int = -1

    # Varnish post-processing settings
    fps: int = 30
    double_num_frames: bool = False
    super_resolution: bool = False
    grain_amount: float = 0.0
    quality: int = 18  # CRF scale (0-51, lower is better)

    # Audio settings
    enable_audio: bool = False
    audio_prompt: str = ""
    audio_negative_prompt: str = "voices, voice, talking, speaking, speech"

    def validate_and_adjust(self) -> "GenerationConfig":
        """Validate and adjust parameters."""
        # Snap num_frames down to the nearest 4k + 1 value
        k = (self.num_frames - 1) // 4
        self.num_frames = (k * 4) + 1

        # Draw a random seed if none was specified
        if self.seed == -1:
            self.seed = random.randint(0, 2**32 - 1)

        return self

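# Example of the adjustment above (illustrative values, not from the original
# module): a request for 50 frames is snapped down to 49 (= 4 * 12 + 1), and an
# unset seed is replaced by a random one, so downstream code can assume both
# are concrete:
#
#   config = GenerationConfig(prompt="a cat playing piano", num_frames=50)
#   config = config.validate_and_adjust()
#   assert config.num_frames == 49 and config.seed >= 0
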
class EndpointHandler:
    """Handles video generation requests using HunyuanVideo and Varnish."""

    def __init__(self, path: str = ""):
        """Initialize handler with models.

        Args:
            path: Path to model weights
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        # Initialize HunyuanVideo pipeline
        self.pipeline = HunyuanVideoPipeline.from_pretrained(
            path,
            torch_dtype=torch.float16,
        ).to(self.device)

        # Text encoders and VAE run in float16; the transformer runs in bfloat16
        self.pipeline.text_encoder = self.pipeline.text_encoder.half()
        self.pipeline.text_encoder_2 = self.pipeline.text_encoder_2.half()
        self.pipeline.transformer = self.pipeline.transformer.to(torch.bfloat16)
        self.pipeline.vae = self.pipeline.vae.half()

        # Initialize Varnish for post-processing
        self.varnish = Varnish(
            device=self.device,
            model_base_dir="/repository/varnish",
        )

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Process video generation requests.

        Args:
            data: Request data containing:
                - inputs (str or dict): Prompt for video generation
                - parameters (dict): Generation parameters

        Returns:
            Dictionary containing:
                - video: Base64-encoded MP4 data URI
                - content-type: MIME type
                - metadata: Generation metadata
        """
        # Extract the prompt from either a bare string or a dict payload
        inputs = data.pop("inputs", data)
        if isinstance(inputs, dict):
            prompt = inputs.get("prompt", "")
        else:
            prompt = inputs

        params = data.get("parameters", {})

        # Create and validate config
        config = GenerationConfig(
            prompt=prompt,
            negative_prompt=params.get("negative_prompt", ""),
            num_frames=params.get("num_frames", 49),
            height=params.get("height", 320),
            width=params.get("width", 576),
            num_inference_steps=params.get("num_inference_steps", 50),
            guidance_scale=params.get("guidance_scale", 7.0),
            seed=params.get("seed", -1),
            fps=params.get("fps", 30),
            double_num_frames=params.get("double_num_frames", False),
            super_resolution=params.get("super_resolution", False),
            grain_amount=params.get("grain_amount", 0.0),
            quality=params.get("quality", 18),
            enable_audio=params.get("enable_audio", False),
            audio_prompt=params.get("audio_prompt", ""),
            audio_negative_prompt=params.get(
                "audio_negative_prompt",
                "voices, voice, talking, speaking, speech",
            ),
        ).validate_and_adjust()

        try:
            # Seed all RNGs; validate_and_adjust() guarantees a concrete seed,
            # so the generator can always be created
            torch.manual_seed(config.seed)
            random.seed(config.seed)
            generator = torch.Generator(device=self.device).manual_seed(config.seed)

            # Generate video frames
            with torch.inference_mode():
                output = self.pipeline(
                    prompt=config.prompt,
                    negative_prompt=config.negative_prompt,
                    num_frames=config.num_frames,
                    height=config.height,
                    width=config.width,
                    num_inference_steps=config.num_inference_steps,
                    guidance_scale=config.guidance_scale,
                    generator=generator,
                    output_type="pt",
                ).frames

            # Varnish exposes an async API; reuse the current event loop or
            # create one if none exists
            try:
                loop = asyncio.get_event_loop()
            except RuntimeError:
                loop = asyncio.new_event_loop()
                asyncio.set_event_loop(loop)

            # Post-process frames (interpolation, upscaling, grain, audio)
            result = loop.run_until_complete(
                self.varnish(
                    input_data=output,
                    fps=config.fps,
                    double_num_frames=config.double_num_frames,
                    super_resolution=config.super_resolution,
                    grain_amount=config.grain_amount,
                    enable_audio=config.enable_audio,
                    audio_prompt=config.audio_prompt,
                    audio_negative_prompt=config.audio_negative_prompt,
                )
            )

            # Encode the result as an MP4 data URI
            video_uri = loop.run_until_complete(
                result.write(
                    type="data-uri",
                    quality=config.quality,
                )
            )

            return {
                "video": video_uri,
                "content-type": "video/mp4",
                "metadata": {
                    "width": result.metadata.width,
                    "height": result.metadata.height,
                    "num_frames": result.metadata.frame_count,
                    "fps": result.metadata.fps,
                    "duration": result.metadata.duration,
                    "seed": config.seed,
                },
            }

        except Exception as e:
            logger.error(f"Error generating video: {e}")
            raise RuntimeError(f"Failed to generate video: {e}") from e
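
# Local smoke test: a minimal sketch, not part of the serving contract. The
# model path and request payload below are assumptions for illustration; on a
# managed inference endpoint the runtime instantiates and calls the handler.
if __name__ == "__main__":
    handler = EndpointHandler("/repository")  # assumed weights location
    response = handler({
        "inputs": "a timelapse of clouds drifting over a mountain lake",
        "parameters": {"num_frames": 49, "seed": 42},
    })
    logger.info(
        "content-type=%s metadata=%s",
        response["content-type"],
        response["metadata"],
    )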