jbilcke-hf HF staff committed on
Commit
2df6ae7
1 Parent(s): c2310c3

Create handler.py

Browse files
Files changed (1) hide show
  1. handler.py +198 -0
handler.py ADDED
@@ -0,0 +1,198 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ from typing import Dict, Any, Optional
3
+ import base64
4
+ import logging
5
+ import random
6
+ import torch
7
+ from diffusers import HunyuanVideoPipeline
8
+ from varnish import Varnish
9
+
10
+ # Configure logging
11
+ logging.basicConfig(level=logging.INFO)
12
+ logger = logging.getLogger(__name__)
13
+
14
@dataclass
class GenerationConfig:
    """Configuration for video generation.

    Holds the prompt, the model sampling settings, the seed, and the
    post-processing / audio options for one request. Call
    ``validate_and_adjust()`` after construction to normalise ``num_frames``
    and resolve the seed.
    """
    # Content settings
    prompt: str
    negative_prompt: str = ""

    # Model settings
    num_frames: int = 49  # Should be 4k + 1 format
    height: int = 320
    width: int = 576
    num_inference_steps: int = 50
    guidance_scale: float = 7.0

    # Reproducibility (-1 means "pick a random seed")
    seed: int = -1

    # Varnish post-processing settings
    fps: int = 30
    double_num_frames: bool = False
    super_resolution: bool = False
    grain_amount: float = 0.0
    quality: int = 18  # CRF scale (0-51, lower is better)

    # Audio settings
    enable_audio: bool = False
    audio_prompt: str = ""
    audio_negative_prompt: str = "voices, voice, talking, speaking, speech"

    def validate_and_adjust(self) -> 'GenerationConfig':
        """Validate and adjust parameters in place.

        * ``num_frames`` is rounded down to the nearest value of the form
          ``4k + 1`` and never drops below 1.
        * ``seed`` is replaced by a random 32-bit value when it is ``-1``
          (or ``None``).

        Returns:
            This same instance, so the call can be chained onto the
            constructor.
        """
        # Requests arrive as JSON, so a frame count may show up as 49.0.
        frames = int(self.num_frames)

        # Floor division alone would turn 0 or negative requests into a
        # negative frame count (e.g. 0 -> -3); clamp k so the result is >= 1.
        k = max((frames - 1) // 4, 0)
        self.num_frames = (k * 4) + 1

        # Set random seed if not specified
        if self.seed is None or self.seed == -1:
            self.seed = random.randint(0, 2**32 - 1)

        return self
54
+
55
class EndpointHandler:
    """Handles video generation requests using HunyuanVideo and Varnish"""

    def __init__(self, path: str = ""):
        """Initialize handler with models

        Args:
            path: Path to model weights
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        # Initialize HunyuanVideo pipeline
        self.pipeline = HunyuanVideoPipeline.from_pretrained(
            path,
            torch_dtype=torch.float16,
        ).to(self.device)

        # Initialize text encoders in float16
        self.pipeline.text_encoder = self.pipeline.text_encoder.half()
        self.pipeline.text_encoder_2 = self.pipeline.text_encoder_2.half()

        # Initialize transformer in bfloat16
        self.pipeline.transformer = self.pipeline.transformer.to(torch.bfloat16)

        # Initialize VAE in float16
        self.pipeline.vae = self.pipeline.vae.half()

        # Initialize Varnish for post-processing
        self.varnish = Varnish(
            device=self.device,
            model_base_dir="/repository/varnish"
        )

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Process video generation requests

        Args:
            data: Request data containing:
                - inputs (str): Prompt for video generation
                - parameters (dict): Generation parameters

        Returns:
            Dictionary containing:
                - video: Base64 encoded MP4 data URI
                - content-type: MIME type
                - metadata: Generation metadata

        Raises:
            RuntimeError: If generation or post-processing fails; the
                original exception is attached as ``__cause__``.
        """
        # Extract inputs. "inputs" may be a plain prompt string or a dict
        # carrying a "prompt" key; when absent, the request dict itself is used.
        inputs = data.pop("inputs", data)
        if isinstance(inputs, dict):
            prompt = inputs.get("prompt", "")
        else:
            prompt = inputs

        # An explicit `"parameters": null` must behave like a missing key,
        # otherwise the .get() calls below fail on None.
        params = data.get("parameters") or {}

        # Create and validate config
        config = GenerationConfig(
            prompt=prompt,
            negative_prompt=params.get("negative_prompt", ""),
            num_frames=params.get("num_frames", 49),
            height=params.get("height", 320),
            width=params.get("width", 576),
            num_inference_steps=params.get("num_inference_steps", 50),
            guidance_scale=params.get("guidance_scale", 7.0),
            seed=params.get("seed", -1),
            fps=params.get("fps", 30),
            double_num_frames=params.get("double_num_frames", False),
            super_resolution=params.get("super_resolution", False),
            grain_amount=params.get("grain_amount", 0.0),
            quality=params.get("quality", 18),
            enable_audio=params.get("enable_audio", False),
            audio_prompt=params.get("audio_prompt", ""),
            audio_negative_prompt=params.get("audio_negative_prompt", "voices, voice, talking, speaking, speech"),
        ).validate_and_adjust()

        try:
            # Set random seeds. validate_and_adjust() has already replaced a
            # seed of -1 with a concrete value, so a generator is always built.
            torch.manual_seed(config.seed)
            random.seed(config.seed)
            generator = torch.Generator(device=self.device).manual_seed(config.seed)

            # Generate video frames
            with torch.inference_mode():
                output = self.pipeline(
                    prompt=config.prompt,
                    negative_prompt=config.negative_prompt,
                    num_frames=config.num_frames,
                    height=config.height,
                    width=config.width,
                    num_inference_steps=config.num_inference_steps,
                    guidance_scale=config.guidance_scale,
                    generator=generator,
                    output_type="pt",
                ).frames

            # Process with Varnish. Its API is awaited, so drive it from an
            # event loop, creating one if this thread has none.
            import asyncio
            try:
                loop = asyncio.get_event_loop()
            except RuntimeError:
                loop = asyncio.new_event_loop()
                asyncio.set_event_loop(loop)

            result = loop.run_until_complete(
                self.varnish(
                    input_data=output,
                    fps=config.fps,
                    double_num_frames=config.double_num_frames,
                    super_resolution=config.super_resolution,
                    grain_amount=config.grain_amount,
                    enable_audio=config.enable_audio,
                    audio_prompt=config.audio_prompt,
                    audio_negative_prompt=config.audio_negative_prompt,
                )
            )

            # Get video data URI
            video_uri = loop.run_until_complete(
                result.write(
                    type="data-uri",
                    quality=config.quality
                )
            )

            return {
                "video": video_uri,
                "content-type": "video/mp4",
                "metadata": {
                    "width": result.metadata.width,
                    "height": result.metadata.height,
                    "num_frames": result.metadata.frame_count,
                    "fps": result.metadata.fps,
                    "duration": result.metadata.duration,
                    "seed": config.seed,
                }
            }

        except Exception as e:
            # logger.exception records the traceback alongside the message.
            logger.exception("Error generating video: %s", e)
            # Chain explicitly so the root cause stays attached for callers.
            raise RuntimeError(f"Failed to generate video: {str(e)}") from e