|
import asyncio
import base64
import logging
import random
from dataclasses import dataclass, fields
from typing import Any, Dict, Optional

import torch
from diffusers import HunyuanVideoPipeline
from varnish import Varnish
|
|
|
|
|
# Configure the root logger once at import time so INFO-level messages from
# this handler are emitted, then obtain a logger scoped to this module.
logging.basicConfig(level="INFO")
logger = logging.getLogger(name=__name__)
|
|
|
@dataclass
class GenerationConfig:
    """Configuration for a single video generation request.

    Fields fall into four groups: text prompts and sampling parameters passed
    to the video pipeline, the RNG seed, options for the Varnish
    post-processing / encoding step, and optional audio-generation options.
    Call :meth:`validate_and_adjust` before use to normalise the values.
    """

    # Text prompts for the video model.
    prompt: str
    negative_prompt: str = ""

    # Sampling parameters passed to the video pipeline.
    num_frames: int = 49
    height: int = 320
    width: int = 576
    num_inference_steps: int = 50
    guidance_scale: float = 7.0

    # -1 (or None) means "choose a random seed" in validate_and_adjust().
    seed: Optional[int] = -1

    # Varnish post-processing / video encoding options.
    fps: int = 30
    double_num_frames: bool = False
    super_resolution: bool = False
    grain_amount: float = 0.0
    quality: int = 18

    # Optional audio generation, also handled by Varnish.
    enable_audio: bool = False
    audio_prompt: str = ""
    audio_negative_prompt: str = "voices, voice, talking, speaking, speech"

    def validate_and_adjust(self) -> "GenerationConfig":
        """Normalise parameters in place and return ``self`` for chaining.

        - ``num_frames`` is rounded down to the nearest value of the form
          ``4k + 1`` and clamped to at least 1. Previously, values below 1
          produced negative frame counts (e.g. 0 -> -3).
        - A ``seed`` of -1 or ``None`` is replaced by a random 32-bit seed,
          so the seed actually used can be reported back to the caller.

        Returns:
            This same instance, after adjustment.
        """
        # Frame counts must have the form 4k + 1.
        # NOTE(review): presumably this matches the VAE's 4x temporal
        # compression; confirm against the model documentation.
        k = max(0, (self.num_frames - 1) // 4)
        self.num_frames = 4 * k + 1

        if self.seed is None or self.seed == -1:
            self.seed = random.randint(0, 2**32 - 1)

        return self
|
|
|
class EndpointHandler:
    """Handles video generation requests using HunyuanVideo and Varnish."""

    def __init__(self, path: str = ""):
        """Load the HunyuanVideo pipeline and the Varnish post-processor.

        Args:
            path: Path to model weights, passed to
                ``HunyuanVideoPipeline.from_pretrained``.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        self.pipeline = HunyuanVideoPipeline.from_pretrained(
            path,
            torch_dtype=torch.float16,
        ).to(self.device)

        # Text encoders and VAE run in float16; the transformer is cast to
        # bfloat16. NOTE(review): presumably bf16 is chosen for the
        # transformer's numerical range; confirm against the model card.
        self.pipeline.text_encoder = self.pipeline.text_encoder.half()
        self.pipeline.text_encoder_2 = self.pipeline.text_encoder_2.half()
        self.pipeline.transformer = self.pipeline.transformer.to(torch.bfloat16)
        self.pipeline.vae = self.pipeline.vae.half()

        # NOTE(review): model_base_dir is hard-coded; presumably this is where
        # Varnish's own model weights are stored. Confirm for each deployment.
        self.varnish = Varnish(
            device=self.device,
            model_base_dir="/repository/varnish",
        )

    @staticmethod
    def _build_config(data: Dict[str, Any]) -> "GenerationConfig":
        """Turn a raw request payload into a validated GenerationConfig.

        The payload is read but never mutated.
        """
        # ``inputs`` is either a bare prompt string or a dict with a "prompt"
        # key. If it is absent, the whole payload is treated as that dict.
        inputs = data.get("inputs", data)
        prompt = inputs.get("prompt", "") if isinstance(inputs, dict) else inputs

        # ``parameters`` may be missing or explicitly null.
        params = data.get("parameters") or {}

        # Forward only the keys the request actually set, so every default
        # lives in GenerationConfig instead of being duplicated here. The
        # prompt always comes from ``inputs``, never from ``parameters``.
        overrides = {
            f.name: params[f.name]
            for f in fields(GenerationConfig)
            if f.name != "prompt" and f.name in params
        }
        return GenerationConfig(prompt=prompt, **overrides).validate_and_adjust()

    def _make_generator(self, seed: Optional[int]) -> Optional[torch.Generator]:
        """Seed the global RNGs and return a device generator for ``seed``.

        Returns None when no concrete seed is set. validate_and_adjust()
        normally resolves -1 to a real seed, so this branch is defensive.
        """
        if seed is None or seed == -1:
            return None
        # The global torch / Python RNGs are seeded as well. Presumably this
        # keeps any later step that draws from them reproducible.
        torch.manual_seed(seed)
        random.seed(seed)
        return torch.Generator(device=self.device).manual_seed(seed)

    @staticmethod
    def _event_loop() -> asyncio.AbstractEventLoop:
        """Return this thread's event loop, creating and registering one if needed.

        The loop is reused across calls on the same thread and is not closed
        here. If a loop is already *running* on this thread, the caller's
        ``run_until_complete`` will raise, exactly as before.
        """
        try:
            loop = asyncio.get_event_loop()
        except RuntimeError:
            # No loop is registered for this thread (e.g. a worker thread, or
            # Python versions where get_event_loop() no longer creates one).
            loop = None
        if loop is None or loop.is_closed():
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
        return loop

    async def _postprocess(self, frames: Any, config: "GenerationConfig") -> tuple:
        """Run Varnish on the generated frames and encode the result.

        Returns:
            A ``(result, video_uri)`` pair: the object returned by Varnish and
            the output of ``result.write(type="data-uri", ...)``.
        """
        result = await self.varnish(
            input_data=frames,
            fps=config.fps,
            double_num_frames=config.double_num_frames,
            super_resolution=config.super_resolution,
            grain_amount=config.grain_amount,
            enable_audio=config.enable_audio,
            audio_prompt=config.audio_prompt,
            audio_negative_prompt=config.audio_negative_prompt,
        )
        video_uri = await result.write(type="data-uri", quality=config.quality)
        return result, video_uri

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Process a video generation request.

        Args:
            data: Request data containing:
                - inputs (str | dict): the prompt, or a dict with a "prompt" key
                - parameters (dict, optional): generation parameters whose
                  names match GenerationConfig fields (unknown keys are ignored)

        Returns:
            Dictionary containing:
                - video: the encoded video from Varnish (requested as "data-uri")
                - content-type: "video/mp4"
                - metadata: width, height, num_frames, fps, duration and the
                  seed actually used

        Raises:
            RuntimeError: if generation, post-processing or encoding fails.
                The original exception is chained as ``__cause__``. Errors
                raised while parsing the request propagate unchanged.
        """
        config = self._build_config(data)

        try:
            generator = self._make_generator(config.seed)

            with torch.inference_mode():
                frames = self.pipeline(
                    prompt=config.prompt,
                    negative_prompt=config.negative_prompt,
                    num_frames=config.num_frames,
                    height=config.height,
                    width=config.width,
                    num_inference_steps=config.num_inference_steps,
                    guidance_scale=config.guidance_scale,
                    generator=generator,
                    output_type="pt",
                ).frames

            result, video_uri = self._event_loop().run_until_complete(
                self._postprocess(frames, config)
            )

            return {
                "video": video_uri,
                "content-type": "video/mp4",
                "metadata": {
                    "width": result.metadata.width,
                    "height": result.metadata.height,
                    "num_frames": result.metadata.frame_count,
                    "fps": result.metadata.fps,
                    "duration": result.metadata.duration,
                    "seed": config.seed,
                },
            }

        except Exception as e:
            # logger.exception keeps the traceback, which logger.error dropped.
            logger.exception("Error generating video: %s", e)
            raise RuntimeError(f"Failed to generate video: {str(e)}") from e