Spaces:
Sleeping
Sleeping
update
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitignore +167 -0
- .isort.cfg +7 -0
- .pre-commit-config.yaml +39 -0
- app.py +102 -205
- docs/dsp.md +0 -25
- docs/pab.md +0 -121
- eval/pab/commom_metrics/README.md +0 -6
- eval/pab/commom_metrics/calculate_lpips.py +0 -97
- eval/pab/commom_metrics/calculate_psnr.py +0 -90
- eval/pab/commom_metrics/calculate_ssim.py +0 -116
- eval/pab/commom_metrics/eval.py +0 -160
- eval/pab/experiments/attention_ablation.py +0 -60
- eval/pab/experiments/components_ablation.py +0 -46
- eval/pab/experiments/latte.py +0 -57
- eval/pab/experiments/opensora.py +0 -44
- eval/pab/experiments/opensora_plan.py +0 -57
- eval/pab/experiments/utils.py +0 -22
- eval/pab/vbench/VBench_full_info.json +0 -0
- eval/pab/vbench/cal_vbench.py +0 -154
- eval/pab/vbench/run_vbench.py +0 -52
- examples/cogvideo/sample.py +0 -14
- examples/latte/sample.py +0 -24
- examples/open_sora/sample.py +0 -24
- examples/open_sora_plan/sample.py +0 -24
- videosys/__init__.py +9 -13
- videosys/core/engine.py +2 -4
- videosys/core/pab_mgr.py +43 -175
- videosys/datasets/dataloader.py +0 -94
- videosys/datasets/image_transform.py +0 -42
- videosys/datasets/video_transform.py +0 -441
- videosys/diffusion/__init__.py +0 -41
- videosys/diffusion/diffusion_utils.py +0 -79
- videosys/diffusion/gaussian_diffusion.py +0 -829
- videosys/diffusion/respace.py +0 -119
- videosys/diffusion/timestep_sampler.py +0 -143
- {eval/pab/commom_metrics → videosys/models/autoencoders}/__init__.py +0 -0
- videosys/models/{cogvideo/autoencoder_kl.py → autoencoders/autoencoder_kl_cogvideox.py} +328 -94
- videosys/models/{open_sora/vae.py → autoencoders/autoencoder_kl_open_sora.py} +2 -9
- videosys/models/{open_sora_plan/ae.py → autoencoders/autoencoder_kl_open_sora_plan.py} +797 -14
- videosys/models/cogvideo/__init__.py +0 -6
- videosys/models/cogvideo/modules.py +0 -317
- videosys/models/cogvideo/retrieve_timesteps.py +0 -74
- videosys/models/latte/__init__.py +0 -7
- {eval/pab/experiments → videosys/models/modules}/__init__.py +0 -0
- videosys/models/modules/activations.py +3 -0
- videosys/{modules/attn.py → models/modules/attentions.py} +45 -131
- videosys/models/modules/downsampling.py +71 -0
- videosys/models/{open_sora/modules.py → modules/embeddings.py} +171 -209
- videosys/models/modules/normalization.py +102 -0
- videosys/models/modules/upsampling.py +67 -0
.gitignore
ADDED
@@ -0,0 +1,167 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
outputs/
|
2 |
+
processed/
|
3 |
+
profile/
|
4 |
+
|
5 |
+
# Byte-compiled / optimized / DLL files
|
6 |
+
__pycache__/
|
7 |
+
*.py[cod]
|
8 |
+
*$py.class
|
9 |
+
|
10 |
+
# C extensions
|
11 |
+
*.so
|
12 |
+
|
13 |
+
# Distribution / packaging
|
14 |
+
.Python
|
15 |
+
build/
|
16 |
+
develop-eggs/
|
17 |
+
dist/
|
18 |
+
downloads/
|
19 |
+
eggs/
|
20 |
+
.eggs/
|
21 |
+
lib/
|
22 |
+
lib64/
|
23 |
+
parts/
|
24 |
+
sdist/
|
25 |
+
var/
|
26 |
+
wheels/
|
27 |
+
pip-wheel-metadata/
|
28 |
+
share/python-wheels/
|
29 |
+
*.egg-info/
|
30 |
+
.installed.cfg
|
31 |
+
*.egg
|
32 |
+
MANIFEST
|
33 |
+
|
34 |
+
# PyInstaller
|
35 |
+
# Usually these files are written by a python script from a template
|
36 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
37 |
+
*.manifest
|
38 |
+
*.spec
|
39 |
+
|
40 |
+
# Installer logs
|
41 |
+
pip-log.txt
|
42 |
+
pip-delete-this-directory.txt
|
43 |
+
|
44 |
+
# Unit test / coverage reports
|
45 |
+
htmlcov/
|
46 |
+
.tox/
|
47 |
+
.nox/
|
48 |
+
.coverage
|
49 |
+
.coverage.*
|
50 |
+
.cache
|
51 |
+
nosetests.xml
|
52 |
+
coverage.xml
|
53 |
+
*.cover
|
54 |
+
*.py,cover
|
55 |
+
.hypothesis/
|
56 |
+
.pytest_cache/
|
57 |
+
|
58 |
+
# Translations
|
59 |
+
*.mo
|
60 |
+
*.pot
|
61 |
+
|
62 |
+
# Django stuff:
|
63 |
+
*.log
|
64 |
+
local_settings.py
|
65 |
+
db.sqlite3
|
66 |
+
db.sqlite3-journal
|
67 |
+
|
68 |
+
# Flask stuff:
|
69 |
+
instance/
|
70 |
+
.webassets-cache
|
71 |
+
|
72 |
+
# Scrapy stuff:
|
73 |
+
.scrapy
|
74 |
+
|
75 |
+
# Sphinx documentation
|
76 |
+
docs/_build/
|
77 |
+
docs/.build/
|
78 |
+
|
79 |
+
# PyBuilder
|
80 |
+
target/
|
81 |
+
|
82 |
+
# Jupyter Notebook
|
83 |
+
.ipynb_checkpoints
|
84 |
+
|
85 |
+
# IPython
|
86 |
+
profile_default/
|
87 |
+
ipython_config.py
|
88 |
+
|
89 |
+
# pyenv
|
90 |
+
.python-version
|
91 |
+
|
92 |
+
# pipenv
|
93 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
94 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
95 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
96 |
+
# install all needed dependencies.
|
97 |
+
#Pipfile.lock
|
98 |
+
|
99 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
|
100 |
+
__pypackages__/
|
101 |
+
|
102 |
+
# Celery stuff
|
103 |
+
celerybeat-schedule
|
104 |
+
celerybeat.pid
|
105 |
+
|
106 |
+
# SageMath parsed files
|
107 |
+
*.sage.py
|
108 |
+
|
109 |
+
# Environments
|
110 |
+
.env
|
111 |
+
.venv
|
112 |
+
env/
|
113 |
+
venv/
|
114 |
+
ENV/
|
115 |
+
env.bak/
|
116 |
+
venv.bak/
|
117 |
+
|
118 |
+
# Spyder project settings
|
119 |
+
.spyderproject
|
120 |
+
.spyproject
|
121 |
+
|
122 |
+
# Rope project settings
|
123 |
+
.ropeproject
|
124 |
+
|
125 |
+
# mkdocs documentation
|
126 |
+
/site
|
127 |
+
|
128 |
+
# mypy
|
129 |
+
.mypy_cache/
|
130 |
+
.dmypy.json
|
131 |
+
dmypy.json
|
132 |
+
|
133 |
+
# Pyre type checker
|
134 |
+
.pyre/
|
135 |
+
|
136 |
+
# IDE
|
137 |
+
.idea/
|
138 |
+
.vscode/
|
139 |
+
|
140 |
+
# macos
|
141 |
+
*.DS_Store
|
142 |
+
#data/
|
143 |
+
|
144 |
+
docs/.build
|
145 |
+
|
146 |
+
# pytorch checkpoint
|
147 |
+
*.pt
|
148 |
+
|
149 |
+
# ignore any kernel build files
|
150 |
+
.o
|
151 |
+
.so
|
152 |
+
|
153 |
+
# ignore python interface defition file
|
154 |
+
.pyi
|
155 |
+
|
156 |
+
# ignore coverage test file
|
157 |
+
coverage.lcov
|
158 |
+
coverage.xml
|
159 |
+
|
160 |
+
# ignore testmon and coverage files
|
161 |
+
.coverage
|
162 |
+
.testmondata*
|
163 |
+
|
164 |
+
pretrained
|
165 |
+
samples
|
166 |
+
cache_dir
|
167 |
+
test_outputs
|
.isort.cfg
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[settings]
|
2 |
+
line_length = 120
|
3 |
+
multi_line_output=3
|
4 |
+
include_trailing_comma = true
|
5 |
+
ignore_comments = true
|
6 |
+
profile = black
|
7 |
+
honor_noqa = true
|
.pre-commit-config.yaml
ADDED
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
repos:
|
2 |
+
|
3 |
+
- repo: https://github.com/PyCQA/autoflake
|
4 |
+
rev: v2.2.1
|
5 |
+
hooks:
|
6 |
+
- id: autoflake
|
7 |
+
name: autoflake (python)
|
8 |
+
args: ['--in-place', '--remove-unused-variables', '--remove-all-unused-imports', '--ignore-init-module-imports']
|
9 |
+
|
10 |
+
- repo: https://github.com/pycqa/isort
|
11 |
+
rev: 5.12.0
|
12 |
+
hooks:
|
13 |
+
- id: isort
|
14 |
+
name: sort all imports (python)
|
15 |
+
|
16 |
+
- repo: https://github.com/psf/black-pre-commit-mirror
|
17 |
+
rev: 23.9.1
|
18 |
+
hooks:
|
19 |
+
- id: black
|
20 |
+
name: black formatter
|
21 |
+
args: ['--line-length=120', '--target-version=py37', '--target-version=py38', '--target-version=py39','--target-version=py310']
|
22 |
+
|
23 |
+
- repo: https://github.com/pre-commit/mirrors-clang-format
|
24 |
+
rev: v13.0.1
|
25 |
+
hooks:
|
26 |
+
- id: clang-format
|
27 |
+
name: clang formatter
|
28 |
+
types_or: [c++, c]
|
29 |
+
|
30 |
+
- repo: https://github.com/pre-commit/pre-commit-hooks
|
31 |
+
rev: v4.3.0
|
32 |
+
hooks:
|
33 |
+
- id: check-yaml
|
34 |
+
- id: check-merge-conflict
|
35 |
+
- id: check-case-conflict
|
36 |
+
- id: trailing-whitespace
|
37 |
+
- id: end-of-file-fixer
|
38 |
+
- id: mixed-line-ending
|
39 |
+
args: ['--fix=lf']
|
app.py
CHANGED
@@ -2,131 +2,107 @@ import os
|
|
2 |
|
3 |
os.environ["GRADIO_TEMP_DIR"] = os.path.join(os.getcwd(), ".tmp_outputs")
|
4 |
|
5 |
-
import torch
|
6 |
-
from openai import OpenAI
|
7 |
-
from time import time
|
8 |
-
import tempfile
|
9 |
-
import uuid
|
10 |
import logging
|
|
|
|
|
|
|
11 |
import gradio as gr
|
12 |
-
from videosys import CogVideoConfig, VideoSysEngine
|
13 |
-
from videosys.models.cogvideo.pipeline import CogVideoPABConfig
|
14 |
import psutil
|
15 |
-
import
|
16 |
-
|
17 |
|
|
|
18 |
|
19 |
logging.basicConfig(level=logging.INFO)
|
20 |
logger = logging.getLogger(__name__)
|
21 |
|
22 |
-
dtype = torch.
|
23 |
-
sys_prompt = """You are part of a team of bots that creates videos. You work with an assistant bot that will draw anything you say in square brackets.
|
24 |
-
|
25 |
-
For example , outputting " a beautiful morning in the woods with the sun peaking through the trees " will trigger your partner bot to output an video of a forest morning , as described. You will be prompted by people looking to create detailed , amazing videos. The way to accomplish this is to take their short prompts and make them extremely detailed and descriptive.
|
26 |
-
There are a few rules to follow:
|
27 |
-
|
28 |
-
You will only ever output a single video description per user request.
|
29 |
-
|
30 |
-
When modifications are requested , you should not simply make the description longer . You should refactor the entire description to integrate the suggestions.
|
31 |
-
Other times the user will not want modifications , but instead want a new image . In this case , you should ignore your previous conversation with the user.
|
32 |
-
|
33 |
-
Video descriptions must have the same num of words as examples below. Extra words will be ignored.
|
34 |
-
"""
|
35 |
|
36 |
-
def convert_prompt(prompt: str, retry_times: int = 3) -> str:
|
37 |
-
if not os.environ.get("OPENAI_API_KEY"):
|
38 |
-
return prompt
|
39 |
-
client = OpenAI()
|
40 |
-
text = prompt.strip()
|
41 |
-
|
42 |
-
for i in range(retry_times):
|
43 |
-
response = client.chat.completions.create(
|
44 |
-
messages=[
|
45 |
-
{"role": "system", "content": sys_prompt},
|
46 |
-
{
|
47 |
-
"role": "user",
|
48 |
-
"content": 'Create an imaginative video descriptive caption or modify an earlier caption for the user input : "a girl is on the beach"',
|
49 |
-
},
|
50 |
-
{
|
51 |
-
"role": "assistant",
|
52 |
-
"content": "A radiant woman stands on a deserted beach, arms outstretched, wearing a beige trench coat, white blouse, light blue jeans, and chic boots, against a backdrop of soft sky and sea. Moments later, she is seen mid-twirl, arms exuberant, with the lighting suggesting dawn or dusk. Then, she runs along the beach, her attire complemented by an off-white scarf and black ankle boots, the tranquil sea behind her. Finally, she holds a paper airplane, her pose reflecting joy and freedom, with the ocean's gentle waves and the sky's soft pastel hues enhancing the serene ambiance.",
|
53 |
-
},
|
54 |
-
{
|
55 |
-
"role": "user",
|
56 |
-
"content": 'Create an imaginative video descriptive caption or modify an earlier caption for the user input : "A man jogging on a football field"',
|
57 |
-
},
|
58 |
-
{
|
59 |
-
"role": "assistant",
|
60 |
-
"content": "A determined man in athletic attire, including a blue long-sleeve shirt, black shorts, and blue socks, jogs around a snow-covered soccer field, showcasing his solitary exercise in a quiet, overcast setting. His long dreadlocks, focused expression, and the serene winter backdrop highlight his dedication to fitness. As he moves, his attire, consisting of a blue sports sweatshirt, black athletic pants, gloves, and sneakers, grips the snowy ground. He is seen running past a chain-link fence enclosing the playground area, with a basketball hoop and children's slide, suggesting a moment of solitary exercise amidst the empty field.",
|
61 |
-
},
|
62 |
-
{
|
63 |
-
"role": "user",
|
64 |
-
"content": 'Create an imaginative video descriptive caption or modify an earlier caption for the user input : " A woman is dancing, HD footage, close-up"',
|
65 |
-
},
|
66 |
-
{
|
67 |
-
"role": "assistant",
|
68 |
-
"content": "A young woman with her hair in an updo and wearing a teal hoodie stands against a light backdrop, initially looking over her shoulder with a contemplative expression. She then confidently makes a subtle dance move, suggesting rhythm and movement. Next, she appears poised and focused, looking directly at the camera. Her expression shifts to one of introspection as she gazes downward slightly. Finally, she dances with confidence, her left hand over her heart, symbolizing a poignant moment, all while dressed in the same teal hoodie against a plain, light-colored background.",
|
69 |
-
},
|
70 |
-
{
|
71 |
-
"role": "user",
|
72 |
-
"content": f'Create an imaginative video descriptive caption or modify an earlier caption in ENGLISH for the user input: "{text}"',
|
73 |
-
},
|
74 |
-
],
|
75 |
-
model="glm-4-0520",
|
76 |
-
temperature=0.01,
|
77 |
-
top_p=0.7,
|
78 |
-
stream=False,
|
79 |
-
max_tokens=250,
|
80 |
-
)
|
81 |
-
if response.choices:
|
82 |
-
return response.choices[0].message.content
|
83 |
-
return prompt
|
84 |
|
85 |
-
def load_model(enable_video_sys=False, pab_threshold=[100, 850],
|
86 |
-
pab_config =
|
87 |
-
config =
|
88 |
engine = VideoSysEngine(config)
|
89 |
return engine
|
90 |
|
|
|
91 |
def generate(engine, prompt, num_inference_steps=50, guidance_scale=6.0):
|
92 |
-
|
93 |
-
video = engine.generate(prompt, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale).video[0]
|
94 |
|
95 |
-
|
96 |
-
|
97 |
-
unique_filename = f"{uuid.uuid4().hex}.mp4"
|
98 |
-
output_path = os.path.join("./temp_outputs", unique_filename)
|
99 |
|
100 |
-
|
101 |
-
|
102 |
-
except Exception as e:
|
103 |
-
logger.error(f"An error occurred: {str(e)}")
|
104 |
-
return None
|
105 |
|
106 |
|
107 |
def get_server_status():
|
108 |
cpu_percent = psutil.cpu_percent()
|
109 |
memory = psutil.virtual_memory()
|
110 |
-
disk = psutil.disk_usage(
|
111 |
gpus = GPUtil.getGPUs()
|
112 |
gpu_info = []
|
113 |
for gpu in gpus:
|
114 |
-
gpu_info.append(
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
122 |
return {
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
-
|
127 |
}
|
128 |
|
129 |
|
|
|
|
|
|
|
|
|
130 |
|
131 |
css = """
|
132 |
body {
|
@@ -137,16 +113,17 @@ body {
|
|
137 |
padding: 20px;
|
138 |
}
|
139 |
|
|
|
140 |
.container {
|
141 |
display: flex;
|
142 |
flex-direction: column;
|
143 |
-
gap:
|
144 |
}
|
145 |
|
146 |
.row {
|
147 |
display: flex;
|
148 |
flex-wrap: wrap;
|
149 |
-
gap:
|
150 |
}
|
151 |
|
152 |
.column {
|
@@ -186,12 +163,6 @@ body {
|
|
186 |
font-size: 0.9em !important;
|
187 |
line-height: 1.2 !important;
|
188 |
}
|
189 |
-
.server-status button {
|
190 |
-
padding: 1px 8px !important;
|
191 |
-
height: 22px !important;
|
192 |
-
font-size: 0.9em !important;
|
193 |
-
margin-top: 2px !important;
|
194 |
-
}
|
195 |
.server-status .textbox {
|
196 |
gap: 0 !important;
|
197 |
}
|
@@ -215,150 +186,76 @@ body {
|
|
215 |
"""
|
216 |
|
217 |
with gr.Blocks(css=css) as demo:
|
218 |
-
gr.HTML(
|
|
|
219 |
<div style="text-align: center; font-size: 32px; font-weight: bold; margin-bottom: 20px;">
|
220 |
-
VideoSys
|
221 |
</div>
|
222 |
<div style="text-align: center; font-size: 15px;">
|
223 |
🌐 Github: <a href="https://github.com/NUS-HPC-AI-Lab/VideoSys">https://github.com/NUS-HPC-AI-Lab/VideoSys</a><br>
|
224 |
-
|
225 |
-
⚠️ This demo is for academic research and experiential use only.
|
226 |
Users should strictly adhere to local laws and ethics.<br>
|
227 |
-
|
228 |
💡 This demo only demonstrates single-device inference. To experience the full power of VideoSys, please deploy it with multiple devices.<br><br>
|
229 |
</div>
|
230 |
</div>
|
231 |
-
"""
|
|
|
232 |
|
233 |
with gr.Row():
|
234 |
with gr.Column():
|
235 |
-
prompt = gr.Textbox(label="Prompt (Less than 200 Words)", value="Sunset over the sea.", lines=
|
236 |
-
with gr.Row():
|
237 |
-
gr.Markdown(
|
238 |
-
"✨Upon pressing the enhanced prompt button, we will use [GLM-4 Model](https://github.com/THUDM/GLM-4) to polish the prompt and overwrite the original one."
|
239 |
-
)
|
240 |
-
enhance_button = gr.Button("✨ Enhance Prompt(Optional)")
|
241 |
|
242 |
with gr.Column():
|
243 |
-
gr.Markdown(
|
244 |
-
"**Optional Parameters** (default values are recommended)<br>"
|
245 |
-
"Turn Inference Steps larger if you want more detailed video, but it will be slower.<br>"
|
246 |
-
"50 steps are recommended for most cases. will cause 120 seconds for inference.<br>"
|
247 |
-
)
|
248 |
with gr.Row():
|
249 |
num_inference_steps = gr.Number(label="Inference Steps", value=50)
|
250 |
guidance_scale = gr.Number(label="Guidance Scale", value=6.0)
|
251 |
-
pab_gap = gr.Number(label="PAB Gap", value=2, precision=0)
|
252 |
-
pab_threshold = gr.Textbox(label="PAB Threshold", value="100,850", lines=1)
|
253 |
with gr.Row():
|
254 |
-
|
|
|
|
|
|
|
|
|
|
|
255 |
generate_button_vs = gr.Button("⚡️ Generate Video with VideoSys (Faster)")
|
|
|
256 |
with gr.Column(elem_classes="server-status"):
|
257 |
gr.Markdown("#### Server Status")
|
258 |
-
|
259 |
with gr.Row():
|
260 |
cpu_status = gr.Textbox(label="CPU", scale=1)
|
261 |
memory_status = gr.Textbox(label="Memory", scale=1)
|
262 |
-
|
263 |
with gr.Row():
|
264 |
disk_status = gr.Textbox(label="Disk", scale=1)
|
265 |
gpu_status = gr.Textbox(label="GPU Memory", scale=1)
|
266 |
-
|
267 |
with gr.Row():
|
268 |
-
refresh_button = gr.Button("Refresh"
|
269 |
|
270 |
with gr.Column():
|
271 |
-
with gr.Row():
|
272 |
-
video_output = gr.Video(label="CogVideoX", width=720, height=480)
|
273 |
-
with gr.Row():
|
274 |
-
download_video_button = gr.File(label="📥 Download Video", visible=False)
|
275 |
-
elapsed_time = gr.Textbox(label="Elapsed Time", value="0s", visible=False)
|
276 |
with gr.Row():
|
277 |
video_output_vs = gr.Video(label="CogVideoX with VideoSys", width=720, height=480)
|
278 |
with gr.Row():
|
279 |
-
|
280 |
-
elapsed_time_vs = gr.Textbox(label="Elapsed Time", value="0s", visible=False)
|
281 |
-
# with gr.Column():
|
282 |
-
# task_status = gr.Textbox(label="任务状态", visible=False)
|
283 |
-
|
284 |
-
|
285 |
-
|
286 |
-
|
287 |
-
def generate_vanilla(prompt, num_inference_steps, guidance_scale, progress=gr.Progress(track_tqdm=True)):
|
288 |
-
engine = load_model()
|
289 |
-
t = time()
|
290 |
-
video_path = generate(engine, prompt, num_inference_steps, guidance_scale)
|
291 |
-
elapsed_time = time() - t
|
292 |
-
video_update = gr.update(visible=True, value=video_path)
|
293 |
-
elapsed_time = gr.update(visible=True, value=f"{elapsed_time:.2f}s")
|
294 |
-
|
295 |
-
return video_path, video_update, elapsed_time
|
296 |
-
|
297 |
-
def generate_vs(prompt, num_inference_steps, guidance_scale, threshold, gap, progress=gr.Progress(track_tqdm=True)):
|
298 |
-
threshold = [int(i) for i in threshold.split(",")]
|
299 |
-
gap = int(gap)
|
300 |
-
engine = load_model(enable_video_sys=True, pab_threshold=threshold, pab_gap=gap)
|
301 |
-
t = time()
|
302 |
-
video_path = generate(engine, prompt, num_inference_steps, guidance_scale)
|
303 |
-
elapsed_time = time() - t
|
304 |
-
video_update = gr.update(visible=True, value=video_path)
|
305 |
-
elapsed_time = gr.update(visible=True, value=f"{elapsed_time:.2f}s")
|
306 |
-
|
307 |
-
return video_path, video_update, elapsed_time
|
308 |
-
|
309 |
-
def enhance_prompt_func(prompt):
|
310 |
-
return convert_prompt(prompt, retry_times=1)
|
311 |
-
|
312 |
-
def get_server_status():
|
313 |
-
cpu_percent = psutil.cpu_percent()
|
314 |
-
memory = psutil.virtual_memory()
|
315 |
-
disk = psutil.disk_usage('/')
|
316 |
-
try:
|
317 |
-
gpus = GPUtil.getGPUs()
|
318 |
-
if gpus:
|
319 |
-
gpu = gpus[0]
|
320 |
-
gpu_memory = f"{gpu.memoryUsed}/{gpu.memoryTotal}MB ({gpu.memoryUtil*100:.1f}%)"
|
321 |
-
else:
|
322 |
-
gpu_memory = "No GPU found"
|
323 |
-
except:
|
324 |
-
gpu_memory = "GPU information unavailable"
|
325 |
-
|
326 |
-
return {
|
327 |
-
'cpu': f"{cpu_percent}%",
|
328 |
-
'memory': f"{memory.percent}%",
|
329 |
-
'disk': f"{disk.percent}%",
|
330 |
-
'gpu_memory': gpu_memory
|
331 |
-
}
|
332 |
-
|
333 |
-
|
334 |
-
def update_server_status():
|
335 |
-
status = get_server_status()
|
336 |
-
return (
|
337 |
-
status['cpu'],
|
338 |
-
status['memory'],
|
339 |
-
status['disk'],
|
340 |
-
status['gpu_memory']
|
341 |
-
)
|
342 |
|
343 |
-
|
344 |
generate_button.click(
|
345 |
generate_vanilla,
|
346 |
inputs=[prompt, num_inference_steps, guidance_scale],
|
347 |
-
outputs=[video_output
|
348 |
)
|
349 |
|
350 |
generate_button_vs.click(
|
351 |
generate_vs,
|
352 |
-
inputs=[prompt, num_inference_steps, guidance_scale,
|
353 |
-
outputs=[video_output_vs
|
354 |
)
|
355 |
|
356 |
-
enhance_button.click(enhance_prompt_func, inputs=[prompt], outputs=[prompt])
|
357 |
-
|
358 |
-
|
359 |
refresh_button.click(update_server_status, outputs=[cpu_status, memory_status, disk_status, gpu_status])
|
360 |
demo.load(update_server_status, outputs=[cpu_status, memory_status, disk_status, gpu_status], every=1)
|
361 |
|
362 |
if __name__ == "__main__":
|
363 |
demo.queue(max_size=10, default_concurrency_limit=1)
|
364 |
-
demo.launch()
|
|
|
2 |
|
3 |
os.environ["GRADIO_TEMP_DIR"] = os.path.join(os.getcwd(), ".tmp_outputs")
|
4 |
|
|
|
|
|
|
|
|
|
|
|
5 |
import logging
|
6 |
+
import uuid
|
7 |
+
|
8 |
+
import GPUtil
|
9 |
import gradio as gr
|
|
|
|
|
10 |
import psutil
|
11 |
+
import torch
|
|
|
12 |
|
13 |
+
from videosys import CogVideoXConfig, CogVideoXPABConfig, VideoSysEngine
|
14 |
|
15 |
logging.basicConfig(level=logging.INFO)
|
16 |
logger = logging.getLogger(__name__)
|
17 |
|
18 |
+
dtype = torch.float16
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
19 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
20 |
|
21 |
+
def load_model(enable_video_sys=False, pab_threshold=[100, 850], pab_range=2):
|
22 |
+
pab_config = CogVideoXPABConfig(spatial_threshold=pab_threshold, spatial_range=pab_range)
|
23 |
+
config = CogVideoXConfig(world_size=1, enable_pab=enable_video_sys, pab_config=pab_config)
|
24 |
engine = VideoSysEngine(config)
|
25 |
return engine
|
26 |
|
27 |
+
|
28 |
def generate(engine, prompt, num_inference_steps=50, guidance_scale=6.0):
|
29 |
+
video = engine.generate(prompt, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale).video[0]
|
|
|
30 |
|
31 |
+
unique_filename = f"{uuid.uuid4().hex}.mp4"
|
32 |
+
output_path = os.path.join("./.tmp_outputs", unique_filename)
|
|
|
|
|
33 |
|
34 |
+
engine.save_video(video, output_path)
|
35 |
+
return output_path
|
|
|
|
|
|
|
36 |
|
37 |
|
38 |
def get_server_status():
|
39 |
cpu_percent = psutil.cpu_percent()
|
40 |
memory = psutil.virtual_memory()
|
41 |
+
disk = psutil.disk_usage("/")
|
42 |
gpus = GPUtil.getGPUs()
|
43 |
gpu_info = []
|
44 |
for gpu in gpus:
|
45 |
+
gpu_info.append(
|
46 |
+
{
|
47 |
+
"id": gpu.id,
|
48 |
+
"name": gpu.name,
|
49 |
+
"load": f"{gpu.load*100:.1f}%",
|
50 |
+
"memory_used": f"{gpu.memoryUsed}MB",
|
51 |
+
"memory_total": f"{gpu.memoryTotal}MB",
|
52 |
+
}
|
53 |
+
)
|
54 |
+
|
55 |
+
return {"cpu": f"{cpu_percent}%", "memory": f"{memory.percent}%", "disk": f"{disk.percent}%", "gpu": gpu_info}
|
56 |
+
|
57 |
+
|
58 |
+
def generate_vanilla(prompt, num_inference_steps, guidance_scale, progress=gr.Progress(track_tqdm=True)):
|
59 |
+
engine = load_model()
|
60 |
+
video_path = generate(engine, prompt, num_inference_steps, guidance_scale)
|
61 |
+
return video_path
|
62 |
+
|
63 |
+
|
64 |
+
def generate_vs(
|
65 |
+
prompt,
|
66 |
+
num_inference_steps,
|
67 |
+
guidance_scale,
|
68 |
+
threshold_start,
|
69 |
+
threshold_end,
|
70 |
+
gap,
|
71 |
+
progress=gr.Progress(track_tqdm=True),
|
72 |
+
):
|
73 |
+
threshold = [int(threshold_end), int(threshold_start)]
|
74 |
+
gap = int(gap)
|
75 |
+
engine = load_model(enable_video_sys=True, pab_threshold=threshold, pab_range=gap)
|
76 |
+
video_path = generate(engine, prompt, num_inference_steps, guidance_scale)
|
77 |
+
return video_path
|
78 |
+
|
79 |
+
|
80 |
+
def get_server_status():
|
81 |
+
cpu_percent = psutil.cpu_percent()
|
82 |
+
memory = psutil.virtual_memory()
|
83 |
+
disk = psutil.disk_usage("/")
|
84 |
+
try:
|
85 |
+
gpus = GPUtil.getGPUs()
|
86 |
+
if gpus:
|
87 |
+
gpu = gpus[0]
|
88 |
+
gpu_memory = f"{gpu.memoryUsed}/{gpu.memoryTotal}MB ({gpu.memoryUtil*100:.1f}%)"
|
89 |
+
else:
|
90 |
+
gpu_memory = "No GPU found"
|
91 |
+
except:
|
92 |
+
gpu_memory = "GPU information unavailable"
|
93 |
+
|
94 |
return {
|
95 |
+
"cpu": f"{cpu_percent}%",
|
96 |
+
"memory": f"{memory.percent}%",
|
97 |
+
"disk": f"{disk.percent}%",
|
98 |
+
"gpu_memory": gpu_memory,
|
99 |
}
|
100 |
|
101 |
|
102 |
+
def update_server_status():
|
103 |
+
status = get_server_status()
|
104 |
+
return (status["cpu"], status["memory"], status["disk"], status["gpu_memory"])
|
105 |
+
|
106 |
|
107 |
css = """
|
108 |
body {
|
|
|
113 |
padding: 20px;
|
114 |
}
|
115 |
|
116 |
+
|
117 |
.container {
|
118 |
display: flex;
|
119 |
flex-direction: column;
|
120 |
+
gap: 10px;
|
121 |
}
|
122 |
|
123 |
.row {
|
124 |
display: flex;
|
125 |
flex-wrap: wrap;
|
126 |
+
gap: 10px;
|
127 |
}
|
128 |
|
129 |
.column {
|
|
|
163 |
font-size: 0.9em !important;
|
164 |
line-height: 1.2 !important;
|
165 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
166 |
.server-status .textbox {
|
167 |
gap: 0 !important;
|
168 |
}
|
|
|
186 |
"""
|
187 |
|
188 |
with gr.Blocks(css=css) as demo:
|
189 |
+
gr.HTML(
|
190 |
+
"""
|
191 |
<div style="text-align: center; font-size: 32px; font-weight: bold; margin-bottom: 20px;">
|
192 |
+
VideoSys for CogVideoX🤗
|
193 |
</div>
|
194 |
<div style="text-align: center; font-size: 15px;">
|
195 |
🌐 Github: <a href="https://github.com/NUS-HPC-AI-Lab/VideoSys">https://github.com/NUS-HPC-AI-Lab/VideoSys</a><br>
|
196 |
+
|
197 |
+
⚠️ This demo is for academic research and experiential use only.
|
198 |
Users should strictly adhere to local laws and ethics.<br>
|
199 |
+
|
200 |
💡 This demo only demonstrates single-device inference. To experience the full power of VideoSys, please deploy it with multiple devices.<br><br>
|
201 |
</div>
|
202 |
</div>
|
203 |
+
"""
|
204 |
+
)
|
205 |
|
206 |
with gr.Row():
|
207 |
with gr.Column():
|
208 |
+
prompt = gr.Textbox(label="Prompt (Less than 200 Words)", value="Sunset over the sea.", lines=4)
|
|
|
|
|
|
|
|
|
|
|
209 |
|
210 |
with gr.Column():
|
211 |
+
gr.Markdown("**Generation Parameters**<br>")
|
|
|
|
|
|
|
|
|
212 |
with gr.Row():
|
213 |
num_inference_steps = gr.Number(label="Inference Steps", value=50)
|
214 |
guidance_scale = gr.Number(label="Guidance Scale", value=6.0)
|
|
|
|
|
215 |
with gr.Row():
|
216 |
+
pab_range = gr.Number(
|
217 |
+
label="PAB Broadcast Range", value=2, precision=0, info="Broadcast timesteps range."
|
218 |
+
)
|
219 |
+
pab_threshold_start = gr.Number(label="PAB Start Timestep", value=850, info="Start from step 1000.")
|
220 |
+
pab_threshold_end = gr.Number(label="PAB End Timestep", value=100, info="End at step 0.")
|
221 |
+
with gr.Row():
|
222 |
generate_button_vs = gr.Button("⚡️ Generate Video with VideoSys (Faster)")
|
223 |
+
generate_button = gr.Button("🎬 Generate Video (Original)")
|
224 |
with gr.Column(elem_classes="server-status"):
|
225 |
gr.Markdown("#### Server Status")
|
226 |
+
|
227 |
with gr.Row():
|
228 |
cpu_status = gr.Textbox(label="CPU", scale=1)
|
229 |
memory_status = gr.Textbox(label="Memory", scale=1)
|
230 |
+
|
231 |
with gr.Row():
|
232 |
disk_status = gr.Textbox(label="Disk", scale=1)
|
233 |
gpu_status = gr.Textbox(label="GPU Memory", scale=1)
|
234 |
+
|
235 |
with gr.Row():
|
236 |
+
refresh_button = gr.Button("Refresh")
|
237 |
|
238 |
with gr.Column():
|
|
|
|
|
|
|
|
|
|
|
239 |
with gr.Row():
|
240 |
video_output_vs = gr.Video(label="CogVideoX with VideoSys", width=720, height=480)
|
241 |
with gr.Row():
|
242 |
+
video_output = gr.Video(label="CogVideoX", width=720, height=480)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
243 |
|
|
|
244 |
generate_button.click(
|
245 |
generate_vanilla,
|
246 |
inputs=[prompt, num_inference_steps, guidance_scale],
|
247 |
+
outputs=[video_output],
|
248 |
)
|
249 |
|
250 |
generate_button_vs.click(
|
251 |
generate_vs,
|
252 |
+
inputs=[prompt, num_inference_steps, guidance_scale, pab_threshold_start, pab_threshold_end, pab_range],
|
253 |
+
outputs=[video_output_vs],
|
254 |
)
|
255 |
|
|
|
|
|
|
|
256 |
refresh_button.click(update_server_status, outputs=[cpu_status, memory_status, disk_status, gpu_status])
|
257 |
demo.load(update_server_status, outputs=[cpu_status, memory_status, disk_status, gpu_status], every=1)
|
258 |
|
259 |
if __name__ == "__main__":
|
260 |
demo.queue(max_size=10, default_concurrency_limit=1)
|
261 |
+
demo.launch()
|
docs/dsp.md
DELETED
@@ -1,25 +0,0 @@
|
|
1 |
-
# DSP
|
2 |
-
|
3 |
-
paper: https://arxiv.org/abs/2403.10266
|
4 |
-
|
5 |
-
![dsp_overview](../assets/figures/dsp_overview.png)
|
6 |
-
|
7 |
-
|
8 |
-
DSP (Dynamic Sequence Parallelism) is a novel, elegant and super efficient sequence parallelism for [OpenSora](https://github.com/hpcaitech/Open-Sora), [Latte](https://github.com/Vchitect/Latte) and other multi-dimensional transformer architecture.
|
9 |
-
|
10 |
-
The key idea is to dynamically switch the parallelism dimension according to the current computation stage, leveraging the potential characteristics of multi-dimensional transformers. Compared with splitting head and sequence dimension as previous methods, it can reduce at least 75% of communication cost.
|
11 |
-
|
12 |
-
It achieves **3x** speed for training and **2x** speed for inference in OpenSora compared with sota sequence parallelism ([DeepSpeed Ulysses](https://arxiv.org/abs/2309.14509)). For a 10s (80 frames) of 512x512 video, the inference latency of OpenSora is:
|
13 |
-
|
14 |
-
| Method | 1xH800 | 8xH800 (DS Ulysses) | 8xH800 (DSP) |
|
15 |
-
| ------ | ------ | ------ | ------ |
|
16 |
-
| Latency(s) | 106 | 45 | 22 |
|
17 |
-
|
18 |
-
The following is DSP's end-to-end throughput for training of OpenSora:
|
19 |
-
|
20 |
-
![dsp_overview](../assets/figures/dsp_exp.png)
|
21 |
-
|
22 |
-
|
23 |
-
### Usage
|
24 |
-
|
25 |
-
DSP is currently supported for: OpenSora, OpenSoraPlan and Latte. To enable DSP, you just need to launch with multiple GPUs.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
docs/pab.md
DELETED
@@ -1,121 +0,0 @@
|
|
1 |
-
# Pyramid Attention Broadcast(PAB)
|
2 |
-
|
3 |
-
[[paper](https://arxiv.org/abs/2408.12588)][[blog](https://arxiv.org/abs/2403.10266)]
|
4 |
-
|
5 |
-
Pyramid Attention Broadcast(PAB)(#pyramid-attention-broadcastpab)
|
6 |
-
- [Pyramid Attention Broadcast(PAB)](#pyramid-attention-broadcastpab)
|
7 |
-
- [Insights](#insights)
|
8 |
-
- [Pyramid Attention Broadcast (PAB) Mechanism](#pyramid-attention-broadcast-pab-mechanism)
|
9 |
-
- [Experimental Results](#experimental-results)
|
10 |
-
- [Usage](#usage)
|
11 |
-
- [Supported Models](#supported-models)
|
12 |
-
- [Configuration for PAB](#configuration-for-pab)
|
13 |
-
- [Parameters](#parameters)
|
14 |
-
- [Example Configuration](#example-configuration)
|
15 |
-
|
16 |
-
|
17 |
-
We introduce Pyramid Attention Broadcast (PAB), the first approach that achieves real-time DiT-based video generation. By mitigating redundant attention computation, PAB achieves up to 21.6 FPS with 10.6x acceleration, without sacrificing quality across popular DiT-based video generation models including Open-Sora, Open-Sora-Plan, and Latte. Notably, as a training-free approach, PAB can enpower any future DiT-based video generation models with real-time capabilities.
|
18 |
-
|
19 |
-
## Insights
|
20 |
-
|
21 |
-
![method](../assets/figures/pab_motivation.png)
|
22 |
-
|
23 |
-
Our study reveals two key insights of three **attention mechanisms** within video diffusion transformers:
|
24 |
-
- First, attention differences across time steps exhibit a U-shaped pattern, with significant variations occurring during the first and last 15% of steps, while the middle 70% of steps show very stable, minor differences.
|
25 |
-
- Second, within the stable middle segment, the variability differs among attention types:
|
26 |
-
- **Spatial attention** varies the most, involving high-frequency elements like edges and textures;
|
27 |
-
- **Temporal attention** exhibits mid-frequency variations related to movements and dynamics in videos;
|
28 |
-
- **Cross-modal attention** is the most stable, linking text with video content, analogous to low-frequency signals reflecting textual semantics.
|
29 |
-
|
30 |
-
## Pyramid Attention Broadcast (PAB) Mechanism
|
31 |
-
|
32 |
-
![method](../assets/figures/pab_method.png)
|
33 |
-
|
34 |
-
Building on these insights, we propose a **pyramid attention broadcast(PAB)** mechanism to minimize unnecessary computations and optimize the utility of each attention module, as shown in Figure[xx figure] below.
|
35 |
-
|
36 |
-
In the middle segment, we broadcast one step's attention outputs to its subsequent several steps, thereby significantly reducing the computational cost on attention modules.
|
37 |
-
|
38 |
-
For more efficient broadcast and minimum influence to effect, we set varied broadcast ranges for different attentions based on their stability and differences.
|
39 |
-
**The smaller the variation in attention, the broader the potential broadcast range.**
|
40 |
-
|
41 |
-
|
42 |
-
## Experimental Results
|
43 |
-
Here are the results of our experiments, more results are shown in https://oahzxl.github.io/PAB:
|
44 |
-
|
45 |
-
![pab_vis](../assets/figures/pab_vis.png)
|
46 |
-
|
47 |
-
|
48 |
-
## Usage
|
49 |
-
|
50 |
-
### Supported Models
|
51 |
-
|
52 |
-
PAB currently supports Open-Sora, Open-Sora-Plan, and Latte.
|
53 |
-
|
54 |
-
### Configuration for PAB
|
55 |
-
|
56 |
-
To efficiently use the Pyramid Attention Broadcast (PAB) mechanism, configure the following parameters to control the broadcasting for different attention types. This helps reduce computational costs by skipping certain steps based on attention stability.
|
57 |
-
|
58 |
-
#### Parameters
|
59 |
-
|
60 |
-
- **spatial_broadcast**: Enable or disable broadcasting for spatial attention.
|
61 |
-
- Type: `True` or `False`
|
62 |
-
|
63 |
-
- **spatial_threshold**: Set the range of diffusion steps within which spatial attention is applied.
|
64 |
-
- Format: `[min_value, max_value]`
|
65 |
-
|
66 |
-
- **spatial_gap**: Number of blocks in model to skip during broadcasting for spatial attention.
|
67 |
-
- Type: Integer
|
68 |
-
|
69 |
-
- **temporal_broadcast**: Enable or disable broadcasting for temporal attention.
|
70 |
-
- Type: `True` or `False`
|
71 |
-
|
72 |
-
- **temporal_threshold**: Set the range of diffusion steps within which temporal attention is applied.
|
73 |
-
- Format: `[min_value, max_value]`
|
74 |
-
|
75 |
-
- **temporal_gap**: Number of steps to skip during broadcasting for temporal attention.
|
76 |
-
- Type: Integer
|
77 |
-
|
78 |
-
- **cross_broadcast**: Enable or disable broadcasting for cross-modal attention.
|
79 |
-
- Type: `True` or `False`
|
80 |
-
|
81 |
-
- **cross_threshold**: Set the range of diffusion steps within which cross-modal attention is applied.
|
82 |
-
- Format: `[min_value, max_value]`
|
83 |
-
|
84 |
-
- **cross_gap**: Number of steps to skip during broadcasting for cross-modal attention.
|
85 |
-
- Type: Integer
|
86 |
-
|
87 |
-
#### Example Configuration
|
88 |
-
|
89 |
-
```yaml
|
90 |
-
spatial_broadcast: True
|
91 |
-
spatial_threshold: [100, 800]
|
92 |
-
spatial_gap: 2
|
93 |
-
|
94 |
-
temporal_broadcast: True
|
95 |
-
temporal_threshold: [100, 800]
|
96 |
-
temporal_gap: 3
|
97 |
-
|
98 |
-
cross_broadcast: True
|
99 |
-
cross_threshold: [100, 900]
|
100 |
-
cross_gap: 5
|
101 |
-
```
|
102 |
-
|
103 |
-
Explanation:
|
104 |
-
|
105 |
-
- **Spatial Attention**:
|
106 |
-
- Broadcasting enabled (`spatial_broadcast: True`)
|
107 |
-
- Applied within the threshold range of 100 to 800
|
108 |
-
- Skips every 2 steps (`spatial_gap: 2`)
|
109 |
-
- Active within the first 28 steps (`spatial_block: [0, 28]`)
|
110 |
-
|
111 |
-
- **Temporal Attention**:
|
112 |
-
- Broadcasting enabled (`temporal_broadcast: True`)
|
113 |
-
- Applied within the threshold range of 100 to 800
|
114 |
-
- Skips every 3 steps (`temporal_gap: 3`)
|
115 |
-
|
116 |
-
- **Cross-Modal Attention**:
|
117 |
-
- Broadcasting enabled (`cross_broadcast: True`)
|
118 |
-
- Applied within the threshold range of 100 to 900
|
119 |
-
- Skips every 5 steps (`cross_gap: 5`)
|
120 |
-
|
121 |
-
Adjust these settings based on your specific needs to optimize the performance of each attention mechanism.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
eval/pab/commom_metrics/README.md
DELETED
@@ -1,6 +0,0 @@
|
|
1 |
-
Common metrics
|
2 |
-
|
3 |
-
Include LPIPS, PSNR and SSIM.
|
4 |
-
|
5 |
-
The code is adapted from [common_metrics_on_video_quality
|
6 |
-
](https://github.com/JunyaoHu/common_metrics_on_video_quality).
|
|
|
|
|
|
|
|
|
|
|
|
|
|
eval/pab/commom_metrics/calculate_lpips.py
DELETED
@@ -1,97 +0,0 @@
|
|
1 |
-
import lpips
|
2 |
-
import numpy as np
|
3 |
-
import torch
|
4 |
-
|
5 |
-
spatial = True # Return a spatial map of perceptual distance.
|
6 |
-
|
7 |
-
# Linearly calibrated models (LPIPS)
|
8 |
-
loss_fn = lpips.LPIPS(net="alex", spatial=spatial) # Can also set net = 'squeeze' or 'vgg'
|
9 |
-
# loss_fn = lpips.LPIPS(net='alex', spatial=spatial, lpips=False) # Can also set net = 'squeeze' or 'vgg'
|
10 |
-
|
11 |
-
|
12 |
-
def trans(x):
|
13 |
-
# if greyscale images add channel
|
14 |
-
if x.shape[-3] == 1:
|
15 |
-
x = x.repeat(1, 1, 3, 1, 1)
|
16 |
-
|
17 |
-
# value range [0, 1] -> [-1, 1]
|
18 |
-
x = x * 2 - 1
|
19 |
-
|
20 |
-
return x
|
21 |
-
|
22 |
-
|
23 |
-
def calculate_lpips(videos1, videos2, device):
|
24 |
-
# image should be RGB, IMPORTANT: normalized to [-1,1]
|
25 |
-
|
26 |
-
assert videos1.shape == videos2.shape
|
27 |
-
|
28 |
-
# videos [batch_size, timestamps, channel, h, w]
|
29 |
-
|
30 |
-
# support grayscale input, if grayscale -> channel*3
|
31 |
-
# value range [0, 1] -> [-1, 1]
|
32 |
-
videos1 = trans(videos1)
|
33 |
-
videos2 = trans(videos2)
|
34 |
-
|
35 |
-
lpips_results = []
|
36 |
-
|
37 |
-
for video_num in range(videos1.shape[0]):
|
38 |
-
# get a video
|
39 |
-
# video [timestamps, channel, h, w]
|
40 |
-
video1 = videos1[video_num]
|
41 |
-
video2 = videos2[video_num]
|
42 |
-
|
43 |
-
lpips_results_of_a_video = []
|
44 |
-
for clip_timestamp in range(len(video1)):
|
45 |
-
# get a img
|
46 |
-
# img [timestamps[x], channel, h, w]
|
47 |
-
# img [channel, h, w] tensor
|
48 |
-
|
49 |
-
img1 = video1[clip_timestamp].unsqueeze(0).to(device)
|
50 |
-
img2 = video2[clip_timestamp].unsqueeze(0).to(device)
|
51 |
-
|
52 |
-
loss_fn.to(device)
|
53 |
-
|
54 |
-
# calculate lpips of a video
|
55 |
-
lpips_results_of_a_video.append(loss_fn.forward(img1, img2).mean().detach().cpu().tolist())
|
56 |
-
lpips_results.append(lpips_results_of_a_video)
|
57 |
-
|
58 |
-
lpips_results = np.array(lpips_results)
|
59 |
-
|
60 |
-
lpips = {}
|
61 |
-
lpips_std = {}
|
62 |
-
|
63 |
-
for clip_timestamp in range(len(video1)):
|
64 |
-
lpips[clip_timestamp] = np.mean(lpips_results[:, clip_timestamp])
|
65 |
-
lpips_std[clip_timestamp] = np.std(lpips_results[:, clip_timestamp])
|
66 |
-
|
67 |
-
result = {
|
68 |
-
"value": lpips,
|
69 |
-
"value_std": lpips_std,
|
70 |
-
"video_setting": video1.shape,
|
71 |
-
"video_setting_name": "time, channel, heigth, width",
|
72 |
-
}
|
73 |
-
|
74 |
-
return result
|
75 |
-
|
76 |
-
|
77 |
-
# test code / using example
|
78 |
-
|
79 |
-
|
80 |
-
def main():
|
81 |
-
NUMBER_OF_VIDEOS = 8
|
82 |
-
VIDEO_LENGTH = 50
|
83 |
-
CHANNEL = 3
|
84 |
-
SIZE = 64
|
85 |
-
videos1 = torch.zeros(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False)
|
86 |
-
videos2 = torch.ones(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False)
|
87 |
-
device = torch.device("cuda")
|
88 |
-
# device = torch.device("cpu")
|
89 |
-
|
90 |
-
import json
|
91 |
-
|
92 |
-
result = calculate_lpips(videos1, videos2, device)
|
93 |
-
print(json.dumps(result, indent=4))
|
94 |
-
|
95 |
-
|
96 |
-
if __name__ == "__main__":
|
97 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
eval/pab/commom_metrics/calculate_psnr.py
DELETED
@@ -1,90 +0,0 @@
|
|
1 |
-
import math
|
2 |
-
|
3 |
-
import numpy as np
|
4 |
-
import torch
|
5 |
-
|
6 |
-
|
7 |
-
def img_psnr(img1, img2):
|
8 |
-
# [0,1]
|
9 |
-
# compute mse
|
10 |
-
# mse = np.mean((img1-img2)**2)
|
11 |
-
mse = np.mean((img1 / 1.0 - img2 / 1.0) ** 2)
|
12 |
-
# compute psnr
|
13 |
-
if mse < 1e-10:
|
14 |
-
return 100
|
15 |
-
psnr = 20 * math.log10(1 / math.sqrt(mse))
|
16 |
-
return psnr
|
17 |
-
|
18 |
-
|
19 |
-
def trans(x):
|
20 |
-
return x
|
21 |
-
|
22 |
-
|
23 |
-
def calculate_psnr(videos1, videos2):
|
24 |
-
# videos [batch_size, timestamps, channel, h, w]
|
25 |
-
|
26 |
-
assert videos1.shape == videos2.shape
|
27 |
-
|
28 |
-
videos1 = trans(videos1)
|
29 |
-
videos2 = trans(videos2)
|
30 |
-
|
31 |
-
psnr_results = []
|
32 |
-
|
33 |
-
for video_num in range(videos1.shape[0]):
|
34 |
-
# get a video
|
35 |
-
# video [timestamps, channel, h, w]
|
36 |
-
video1 = videos1[video_num]
|
37 |
-
video2 = videos2[video_num]
|
38 |
-
|
39 |
-
psnr_results_of_a_video = []
|
40 |
-
for clip_timestamp in range(len(video1)):
|
41 |
-
# get a img
|
42 |
-
# img [timestamps[x], channel, h, w]
|
43 |
-
# img [channel, h, w] numpy
|
44 |
-
|
45 |
-
img1 = video1[clip_timestamp].numpy()
|
46 |
-
img2 = video2[clip_timestamp].numpy()
|
47 |
-
|
48 |
-
# calculate psnr of a video
|
49 |
-
psnr_results_of_a_video.append(img_psnr(img1, img2))
|
50 |
-
|
51 |
-
psnr_results.append(psnr_results_of_a_video)
|
52 |
-
|
53 |
-
psnr_results = np.array(psnr_results)
|
54 |
-
|
55 |
-
psnr = {}
|
56 |
-
psnr_std = {}
|
57 |
-
|
58 |
-
for clip_timestamp in range(len(video1)):
|
59 |
-
psnr[clip_timestamp] = np.mean(psnr_results[:, clip_timestamp])
|
60 |
-
psnr_std[clip_timestamp] = np.std(psnr_results[:, clip_timestamp])
|
61 |
-
|
62 |
-
result = {
|
63 |
-
"value": psnr,
|
64 |
-
"value_std": psnr_std,
|
65 |
-
"video_setting": video1.shape,
|
66 |
-
"video_setting_name": "time, channel, heigth, width",
|
67 |
-
}
|
68 |
-
|
69 |
-
return result
|
70 |
-
|
71 |
-
|
72 |
-
# test code / using example
|
73 |
-
|
74 |
-
|
75 |
-
def main():
|
76 |
-
NUMBER_OF_VIDEOS = 8
|
77 |
-
VIDEO_LENGTH = 50
|
78 |
-
CHANNEL = 3
|
79 |
-
SIZE = 64
|
80 |
-
videos1 = torch.zeros(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False)
|
81 |
-
videos2 = torch.zeros(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False)
|
82 |
-
|
83 |
-
import json
|
84 |
-
|
85 |
-
result = calculate_psnr(videos1, videos2)
|
86 |
-
print(json.dumps(result, indent=4))
|
87 |
-
|
88 |
-
|
89 |
-
if __name__ == "__main__":
|
90 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
eval/pab/commom_metrics/calculate_ssim.py
DELETED
@@ -1,116 +0,0 @@
|
|
1 |
-
import cv2
|
2 |
-
import numpy as np
|
3 |
-
import torch
|
4 |
-
|
5 |
-
|
6 |
-
def ssim(img1, img2):
|
7 |
-
C1 = 0.01**2
|
8 |
-
C2 = 0.03**2
|
9 |
-
img1 = img1.astype(np.float64)
|
10 |
-
img2 = img2.astype(np.float64)
|
11 |
-
kernel = cv2.getGaussianKernel(11, 1.5)
|
12 |
-
window = np.outer(kernel, kernel.transpose())
|
13 |
-
mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid
|
14 |
-
mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
|
15 |
-
mu1_sq = mu1**2
|
16 |
-
mu2_sq = mu2**2
|
17 |
-
mu1_mu2 = mu1 * mu2
|
18 |
-
sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq
|
19 |
-
sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq
|
20 |
-
sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2
|
21 |
-
ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
|
22 |
-
return ssim_map.mean()
|
23 |
-
|
24 |
-
|
25 |
-
def calculate_ssim_function(img1, img2):
|
26 |
-
# [0,1]
|
27 |
-
# ssim is the only metric extremely sensitive to gray being compared to b/w
|
28 |
-
if not img1.shape == img2.shape:
|
29 |
-
raise ValueError("Input images must have the same dimensions.")
|
30 |
-
if img1.ndim == 2:
|
31 |
-
return ssim(img1, img2)
|
32 |
-
elif img1.ndim == 3:
|
33 |
-
if img1.shape[0] == 3:
|
34 |
-
ssims = []
|
35 |
-
for i in range(3):
|
36 |
-
ssims.append(ssim(img1[i], img2[i]))
|
37 |
-
return np.array(ssims).mean()
|
38 |
-
elif img1.shape[0] == 1:
|
39 |
-
return ssim(np.squeeze(img1), np.squeeze(img2))
|
40 |
-
else:
|
41 |
-
raise ValueError("Wrong input image dimensions.")
|
42 |
-
|
43 |
-
|
44 |
-
def trans(x):
|
45 |
-
return x
|
46 |
-
|
47 |
-
|
48 |
-
def calculate_ssim(videos1, videos2):
|
49 |
-
# videos [batch_size, timestamps, channel, h, w]
|
50 |
-
|
51 |
-
assert videos1.shape == videos2.shape
|
52 |
-
|
53 |
-
videos1 = trans(videos1)
|
54 |
-
videos2 = trans(videos2)
|
55 |
-
|
56 |
-
ssim_results = []
|
57 |
-
|
58 |
-
for video_num in range(videos1.shape[0]):
|
59 |
-
# get a video
|
60 |
-
# video [timestamps, channel, h, w]
|
61 |
-
video1 = videos1[video_num]
|
62 |
-
video2 = videos2[video_num]
|
63 |
-
|
64 |
-
ssim_results_of_a_video = []
|
65 |
-
for clip_timestamp in range(len(video1)):
|
66 |
-
# get a img
|
67 |
-
# img [timestamps[x], channel, h, w]
|
68 |
-
# img [channel, h, w] numpy
|
69 |
-
|
70 |
-
img1 = video1[clip_timestamp].numpy()
|
71 |
-
img2 = video2[clip_timestamp].numpy()
|
72 |
-
|
73 |
-
# calculate ssim of a video
|
74 |
-
ssim_results_of_a_video.append(calculate_ssim_function(img1, img2))
|
75 |
-
|
76 |
-
ssim_results.append(ssim_results_of_a_video)
|
77 |
-
|
78 |
-
ssim_results = np.array(ssim_results)
|
79 |
-
|
80 |
-
ssim = {}
|
81 |
-
ssim_std = {}
|
82 |
-
|
83 |
-
for clip_timestamp in range(len(video1)):
|
84 |
-
ssim[clip_timestamp] = np.mean(ssim_results[:, clip_timestamp])
|
85 |
-
ssim_std[clip_timestamp] = np.std(ssim_results[:, clip_timestamp])
|
86 |
-
|
87 |
-
result = {
|
88 |
-
"value": ssim,
|
89 |
-
"value_std": ssim_std,
|
90 |
-
"video_setting": video1.shape,
|
91 |
-
"video_setting_name": "time, channel, heigth, width",
|
92 |
-
}
|
93 |
-
|
94 |
-
return result
|
95 |
-
|
96 |
-
|
97 |
-
# test code / using example
|
98 |
-
|
99 |
-
|
100 |
-
def main():
|
101 |
-
NUMBER_OF_VIDEOS = 8
|
102 |
-
VIDEO_LENGTH = 50
|
103 |
-
CHANNEL = 3
|
104 |
-
SIZE = 64
|
105 |
-
videos1 = torch.zeros(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False)
|
106 |
-
videos2 = torch.zeros(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False)
|
107 |
-
torch.device("cuda")
|
108 |
-
|
109 |
-
import json
|
110 |
-
|
111 |
-
result = calculate_ssim(videos1, videos2)
|
112 |
-
print(json.dumps(result, indent=4))
|
113 |
-
|
114 |
-
|
115 |
-
if __name__ == "__main__":
|
116 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
eval/pab/commom_metrics/eval.py
DELETED
@@ -1,160 +0,0 @@
|
|
1 |
-
import argparse
|
2 |
-
import os
|
3 |
-
|
4 |
-
import imageio
|
5 |
-
import torch
|
6 |
-
import torchvision.transforms.functional as F
|
7 |
-
import tqdm
|
8 |
-
from calculate_lpips import calculate_lpips
|
9 |
-
from calculate_psnr import calculate_psnr
|
10 |
-
from calculate_ssim import calculate_ssim
|
11 |
-
|
12 |
-
|
13 |
-
def load_videos(directory, video_ids, file_extension):
|
14 |
-
videos = []
|
15 |
-
for video_id in video_ids:
|
16 |
-
video_path = os.path.join(directory, f"{video_id}.{file_extension}")
|
17 |
-
if os.path.exists(video_path):
|
18 |
-
video = load_video(video_path) # Define load_video based on how videos are stored
|
19 |
-
videos.append(video)
|
20 |
-
else:
|
21 |
-
raise ValueError(f"Video {video_id}.{file_extension} not found in {directory}")
|
22 |
-
return videos
|
23 |
-
|
24 |
-
|
25 |
-
def load_video(video_path):
|
26 |
-
"""
|
27 |
-
Load a video from the given path and convert it to a PyTorch tensor.
|
28 |
-
"""
|
29 |
-
# Read the video using imageio
|
30 |
-
reader = imageio.get_reader(video_path, "ffmpeg")
|
31 |
-
|
32 |
-
# Extract frames and convert to a list of tensors
|
33 |
-
frames = []
|
34 |
-
for frame in reader:
|
35 |
-
# Convert the frame to a tensor and permute the dimensions to match (C, H, W)
|
36 |
-
frame_tensor = torch.tensor(frame).cuda().permute(2, 0, 1)
|
37 |
-
frames.append(frame_tensor)
|
38 |
-
|
39 |
-
# Stack the list of tensors into a single tensor with shape (T, C, H, W)
|
40 |
-
video_tensor = torch.stack(frames)
|
41 |
-
|
42 |
-
return video_tensor
|
43 |
-
|
44 |
-
|
45 |
-
def resize_video(video, target_height, target_width):
|
46 |
-
resized_frames = []
|
47 |
-
for frame in video:
|
48 |
-
resized_frame = F.resize(frame, [target_height, target_width])
|
49 |
-
resized_frames.append(resized_frame)
|
50 |
-
return torch.stack(resized_frames)
|
51 |
-
|
52 |
-
|
53 |
-
def preprocess_eval_video(eval_video, generated_video_shape):
|
54 |
-
T_gen, _, H_gen, W_gen = generated_video_shape
|
55 |
-
T_eval, _, H_eval, W_eval = eval_video.shape
|
56 |
-
|
57 |
-
if T_eval < T_gen:
|
58 |
-
raise ValueError(f"Eval video time steps ({T_eval}) are less than generated video time steps ({T_gen}).")
|
59 |
-
|
60 |
-
if H_eval < H_gen or W_eval < W_gen:
|
61 |
-
# Resize the video maintaining the aspect ratio
|
62 |
-
resize_height = max(H_gen, int(H_gen * (H_eval / W_eval)))
|
63 |
-
resize_width = max(W_gen, int(W_gen * (W_eval / H_eval)))
|
64 |
-
eval_video = resize_video(eval_video, resize_height, resize_width)
|
65 |
-
# Recalculate the dimensions
|
66 |
-
T_eval, _, H_eval, W_eval = eval_video.shape
|
67 |
-
|
68 |
-
# Center crop
|
69 |
-
start_h = (H_eval - H_gen) // 2
|
70 |
-
start_w = (W_eval - W_gen) // 2
|
71 |
-
cropped_video = eval_video[:T_gen, :, start_h : start_h + H_gen, start_w : start_w + W_gen]
|
72 |
-
|
73 |
-
return cropped_video
|
74 |
-
|
75 |
-
|
76 |
-
def main(args):
|
77 |
-
device = "cuda"
|
78 |
-
gt_video_dir = args.gt_video_dir
|
79 |
-
generated_video_dir = args.generated_video_dir
|
80 |
-
|
81 |
-
video_ids = []
|
82 |
-
file_extension = "mp4"
|
83 |
-
for f in os.listdir(generated_video_dir):
|
84 |
-
if f.endswith(f".{file_extension}"):
|
85 |
-
video_ids.append(f.replace(f".{file_extension}", ""))
|
86 |
-
if not video_ids:
|
87 |
-
raise ValueError("No videos found in the generated video dataset. Exiting.")
|
88 |
-
|
89 |
-
print(f"Find {len(video_ids)} videos")
|
90 |
-
prompt_interval = 1
|
91 |
-
batch_size = 16
|
92 |
-
calculate_lpips_flag, calculate_psnr_flag, calculate_ssim_flag = True, True, True
|
93 |
-
|
94 |
-
lpips_results = []
|
95 |
-
psnr_results = []
|
96 |
-
ssim_results = []
|
97 |
-
|
98 |
-
total_len = len(video_ids) // batch_size + (1 if len(video_ids) % batch_size != 0 else 0)
|
99 |
-
|
100 |
-
for idx, video_id in enumerate(tqdm.tqdm(range(total_len))):
|
101 |
-
gt_videos_tensor = []
|
102 |
-
generated_videos_tensor = []
|
103 |
-
for i in range(batch_size):
|
104 |
-
video_idx = idx * batch_size + i
|
105 |
-
if video_idx >= len(video_ids):
|
106 |
-
break
|
107 |
-
video_id = video_ids[video_idx]
|
108 |
-
generated_video = load_video(os.path.join(generated_video_dir, f"{video_id}.{file_extension}"))
|
109 |
-
generated_videos_tensor.append(generated_video)
|
110 |
-
eval_video = load_video(os.path.join(gt_video_dir, f"{video_id}.{file_extension}"))
|
111 |
-
gt_videos_tensor.append(eval_video)
|
112 |
-
gt_videos_tensor = (torch.stack(gt_videos_tensor) / 255.0).cpu()
|
113 |
-
generated_videos_tensor = (torch.stack(generated_videos_tensor) / 255.0).cpu()
|
114 |
-
|
115 |
-
if calculate_lpips_flag:
|
116 |
-
result = calculate_lpips(gt_videos_tensor, generated_videos_tensor, device=device)
|
117 |
-
result = result["value"].values()
|
118 |
-
result = sum(result) / len(result)
|
119 |
-
lpips_results.append(result)
|
120 |
-
|
121 |
-
if calculate_psnr_flag:
|
122 |
-
result = calculate_psnr(gt_videos_tensor, generated_videos_tensor)
|
123 |
-
result = result["value"].values()
|
124 |
-
result = sum(result) / len(result)
|
125 |
-
psnr_results.append(result)
|
126 |
-
|
127 |
-
if calculate_ssim_flag:
|
128 |
-
result = calculate_ssim(gt_videos_tensor, generated_videos_tensor)
|
129 |
-
result = result["value"].values()
|
130 |
-
result = sum(result) / len(result)
|
131 |
-
ssim_results.append(result)
|
132 |
-
|
133 |
-
if (idx + 1) % prompt_interval == 0:
|
134 |
-
out_str = ""
|
135 |
-
for results, name in zip([lpips_results, psnr_results, ssim_results], ["lpips", "psnr", "ssim"]):
|
136 |
-
result = sum(results) / len(results)
|
137 |
-
out_str += f"{name}: {result:.4f}, "
|
138 |
-
print(f"Processed {idx + 1} videos. {out_str[:-2]}")
|
139 |
-
|
140 |
-
out_str = ""
|
141 |
-
for results, name in zip([lpips_results, psnr_results, ssim_results], ["lpips", "psnr", "ssim"]):
|
142 |
-
result = sum(results) / len(results)
|
143 |
-
out_str += f"{name}: {result:.4f}, "
|
144 |
-
out_str = out_str[:-2]
|
145 |
-
|
146 |
-
# save
|
147 |
-
with open(f"./{os.path.basename(generated_video_dir)}.txt", "w+") as f:
|
148 |
-
f.write(out_str)
|
149 |
-
|
150 |
-
print(f"Processed all videos. {out_str}")
|
151 |
-
|
152 |
-
|
153 |
-
if __name__ == "__main__":
|
154 |
-
parser = argparse.ArgumentParser()
|
155 |
-
parser.add_argument("--gt_video_dir", type=str)
|
156 |
-
parser.add_argument("--generated_video_dir", type=str)
|
157 |
-
|
158 |
-
args = parser.parse_args()
|
159 |
-
|
160 |
-
main(args)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
eval/pab/experiments/attention_ablation.py
DELETED
@@ -1,60 +0,0 @@
|
|
1 |
-
from utils import generate_func, read_prompt_list
|
2 |
-
|
3 |
-
import videosys
|
4 |
-
from videosys import OpenSoraConfig, OpenSoraPipeline
|
5 |
-
from videosys.models.open_sora import OpenSoraPABConfig
|
6 |
-
|
7 |
-
|
8 |
-
def attention_ablation_func(pab_kwargs, prompt_list, output_dir):
|
9 |
-
pab_config = OpenSoraPABConfig(**pab_kwargs)
|
10 |
-
config = OpenSoraConfig(enable_pab=True, pab_config=pab_config)
|
11 |
-
pipeline = OpenSoraPipeline(config)
|
12 |
-
|
13 |
-
generate_func(pipeline, prompt_list, output_dir)
|
14 |
-
|
15 |
-
|
16 |
-
def main(prompt_list):
|
17 |
-
# spatial
|
18 |
-
gap_list = [2, 3, 4, 5]
|
19 |
-
for gap in gap_list:
|
20 |
-
pab_kwargs = {
|
21 |
-
"spatial_broadcast": True,
|
22 |
-
"spatial_gap": gap,
|
23 |
-
"temporal_broadcast": False,
|
24 |
-
"cross_broadcast": False,
|
25 |
-
"mlp_skip": False,
|
26 |
-
}
|
27 |
-
output_dir = f"./samples/attention_ablation/spatial_g{gap}"
|
28 |
-
attention_ablation_func(pab_kwargs, prompt_list, output_dir)
|
29 |
-
|
30 |
-
# temporal
|
31 |
-
gap_list = [3, 4, 5, 6]
|
32 |
-
for gap in gap_list:
|
33 |
-
pab_kwargs = {
|
34 |
-
"spatial_broadcast": False,
|
35 |
-
"temporal_broadcast": True,
|
36 |
-
"temporal_gap": gap,
|
37 |
-
"cross_broadcast": False,
|
38 |
-
"mlp_skip": False,
|
39 |
-
}
|
40 |
-
output_dir = f"./samples/attention_ablation/temporal_g{gap}"
|
41 |
-
attention_ablation_func(pab_kwargs, prompt_list, output_dir)
|
42 |
-
|
43 |
-
# cross
|
44 |
-
gap_list = [5, 6, 7, 8]
|
45 |
-
for gap in gap_list:
|
46 |
-
pab_kwargs = {
|
47 |
-
"spatial_broadcast": False,
|
48 |
-
"temporal_broadcast": False,
|
49 |
-
"cross_broadcast": True,
|
50 |
-
"cross_gap": gap,
|
51 |
-
"mlp_skip": False,
|
52 |
-
}
|
53 |
-
output_dir = f"./samples/attention_ablation/cross_g{gap}"
|
54 |
-
attention_ablation_func(pab_kwargs, prompt_list, output_dir)
|
55 |
-
|
56 |
-
|
57 |
-
if __name__ == "__main__":
|
58 |
-
videosys.initialize(42)
|
59 |
-
prompt_list = read_prompt_list("vbench/VBench_full_info.json")
|
60 |
-
main(prompt_list)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
eval/pab/experiments/components_ablation.py
DELETED
@@ -1,46 +0,0 @@
|
|
1 |
-
from utils import generate_func, read_prompt_list
|
2 |
-
|
3 |
-
import videosys
|
4 |
-
from videosys import OpenSoraConfig, OpenSoraPipeline
|
5 |
-
from videosys.models.open_sora import OpenSoraPABConfig
|
6 |
-
|
7 |
-
|
8 |
-
def wo_spatial(prompt_list):
|
9 |
-
pab_config = OpenSoraPABConfig(spatial_broadcast=False)
|
10 |
-
config = OpenSoraConfig(enable_pab=True, pab_config=pab_config)
|
11 |
-
pipeline = OpenSoraPipeline(config)
|
12 |
-
|
13 |
-
generate_func(pipeline, prompt_list, "./samples/components_ablation/wo_spatial")
|
14 |
-
|
15 |
-
|
16 |
-
def wo_temporal(prompt_list):
|
17 |
-
pab_config = OpenSoraPABConfig(temporal_broadcast=False)
|
18 |
-
config = OpenSoraConfig(enable_pab=True, pab_config=pab_config)
|
19 |
-
pipeline = OpenSoraPipeline(config)
|
20 |
-
|
21 |
-
generate_func(pipeline, prompt_list, "./samples/components_ablation/wo_temporal")
|
22 |
-
|
23 |
-
|
24 |
-
def wo_cross(prompt_list):
|
25 |
-
pab_config = OpenSoraPABConfig(cross_broadcast=False)
|
26 |
-
config = OpenSoraConfig(enable_pab=True, pab_config=pab_config)
|
27 |
-
pipeline = OpenSoraPipeline(config)
|
28 |
-
|
29 |
-
generate_func(pipeline, prompt_list, "./samples/components_ablation/wo_cross")
|
30 |
-
|
31 |
-
|
32 |
-
def wo_mlp(prompt_list):
|
33 |
-
pab_config = OpenSoraPABConfig(mlp_skip=False)
|
34 |
-
config = OpenSoraConfig(enable_pab=True, pab_config=pab_config)
|
35 |
-
pipeline = OpenSoraPipeline(config)
|
36 |
-
|
37 |
-
generate_func(pipeline, prompt_list, "./samples/components_ablation/wo_mlp")
|
38 |
-
|
39 |
-
|
40 |
-
if __name__ == "__main__":
|
41 |
-
videosys.initialize(42)
|
42 |
-
prompt_list = read_prompt_list("./vbench/VBench_full_info.json")
|
43 |
-
wo_spatial(prompt_list)
|
44 |
-
wo_temporal(prompt_list)
|
45 |
-
wo_cross(prompt_list)
|
46 |
-
wo_mlp(prompt_list)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
eval/pab/experiments/latte.py
DELETED
@@ -1,57 +0,0 @@
|
|
1 |
-
from utils import generate_func, read_prompt_list
|
2 |
-
|
3 |
-
import videosys
|
4 |
-
from videosys import LatteConfig, LattePipeline
|
5 |
-
from videosys.models.latte import LattePABConfig
|
6 |
-
|
7 |
-
|
8 |
-
def eval_base(prompt_list):
|
9 |
-
config = LatteConfig()
|
10 |
-
pipeline = LattePipeline(config)
|
11 |
-
|
12 |
-
generate_func(pipeline, prompt_list, "./samples/latte_base", loop=5)
|
13 |
-
|
14 |
-
|
15 |
-
def eval_pab1(prompt_list):
|
16 |
-
pab_config = LattePABConfig(
|
17 |
-
spatial_gap=2,
|
18 |
-
temporal_gap=3,
|
19 |
-
cross_gap=6,
|
20 |
-
)
|
21 |
-
config = LatteConfig(enable_pab=True, pab_config=pab_config)
|
22 |
-
pipeline = LattePipeline(config)
|
23 |
-
|
24 |
-
generate_func(pipeline, prompt_list, "./samples/latte_pab1", loop=5)
|
25 |
-
|
26 |
-
|
27 |
-
def eval_pab2(prompt_list):
|
28 |
-
pab_config = LattePABConfig(
|
29 |
-
spatial_gap=3,
|
30 |
-
temporal_gap=4,
|
31 |
-
cross_gap=7,
|
32 |
-
)
|
33 |
-
config = LatteConfig(enable_pab=True, pab_config=pab_config)
|
34 |
-
pipeline = LattePipeline(config)
|
35 |
-
|
36 |
-
generate_func(pipeline, prompt_list, "./samples/latte_pab2", loop=5)
|
37 |
-
|
38 |
-
|
39 |
-
def eval_pab3(prompt_list):
|
40 |
-
pab_config = LattePABConfig(
|
41 |
-
spatial_gap=4,
|
42 |
-
temporal_gap=6,
|
43 |
-
cross_gap=9,
|
44 |
-
)
|
45 |
-
config = LatteConfig(enable_pab=True, pab_config=pab_config)
|
46 |
-
pipeline = LattePipeline(config)
|
47 |
-
|
48 |
-
generate_func(pipeline, prompt_list, "./samples/latte_pab3", loop=5)
|
49 |
-
|
50 |
-
|
51 |
-
if __name__ == "__main__":
|
52 |
-
videosys.initialize(42)
|
53 |
-
prompt_list = read_prompt_list("vbench/VBench_full_info.json")
|
54 |
-
eval_base(prompt_list)
|
55 |
-
eval_pab1(prompt_list)
|
56 |
-
eval_pab2(prompt_list)
|
57 |
-
eval_pab3(prompt_list)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
eval/pab/experiments/opensora.py
DELETED
@@ -1,44 +0,0 @@
|
|
1 |
-
from utils import generate_func, read_prompt_list
|
2 |
-
|
3 |
-
import videosys
|
4 |
-
from videosys import OpenSoraConfig, OpenSoraPipeline
|
5 |
-
from videosys.models.open_sora import OpenSoraPABConfig
|
6 |
-
|
7 |
-
|
8 |
-
def eval_base(prompt_list):
|
9 |
-
config = OpenSoraConfig()
|
10 |
-
pipeline = OpenSoraPipeline(config)
|
11 |
-
|
12 |
-
generate_func(pipeline, prompt_list, "./samples/opensora_base", loop=5)
|
13 |
-
|
14 |
-
|
15 |
-
def eval_pab1(prompt_list):
|
16 |
-
config = OpenSoraConfig(enable_pab=True)
|
17 |
-
pipeline = OpenSoraPipeline(config)
|
18 |
-
|
19 |
-
generate_func(pipeline, prompt_list, "./samples/opensora_pab1", loop=5)
|
20 |
-
|
21 |
-
|
22 |
-
def eval_pab2(prompt_list):
|
23 |
-
pab_config = OpenSoraPABConfig(spatial_gap=3, temporal_gap=5, cross_gap=7)
|
24 |
-
config = OpenSoraConfig(enable_pab=True, pab_config=pab_config)
|
25 |
-
pipeline = OpenSoraPipeline(config)
|
26 |
-
|
27 |
-
generate_func(pipeline, prompt_list, "./samples/opensora_pab2", loop=5)
|
28 |
-
|
29 |
-
|
30 |
-
def eval_pab3(prompt_list):
|
31 |
-
pab_config = OpenSoraPABConfig(spatial_gap=5, temporal_gap=7, cross_gap=9)
|
32 |
-
config = OpenSoraConfig(enable_pab=True, pab_config=pab_config)
|
33 |
-
pipeline = OpenSoraPipeline(config)
|
34 |
-
|
35 |
-
generate_func(pipeline, prompt_list, "./samples/opensora_pab3", loop=5)
|
36 |
-
|
37 |
-
|
38 |
-
if __name__ == "__main__":
|
39 |
-
videosys.initialize(42)
|
40 |
-
prompt_list = read_prompt_list("vbench/VBench_full_info.json")
|
41 |
-
eval_base(prompt_list)
|
42 |
-
eval_pab1(prompt_list)
|
43 |
-
eval_pab2(prompt_list)
|
44 |
-
eval_pab3(prompt_list)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
eval/pab/experiments/opensora_plan.py
DELETED
@@ -1,57 +0,0 @@
|
|
1 |
-
from utils import generate_func, read_prompt_list
|
2 |
-
|
3 |
-
import videosys
|
4 |
-
from videosys import OpenSoraPlanConfig, OpenSoraPlanPipeline
|
5 |
-
from videosys.models.open_sora_plan import OpenSoraPlanPABConfig
|
6 |
-
|
7 |
-
|
8 |
-
def eval_base(prompt_list):
|
9 |
-
config = OpenSoraPlanConfig()
|
10 |
-
pipeline = OpenSoraPlanPipeline(config)
|
11 |
-
|
12 |
-
generate_func(pipeline, prompt_list, "./samples/opensoraplan_base", loop=5)
|
13 |
-
|
14 |
-
|
15 |
-
def eval_pab1(prompt_list):
|
16 |
-
pab_config = OpenSoraPlanPABConfig(
|
17 |
-
spatial_gap=2,
|
18 |
-
temporal_gap=4,
|
19 |
-
cross_gap=6,
|
20 |
-
)
|
21 |
-
config = OpenSoraPlanConfig(enable_pab=True, pab_config=pab_config)
|
22 |
-
pipeline = OpenSoraPlanPipeline(config)
|
23 |
-
|
24 |
-
generate_func(pipeline, prompt_list, "./samples/opensoraplan_pab1", loop=5)
|
25 |
-
|
26 |
-
|
27 |
-
def eval_pab2(prompt_list):
|
28 |
-
pab_config = OpenSoraPlanPABConfig(
|
29 |
-
spatial_gap=3,
|
30 |
-
temporal_gap=5,
|
31 |
-
cross_gap=7,
|
32 |
-
)
|
33 |
-
config = OpenSoraPlanConfig(enable_pab=True, pab_config=pab_config)
|
34 |
-
pipeline = OpenSoraPlanPipeline(config)
|
35 |
-
|
36 |
-
generate_func(pipeline, prompt_list, "./samples/opensoraplan_pab2", loop=5)
|
37 |
-
|
38 |
-
|
39 |
-
def eval_pab3(prompt_list):
|
40 |
-
pab_config = OpenSoraPlanPABConfig(
|
41 |
-
spatial_gap=5,
|
42 |
-
temporal_gap=7,
|
43 |
-
cross_gap=9,
|
44 |
-
)
|
45 |
-
config = OpenSoraPlanConfig(enable_pab=True, pab_config=pab_config)
|
46 |
-
pipeline = OpenSoraPlanPipeline(config)
|
47 |
-
|
48 |
-
generate_func(pipeline, prompt_list, "./samples/opensoraplan_pab3", loop=5)
|
49 |
-
|
50 |
-
|
51 |
-
if __name__ == "__main__":
|
52 |
-
videosys.initialize(42)
|
53 |
-
prompt_list = read_prompt_list("vbench/VBench_full_info.json")
|
54 |
-
eval_base(prompt_list)
|
55 |
-
eval_pab1(prompt_list)
|
56 |
-
eval_pab2(prompt_list)
|
57 |
-
eval_pab3(prompt_list)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
eval/pab/experiments/utils.py
DELETED
@@ -1,22 +0,0 @@
|
|
1 |
-
import json
|
2 |
-
import os
|
3 |
-
|
4 |
-
import tqdm
|
5 |
-
|
6 |
-
from videosys.utils.utils import set_seed
|
7 |
-
|
8 |
-
|
9 |
-
def generate_func(pipeline, prompt_list, output_dir, loop: int = 5, kwargs: dict = {}):
|
10 |
-
kwargs["verbose"] = False
|
11 |
-
for prompt in tqdm.tqdm(prompt_list):
|
12 |
-
for l in range(loop):
|
13 |
-
set_seed(l)
|
14 |
-
video = pipeline.generate(prompt, **kwargs).video[0]
|
15 |
-
pipeline.save_video(video, os.path.join(output_dir, f"{prompt}-{l}.mp4"))
|
16 |
-
|
17 |
-
|
18 |
-
def read_prompt_list(prompt_list_path):
|
19 |
-
with open(prompt_list_path, "r") as f:
|
20 |
-
prompt_list = json.load(f)
|
21 |
-
prompt_list = [prompt["prompt_en"] for prompt in prompt_list]
|
22 |
-
return prompt_list
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
eval/pab/vbench/VBench_full_info.json
DELETED
The diff for this file is too large to render.
See raw diff
|
|
eval/pab/vbench/cal_vbench.py
DELETED
@@ -1,154 +0,0 @@
|
|
1 |
-
import argparse
|
2 |
-
import json
|
3 |
-
import os
|
4 |
-
|
5 |
-
SEMANTIC_WEIGHT = 1
|
6 |
-
QUALITY_WEIGHT = 4
|
7 |
-
|
8 |
-
QUALITY_LIST = [
|
9 |
-
"subject consistency",
|
10 |
-
"background consistency",
|
11 |
-
"temporal flickering",
|
12 |
-
"motion smoothness",
|
13 |
-
"aesthetic quality",
|
14 |
-
"imaging quality",
|
15 |
-
"dynamic degree",
|
16 |
-
]
|
17 |
-
|
18 |
-
SEMANTIC_LIST = [
|
19 |
-
"object class",
|
20 |
-
"multiple objects",
|
21 |
-
"human action",
|
22 |
-
"color",
|
23 |
-
"spatial relationship",
|
24 |
-
"scene",
|
25 |
-
"appearance style",
|
26 |
-
"temporal style",
|
27 |
-
"overall consistency",
|
28 |
-
]
|
29 |
-
|
30 |
-
NORMALIZE_DIC = {
|
31 |
-
"subject consistency": {"Min": 0.1462, "Max": 1.0},
|
32 |
-
"background consistency": {"Min": 0.2615, "Max": 1.0},
|
33 |
-
"temporal flickering": {"Min": 0.6293, "Max": 1.0},
|
34 |
-
"motion smoothness": {"Min": 0.706, "Max": 0.9975},
|
35 |
-
"dynamic degree": {"Min": 0.0, "Max": 1.0},
|
36 |
-
"aesthetic quality": {"Min": 0.0, "Max": 1.0},
|
37 |
-
"imaging quality": {"Min": 0.0, "Max": 1.0},
|
38 |
-
"object class": {"Min": 0.0, "Max": 1.0},
|
39 |
-
"multiple objects": {"Min": 0.0, "Max": 1.0},
|
40 |
-
"human action": {"Min": 0.0, "Max": 1.0},
|
41 |
-
"color": {"Min": 0.0, "Max": 1.0},
|
42 |
-
"spatial relationship": {"Min": 0.0, "Max": 1.0},
|
43 |
-
"scene": {"Min": 0.0, "Max": 0.8222},
|
44 |
-
"appearance style": {"Min": 0.0009, "Max": 0.2855},
|
45 |
-
"temporal style": {"Min": 0.0, "Max": 0.364},
|
46 |
-
"overall consistency": {"Min": 0.0, "Max": 0.364},
|
47 |
-
}
|
48 |
-
|
49 |
-
DIM_WEIGHT = {
|
50 |
-
"subject consistency": 1,
|
51 |
-
"background consistency": 1,
|
52 |
-
"temporal flickering": 1,
|
53 |
-
"motion smoothness": 1,
|
54 |
-
"aesthetic quality": 1,
|
55 |
-
"imaging quality": 1,
|
56 |
-
"dynamic degree": 0.5,
|
57 |
-
"object class": 1,
|
58 |
-
"multiple objects": 1,
|
59 |
-
"human action": 1,
|
60 |
-
"color": 1,
|
61 |
-
"spatial relationship": 1,
|
62 |
-
"scene": 1,
|
63 |
-
"appearance style": 1,
|
64 |
-
"temporal style": 1,
|
65 |
-
"overall consistency": 1,
|
66 |
-
}
|
67 |
-
|
68 |
-
ordered_scaled_res = [
|
69 |
-
"total score",
|
70 |
-
"quality score",
|
71 |
-
"semantic score",
|
72 |
-
"subject consistency",
|
73 |
-
"background consistency",
|
74 |
-
"temporal flickering",
|
75 |
-
"motion smoothness",
|
76 |
-
"dynamic degree",
|
77 |
-
"aesthetic quality",
|
78 |
-
"imaging quality",
|
79 |
-
"object class",
|
80 |
-
"multiple objects",
|
81 |
-
"human action",
|
82 |
-
"color",
|
83 |
-
"spatial relationship",
|
84 |
-
"scene",
|
85 |
-
"appearance style",
|
86 |
-
"temporal style",
|
87 |
-
"overall consistency",
|
88 |
-
]
|
89 |
-
|
90 |
-
|
91 |
-
def parse_args():
|
92 |
-
parser = argparse.ArgumentParser()
|
93 |
-
parser.add_argument("--score_dir", required=True, type=str)
|
94 |
-
args = parser.parse_args()
|
95 |
-
return args
|
96 |
-
|
97 |
-
|
98 |
-
if __name__ == "__main__":
|
99 |
-
args = parse_args()
|
100 |
-
res_postfix = "_eval_results.json"
|
101 |
-
info_postfix = "_full_info.json"
|
102 |
-
files = os.listdir(args.score_dir)
|
103 |
-
res_files = [x for x in files if res_postfix in x]
|
104 |
-
info_files = [x for x in files if info_postfix in x]
|
105 |
-
assert len(res_files) == len(info_files), f"got {len(res_files)} res files, but {len(info_files)} info files"
|
106 |
-
|
107 |
-
full_results = {}
|
108 |
-
for res_file in res_files:
|
109 |
-
# first check if results is normal
|
110 |
-
info_file = res_file.split(res_postfix)[0] + info_postfix
|
111 |
-
with open(os.path.join(args.score_dir, info_file), "r", encoding="utf-8") as f:
|
112 |
-
info = json.load(f)
|
113 |
-
assert len(info[0]["video_list"]) > 0, f"Error: {info_file} has 0 video list"
|
114 |
-
# read results
|
115 |
-
with open(os.path.join(args.score_dir, res_file), "r", encoding="utf-8") as f:
|
116 |
-
data = json.load(f)
|
117 |
-
for key, val in data.items():
|
118 |
-
full_results[key] = format(val[0], ".4f")
|
119 |
-
|
120 |
-
scaled_results = {}
|
121 |
-
dims = set()
|
122 |
-
for key, val in full_results.items():
|
123 |
-
dim = key.replace("_", " ") if "_" in key else key
|
124 |
-
scaled_score = (float(val) - NORMALIZE_DIC[dim]["Min"]) / (
|
125 |
-
NORMALIZE_DIC[dim]["Max"] - NORMALIZE_DIC[dim]["Min"]
|
126 |
-
)
|
127 |
-
scaled_score *= DIM_WEIGHT[dim]
|
128 |
-
scaled_results[dim] = scaled_score
|
129 |
-
dims.add(dim)
|
130 |
-
|
131 |
-
assert len(dims) == len(NORMALIZE_DIC), f"{set(NORMALIZE_DIC.keys())-dims} not calculated yet"
|
132 |
-
|
133 |
-
quality_score = sum([scaled_results[i] for i in QUALITY_LIST]) / sum([DIM_WEIGHT[i] for i in QUALITY_LIST])
|
134 |
-
semantic_score = sum([scaled_results[i] for i in SEMANTIC_LIST]) / sum([DIM_WEIGHT[i] for i in SEMANTIC_LIST])
|
135 |
-
scaled_results["quality score"] = quality_score
|
136 |
-
scaled_results["semantic score"] = semantic_score
|
137 |
-
scaled_results["total score"] = (quality_score * QUALITY_WEIGHT + semantic_score * SEMANTIC_WEIGHT) / (
|
138 |
-
QUALITY_WEIGHT + SEMANTIC_WEIGHT
|
139 |
-
)
|
140 |
-
|
141 |
-
formated_scaled_results = {"items": []}
|
142 |
-
for key in ordered_scaled_res:
|
143 |
-
formated_score = format(scaled_results[key] * 100, ".2f") + "%"
|
144 |
-
formated_scaled_results["items"].append({key: formated_score})
|
145 |
-
|
146 |
-
output_file_path = os.path.join(args.score_dir, "all_results.json")
|
147 |
-
with open(output_file_path, "w") as outfile:
|
148 |
-
json.dump(full_results, outfile, indent=4, sort_keys=True)
|
149 |
-
print(f"results saved to: {output_file_path}")
|
150 |
-
|
151 |
-
scaled_file_path = os.path.join(args.score_dir, "scaled_results.json")
|
152 |
-
with open(scaled_file_path, "w") as outfile:
|
153 |
-
json.dump(formated_scaled_results, outfile, indent=4, sort_keys=True)
|
154 |
-
print(f"results saved to: {scaled_file_path}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
eval/pab/vbench/run_vbench.py
DELETED
@@ -1,52 +0,0 @@
|
|
1 |
-
import argparse
|
2 |
-
|
3 |
-
import torch
|
4 |
-
from vbench import VBench
|
5 |
-
|
6 |
-
full_info_path = "./vbench/VBench_full_info.json"
|
7 |
-
|
8 |
-
dimensions = [
|
9 |
-
"subject_consistency",
|
10 |
-
"imaging_quality",
|
11 |
-
"background_consistency",
|
12 |
-
"motion_smoothness",
|
13 |
-
"overall_consistency",
|
14 |
-
"human_action",
|
15 |
-
"multiple_objects",
|
16 |
-
"spatial_relationship",
|
17 |
-
"object_class",
|
18 |
-
"color",
|
19 |
-
"aesthetic_quality",
|
20 |
-
"appearance_style",
|
21 |
-
"temporal_flickering",
|
22 |
-
"scene",
|
23 |
-
"temporal_style",
|
24 |
-
"dynamic_degree",
|
25 |
-
]
|
26 |
-
|
27 |
-
|
28 |
-
def parse_args():
|
29 |
-
parser = argparse.ArgumentParser()
|
30 |
-
parser.add_argument("--video_path", required=True, type=str)
|
31 |
-
args = parser.parse_args()
|
32 |
-
return args
|
33 |
-
|
34 |
-
|
35 |
-
if __name__ == "__main__":
|
36 |
-
args = parse_args()
|
37 |
-
save_path = args.video_path.replace("/samples/", "/vbench_out/")
|
38 |
-
|
39 |
-
kwargs = {}
|
40 |
-
kwargs["imaging_quality_preprocessing_mode"] = "longer" # use VBench/evaluate.py default
|
41 |
-
|
42 |
-
for dimension in dimensions:
|
43 |
-
my_VBench = VBench(torch.device("cuda"), full_info_path, save_path)
|
44 |
-
my_VBench.evaluate(
|
45 |
-
videos_path=args.video_path,
|
46 |
-
name=dimension,
|
47 |
-
local=False,
|
48 |
-
read_frame=False,
|
49 |
-
dimension_list=[dimension],
|
50 |
-
mode="vbench_standard",
|
51 |
-
**kwargs,
|
52 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
examples/cogvideo/sample.py
DELETED
@@ -1,14 +0,0 @@
|
|
1 |
-
from videosys import CogVideoConfig, VideoSysEngine
|
2 |
-
|
3 |
-
|
4 |
-
def run_base():
|
5 |
-
config = CogVideoConfig(world_size=1)
|
6 |
-
engine = VideoSysEngine(config)
|
7 |
-
|
8 |
-
prompt = "Sunset over the sea."
|
9 |
-
video = engine.generate(prompt).video[0]
|
10 |
-
engine.save_video(video, f"./outputs/{prompt}.mp4")
|
11 |
-
|
12 |
-
|
13 |
-
if __name__ == "__main__":
|
14 |
-
run_base()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
examples/latte/sample.py
DELETED
@@ -1,24 +0,0 @@
|
|
1 |
-
from videosys import LatteConfig, VideoSysEngine
|
2 |
-
|
3 |
-
|
4 |
-
def run_base():
|
5 |
-
config = LatteConfig(world_size=1)
|
6 |
-
engine = VideoSysEngine(config)
|
7 |
-
|
8 |
-
prompt = "Sunset over the sea."
|
9 |
-
video = engine.generate(prompt).video[0]
|
10 |
-
engine.save_video(video, f"./outputs/{prompt}.mp4")
|
11 |
-
|
12 |
-
|
13 |
-
def run_pab():
|
14 |
-
config = LatteConfig(world_size=1)
|
15 |
-
engine = VideoSysEngine(config)
|
16 |
-
|
17 |
-
prompt = "Sunset over the sea."
|
18 |
-
video = engine.generate(prompt).video[0]
|
19 |
-
engine.save_video(video, f"./outputs/{prompt}.mp4")
|
20 |
-
|
21 |
-
|
22 |
-
if __name__ == "__main__":
|
23 |
-
run_base()
|
24 |
-
# run_pab()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
examples/open_sora/sample.py
DELETED
@@ -1,24 +0,0 @@
|
|
1 |
-
from videosys import OpenSoraConfig, VideoSysEngine
|
2 |
-
|
3 |
-
|
4 |
-
def run_base():
|
5 |
-
config = OpenSoraConfig(world_size=1)
|
6 |
-
engine = VideoSysEngine(config)
|
7 |
-
|
8 |
-
prompt = "Sunset over the sea."
|
9 |
-
video = engine.generate(prompt).video[0]
|
10 |
-
engine.save_video(video, f"./outputs/{prompt}.mp4")
|
11 |
-
|
12 |
-
|
13 |
-
def run_pab():
|
14 |
-
config = OpenSoraConfig(world_size=1, enable_pab=True)
|
15 |
-
engine = VideoSysEngine(config)
|
16 |
-
|
17 |
-
prompt = "Sunset over the sea."
|
18 |
-
video = engine.generate(prompt).video[0]
|
19 |
-
engine.save_video(video, f"./outputs/{prompt}.mp4")
|
20 |
-
|
21 |
-
|
22 |
-
if __name__ == "__main__":
|
23 |
-
run_base()
|
24 |
-
run_pab()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
examples/open_sora_plan/sample.py
DELETED
@@ -1,24 +0,0 @@
|
|
1 |
-
from videosys import OpenSoraPlanConfig, VideoSysEngine
|
2 |
-
|
3 |
-
|
4 |
-
def run_base():
|
5 |
-
config = OpenSoraPlanConfig(world_size=1)
|
6 |
-
engine = VideoSysEngine(config)
|
7 |
-
|
8 |
-
prompt = "Sunset over the sea."
|
9 |
-
video = engine.generate(prompt).video[0]
|
10 |
-
engine.save_video(video, f"./outputs/{prompt}.mp4")
|
11 |
-
|
12 |
-
|
13 |
-
def run_pab():
|
14 |
-
config = OpenSoraPlanConfig(world_size=1)
|
15 |
-
engine = VideoSysEngine(config)
|
16 |
-
|
17 |
-
prompt = "Sunset over the sea."
|
18 |
-
video = engine.generate(prompt).video[0]
|
19 |
-
engine.save_video(video, f"./outputs/{prompt}.mp4")
|
20 |
-
|
21 |
-
|
22 |
-
if __name__ == "__main__":
|
23 |
-
run_base()
|
24 |
-
# run_pab()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
videosys/__init__.py
CHANGED
@@ -1,19 +1,15 @@
|
|
1 |
from .core.engine import VideoSysEngine
|
2 |
from .core.parallel_mgr import initialize
|
3 |
-
from .
|
4 |
-
from .
|
5 |
-
from .
|
6 |
-
from .
|
7 |
|
8 |
__all__ = [
|
9 |
"initialize",
|
10 |
"VideoSysEngine",
|
11 |
-
"LattePipeline",
|
12 |
-
"
|
13 |
-
"
|
14 |
-
"
|
15 |
-
|
16 |
-
"OpenSoraConfig",
|
17 |
-
"CogVideoConfig",
|
18 |
-
"CogVideoPipeline",
|
19 |
-
]
|
|
|
1 |
from .core.engine import VideoSysEngine
|
2 |
from .core.parallel_mgr import initialize
|
3 |
+
from .pipelines.cogvideox import CogVideoXConfig, CogVideoXPABConfig, CogVideoXPipeline
|
4 |
+
from .pipelines.latte import LatteConfig, LattePABConfig, LattePipeline
|
5 |
+
from .pipelines.open_sora import OpenSoraConfig, OpenSoraPABConfig, OpenSoraPipeline
|
6 |
+
from .pipelines.open_sora_plan import OpenSoraPlanConfig, OpenSoraPlanPABConfig, OpenSoraPlanPipeline
|
7 |
|
8 |
__all__ = [
|
9 |
"initialize",
|
10 |
"VideoSysEngine",
|
11 |
+
"LattePipeline", "LatteConfig", "LattePABConfig",
|
12 |
+
"OpenSoraPlanPipeline", "OpenSoraPlanConfig", "OpenSoraPlanPABConfig",
|
13 |
+
"OpenSoraPipeline", "OpenSoraConfig", "OpenSoraPABConfig",
|
14 |
+
"CogVideoXConfig", "CogVideoXPipeline", "CogVideoXPABConfig"
|
15 |
+
] # fmt: skip
|
|
|
|
|
|
|
|
videosys/core/engine.py
CHANGED
@@ -2,7 +2,6 @@ import os
|
|
2 |
from functools import partial
|
3 |
from typing import Any, Optional
|
4 |
|
5 |
-
import imageio
|
6 |
import torch
|
7 |
|
8 |
import videosys
|
@@ -120,8 +119,7 @@ class VideoSysEngine:
|
|
120 |
result.get()
|
121 |
|
122 |
def save_video(self, video, output_path):
|
123 |
-
|
124 |
-
imageio.mimwrite(output_path, video, fps=24)
|
125 |
|
126 |
def shutdown(self):
|
127 |
if (worker_monitor := getattr(self, "worker_monitor", None)) is not None:
|
@@ -129,4 +127,4 @@ class VideoSysEngine:
|
|
129 |
torch.distributed.destroy_process_group()
|
130 |
|
131 |
def __del__(self):
|
132 |
-
self.shutdown()
|
|
|
2 |
from functools import partial
|
3 |
from typing import Any, Optional
|
4 |
|
|
|
5 |
import torch
|
6 |
|
7 |
import videosys
|
|
|
119 |
result.get()
|
120 |
|
121 |
def save_video(self, video, output_path):
|
122 |
+
return self.driver_worker.save_video(video, output_path)
|
|
|
123 |
|
124 |
def shutdown(self):
|
125 |
if (worker_monitor := getattr(self, "worker_monitor", None)) is not None:
|
|
|
127 |
torch.distributed.destroy_process_group()
|
128 |
|
129 |
def __del__(self):
|
130 |
+
self.shutdown()
|
videosys/core/pab_mgr.py
CHANGED
@@ -1,8 +1,3 @@
|
|
1 |
-
import random
|
2 |
-
|
3 |
-
import numpy as np
|
4 |
-
import torch
|
5 |
-
|
6 |
from videosys.utils.logging import logger
|
7 |
|
8 |
PAB_MANAGER = None
|
@@ -12,71 +7,56 @@ class PABConfig:
|
|
12 |
def __init__(
|
13 |
self,
|
14 |
steps: int,
|
15 |
-
cross_broadcast: bool,
|
16 |
-
cross_threshold: list,
|
17 |
-
|
18 |
-
spatial_broadcast: bool,
|
19 |
-
spatial_threshold: list,
|
20 |
-
|
21 |
-
temporal_broadcast: bool,
|
22 |
-
temporal_threshold: list,
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
mlp_skip: bool,
|
28 |
-
mlp_spatial_skip_config: dict,
|
29 |
-
mlp_temporal_skip_config: dict,
|
30 |
-
full_broadcast: bool = False,
|
31 |
-
full_threshold: list = None,
|
32 |
-
full_gap: int = 1,
|
33 |
):
|
34 |
self.steps = steps
|
35 |
|
36 |
self.cross_broadcast = cross_broadcast
|
37 |
self.cross_threshold = cross_threshold
|
38 |
-
self.
|
39 |
|
40 |
self.spatial_broadcast = spatial_broadcast
|
41 |
self.spatial_threshold = spatial_threshold
|
42 |
-
self.
|
43 |
|
44 |
self.temporal_broadcast = temporal_broadcast
|
45 |
self.temporal_threshold = temporal_threshold
|
46 |
-
self.
|
47 |
-
|
48 |
-
self.diffusion_skip = diffusion_skip
|
49 |
-
self.diffusion_timestep_respacing = diffusion_timestep_respacing
|
50 |
-
self.diffusion_skip_timestep = diffusion_skip_timestep
|
51 |
|
52 |
-
self.
|
53 |
-
self.
|
54 |
-
self.
|
55 |
-
|
56 |
-
self.
|
57 |
-
self.spatial_mlp_outputs = {}
|
58 |
-
|
59 |
-
self.full_broadcast = full_broadcast
|
60 |
-
self.full_threshold = full_threshold
|
61 |
-
self.full_gap = full_gap
|
62 |
|
63 |
|
64 |
class PABManager:
|
65 |
def __init__(self, config: PABConfig):
|
66 |
self.config: PABConfig = config
|
67 |
|
68 |
-
init_prompt = f"Init
|
69 |
-
init_prompt += f"
|
70 |
-
init_prompt += f"
|
71 |
-
init_prompt += f"
|
72 |
-
init_prompt += f"
|
73 |
logger.info(init_prompt)
|
74 |
|
75 |
def if_broadcast_cross(self, timestep: int, count: int):
|
76 |
if (
|
77 |
self.config.cross_broadcast
|
78 |
and (timestep is not None)
|
79 |
-
and (count % self.config.
|
80 |
and (self.config.cross_threshold[0] < timestep < self.config.cross_threshold[1])
|
81 |
):
|
82 |
flag = True
|
@@ -89,7 +69,7 @@ class PABManager:
|
|
89 |
if (
|
90 |
self.config.temporal_broadcast
|
91 |
and (timestep is not None)
|
92 |
-
and (count % self.config.
|
93 |
and (self.config.temporal_threshold[0] < timestep < self.config.temporal_threshold[1])
|
94 |
):
|
95 |
flag = True
|
@@ -102,7 +82,7 @@ class PABManager:
|
|
102 |
if (
|
103 |
self.config.spatial_broadcast
|
104 |
and (timestep is not None)
|
105 |
-
and (count % self.config.
|
106 |
and (self.config.spatial_threshold[0] < timestep < self.config.spatial_threshold[1])
|
107 |
):
|
108 |
flag = True
|
@@ -111,19 +91,6 @@ class PABManager:
|
|
111 |
count = (count + 1) % self.config.steps
|
112 |
return flag, count
|
113 |
|
114 |
-
def if_broadcast_full(self, timestep: int, count: int, block_idx: int):
|
115 |
-
if (
|
116 |
-
self.config.full_broadcast
|
117 |
-
and (timestep is not None)
|
118 |
-
and (count % self.config.full_gap != 0)
|
119 |
-
and (self.config.full_threshold[0] < timestep < self.config.full_threshold[1])
|
120 |
-
):
|
121 |
-
flag = True
|
122 |
-
else:
|
123 |
-
flag = False
|
124 |
-
count = (count + 1) % self.config.steps
|
125 |
-
return flag, count
|
126 |
-
|
127 |
@staticmethod
|
128 |
def _is_t_in_skip_config(all_timesteps, timestep, config):
|
129 |
is_t_in_skip_config = False
|
@@ -139,18 +106,18 @@ class PABManager:
|
|
139 |
return is_t_in_skip_config, skip_range
|
140 |
|
141 |
def if_skip_mlp(self, timestep: int, count: int, block_idx: int, all_timesteps, is_temporal=False):
|
142 |
-
if not self.config.
|
143 |
return False, None, False, None
|
144 |
|
145 |
if is_temporal:
|
146 |
-
cur_config = self.config.
|
147 |
else:
|
148 |
-
cur_config = self.config.
|
149 |
|
150 |
is_t_in_skip_config, skip_range = self._is_t_in_skip_config(all_timesteps, timestep, cur_config)
|
151 |
next_flag = False
|
152 |
if (
|
153 |
-
self.config.
|
154 |
and (timestep is not None)
|
155 |
and (timestep in cur_config)
|
156 |
and (block_idx in cur_config[timestep]["block"])
|
@@ -159,7 +126,7 @@ class PABManager:
|
|
159 |
next_flag = True
|
160 |
count = count + 1
|
161 |
elif (
|
162 |
-
self.config.
|
163 |
and (timestep is not None)
|
164 |
and (is_t_in_skip_config)
|
165 |
and (block_idx in cur_config[skip_range[0]]["block"])
|
@@ -173,22 +140,22 @@ class PABManager:
|
|
173 |
|
174 |
def save_skip_output(self, timestep, block_idx, ff_output, is_temporal=False):
|
175 |
if is_temporal:
|
176 |
-
self.config.
|
177 |
else:
|
178 |
-
self.config.
|
179 |
|
180 |
def get_mlp_output(self, skip_range, timestep, block_idx, is_temporal=False):
|
181 |
skip_start_t = skip_range[0]
|
182 |
if is_temporal:
|
183 |
skip_output = (
|
184 |
-
self.config.
|
185 |
-
if self.config.
|
186 |
else None
|
187 |
)
|
188 |
else:
|
189 |
skip_output = (
|
190 |
-
self.config.
|
191 |
-
if self.config.
|
192 |
else None
|
193 |
)
|
194 |
|
@@ -196,9 +163,9 @@ class PABManager:
|
|
196 |
if timestep == skip_range[-1]:
|
197 |
# TODO: save memory
|
198 |
if is_temporal:
|
199 |
-
del self.config.
|
200 |
else:
|
201 |
-
del self.config.
|
202 |
else:
|
203 |
raise ValueError(
|
204 |
f"No stored MLP output found | t {timestep} |[{skip_range[0]}, {skip_range[-1]}] | block {block_idx}"
|
@@ -207,10 +174,10 @@ class PABManager:
|
|
207 |
return skip_output
|
208 |
|
209 |
def get_spatial_mlp_outputs(self):
|
210 |
-
return self.config.
|
211 |
|
212 |
def get_temporal_mlp_outputs(self):
|
213 |
-
return self.config.
|
214 |
|
215 |
|
216 |
def set_pab_manager(config: PABConfig):
|
@@ -250,11 +217,6 @@ def if_broadcast_spatial(timestep: int, count: int, block_idx: int):
|
|
250 |
return False, count
|
251 |
return PAB_MANAGER.if_broadcast_spatial(timestep, count, block_idx)
|
252 |
|
253 |
-
def if_broadcast_full(timestep: int, count: int, block_idx: int):
|
254 |
-
if not enable_pab():
|
255 |
-
return False, count
|
256 |
-
return PAB_MANAGER.if_broadcast_full(timestep, count, block_idx)
|
257 |
-
|
258 |
|
259 |
def if_broadcast_mlp(timestep: int, count: int, block_idx: int, all_timesteps, is_temporal=False):
|
260 |
if not enable_pab():
|
@@ -268,97 +230,3 @@ def save_mlp_output(timestep: int, block_idx: int, ff_output, is_temporal=False)
|
|
268 |
|
269 |
def get_mlp_output(skip_range, timestep, block_idx: int, is_temporal=False):
|
270 |
return PAB_MANAGER.get_mlp_output(skip_range, timestep, block_idx, is_temporal)
|
271 |
-
|
272 |
-
|
273 |
-
def get_diffusion_skip():
|
274 |
-
return enable_pab() and PAB_MANAGER.config.diffusion_skip
|
275 |
-
|
276 |
-
|
277 |
-
def get_diffusion_timestep_respacing():
|
278 |
-
return PAB_MANAGER.config.diffusion_timestep_respacing
|
279 |
-
|
280 |
-
|
281 |
-
def get_diffusion_skip_timestep():
|
282 |
-
return enable_pab() and PAB_MANAGER.config.diffusion_skip_timestep
|
283 |
-
|
284 |
-
|
285 |
-
def space_timesteps(time_steps, time_bins):
|
286 |
-
num_bins = len(time_bins)
|
287 |
-
bin_size = time_steps // num_bins
|
288 |
-
|
289 |
-
result = []
|
290 |
-
|
291 |
-
for i, bin_count in enumerate(time_bins):
|
292 |
-
start = i * bin_size
|
293 |
-
end = start + bin_size
|
294 |
-
|
295 |
-
bin_steps = np.linspace(start, end, bin_count, endpoint=False, dtype=int).tolist()
|
296 |
-
result.extend(bin_steps)
|
297 |
-
|
298 |
-
result_tensor = torch.tensor(result, dtype=torch.int32)
|
299 |
-
sorted_tensor = torch.sort(result_tensor, descending=True).values
|
300 |
-
|
301 |
-
return sorted_tensor
|
302 |
-
|
303 |
-
|
304 |
-
def skip_diffusion_timestep(timesteps, diffusion_skip_timestep):
|
305 |
-
if isinstance(timesteps, list):
|
306 |
-
# If timesteps is a list, we assume each element is a tensor
|
307 |
-
timesteps_np = [t.cpu().numpy() for t in timesteps]
|
308 |
-
device = timesteps[0].device
|
309 |
-
else:
|
310 |
-
# If timesteps is a tensor
|
311 |
-
timesteps_np = timesteps.cpu().numpy()
|
312 |
-
device = timesteps.device
|
313 |
-
|
314 |
-
num_bins = len(diffusion_skip_timestep)
|
315 |
-
|
316 |
-
if isinstance(timesteps_np, list):
|
317 |
-
bin_size = len(timesteps_np) // num_bins
|
318 |
-
new_timesteps = []
|
319 |
-
|
320 |
-
for i in range(num_bins):
|
321 |
-
bin_start = i * bin_size
|
322 |
-
bin_end = (i + 1) * bin_size if i != num_bins - 1 else len(timesteps_np)
|
323 |
-
bin_timesteps = timesteps_np[bin_start:bin_end]
|
324 |
-
|
325 |
-
if diffusion_skip_timestep[i] == 0:
|
326 |
-
# If the bin is marked with 0, keep all timesteps
|
327 |
-
new_timesteps.extend(bin_timesteps)
|
328 |
-
elif diffusion_skip_timestep[i] == 1:
|
329 |
-
# If the bin is marked with 1, omit the last timestep in the bin
|
330 |
-
new_timesteps.extend(bin_timesteps[1:])
|
331 |
-
|
332 |
-
new_timesteps_tensor = [torch.tensor(t, device=device) for t in new_timesteps]
|
333 |
-
else:
|
334 |
-
bin_size = len(timesteps_np) // num_bins
|
335 |
-
new_timesteps = []
|
336 |
-
|
337 |
-
for i in range(num_bins):
|
338 |
-
bin_start = i * bin_size
|
339 |
-
bin_end = (i + 1) * bin_size if i != num_bins - 1 else len(timesteps_np)
|
340 |
-
bin_timesteps = timesteps_np[bin_start:bin_end]
|
341 |
-
|
342 |
-
if diffusion_skip_timestep[i] == 0:
|
343 |
-
# If the bin is marked with 0, keep all timesteps
|
344 |
-
new_timesteps.extend(bin_timesteps)
|
345 |
-
elif diffusion_skip_timestep[i] == 1:
|
346 |
-
# If the bin is marked with 1, omit the last timestep in the bin
|
347 |
-
new_timesteps.extend(bin_timesteps[1:])
|
348 |
-
elif diffusion_skip_timestep[i] != 0:
|
349 |
-
# If the bin is marked with a non-zero value, randomly omit n timesteps
|
350 |
-
if len(bin_timesteps) > diffusion_skip_timestep[i]:
|
351 |
-
indices_to_remove = set(random.sample(range(len(bin_timesteps)), diffusion_skip_timestep[i]))
|
352 |
-
timesteps_to_keep = [
|
353 |
-
timestep for idx, timestep in enumerate(bin_timesteps) if idx not in indices_to_remove
|
354 |
-
]
|
355 |
-
else:
|
356 |
-
timesteps_to_keep = bin_timesteps # 如果bin_timesteps的长度小于等于n,则不删除任何元素
|
357 |
-
new_timesteps.extend(timesteps_to_keep)
|
358 |
-
|
359 |
-
new_timesteps_tensor = torch.tensor(new_timesteps, device=device)
|
360 |
-
|
361 |
-
if isinstance(timesteps, list):
|
362 |
-
return new_timesteps_tensor
|
363 |
-
else:
|
364 |
-
return new_timesteps_tensor
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
from videosys.utils.logging import logger
|
2 |
|
3 |
PAB_MANAGER = None
|
|
|
7 |
def __init__(
|
8 |
self,
|
9 |
steps: int,
|
10 |
+
cross_broadcast: bool = False,
|
11 |
+
cross_threshold: list = None,
|
12 |
+
cross_range: int = None,
|
13 |
+
spatial_broadcast: bool = False,
|
14 |
+
spatial_threshold: list = None,
|
15 |
+
spatial_range: int = None,
|
16 |
+
temporal_broadcast: bool = False,
|
17 |
+
temporal_threshold: list = None,
|
18 |
+
temporal_range: int = None,
|
19 |
+
mlp_broadcast: bool = False,
|
20 |
+
mlp_spatial_broadcast_config: dict = None,
|
21 |
+
mlp_temporal_broadcast_config: dict = None,
|
|
|
|
|
|
|
|
|
|
|
|
|
22 |
):
|
23 |
self.steps = steps
|
24 |
|
25 |
self.cross_broadcast = cross_broadcast
|
26 |
self.cross_threshold = cross_threshold
|
27 |
+
self.cross_range = cross_range
|
28 |
|
29 |
self.spatial_broadcast = spatial_broadcast
|
30 |
self.spatial_threshold = spatial_threshold
|
31 |
+
self.spatial_range = spatial_range
|
32 |
|
33 |
self.temporal_broadcast = temporal_broadcast
|
34 |
self.temporal_threshold = temporal_threshold
|
35 |
+
self.temporal_range = temporal_range
|
|
|
|
|
|
|
|
|
36 |
|
37 |
+
self.mlp_broadcast = mlp_broadcast
|
38 |
+
self.mlp_spatial_broadcast_config = mlp_spatial_broadcast_config
|
39 |
+
self.mlp_temporal_broadcast_config = mlp_temporal_broadcast_config
|
40 |
+
self.mlp_temporal_outputs = {}
|
41 |
+
self.mlp_spatial_outputs = {}
|
|
|
|
|
|
|
|
|
|
|
42 |
|
43 |
|
44 |
class PABManager:
|
45 |
def __init__(self, config: PABConfig):
|
46 |
self.config: PABConfig = config
|
47 |
|
48 |
+
init_prompt = f"Init Pyramid Attention Broadcast. steps: {config.steps}."
|
49 |
+
init_prompt += f" spatial broadcast: {config.spatial_broadcast}, spatial range: {config.spatial_range}, spatial threshold: {config.spatial_threshold}."
|
50 |
+
init_prompt += f" temporal broadcast: {config.temporal_broadcast}, temporal range: {config.temporal_range}, temporal_threshold: {config.temporal_threshold}."
|
51 |
+
init_prompt += f" cross broadcast: {config.cross_broadcast}, cross range: {config.cross_range}, cross threshold: {config.cross_threshold}."
|
52 |
+
init_prompt += f" mlp broadcast: {config.mlp_broadcast}."
|
53 |
logger.info(init_prompt)
|
54 |
|
55 |
def if_broadcast_cross(self, timestep: int, count: int):
|
56 |
if (
|
57 |
self.config.cross_broadcast
|
58 |
and (timestep is not None)
|
59 |
+
and (count % self.config.cross_range != 0)
|
60 |
and (self.config.cross_threshold[0] < timestep < self.config.cross_threshold[1])
|
61 |
):
|
62 |
flag = True
|
|
|
69 |
if (
|
70 |
self.config.temporal_broadcast
|
71 |
and (timestep is not None)
|
72 |
+
and (count % self.config.temporal_range != 0)
|
73 |
and (self.config.temporal_threshold[0] < timestep < self.config.temporal_threshold[1])
|
74 |
):
|
75 |
flag = True
|
|
|
82 |
if (
|
83 |
self.config.spatial_broadcast
|
84 |
and (timestep is not None)
|
85 |
+
and (count % self.config.spatial_range != 0)
|
86 |
and (self.config.spatial_threshold[0] < timestep < self.config.spatial_threshold[1])
|
87 |
):
|
88 |
flag = True
|
|
|
91 |
count = (count + 1) % self.config.steps
|
92 |
return flag, count
|
93 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
94 |
@staticmethod
|
95 |
def _is_t_in_skip_config(all_timesteps, timestep, config):
|
96 |
is_t_in_skip_config = False
|
|
|
106 |
return is_t_in_skip_config, skip_range
|
107 |
|
108 |
def if_skip_mlp(self, timestep: int, count: int, block_idx: int, all_timesteps, is_temporal=False):
|
109 |
+
if not self.config.mlp_broadcast:
|
110 |
return False, None, False, None
|
111 |
|
112 |
if is_temporal:
|
113 |
+
cur_config = self.config.mlp_temporal_broadcast_config
|
114 |
else:
|
115 |
+
cur_config = self.config.mlp_spatial_broadcast_config
|
116 |
|
117 |
is_t_in_skip_config, skip_range = self._is_t_in_skip_config(all_timesteps, timestep, cur_config)
|
118 |
next_flag = False
|
119 |
if (
|
120 |
+
self.config.mlp_broadcast
|
121 |
and (timestep is not None)
|
122 |
and (timestep in cur_config)
|
123 |
and (block_idx in cur_config[timestep]["block"])
|
|
|
126 |
next_flag = True
|
127 |
count = count + 1
|
128 |
elif (
|
129 |
+
self.config.mlp_broadcast
|
130 |
and (timestep is not None)
|
131 |
and (is_t_in_skip_config)
|
132 |
and (block_idx in cur_config[skip_range[0]]["block"])
|
|
|
140 |
|
141 |
def save_skip_output(self, timestep, block_idx, ff_output, is_temporal=False):
|
142 |
if is_temporal:
|
143 |
+
self.config.mlp_temporal_outputs[(timestep, block_idx)] = ff_output
|
144 |
else:
|
145 |
+
self.config.mlp_spatial_outputs[(timestep, block_idx)] = ff_output
|
146 |
|
147 |
def get_mlp_output(self, skip_range, timestep, block_idx, is_temporal=False):
|
148 |
skip_start_t = skip_range[0]
|
149 |
if is_temporal:
|
150 |
skip_output = (
|
151 |
+
self.config.mlp_temporal_outputs.get((skip_start_t, block_idx), None)
|
152 |
+
if self.config.mlp_temporal_outputs is not None
|
153 |
else None
|
154 |
)
|
155 |
else:
|
156 |
skip_output = (
|
157 |
+
self.config.mlp_spatial_outputs.get((skip_start_t, block_idx), None)
|
158 |
+
if self.config.mlp_spatial_outputs is not None
|
159 |
else None
|
160 |
)
|
161 |
|
|
|
163 |
if timestep == skip_range[-1]:
|
164 |
# TODO: save memory
|
165 |
if is_temporal:
|
166 |
+
del self.config.mlp_temporal_outputs[(skip_start_t, block_idx)]
|
167 |
else:
|
168 |
+
del self.config.mlp_spatial_outputs[(skip_start_t, block_idx)]
|
169 |
else:
|
170 |
raise ValueError(
|
171 |
f"No stored MLP output found | t {timestep} |[{skip_range[0]}, {skip_range[-1]}] | block {block_idx}"
|
|
|
174 |
return skip_output
|
175 |
|
176 |
def get_spatial_mlp_outputs(self):
|
177 |
+
return self.config.mlp_spatial_outputs
|
178 |
|
179 |
def get_temporal_mlp_outputs(self):
|
180 |
+
return self.config.mlp_temporal_outputs
|
181 |
|
182 |
|
183 |
def set_pab_manager(config: PABConfig):
|
|
|
217 |
return False, count
|
218 |
return PAB_MANAGER.if_broadcast_spatial(timestep, count, block_idx)
|
219 |
|
|
|
|
|
|
|
|
|
|
|
220 |
|
221 |
def if_broadcast_mlp(timestep: int, count: int, block_idx: int, all_timesteps, is_temporal=False):
|
222 |
if not enable_pab():
|
|
|
230 |
|
231 |
def get_mlp_output(skip_range, timestep, block_idx: int, is_temporal=False):
|
232 |
return PAB_MANAGER.get_mlp_output(skip_range, timestep, block_idx, is_temporal)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
videosys/datasets/dataloader.py
DELETED
@@ -1,94 +0,0 @@
|
|
1 |
-
import random
|
2 |
-
from typing import Iterator, Optional
|
3 |
-
|
4 |
-
import numpy as np
|
5 |
-
import torch
|
6 |
-
from torch.utils.data import DataLoader, Dataset, DistributedSampler
|
7 |
-
from torch.utils.data.distributed import DistributedSampler
|
8 |
-
|
9 |
-
from videosys.core.parallel_mgr import ParallelManager
|
10 |
-
|
11 |
-
|
12 |
-
class StatefulDistributedSampler(DistributedSampler):
|
13 |
-
def __init__(
|
14 |
-
self,
|
15 |
-
dataset: Dataset,
|
16 |
-
num_replicas: Optional[int] = None,
|
17 |
-
rank: Optional[int] = None,
|
18 |
-
shuffle: bool = True,
|
19 |
-
seed: int = 0,
|
20 |
-
drop_last: bool = False,
|
21 |
-
) -> None:
|
22 |
-
super().__init__(dataset, num_replicas, rank, shuffle, seed, drop_last)
|
23 |
-
self.start_index: int = 0
|
24 |
-
|
25 |
-
def __iter__(self) -> Iterator:
|
26 |
-
iterator = super().__iter__()
|
27 |
-
indices = list(iterator)
|
28 |
-
indices = indices[self.start_index :]
|
29 |
-
return iter(indices)
|
30 |
-
|
31 |
-
def __len__(self) -> int:
|
32 |
-
return self.num_samples - self.start_index
|
33 |
-
|
34 |
-
def set_start_index(self, start_index: int) -> None:
|
35 |
-
self.start_index = start_index
|
36 |
-
|
37 |
-
|
38 |
-
def prepare_dataloader(
|
39 |
-
dataset,
|
40 |
-
batch_size,
|
41 |
-
shuffle=False,
|
42 |
-
seed=1024,
|
43 |
-
drop_last=False,
|
44 |
-
pin_memory=False,
|
45 |
-
num_workers=0,
|
46 |
-
pg_manager: Optional[ParallelManager] = None,
|
47 |
-
**kwargs,
|
48 |
-
):
|
49 |
-
r"""
|
50 |
-
Prepare a dataloader for distributed training. The dataloader will be wrapped by
|
51 |
-
`torch.utils.data.DataLoader` and `StatefulDistributedSampler`.
|
52 |
-
|
53 |
-
|
54 |
-
Args:
|
55 |
-
dataset (`torch.utils.data.Dataset`): The dataset to be loaded.
|
56 |
-
shuffle (bool, optional): Whether to shuffle the dataset. Defaults to False.
|
57 |
-
seed (int, optional): Random worker seed for sampling, defaults to 1024.
|
58 |
-
add_sampler: Whether to add ``DistributedDataParallelSampler`` to the dataset. Defaults to True.
|
59 |
-
drop_last (bool, optional): Set to True to drop the last incomplete batch, if the dataset size
|
60 |
-
is not divisible by the batch size. If False and the size of dataset is not divisible by
|
61 |
-
the batch size, then the last batch will be smaller, defaults to False.
|
62 |
-
pin_memory (bool, optional): Whether to pin memory address in CPU memory. Defaults to False.
|
63 |
-
num_workers (int, optional): Number of worker threads for this dataloader. Defaults to 0.
|
64 |
-
kwargs (dict): optional parameters for ``torch.utils.data.DataLoader``, more details could be found in
|
65 |
-
`DataLoader <https://pytorch.org/docs/stable/_modules/torch/utils/data/dataloader.html#DataLoader>`_.
|
66 |
-
|
67 |
-
Returns:
|
68 |
-
:class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing.
|
69 |
-
"""
|
70 |
-
_kwargs = kwargs.copy()
|
71 |
-
sampler = StatefulDistributedSampler(
|
72 |
-
dataset,
|
73 |
-
num_replicas=pg_manager.size(pg_manager.dp_axis),
|
74 |
-
rank=pg_manager.coordinate(pg_manager.dp_axis),
|
75 |
-
shuffle=shuffle,
|
76 |
-
)
|
77 |
-
|
78 |
-
# Deterministic dataloader
|
79 |
-
def seed_worker(worker_id):
|
80 |
-
worker_seed = seed
|
81 |
-
np.random.seed(worker_seed)
|
82 |
-
torch.manual_seed(worker_seed)
|
83 |
-
random.seed(worker_seed)
|
84 |
-
|
85 |
-
return DataLoader(
|
86 |
-
dataset,
|
87 |
-
batch_size=batch_size,
|
88 |
-
sampler=sampler,
|
89 |
-
worker_init_fn=seed_worker,
|
90 |
-
drop_last=drop_last,
|
91 |
-
pin_memory=pin_memory,
|
92 |
-
num_workers=num_workers,
|
93 |
-
**_kwargs,
|
94 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
videosys/datasets/image_transform.py
DELETED
@@ -1,42 +0,0 @@
|
|
1 |
-
# Adapted from DiT
|
2 |
-
|
3 |
-
# This source code is licensed under the license found in the
|
4 |
-
# LICENSE file in the root directory of this source tree.
|
5 |
-
# --------------------------------------------------------
|
6 |
-
# References:
|
7 |
-
# DiT: https://github.com/facebookresearch/DiT
|
8 |
-
# --------------------------------------------------------
|
9 |
-
|
10 |
-
|
11 |
-
import numpy as np
|
12 |
-
import torchvision.transforms as transforms
|
13 |
-
from PIL import Image
|
14 |
-
|
15 |
-
|
16 |
-
def center_crop_arr(pil_image, image_size):
|
17 |
-
"""
|
18 |
-
Center cropping implementation from ADM.
|
19 |
-
https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126
|
20 |
-
"""
|
21 |
-
while min(*pil_image.size) >= 2 * image_size:
|
22 |
-
pil_image = pil_image.resize(tuple(x // 2 for x in pil_image.size), resample=Image.BOX)
|
23 |
-
|
24 |
-
scale = image_size / min(*pil_image.size)
|
25 |
-
pil_image = pil_image.resize(tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC)
|
26 |
-
|
27 |
-
arr = np.array(pil_image)
|
28 |
-
crop_y = (arr.shape[0] - image_size) // 2
|
29 |
-
crop_x = (arr.shape[1] - image_size) // 2
|
30 |
-
return Image.fromarray(arr[crop_y : crop_y + image_size, crop_x : crop_x + image_size])
|
31 |
-
|
32 |
-
|
33 |
-
def get_transforms_image(image_size=256):
|
34 |
-
transform = transforms.Compose(
|
35 |
-
[
|
36 |
-
transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, image_size)),
|
37 |
-
transforms.RandomHorizontalFlip(),
|
38 |
-
transforms.ToTensor(),
|
39 |
-
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
|
40 |
-
]
|
41 |
-
)
|
42 |
-
return transform
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
videosys/datasets/video_transform.py
DELETED
@@ -1,441 +0,0 @@
|
|
1 |
-
# Adapted from OpenSora and Latte
|
2 |
-
|
3 |
-
# This source code is licensed under the license found in the
|
4 |
-
# LICENSE file in the root directory of this source tree.
|
5 |
-
# --------------------------------------------------------
|
6 |
-
# References:
|
7 |
-
# OpenSora: https://github.com/hpcaitech/Open-Sora
|
8 |
-
# Latte: https://github.com/Vchitect/Latte
|
9 |
-
# --------------------------------------------------------
|
10 |
-
|
11 |
-
import numbers
|
12 |
-
import random
|
13 |
-
|
14 |
-
import numpy as np
|
15 |
-
import torch
|
16 |
-
from PIL import Image
|
17 |
-
|
18 |
-
|
19 |
-
def _is_tensor_video_clip(clip):
|
20 |
-
if not torch.is_tensor(clip):
|
21 |
-
raise TypeError("clip should be Tensor. Got %s" % type(clip))
|
22 |
-
|
23 |
-
if not clip.ndimension() == 4:
|
24 |
-
raise ValueError("clip should be 4D. Got %dD" % clip.dim())
|
25 |
-
|
26 |
-
return True
|
27 |
-
|
28 |
-
|
29 |
-
def center_crop_arr(pil_image, image_size):
|
30 |
-
"""
|
31 |
-
Center cropping implementation from ADM.
|
32 |
-
https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126
|
33 |
-
"""
|
34 |
-
while min(*pil_image.size) >= 2 * image_size:
|
35 |
-
pil_image = pil_image.resize(tuple(x // 2 for x in pil_image.size), resample=Image.BOX)
|
36 |
-
|
37 |
-
scale = image_size / min(*pil_image.size)
|
38 |
-
pil_image = pil_image.resize(tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC)
|
39 |
-
|
40 |
-
arr = np.array(pil_image)
|
41 |
-
crop_y = (arr.shape[0] - image_size) // 2
|
42 |
-
crop_x = (arr.shape[1] - image_size) // 2
|
43 |
-
return Image.fromarray(arr[crop_y : crop_y + image_size, crop_x : crop_x + image_size])
|
44 |
-
|
45 |
-
|
46 |
-
def crop(clip, i, j, h, w):
|
47 |
-
"""
|
48 |
-
Args:
|
49 |
-
clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
|
50 |
-
"""
|
51 |
-
if len(clip.size()) != 4:
|
52 |
-
raise ValueError("clip should be a 4D tensor")
|
53 |
-
return clip[..., i : i + h, j : j + w]
|
54 |
-
|
55 |
-
|
56 |
-
def resize(clip, target_size, interpolation_mode):
|
57 |
-
if len(target_size) != 2:
|
58 |
-
raise ValueError(f"target size should be tuple (height, width), instead got {target_size}")
|
59 |
-
return torch.nn.functional.interpolate(clip, size=target_size, mode=interpolation_mode, align_corners=False)
|
60 |
-
|
61 |
-
|
62 |
-
def resize_scale(clip, target_size, interpolation_mode):
|
63 |
-
if len(target_size) != 2:
|
64 |
-
raise ValueError(f"target size should be tuple (height, width), instead got {target_size}")
|
65 |
-
H, W = clip.size(-2), clip.size(-1)
|
66 |
-
scale_ = target_size[0] / min(H, W)
|
67 |
-
return torch.nn.functional.interpolate(clip, scale_factor=scale_, mode=interpolation_mode, align_corners=False)
|
68 |
-
|
69 |
-
|
70 |
-
def resized_crop(clip, i, j, h, w, size, interpolation_mode="bilinear"):
|
71 |
-
"""
|
72 |
-
Do spatial cropping and resizing to the video clip
|
73 |
-
Args:
|
74 |
-
clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
|
75 |
-
i (int): i in (i,j) i.e coordinates of the upper left corner.
|
76 |
-
j (int): j in (i,j) i.e coordinates of the upper left corner.
|
77 |
-
h (int): Height of the cropped region.
|
78 |
-
w (int): Width of the cropped region.
|
79 |
-
size (tuple(int, int)): height and width of resized clip
|
80 |
-
Returns:
|
81 |
-
clip (torch.tensor): Resized and cropped clip. Size is (T, C, H, W)
|
82 |
-
"""
|
83 |
-
if not _is_tensor_video_clip(clip):
|
84 |
-
raise ValueError("clip should be a 4D torch.tensor")
|
85 |
-
clip = crop(clip, i, j, h, w)
|
86 |
-
clip = resize(clip, size, interpolation_mode)
|
87 |
-
return clip
|
88 |
-
|
89 |
-
|
90 |
-
def center_crop(clip, crop_size):
|
91 |
-
if not _is_tensor_video_clip(clip):
|
92 |
-
raise ValueError("clip should be a 4D torch.tensor")
|
93 |
-
h, w = clip.size(-2), clip.size(-1)
|
94 |
-
th, tw = crop_size
|
95 |
-
if h < th or w < tw:
|
96 |
-
raise ValueError("height and width must be no smaller than crop_size")
|
97 |
-
|
98 |
-
i = int(round((h - th) / 2.0))
|
99 |
-
j = int(round((w - tw) / 2.0))
|
100 |
-
return crop(clip, i, j, th, tw)
|
101 |
-
|
102 |
-
|
103 |
-
def center_crop_using_short_edge(clip):
|
104 |
-
if not _is_tensor_video_clip(clip):
|
105 |
-
raise ValueError("clip should be a 4D torch.tensor")
|
106 |
-
h, w = clip.size(-2), clip.size(-1)
|
107 |
-
if h < w:
|
108 |
-
th, tw = h, h
|
109 |
-
i = 0
|
110 |
-
j = int(round((w - tw) / 2.0))
|
111 |
-
else:
|
112 |
-
th, tw = w, w
|
113 |
-
i = int(round((h - th) / 2.0))
|
114 |
-
j = 0
|
115 |
-
return crop(clip, i, j, th, tw)
|
116 |
-
|
117 |
-
|
118 |
-
def random_shift_crop(clip):
|
119 |
-
"""
|
120 |
-
Slide along the long edge, with the short edge as crop size
|
121 |
-
"""
|
122 |
-
if not _is_tensor_video_clip(clip):
|
123 |
-
raise ValueError("clip should be a 4D torch.tensor")
|
124 |
-
h, w = clip.size(-2), clip.size(-1)
|
125 |
-
|
126 |
-
if h <= w:
|
127 |
-
short_edge = h
|
128 |
-
else:
|
129 |
-
short_edge = w
|
130 |
-
|
131 |
-
th, tw = short_edge, short_edge
|
132 |
-
|
133 |
-
i = torch.randint(0, h - th + 1, size=(1,)).item()
|
134 |
-
j = torch.randint(0, w - tw + 1, size=(1,)).item()
|
135 |
-
return crop(clip, i, j, th, tw)
|
136 |
-
|
137 |
-
|
138 |
-
def to_tensor(clip):
|
139 |
-
"""
|
140 |
-
Convert tensor data type from uint8 to float, divide value by 255.0 and
|
141 |
-
permute the dimensions of clip tensor
|
142 |
-
Args:
|
143 |
-
clip (torch.tensor, dtype=torch.uint8): Size is (T, C, H, W)
|
144 |
-
Return:
|
145 |
-
clip (torch.tensor, dtype=torch.float): Size is (T, C, H, W)
|
146 |
-
"""
|
147 |
-
_is_tensor_video_clip(clip)
|
148 |
-
if not clip.dtype == torch.uint8:
|
149 |
-
raise TypeError("clip tensor should have data type uint8. Got %s" % str(clip.dtype))
|
150 |
-
# return clip.float().permute(3, 0, 1, 2) / 255.0
|
151 |
-
return clip.float() / 255.0
|
152 |
-
|
153 |
-
|
154 |
-
def normalize(clip, mean, std, inplace=False):
|
155 |
-
"""
|
156 |
-
Args:
|
157 |
-
clip (torch.tensor): Video clip to be normalized. Size is (T, C, H, W)
|
158 |
-
mean (tuple): pixel RGB mean. Size is (3)
|
159 |
-
std (tuple): pixel standard deviation. Size is (3)
|
160 |
-
Returns:
|
161 |
-
normalized clip (torch.tensor): Size is (T, C, H, W)
|
162 |
-
"""
|
163 |
-
if not _is_tensor_video_clip(clip):
|
164 |
-
raise ValueError("clip should be a 4D torch.tensor")
|
165 |
-
if not inplace:
|
166 |
-
clip = clip.clone()
|
167 |
-
mean = torch.as_tensor(mean, dtype=clip.dtype, device=clip.device)
|
168 |
-
# print(mean)
|
169 |
-
std = torch.as_tensor(std, dtype=clip.dtype, device=clip.device)
|
170 |
-
clip.sub_(mean[:, None, None, None]).div_(std[:, None, None, None])
|
171 |
-
return clip
|
172 |
-
|
173 |
-
|
174 |
-
def hflip(clip):
|
175 |
-
"""
|
176 |
-
Args:
|
177 |
-
clip (torch.tensor): Video clip to be normalized. Size is (T, C, H, W)
|
178 |
-
Returns:
|
179 |
-
flipped clip (torch.tensor): Size is (T, C, H, W)
|
180 |
-
"""
|
181 |
-
if not _is_tensor_video_clip(clip):
|
182 |
-
raise ValueError("clip should be a 4D torch.tensor")
|
183 |
-
return clip.flip(-1)
|
184 |
-
|
185 |
-
|
186 |
-
class RandomCropVideo:
|
187 |
-
def __init__(self, size):
|
188 |
-
if isinstance(size, numbers.Number):
|
189 |
-
self.size = (int(size), int(size))
|
190 |
-
else:
|
191 |
-
self.size = size
|
192 |
-
|
193 |
-
def __call__(self, clip):
|
194 |
-
"""
|
195 |
-
Args:
|
196 |
-
clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
|
197 |
-
Returns:
|
198 |
-
torch.tensor: randomly cropped video clip.
|
199 |
-
size is (T, C, OH, OW)
|
200 |
-
"""
|
201 |
-
i, j, h, w = self.get_params(clip)
|
202 |
-
return crop(clip, i, j, h, w)
|
203 |
-
|
204 |
-
def get_params(self, clip):
|
205 |
-
h, w = clip.shape[-2:]
|
206 |
-
th, tw = self.size
|
207 |
-
|
208 |
-
if h < th or w < tw:
|
209 |
-
raise ValueError(f"Required crop size {(th, tw)} is larger than input image size {(h, w)}")
|
210 |
-
|
211 |
-
if w == tw and h == th:
|
212 |
-
return 0, 0, h, w
|
213 |
-
|
214 |
-
i = torch.randint(0, h - th + 1, size=(1,)).item()
|
215 |
-
j = torch.randint(0, w - tw + 1, size=(1,)).item()
|
216 |
-
|
217 |
-
return i, j, th, tw
|
218 |
-
|
219 |
-
def __repr__(self) -> str:
|
220 |
-
return f"{self.__class__.__name__}(size={self.size})"
|
221 |
-
|
222 |
-
|
223 |
-
class CenterCropResizeVideo:
|
224 |
-
"""
|
225 |
-
First use the short side for cropping length,
|
226 |
-
center crop video, then resize to the specified size
|
227 |
-
"""
|
228 |
-
|
229 |
-
def __init__(
|
230 |
-
self,
|
231 |
-
size,
|
232 |
-
interpolation_mode="bilinear",
|
233 |
-
):
|
234 |
-
if isinstance(size, tuple):
|
235 |
-
if len(size) != 2:
|
236 |
-
raise ValueError(f"size should be tuple (height, width), instead got {size}")
|
237 |
-
self.size = size
|
238 |
-
else:
|
239 |
-
self.size = (size, size)
|
240 |
-
|
241 |
-
self.interpolation_mode = interpolation_mode
|
242 |
-
|
243 |
-
def __call__(self, clip):
|
244 |
-
"""
|
245 |
-
Args:
|
246 |
-
clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
|
247 |
-
Returns:
|
248 |
-
torch.tensor: scale resized / center cropped video clip.
|
249 |
-
size is (T, C, crop_size, crop_size)
|
250 |
-
"""
|
251 |
-
clip_center_crop = center_crop_using_short_edge(clip)
|
252 |
-
clip_center_crop_resize = resize(
|
253 |
-
clip_center_crop, target_size=self.size, interpolation_mode=self.interpolation_mode
|
254 |
-
)
|
255 |
-
return clip_center_crop_resize
|
256 |
-
|
257 |
-
def __repr__(self) -> str:
|
258 |
-
return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}"
|
259 |
-
|
260 |
-
|
261 |
-
class UCFCenterCropVideo:
|
262 |
-
"""
|
263 |
-
First scale to the specified size in equal proportion to the short edge,
|
264 |
-
then center cropping
|
265 |
-
"""
|
266 |
-
|
267 |
-
def __init__(
|
268 |
-
self,
|
269 |
-
size,
|
270 |
-
interpolation_mode="bilinear",
|
271 |
-
):
|
272 |
-
if isinstance(size, tuple):
|
273 |
-
if len(size) != 2:
|
274 |
-
raise ValueError(f"size should be tuple (height, width), instead got {size}")
|
275 |
-
self.size = size
|
276 |
-
else:
|
277 |
-
self.size = (size, size)
|
278 |
-
|
279 |
-
self.interpolation_mode = interpolation_mode
|
280 |
-
|
281 |
-
def __call__(self, clip):
|
282 |
-
"""
|
283 |
-
Args:
|
284 |
-
clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
|
285 |
-
Returns:
|
286 |
-
torch.tensor: scale resized / center cropped video clip.
|
287 |
-
size is (T, C, crop_size, crop_size)
|
288 |
-
"""
|
289 |
-
clip_resize = resize_scale(clip=clip, target_size=self.size, interpolation_mode=self.interpolation_mode)
|
290 |
-
clip_center_crop = center_crop(clip_resize, self.size)
|
291 |
-
return clip_center_crop
|
292 |
-
|
293 |
-
def __repr__(self) -> str:
|
294 |
-
return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}"
|
295 |
-
|
296 |
-
|
297 |
-
class KineticsRandomCropResizeVideo:
|
298 |
-
"""
|
299 |
-
Slide along the long edge, with the short edge as crop size. And resie to the desired size.
|
300 |
-
"""
|
301 |
-
|
302 |
-
def __init__(
|
303 |
-
self,
|
304 |
-
size,
|
305 |
-
interpolation_mode="bilinear",
|
306 |
-
):
|
307 |
-
if isinstance(size, tuple):
|
308 |
-
if len(size) != 2:
|
309 |
-
raise ValueError(f"size should be tuple (height, width), instead got {size}")
|
310 |
-
self.size = size
|
311 |
-
else:
|
312 |
-
self.size = (size, size)
|
313 |
-
|
314 |
-
self.interpolation_mode = interpolation_mode
|
315 |
-
|
316 |
-
def __call__(self, clip):
|
317 |
-
clip_random_crop = random_shift_crop(clip)
|
318 |
-
clip_resize = resize(clip_random_crop, self.size, self.interpolation_mode)
|
319 |
-
return clip_resize
|
320 |
-
|
321 |
-
|
322 |
-
class CenterCropVideo:
|
323 |
-
def __init__(
|
324 |
-
self,
|
325 |
-
size,
|
326 |
-
interpolation_mode="bilinear",
|
327 |
-
):
|
328 |
-
if isinstance(size, tuple):
|
329 |
-
if len(size) != 2:
|
330 |
-
raise ValueError(f"size should be tuple (height, width), instead got {size}")
|
331 |
-
self.size = size
|
332 |
-
else:
|
333 |
-
self.size = (size, size)
|
334 |
-
|
335 |
-
self.interpolation_mode = interpolation_mode
|
336 |
-
|
337 |
-
def __call__(self, clip):
|
338 |
-
"""
|
339 |
-
Args:
|
340 |
-
clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
|
341 |
-
Returns:
|
342 |
-
torch.tensor: center cropped video clip.
|
343 |
-
size is (T, C, crop_size, crop_size)
|
344 |
-
"""
|
345 |
-
clip_center_crop = center_crop(clip, self.size)
|
346 |
-
return clip_center_crop
|
347 |
-
|
348 |
-
def __repr__(self) -> str:
|
349 |
-
return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}"
|
350 |
-
|
351 |
-
|
352 |
-
class NormalizeVideo:
|
353 |
-
"""
|
354 |
-
Normalize the video clip by mean subtraction and division by standard deviation
|
355 |
-
Args:
|
356 |
-
mean (3-tuple): pixel RGB mean
|
357 |
-
std (3-tuple): pixel RGB standard deviation
|
358 |
-
inplace (boolean): whether do in-place normalization
|
359 |
-
"""
|
360 |
-
|
361 |
-
def __init__(self, mean, std, inplace=False):
|
362 |
-
self.mean = mean
|
363 |
-
self.std = std
|
364 |
-
self.inplace = inplace
|
365 |
-
|
366 |
-
def __call__(self, clip):
|
367 |
-
"""
|
368 |
-
Args:
|
369 |
-
clip (torch.tensor): video clip must be normalized. Size is (C, T, H, W)
|
370 |
-
"""
|
371 |
-
return normalize(clip, self.mean, self.std, self.inplace)
|
372 |
-
|
373 |
-
def __repr__(self) -> str:
|
374 |
-
return f"{self.__class__.__name__}(mean={self.mean}, std={self.std}, inplace={self.inplace})"
|
375 |
-
|
376 |
-
|
377 |
-
class ToTensorVideo:
|
378 |
-
"""
|
379 |
-
Convert tensor data type from uint8 to float, divide value by 255.0 and
|
380 |
-
permute the dimensions of clip tensor
|
381 |
-
"""
|
382 |
-
|
383 |
-
def __init__(self):
|
384 |
-
pass
|
385 |
-
|
386 |
-
def __call__(self, clip):
|
387 |
-
"""
|
388 |
-
Args:
|
389 |
-
clip (torch.tensor, dtype=torch.uint8): Size is (T, C, H, W)
|
390 |
-
Return:
|
391 |
-
clip (torch.tensor, dtype=torch.float): Size is (T, C, H, W)
|
392 |
-
"""
|
393 |
-
return to_tensor(clip)
|
394 |
-
|
395 |
-
def __repr__(self) -> str:
|
396 |
-
return self.__class__.__name__
|
397 |
-
|
398 |
-
|
399 |
-
class RandomHorizontalFlipVideo:
|
400 |
-
"""
|
401 |
-
Flip the video clip along the horizontal direction with a given probability
|
402 |
-
Args:
|
403 |
-
p (float): probability of the clip being flipped. Default value is 0.5
|
404 |
-
"""
|
405 |
-
|
406 |
-
def __init__(self, p=0.5):
|
407 |
-
self.p = p
|
408 |
-
|
409 |
-
def __call__(self, clip):
|
410 |
-
"""
|
411 |
-
Args:
|
412 |
-
clip (torch.tensor): Size is (T, C, H, W)
|
413 |
-
Return:
|
414 |
-
clip (torch.tensor): Size is (T, C, H, W)
|
415 |
-
"""
|
416 |
-
if random.random() < self.p:
|
417 |
-
clip = hflip(clip)
|
418 |
-
return clip
|
419 |
-
|
420 |
-
def __repr__(self) -> str:
|
421 |
-
return f"{self.__class__.__name__}(p={self.p})"
|
422 |
-
|
423 |
-
|
424 |
-
# ------------------------------------------------------------
|
425 |
-
# --------------------- Sampling ---------------------------
|
426 |
-
# ------------------------------------------------------------
|
427 |
-
class TemporalRandomCrop(object):
|
428 |
-
"""Temporally crop the given frame indices at a random location.
|
429 |
-
|
430 |
-
Args:
|
431 |
-
size (int): Desired length of frames will be seen in the model.
|
432 |
-
"""
|
433 |
-
|
434 |
-
def __init__(self, size):
|
435 |
-
self.size = size
|
436 |
-
|
437 |
-
def __call__(self, total_frames):
|
438 |
-
rand_end = max(0, total_frames - self.size - 1)
|
439 |
-
begin_index = random.randint(0, rand_end)
|
440 |
-
end_index = min(begin_index + self.size, total_frames)
|
441 |
-
return begin_index, end_index
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
videosys/diffusion/__init__.py
DELETED
@@ -1,41 +0,0 @@
|
|
1 |
-
# Modified from OpenAI's diffusion repos and Meta DiT
|
2 |
-
# DiT: https://github.com/facebookresearch/DiT/tree/main
|
3 |
-
# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
|
4 |
-
# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
|
5 |
-
# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
|
6 |
-
|
7 |
-
from . import gaussian_diffusion as gd
|
8 |
-
from .respace import SpacedDiffusion, space_timesteps
|
9 |
-
|
10 |
-
|
11 |
-
def create_diffusion(
|
12 |
-
timestep_respacing,
|
13 |
-
noise_schedule="linear",
|
14 |
-
use_kl=False,
|
15 |
-
sigma_small=False,
|
16 |
-
predict_xstart=False,
|
17 |
-
learn_sigma=True,
|
18 |
-
rescale_learned_sigmas=False,
|
19 |
-
diffusion_steps=1000,
|
20 |
-
):
|
21 |
-
betas = gd.get_named_beta_schedule(noise_schedule, diffusion_steps)
|
22 |
-
if use_kl:
|
23 |
-
loss_type = gd.LossType.RESCALED_KL
|
24 |
-
elif rescale_learned_sigmas:
|
25 |
-
loss_type = gd.LossType.RESCALED_MSE
|
26 |
-
else:
|
27 |
-
loss_type = gd.LossType.MSE
|
28 |
-
if timestep_respacing is None or timestep_respacing == "":
|
29 |
-
timestep_respacing = [diffusion_steps]
|
30 |
-
return SpacedDiffusion(
|
31 |
-
use_timesteps=space_timesteps(diffusion_steps, timestep_respacing),
|
32 |
-
betas=betas,
|
33 |
-
model_mean_type=(gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X),
|
34 |
-
model_var_type=(
|
35 |
-
(gd.ModelVarType.FIXED_LARGE if not sigma_small else gd.ModelVarType.FIXED_SMALL)
|
36 |
-
if not learn_sigma
|
37 |
-
else gd.ModelVarType.LEARNED_RANGE
|
38 |
-
),
|
39 |
-
loss_type=loss_type
|
40 |
-
# rescale_timesteps=rescale_timesteps,
|
41 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
videosys/diffusion/diffusion_utils.py
DELETED
@@ -1,79 +0,0 @@
|
|
1 |
-
# Modified from OpenAI's diffusion repos
|
2 |
-
# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
|
3 |
-
# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
|
4 |
-
# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
|
5 |
-
|
6 |
-
import numpy as np
|
7 |
-
import torch as th
|
8 |
-
|
9 |
-
|
10 |
-
def normal_kl(mean1, logvar1, mean2, logvar2):
|
11 |
-
"""
|
12 |
-
Compute the KL divergence between two gaussians.
|
13 |
-
Shapes are automatically broadcasted, so batches can be compared to
|
14 |
-
scalars, among other use cases.
|
15 |
-
"""
|
16 |
-
tensor = None
|
17 |
-
for obj in (mean1, logvar1, mean2, logvar2):
|
18 |
-
if isinstance(obj, th.Tensor):
|
19 |
-
tensor = obj
|
20 |
-
break
|
21 |
-
assert tensor is not None, "at least one argument must be a Tensor"
|
22 |
-
|
23 |
-
# Force variances to be Tensors. Broadcasting helps convert scalars to
|
24 |
-
# Tensors, but it does not work for th.exp().
|
25 |
-
logvar1, logvar2 = [x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor) for x in (logvar1, logvar2)]
|
26 |
-
|
27 |
-
return 0.5 * (-1.0 + logvar2 - logvar1 + th.exp(logvar1 - logvar2) + ((mean1 - mean2) ** 2) * th.exp(-logvar2))
|
28 |
-
|
29 |
-
|
30 |
-
def approx_standard_normal_cdf(x):
|
31 |
-
"""
|
32 |
-
A fast approximation of the cumulative distribution function of the
|
33 |
-
standard normal.
|
34 |
-
"""
|
35 |
-
return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3))))
|
36 |
-
|
37 |
-
|
38 |
-
def continuous_gaussian_log_likelihood(x, *, means, log_scales):
|
39 |
-
"""
|
40 |
-
Compute the log-likelihood of a continuous Gaussian distribution.
|
41 |
-
:param x: the targets
|
42 |
-
:param means: the Gaussian mean Tensor.
|
43 |
-
:param log_scales: the Gaussian log stddev Tensor.
|
44 |
-
:return: a tensor like x of log probabilities (in nats).
|
45 |
-
"""
|
46 |
-
centered_x = x - means
|
47 |
-
inv_stdv = th.exp(-log_scales)
|
48 |
-
normalized_x = centered_x * inv_stdv
|
49 |
-
log_probs = th.distributions.Normal(th.zeros_like(x), th.ones_like(x)).log_prob(normalized_x)
|
50 |
-
return log_probs
|
51 |
-
|
52 |
-
|
53 |
-
def discretized_gaussian_log_likelihood(x, *, means, log_scales):
|
54 |
-
"""
|
55 |
-
Compute the log-likelihood of a Gaussian distribution discretizing to a
|
56 |
-
given image.
|
57 |
-
:param x: the target images. It is assumed that this was uint8 values,
|
58 |
-
rescaled to the range [-1, 1].
|
59 |
-
:param means: the Gaussian mean Tensor.
|
60 |
-
:param log_scales: the Gaussian log stddev Tensor.
|
61 |
-
:return: a tensor like x of log probabilities (in nats).
|
62 |
-
"""
|
63 |
-
assert x.shape == means.shape == log_scales.shape
|
64 |
-
centered_x = x - means
|
65 |
-
inv_stdv = th.exp(-log_scales)
|
66 |
-
plus_in = inv_stdv * (centered_x + 1.0 / 255.0)
|
67 |
-
cdf_plus = approx_standard_normal_cdf(plus_in)
|
68 |
-
min_in = inv_stdv * (centered_x - 1.0 / 255.0)
|
69 |
-
cdf_min = approx_standard_normal_cdf(min_in)
|
70 |
-
log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12))
|
71 |
-
log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12))
|
72 |
-
cdf_delta = cdf_plus - cdf_min
|
73 |
-
log_probs = th.where(
|
74 |
-
x < -0.999,
|
75 |
-
log_cdf_plus,
|
76 |
-
th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))),
|
77 |
-
)
|
78 |
-
assert log_probs.shape == x.shape
|
79 |
-
return log_probs
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
videosys/diffusion/gaussian_diffusion.py
DELETED
@@ -1,829 +0,0 @@
|
|
1 |
-
# Modified from OpenAI's diffusion repos
|
2 |
-
# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
|
3 |
-
# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
|
4 |
-
# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
|
5 |
-
|
6 |
-
|
7 |
-
import enum
|
8 |
-
import math
|
9 |
-
|
10 |
-
import numpy as np
|
11 |
-
import torch as th
|
12 |
-
|
13 |
-
from .diffusion_utils import discretized_gaussian_log_likelihood, normal_kl
|
14 |
-
|
15 |
-
|
16 |
-
def mean_flat(tensor):
|
17 |
-
"""
|
18 |
-
Take the mean over all non-batch dimensions.
|
19 |
-
"""
|
20 |
-
return tensor.mean(dim=list(range(1, len(tensor.shape))))
|
21 |
-
|
22 |
-
|
23 |
-
class ModelMeanType(enum.Enum):
|
24 |
-
"""
|
25 |
-
Which type of output the model predicts.
|
26 |
-
"""
|
27 |
-
|
28 |
-
PREVIOUS_X = enum.auto() # the model predicts x_{t-1}
|
29 |
-
START_X = enum.auto() # the model predicts x_0
|
30 |
-
EPSILON = enum.auto() # the model predicts epsilon
|
31 |
-
|
32 |
-
|
33 |
-
class ModelVarType(enum.Enum):
|
34 |
-
"""
|
35 |
-
What is used as the model's output variance.
|
36 |
-
The LEARNED_RANGE option has been added to allow the model to predict
|
37 |
-
values between FIXED_SMALL and FIXED_LARGE, making its job easier.
|
38 |
-
"""
|
39 |
-
|
40 |
-
LEARNED = enum.auto()
|
41 |
-
FIXED_SMALL = enum.auto()
|
42 |
-
FIXED_LARGE = enum.auto()
|
43 |
-
LEARNED_RANGE = enum.auto()
|
44 |
-
|
45 |
-
|
46 |
-
class LossType(enum.Enum):
|
47 |
-
MSE = enum.auto() # use raw MSE loss (and KL when learning variances)
|
48 |
-
RESCALED_MSE = enum.auto() # use raw MSE loss (with RESCALED_KL when learning variances)
|
49 |
-
KL = enum.auto() # use the variational lower-bound
|
50 |
-
RESCALED_KL = enum.auto() # like KL, but rescale to estimate the full VLB
|
51 |
-
|
52 |
-
def is_vb(self):
|
53 |
-
return self == LossType.KL or self == LossType.RESCALED_KL
|
54 |
-
|
55 |
-
|
56 |
-
def _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, warmup_frac):
|
57 |
-
betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
|
58 |
-
warmup_time = int(num_diffusion_timesteps * warmup_frac)
|
59 |
-
betas[:warmup_time] = np.linspace(beta_start, beta_end, warmup_time, dtype=np.float64)
|
60 |
-
return betas
|
61 |
-
|
62 |
-
|
63 |
-
def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps):
|
64 |
-
"""
|
65 |
-
This is the deprecated API for creating beta schedules.
|
66 |
-
See get_named_beta_schedule() for the new library of schedules.
|
67 |
-
"""
|
68 |
-
if beta_schedule == "quad":
|
69 |
-
betas = (
|
70 |
-
np.linspace(
|
71 |
-
beta_start**0.5,
|
72 |
-
beta_end**0.5,
|
73 |
-
num_diffusion_timesteps,
|
74 |
-
dtype=np.float64,
|
75 |
-
)
|
76 |
-
** 2
|
77 |
-
)
|
78 |
-
elif beta_schedule == "linear":
|
79 |
-
betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64)
|
80 |
-
elif beta_schedule == "warmup10":
|
81 |
-
betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.1)
|
82 |
-
elif beta_schedule == "warmup50":
|
83 |
-
betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.5)
|
84 |
-
elif beta_schedule == "const":
|
85 |
-
betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
|
86 |
-
elif beta_schedule == "jsd": # 1/T, 1/(T-1), 1/(T-2), ..., 1
|
87 |
-
betas = 1.0 / np.linspace(num_diffusion_timesteps, 1, num_diffusion_timesteps, dtype=np.float64)
|
88 |
-
else:
|
89 |
-
raise NotImplementedError(beta_schedule)
|
90 |
-
assert betas.shape == (num_diffusion_timesteps,)
|
91 |
-
return betas
|
92 |
-
|
93 |
-
|
94 |
-
def get_named_beta_schedule(schedule_name, num_diffusion_timesteps):
|
95 |
-
"""
|
96 |
-
Get a pre-defined beta schedule for the given name.
|
97 |
-
The beta schedule library consists of beta schedules which remain similar
|
98 |
-
in the limit of num_diffusion_timesteps.
|
99 |
-
Beta schedules may be added, but should not be removed or changed once
|
100 |
-
they are committed to maintain backwards compatibility.
|
101 |
-
"""
|
102 |
-
if schedule_name == "linear":
|
103 |
-
# Linear schedule from Ho et al, extended to work for any number of
|
104 |
-
# diffusion steps.
|
105 |
-
scale = 1000 / num_diffusion_timesteps
|
106 |
-
return get_beta_schedule(
|
107 |
-
"linear",
|
108 |
-
beta_start=scale * 0.0001,
|
109 |
-
beta_end=scale * 0.02,
|
110 |
-
num_diffusion_timesteps=num_diffusion_timesteps,
|
111 |
-
)
|
112 |
-
elif schedule_name == "squaredcos_cap_v2":
|
113 |
-
return betas_for_alpha_bar(
|
114 |
-
num_diffusion_timesteps,
|
115 |
-
lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
|
116 |
-
)
|
117 |
-
else:
|
118 |
-
raise NotImplementedError(f"unknown beta schedule: {schedule_name}")
|
119 |
-
|
120 |
-
|
121 |
-
def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
|
122 |
-
"""
|
123 |
-
Create a beta schedule that discretizes the given alpha_t_bar function,
|
124 |
-
which defines the cumulative product of (1-beta) over time from t = [0,1].
|
125 |
-
:param num_diffusion_timesteps: the number of betas to produce.
|
126 |
-
:param alpha_bar: a lambda that takes an argument t from 0 to 1 and
|
127 |
-
produces the cumulative product of (1-beta) up to that
|
128 |
-
part of the diffusion process.
|
129 |
-
:param max_beta: the maximum beta to use; use values lower than 1 to
|
130 |
-
prevent singularities.
|
131 |
-
"""
|
132 |
-
betas = []
|
133 |
-
for i in range(num_diffusion_timesteps):
|
134 |
-
t1 = i / num_diffusion_timesteps
|
135 |
-
t2 = (i + 1) / num_diffusion_timesteps
|
136 |
-
betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
|
137 |
-
return np.array(betas)
|
138 |
-
|
139 |
-
|
140 |
-
class GaussianDiffusion:
|
141 |
-
"""
|
142 |
-
Utilities for training and sampling diffusion models.
|
143 |
-
Original ported from this codebase:
|
144 |
-
https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42
|
145 |
-
:param betas: a 1-D numpy array of betas for each diffusion timestep,
|
146 |
-
starting at T and going to 1.
|
147 |
-
"""
|
148 |
-
|
149 |
-
def __init__(self, *, betas, model_mean_type, model_var_type, loss_type):
|
150 |
-
self.model_mean_type = model_mean_type
|
151 |
-
self.model_var_type = model_var_type
|
152 |
-
self.loss_type = loss_type
|
153 |
-
|
154 |
-
# Use float64 for accuracy.
|
155 |
-
betas = np.array(betas, dtype=np.float64)
|
156 |
-
self.betas = betas
|
157 |
-
assert len(betas.shape) == 1, "betas must be 1-D"
|
158 |
-
assert (betas > 0).all() and (betas <= 1).all()
|
159 |
-
|
160 |
-
self.num_timesteps = int(betas.shape[0])
|
161 |
-
|
162 |
-
alphas = 1.0 - betas
|
163 |
-
self.alphas_cumprod = np.cumprod(alphas, axis=0)
|
164 |
-
self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1])
|
165 |
-
self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0)
|
166 |
-
assert self.alphas_cumprod_prev.shape == (self.num_timesteps,)
|
167 |
-
|
168 |
-
# calculations for diffusion q(x_t | x_{t-1}) and others
|
169 |
-
self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod)
|
170 |
-
self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod)
|
171 |
-
self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod)
|
172 |
-
self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod)
|
173 |
-
self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1)
|
174 |
-
|
175 |
-
# calculations for posterior q(x_{t-1} | x_t, x_0)
|
176 |
-
self.posterior_variance = betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
|
177 |
-
# below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
|
178 |
-
self.posterior_log_variance_clipped = (
|
179 |
-
np.log(np.append(self.posterior_variance[1], self.posterior_variance[1:]))
|
180 |
-
if len(self.posterior_variance) > 1
|
181 |
-
else np.array([])
|
182 |
-
)
|
183 |
-
|
184 |
-
self.posterior_mean_coef1 = betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
|
185 |
-
self.posterior_mean_coef2 = (1.0 - self.alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - self.alphas_cumprod)
|
186 |
-
|
187 |
-
def q_mean_variance(self, x_start, t):
|
188 |
-
"""
|
189 |
-
Get the distribution q(x_t | x_0).
|
190 |
-
:param x_start: the [N x C x ...] tensor of noiseless inputs.
|
191 |
-
:param t: the number of diffusion steps (minus 1). Here, 0 means one step.
|
192 |
-
:return: A tuple (mean, variance, log_variance), all of x_start's shape.
|
193 |
-
"""
|
194 |
-
mean = _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
|
195 |
-
variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
|
196 |
-
log_variance = _extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape)
|
197 |
-
return mean, variance, log_variance
|
198 |
-
|
199 |
-
def q_sample(self, x_start, t, noise=None):
|
200 |
-
"""
|
201 |
-
Diffuse the data for a given number of diffusion steps.
|
202 |
-
In other words, sample from q(x_t | x_0).
|
203 |
-
:param x_start: the initial data batch.
|
204 |
-
:param t: the number of diffusion steps (minus 1). Here, 0 means one step.
|
205 |
-
:param noise: if specified, the split-out normal noise.
|
206 |
-
:return: A noisy version of x_start.
|
207 |
-
"""
|
208 |
-
if noise is None:
|
209 |
-
noise = th.randn_like(x_start)
|
210 |
-
assert noise.shape == x_start.shape
|
211 |
-
return (
|
212 |
-
_extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
|
213 |
-
+ _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
|
214 |
-
)
|
215 |
-
|
216 |
-
def q_posterior_mean_variance(self, x_start, x_t, t):
|
217 |
-
"""
|
218 |
-
Compute the mean and variance of the diffusion posterior:
|
219 |
-
q(x_{t-1} | x_t, x_0)
|
220 |
-
"""
|
221 |
-
assert x_start.shape == x_t.shape
|
222 |
-
posterior_mean = (
|
223 |
-
_extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start
|
224 |
-
+ _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
|
225 |
-
)
|
226 |
-
posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape)
|
227 |
-
posterior_log_variance_clipped = _extract_into_tensor(self.posterior_log_variance_clipped, t, x_t.shape)
|
228 |
-
assert (
|
229 |
-
posterior_mean.shape[0]
|
230 |
-
== posterior_variance.shape[0]
|
231 |
-
== posterior_log_variance_clipped.shape[0]
|
232 |
-
== x_start.shape[0]
|
233 |
-
)
|
234 |
-
return posterior_mean, posterior_variance, posterior_log_variance_clipped
|
235 |
-
|
236 |
-
def p_mean_variance(self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None):
|
237 |
-
"""
|
238 |
-
Apply the model to get p(x_{t-1} | x_t), as well as a prediction of
|
239 |
-
the initial x, x_0.
|
240 |
-
:param model: the model, which takes a signal and a batch of timesteps
|
241 |
-
as input.
|
242 |
-
:param x: the [N x C x ...] tensor at time t.
|
243 |
-
:param t: a 1-D Tensor of timesteps.
|
244 |
-
:param clip_denoised: if True, clip the denoised signal into [-1, 1].
|
245 |
-
:param denoised_fn: if not None, a function which applies to the
|
246 |
-
x_start prediction before it is used to sample. Applies before
|
247 |
-
clip_denoised.
|
248 |
-
:param model_kwargs: if not None, a dict of extra keyword arguments to
|
249 |
-
pass to the model. This can be used for conditioning.
|
250 |
-
:return: a dict with the following keys:
|
251 |
-
- 'mean': the model mean output.
|
252 |
-
- 'variance': the model variance output.
|
253 |
-
- 'log_variance': the log of 'variance'.
|
254 |
-
- 'pred_xstart': the prediction for x_0.
|
255 |
-
"""
|
256 |
-
if model_kwargs is None:
|
257 |
-
model_kwargs = {}
|
258 |
-
|
259 |
-
B, C = x.shape[:2]
|
260 |
-
assert t.shape == (B,)
|
261 |
-
model_output = model(x, t, **model_kwargs)
|
262 |
-
if isinstance(model_output, tuple):
|
263 |
-
model_output, extra = model_output
|
264 |
-
else:
|
265 |
-
extra = None
|
266 |
-
|
267 |
-
if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]:
|
268 |
-
assert model_output.shape == (B, C * 2, *x.shape[2:])
|
269 |
-
model_output, model_var_values = th.split(model_output, C, dim=1)
|
270 |
-
min_log = _extract_into_tensor(self.posterior_log_variance_clipped, t, x.shape)
|
271 |
-
max_log = _extract_into_tensor(np.log(self.betas), t, x.shape)
|
272 |
-
# The model_var_values is [-1, 1] for [min_var, max_var].
|
273 |
-
frac = (model_var_values + 1) / 2
|
274 |
-
model_log_variance = frac * max_log + (1 - frac) * min_log
|
275 |
-
model_variance = th.exp(model_log_variance)
|
276 |
-
else:
|
277 |
-
model_variance, model_log_variance = {
|
278 |
-
# for fixedlarge, we set the initial (log-)variance like so
|
279 |
-
# to get a better decoder log likelihood.
|
280 |
-
ModelVarType.FIXED_LARGE: (
|
281 |
-
np.append(self.posterior_variance[1], self.betas[1:]),
|
282 |
-
np.log(np.append(self.posterior_variance[1], self.betas[1:])),
|
283 |
-
),
|
284 |
-
ModelVarType.FIXED_SMALL: (
|
285 |
-
self.posterior_variance,
|
286 |
-
self.posterior_log_variance_clipped,
|
287 |
-
),
|
288 |
-
}[self.model_var_type]
|
289 |
-
model_variance = _extract_into_tensor(model_variance, t, x.shape)
|
290 |
-
model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape)
|
291 |
-
|
292 |
-
def process_xstart(x):
|
293 |
-
if denoised_fn is not None:
|
294 |
-
x = denoised_fn(x)
|
295 |
-
if clip_denoised:
|
296 |
-
return x.clamp(-1, 1)
|
297 |
-
return x
|
298 |
-
|
299 |
-
if self.model_mean_type == ModelMeanType.START_X:
|
300 |
-
pred_xstart = process_xstart(model_output)
|
301 |
-
else:
|
302 |
-
pred_xstart = process_xstart(self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output))
|
303 |
-
model_mean, _, _ = self.q_posterior_mean_variance(x_start=pred_xstart, x_t=x, t=t)
|
304 |
-
|
305 |
-
assert model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape
|
306 |
-
return {
|
307 |
-
"mean": model_mean,
|
308 |
-
"variance": model_variance,
|
309 |
-
"log_variance": model_log_variance,
|
310 |
-
"pred_xstart": pred_xstart,
|
311 |
-
"extra": extra,
|
312 |
-
}
|
313 |
-
|
314 |
-
def _predict_xstart_from_eps(self, x_t, t, eps):
|
315 |
-
assert x_t.shape == eps.shape
|
316 |
-
return (
|
317 |
-
_extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
|
318 |
-
- _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps
|
319 |
-
)
|
320 |
-
|
321 |
-
def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
|
322 |
-
return (
|
323 |
-
_extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart
|
324 |
-
) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
|
325 |
-
|
326 |
-
def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
|
327 |
-
"""
|
328 |
-
Compute the mean for the previous step, given a function cond_fn that
|
329 |
-
computes the gradient of a conditional log probability with respect to
|
330 |
-
x. In particular, cond_fn computes grad(log(p(y|x))), and we want to
|
331 |
-
condition on y.
|
332 |
-
This uses the conditioning strategy from Sohl-Dickstein et al. (2015).
|
333 |
-
"""
|
334 |
-
gradient = cond_fn(x, t, **model_kwargs)
|
335 |
-
new_mean = p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float()
|
336 |
-
return new_mean
|
337 |
-
|
338 |
-
def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
|
339 |
-
"""
|
340 |
-
Compute what the p_mean_variance output would have been, should the
|
341 |
-
model's score function be conditioned by cond_fn.
|
342 |
-
See condition_mean() for details on cond_fn.
|
343 |
-
Unlike condition_mean(), this instead uses the conditioning strategy
|
344 |
-
from Song et al (2020).
|
345 |
-
"""
|
346 |
-
alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
|
347 |
-
|
348 |
-
eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"])
|
349 |
-
eps = eps - (1 - alpha_bar).sqrt() * cond_fn(x, t, **model_kwargs)
|
350 |
-
|
351 |
-
out = p_mean_var.copy()
|
352 |
-
out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps)
|
353 |
-
out["mean"], _, _ = self.q_posterior_mean_variance(x_start=out["pred_xstart"], x_t=x, t=t)
|
354 |
-
return out
|
355 |
-
|
356 |
-
def p_sample(
|
357 |
-
self,
|
358 |
-
model,
|
359 |
-
x,
|
360 |
-
t,
|
361 |
-
clip_denoised=True,
|
362 |
-
denoised_fn=None,
|
363 |
-
cond_fn=None,
|
364 |
-
model_kwargs=None,
|
365 |
-
):
|
366 |
-
"""
|
367 |
-
Sample x_{t-1} from the model at the given timestep.
|
368 |
-
:param model: the model to sample from.
|
369 |
-
:param x: the current tensor at x_{t-1}.
|
370 |
-
:param t: the value of t, starting at 0 for the first diffusion step.
|
371 |
-
:param clip_denoised: if True, clip the x_start prediction to [-1, 1].
|
372 |
-
:param denoised_fn: if not None, a function which applies to the
|
373 |
-
x_start prediction before it is used to sample.
|
374 |
-
:param cond_fn: if not None, this is a gradient function that acts
|
375 |
-
similarly to the model.
|
376 |
-
:param model_kwargs: if not None, a dict of extra keyword arguments to
|
377 |
-
pass to the model. This can be used for conditioning.
|
378 |
-
:return: a dict containing the following keys:
|
379 |
-
- 'sample': a random sample from the model.
|
380 |
-
- 'pred_xstart': a prediction of x_0.
|
381 |
-
"""
|
382 |
-
out = self.p_mean_variance(
|
383 |
-
model,
|
384 |
-
x,
|
385 |
-
t,
|
386 |
-
clip_denoised=clip_denoised,
|
387 |
-
denoised_fn=denoised_fn,
|
388 |
-
model_kwargs=model_kwargs,
|
389 |
-
)
|
390 |
-
noise = th.randn_like(x)
|
391 |
-
nonzero_mask = (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) # no noise when t == 0
|
392 |
-
if cond_fn is not None:
|
393 |
-
out["mean"] = self.condition_mean(cond_fn, out, x, t, model_kwargs=model_kwargs)
|
394 |
-
sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise
|
395 |
-
return {"sample": sample, "pred_xstart": out["pred_xstart"]}
|
396 |
-
|
397 |
-
def p_sample_loop(
|
398 |
-
self,
|
399 |
-
model,
|
400 |
-
shape,
|
401 |
-
noise=None,
|
402 |
-
clip_denoised=True,
|
403 |
-
denoised_fn=None,
|
404 |
-
cond_fn=None,
|
405 |
-
model_kwargs=None,
|
406 |
-
device=None,
|
407 |
-
progress=False,
|
408 |
-
):
|
409 |
-
"""
|
410 |
-
Generate samples from the model.
|
411 |
-
:param model: the model module.
|
412 |
-
:param shape: the shape of the samples, (N, C, H, W).
|
413 |
-
:param noise: if specified, the noise from the encoder to sample.
|
414 |
-
Should be of the same shape as `shape`.
|
415 |
-
:param clip_denoised: if True, clip x_start predictions to [-1, 1].
|
416 |
-
:param denoised_fn: if not None, a function which applies to the
|
417 |
-
x_start prediction before it is used to sample.
|
418 |
-
:param cond_fn: if not None, this is a gradient function that acts
|
419 |
-
similarly to the model.
|
420 |
-
:param model_kwargs: if not None, a dict of extra keyword arguments to
|
421 |
-
pass to the model. This can be used for conditioning.
|
422 |
-
:param device: if specified, the device to create the samples on.
|
423 |
-
If not specified, use a model parameter's device.
|
424 |
-
:param progress: if True, show a tqdm progress bar.
|
425 |
-
:return: a non-differentiable batch of samples.
|
426 |
-
"""
|
427 |
-
final = None
|
428 |
-
for sample in self.p_sample_loop_progressive(
|
429 |
-
model,
|
430 |
-
shape,
|
431 |
-
noise=noise,
|
432 |
-
clip_denoised=clip_denoised,
|
433 |
-
denoised_fn=denoised_fn,
|
434 |
-
cond_fn=cond_fn,
|
435 |
-
model_kwargs=model_kwargs,
|
436 |
-
device=device,
|
437 |
-
progress=progress,
|
438 |
-
):
|
439 |
-
final = sample
|
440 |
-
return final["sample"]
|
441 |
-
|
442 |
-
def p_sample_loop_progressive(
|
443 |
-
self,
|
444 |
-
model,
|
445 |
-
shape,
|
446 |
-
noise=None,
|
447 |
-
clip_denoised=True,
|
448 |
-
denoised_fn=None,
|
449 |
-
cond_fn=None,
|
450 |
-
model_kwargs=None,
|
451 |
-
device=None,
|
452 |
-
progress=False,
|
453 |
-
):
|
454 |
-
"""
|
455 |
-
Generate samples from the model and yield intermediate samples from
|
456 |
-
each timestep of diffusion.
|
457 |
-
Arguments are the same as p_sample_loop().
|
458 |
-
Returns a generator over dicts, where each dict is the return value of
|
459 |
-
p_sample().
|
460 |
-
"""
|
461 |
-
if device is None:
|
462 |
-
device = next(model.parameters()).device
|
463 |
-
assert isinstance(shape, (tuple, list))
|
464 |
-
if noise is not None:
|
465 |
-
img = noise
|
466 |
-
else:
|
467 |
-
img = th.randn(*shape, device=device)
|
468 |
-
indices = list(range(self.num_timesteps))[::-1]
|
469 |
-
|
470 |
-
if progress:
|
471 |
-
# Lazy import so that we don't depend on tqdm.
|
472 |
-
from tqdm.auto import tqdm
|
473 |
-
|
474 |
-
indices = tqdm(indices)
|
475 |
-
|
476 |
-
for i in indices:
|
477 |
-
t = th.tensor([i] * shape[0], device=device)
|
478 |
-
with th.no_grad():
|
479 |
-
out = self.p_sample(
|
480 |
-
model,
|
481 |
-
img,
|
482 |
-
t,
|
483 |
-
clip_denoised=clip_denoised,
|
484 |
-
denoised_fn=denoised_fn,
|
485 |
-
cond_fn=cond_fn,
|
486 |
-
model_kwargs=model_kwargs,
|
487 |
-
)
|
488 |
-
yield out
|
489 |
-
img = out["sample"]
|
490 |
-
|
491 |
-
def ddim_sample(
|
492 |
-
self,
|
493 |
-
model,
|
494 |
-
x,
|
495 |
-
t,
|
496 |
-
clip_denoised=True,
|
497 |
-
denoised_fn=None,
|
498 |
-
cond_fn=None,
|
499 |
-
model_kwargs=None,
|
500 |
-
eta=0.0,
|
501 |
-
):
|
502 |
-
"""
|
503 |
-
Sample x_{t-1} from the model using DDIM.
|
504 |
-
Same usage as p_sample().
|
505 |
-
"""
|
506 |
-
out = self.p_mean_variance(
|
507 |
-
model,
|
508 |
-
x,
|
509 |
-
t,
|
510 |
-
clip_denoised=clip_denoised,
|
511 |
-
denoised_fn=denoised_fn,
|
512 |
-
model_kwargs=model_kwargs,
|
513 |
-
)
|
514 |
-
if cond_fn is not None:
|
515 |
-
out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs)
|
516 |
-
|
517 |
-
# Usually our model outputs epsilon, but we re-derive it
|
518 |
-
# in case we used x_start or x_prev prediction.
|
519 |
-
eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"])
|
520 |
-
|
521 |
-
alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
|
522 |
-
alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape)
|
523 |
-
sigma = eta * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar)) * th.sqrt(1 - alpha_bar / alpha_bar_prev)
|
524 |
-
# Equation 12.
|
525 |
-
noise = th.randn_like(x)
|
526 |
-
mean_pred = out["pred_xstart"] * th.sqrt(alpha_bar_prev) + th.sqrt(1 - alpha_bar_prev - sigma**2) * eps
|
527 |
-
nonzero_mask = (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) # no noise when t == 0
|
528 |
-
sample = mean_pred + nonzero_mask * sigma * noise
|
529 |
-
return {"sample": sample, "pred_xstart": out["pred_xstart"]}
|
530 |
-
|
531 |
-
def ddim_reverse_sample(
|
532 |
-
self,
|
533 |
-
model,
|
534 |
-
x,
|
535 |
-
t,
|
536 |
-
clip_denoised=True,
|
537 |
-
denoised_fn=None,
|
538 |
-
cond_fn=None,
|
539 |
-
model_kwargs=None,
|
540 |
-
eta=0.0,
|
541 |
-
):
|
542 |
-
"""
|
543 |
-
Sample x_{t+1} from the model using DDIM reverse ODE.
|
544 |
-
"""
|
545 |
-
assert eta == 0.0, "Reverse ODE only for deterministic path"
|
546 |
-
out = self.p_mean_variance(
|
547 |
-
model,
|
548 |
-
x,
|
549 |
-
t,
|
550 |
-
clip_denoised=clip_denoised,
|
551 |
-
denoised_fn=denoised_fn,
|
552 |
-
model_kwargs=model_kwargs,
|
553 |
-
)
|
554 |
-
if cond_fn is not None:
|
555 |
-
out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs)
|
556 |
-
# Usually our model outputs epsilon, but we re-derive it
|
557 |
-
# in case we used x_start or x_prev prediction.
|
558 |
-
eps = (
|
559 |
-
_extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x - out["pred_xstart"]
|
560 |
-
) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape)
|
561 |
-
alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape)
|
562 |
-
|
563 |
-
# Equation 12. reversed
|
564 |
-
mean_pred = out["pred_xstart"] * th.sqrt(alpha_bar_next) + th.sqrt(1 - alpha_bar_next) * eps
|
565 |
-
|
566 |
-
return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]}
|
567 |
-
|
568 |
-
def ddim_sample_loop(
|
569 |
-
self,
|
570 |
-
model,
|
571 |
-
shape,
|
572 |
-
noise=None,
|
573 |
-
clip_denoised=True,
|
574 |
-
denoised_fn=None,
|
575 |
-
cond_fn=None,
|
576 |
-
model_kwargs=None,
|
577 |
-
device=None,
|
578 |
-
progress=False,
|
579 |
-
eta=0.0,
|
580 |
-
):
|
581 |
-
"""
|
582 |
-
Generate samples from the model using DDIM.
|
583 |
-
Same usage as p_sample_loop().
|
584 |
-
"""
|
585 |
-
final = None
|
586 |
-
for sample in self.ddim_sample_loop_progressive(
|
587 |
-
model,
|
588 |
-
shape,
|
589 |
-
noise=noise,
|
590 |
-
clip_denoised=clip_denoised,
|
591 |
-
denoised_fn=denoised_fn,
|
592 |
-
cond_fn=cond_fn,
|
593 |
-
model_kwargs=model_kwargs,
|
594 |
-
device=device,
|
595 |
-
progress=progress,
|
596 |
-
eta=eta,
|
597 |
-
):
|
598 |
-
final = sample
|
599 |
-
return final["sample"]
|
600 |
-
|
601 |
-
def ddim_sample_loop_progressive(
|
602 |
-
self,
|
603 |
-
model,
|
604 |
-
shape,
|
605 |
-
noise=None,
|
606 |
-
clip_denoised=True,
|
607 |
-
denoised_fn=None,
|
608 |
-
cond_fn=None,
|
609 |
-
model_kwargs=None,
|
610 |
-
device=None,
|
611 |
-
progress=False,
|
612 |
-
eta=0.0,
|
613 |
-
):
|
614 |
-
"""
|
615 |
-
Use DDIM to sample from the model and yield intermediate samples from
|
616 |
-
each timestep of DDIM.
|
617 |
-
Same usage as p_sample_loop_progressive().
|
618 |
-
"""
|
619 |
-
if device is None:
|
620 |
-
device = next(model.parameters()).device
|
621 |
-
assert isinstance(shape, (tuple, list))
|
622 |
-
if noise is not None:
|
623 |
-
img = noise
|
624 |
-
else:
|
625 |
-
img = th.randn(*shape, device=device)
|
626 |
-
indices = list(range(self.num_timesteps))[::-1]
|
627 |
-
|
628 |
-
if progress:
|
629 |
-
# Lazy import so that we don't depend on tqdm.
|
630 |
-
from tqdm.auto import tqdm
|
631 |
-
|
632 |
-
indices = tqdm(indices)
|
633 |
-
|
634 |
-
for i in indices:
|
635 |
-
t = th.tensor([i] * shape[0], device=device)
|
636 |
-
with th.no_grad():
|
637 |
-
out = self.ddim_sample(
|
638 |
-
model,
|
639 |
-
img,
|
640 |
-
t,
|
641 |
-
clip_denoised=clip_denoised,
|
642 |
-
denoised_fn=denoised_fn,
|
643 |
-
cond_fn=cond_fn,
|
644 |
-
model_kwargs=model_kwargs,
|
645 |
-
eta=eta,
|
646 |
-
)
|
647 |
-
yield out
|
648 |
-
img = out["sample"]
|
649 |
-
|
650 |
-
def _vb_terms_bpd(self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None):
|
651 |
-
"""
|
652 |
-
Get a term for the variational lower-bound.
|
653 |
-
The resulting units are bits (rather than nats, as one might expect).
|
654 |
-
This allows for comparison to other papers.
|
655 |
-
:return: a dict with the following keys:
|
656 |
-
- 'output': a shape [N] tensor of NLLs or KLs.
|
657 |
-
- 'pred_xstart': the x_0 predictions.
|
658 |
-
"""
|
659 |
-
true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(x_start=x_start, x_t=x_t, t=t)
|
660 |
-
out = self.p_mean_variance(model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs)
|
661 |
-
kl = normal_kl(true_mean, true_log_variance_clipped, out["mean"], out["log_variance"])
|
662 |
-
kl = mean_flat(kl) / np.log(2.0)
|
663 |
-
|
664 |
-
decoder_nll = -discretized_gaussian_log_likelihood(
|
665 |
-
x_start, means=out["mean"], log_scales=0.5 * out["log_variance"]
|
666 |
-
)
|
667 |
-
assert decoder_nll.shape == x_start.shape
|
668 |
-
decoder_nll = mean_flat(decoder_nll) / np.log(2.0)
|
669 |
-
|
670 |
-
# At the first timestep return the decoder NLL,
|
671 |
-
# otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t))
|
672 |
-
output = th.where((t == 0), decoder_nll, kl)
|
673 |
-
return {"output": output, "pred_xstart": out["pred_xstart"]}
|
674 |
-
|
675 |
-
def training_losses(self, model, x_start, t, model_kwargs=None, noise=None):
|
676 |
-
"""
|
677 |
-
Compute training losses for a single timestep.
|
678 |
-
:param model: the model to evaluate loss on.
|
679 |
-
:param x_start: the [N x C x ...] tensor of inputs.
|
680 |
-
:param t: a batch of timestep indices.
|
681 |
-
:param model_kwargs: if not None, a dict of extra keyword arguments to
|
682 |
-
pass to the model. This can be used for conditioning.
|
683 |
-
:param noise: if specified, the specific Gaussian noise to try to remove.
|
684 |
-
:return: a dict with the key "loss" containing a tensor of shape [N].
|
685 |
-
Some mean or variance settings may also have other keys.
|
686 |
-
"""
|
687 |
-
if model_kwargs is None:
|
688 |
-
model_kwargs = {}
|
689 |
-
if noise is None:
|
690 |
-
noise = th.randn_like(x_start)
|
691 |
-
x_t = self.q_sample(x_start, t, noise=noise)
|
692 |
-
|
693 |
-
terms = {}
|
694 |
-
|
695 |
-
if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL:
|
696 |
-
terms["loss"] = self._vb_terms_bpd(
|
697 |
-
model=model,
|
698 |
-
x_start=x_start,
|
699 |
-
x_t=x_t,
|
700 |
-
t=t,
|
701 |
-
clip_denoised=False,
|
702 |
-
model_kwargs=model_kwargs,
|
703 |
-
)["output"]
|
704 |
-
if self.loss_type == LossType.RESCALED_KL:
|
705 |
-
terms["loss"] *= self.num_timesteps
|
706 |
-
elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE:
|
707 |
-
model_output = model(x_t, t, **model_kwargs)
|
708 |
-
|
709 |
-
if self.model_var_type in [
|
710 |
-
ModelVarType.LEARNED,
|
711 |
-
ModelVarType.LEARNED_RANGE,
|
712 |
-
]:
|
713 |
-
B, C = x_t.shape[:2]
|
714 |
-
assert model_output.shape == (B, C * 2, *x_t.shape[2:])
|
715 |
-
model_output, model_var_values = th.split(model_output, C, dim=1)
|
716 |
-
# Learn the variance using the variational bound, but don't let
|
717 |
-
# it affect our mean prediction.
|
718 |
-
frozen_out = th.cat([model_output.detach(), model_var_values], dim=1)
|
719 |
-
terms["vb"] = self._vb_terms_bpd(
|
720 |
-
model=lambda *args, r=frozen_out: r,
|
721 |
-
x_start=x_start,
|
722 |
-
x_t=x_t,
|
723 |
-
t=t,
|
724 |
-
clip_denoised=False,
|
725 |
-
)["output"]
|
726 |
-
if self.loss_type == LossType.RESCALED_MSE:
|
727 |
-
# Divide by 1000 for equivalence with initial implementation.
|
728 |
-
# Without a factor of 1/1000, the VB term hurts the MSE term.
|
729 |
-
terms["vb"] *= self.num_timesteps / 1000.0
|
730 |
-
|
731 |
-
target = {
|
732 |
-
ModelMeanType.PREVIOUS_X: self.q_posterior_mean_variance(x_start=x_start, x_t=x_t, t=t)[0],
|
733 |
-
ModelMeanType.START_X: x_start,
|
734 |
-
ModelMeanType.EPSILON: noise,
|
735 |
-
}[self.model_mean_type]
|
736 |
-
assert model_output.shape == target.shape == x_start.shape
|
737 |
-
terms["mse"] = mean_flat((target - model_output) ** 2)
|
738 |
-
if "vb" in terms:
|
739 |
-
terms["loss"] = terms["mse"] + terms["vb"]
|
740 |
-
else:
|
741 |
-
terms["loss"] = terms["mse"]
|
742 |
-
else:
|
743 |
-
raise NotImplementedError(self.loss_type)
|
744 |
-
|
745 |
-
return terms
|
746 |
-
|
747 |
-
def _prior_bpd(self, x_start):
|
748 |
-
"""
|
749 |
-
Get the prior KL term for the variational lower-bound, measured in
|
750 |
-
bits-per-dim.
|
751 |
-
This term can't be optimized, as it only depends on the encoder.
|
752 |
-
:param x_start: the [N x C x ...] tensor of inputs.
|
753 |
-
:return: a batch of [N] KL values (in bits), one per batch element.
|
754 |
-
"""
|
755 |
-
batch_size = x_start.shape[0]
|
756 |
-
t = th.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
|
757 |
-
qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
|
758 |
-
kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0)
|
759 |
-
return mean_flat(kl_prior) / np.log(2.0)
|
760 |
-
|
761 |
-
def calc_bpd_loop(self, model, x_start, clip_denoised=True, model_kwargs=None):
|
762 |
-
"""
|
763 |
-
Compute the entire variational lower-bound, measured in bits-per-dim,
|
764 |
-
as well as other related quantities.
|
765 |
-
:param model: the model to evaluate loss on.
|
766 |
-
:param x_start: the [N x C x ...] tensor of inputs.
|
767 |
-
:param clip_denoised: if True, clip denoised samples.
|
768 |
-
:param model_kwargs: if not None, a dict of extra keyword arguments to
|
769 |
-
pass to the model. This can be used for conditioning.
|
770 |
-
:return: a dict containing the following keys:
|
771 |
-
- total_bpd: the total variational lower-bound, per batch element.
|
772 |
-
- prior_bpd: the prior term in the lower-bound.
|
773 |
-
- vb: an [N x T] tensor of terms in the lower-bound.
|
774 |
-
- xstart_mse: an [N x T] tensor of x_0 MSEs for each timestep.
|
775 |
-
- mse: an [N x T] tensor of epsilon MSEs for each timestep.
|
776 |
-
"""
|
777 |
-
device = x_start.device
|
778 |
-
batch_size = x_start.shape[0]
|
779 |
-
|
780 |
-
vb = []
|
781 |
-
xstart_mse = []
|
782 |
-
mse = []
|
783 |
-
for t in list(range(self.num_timesteps))[::-1]:
|
784 |
-
t_batch = th.tensor([t] * batch_size, device=device)
|
785 |
-
noise = th.randn_like(x_start)
|
786 |
-
x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise)
|
787 |
-
# Calculate VLB term at the current timestep
|
788 |
-
with th.no_grad():
|
789 |
-
out = self._vb_terms_bpd(
|
790 |
-
model,
|
791 |
-
x_start=x_start,
|
792 |
-
x_t=x_t,
|
793 |
-
t=t_batch,
|
794 |
-
clip_denoised=clip_denoised,
|
795 |
-
model_kwargs=model_kwargs,
|
796 |
-
)
|
797 |
-
vb.append(out["output"])
|
798 |
-
xstart_mse.append(mean_flat((out["pred_xstart"] - x_start) ** 2))
|
799 |
-
eps = self._predict_eps_from_xstart(x_t, t_batch, out["pred_xstart"])
|
800 |
-
mse.append(mean_flat((eps - noise) ** 2))
|
801 |
-
|
802 |
-
vb = th.stack(vb, dim=1)
|
803 |
-
xstart_mse = th.stack(xstart_mse, dim=1)
|
804 |
-
mse = th.stack(mse, dim=1)
|
805 |
-
|
806 |
-
prior_bpd = self._prior_bpd(x_start)
|
807 |
-
total_bpd = vb.sum(dim=1) + prior_bpd
|
808 |
-
return {
|
809 |
-
"total_bpd": total_bpd,
|
810 |
-
"prior_bpd": prior_bpd,
|
811 |
-
"vb": vb,
|
812 |
-
"xstart_mse": xstart_mse,
|
813 |
-
"mse": mse,
|
814 |
-
}
|
815 |
-
|
816 |
-
|
817 |
-
def _extract_into_tensor(arr, timesteps, broadcast_shape):
|
818 |
-
"""
|
819 |
-
Extract values from a 1-D numpy array for a batch of indices.
|
820 |
-
:param arr: the 1-D numpy array.
|
821 |
-
:param timesteps: a tensor of indices into the array to extract.
|
822 |
-
:param broadcast_shape: a larger shape of K dimensions with the batch
|
823 |
-
dimension equal to the length of timesteps.
|
824 |
-
:return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
|
825 |
-
"""
|
826 |
-
res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
|
827 |
-
while len(res.shape) < len(broadcast_shape):
|
828 |
-
res = res[..., None]
|
829 |
-
return res + th.zeros(broadcast_shape, device=timesteps.device)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
videosys/diffusion/respace.py
DELETED
@@ -1,119 +0,0 @@
|
|
1 |
-
# Modified from OpenAI's diffusion repos
|
2 |
-
# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
|
3 |
-
# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
|
4 |
-
# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
|
5 |
-
|
6 |
-
import numpy as np
|
7 |
-
import torch as th
|
8 |
-
|
9 |
-
from .gaussian_diffusion import GaussianDiffusion
|
10 |
-
|
11 |
-
|
12 |
-
def space_timesteps(num_timesteps, section_counts):
|
13 |
-
"""
|
14 |
-
Create a list of timesteps to use from an original diffusion process,
|
15 |
-
given the number of timesteps we want to take from equally-sized portions
|
16 |
-
of the original process.
|
17 |
-
For example, if there's 300 timesteps and the section counts are [10,15,20]
|
18 |
-
then the first 100 timesteps are strided to be 10 timesteps, the second 100
|
19 |
-
are strided to be 15 timesteps, and the final 100 are strided to be 20.
|
20 |
-
If the stride is a string starting with "ddim", then the fixed striding
|
21 |
-
from the DDIM paper is used, and only one section is allowed.
|
22 |
-
:param num_timesteps: the number of diffusion steps in the original
|
23 |
-
process to divide up.
|
24 |
-
:param section_counts: either a list of numbers, or a string containing
|
25 |
-
comma-separated numbers, indicating the step count
|
26 |
-
per section. As a special case, use "ddimN" where N
|
27 |
-
is a number of steps to use the striding from the
|
28 |
-
DDIM paper.
|
29 |
-
:return: a set of diffusion steps from the original process to use.
|
30 |
-
"""
|
31 |
-
if isinstance(section_counts, str):
|
32 |
-
if section_counts.startswith("ddim"):
|
33 |
-
desired_count = int(section_counts[len("ddim") :])
|
34 |
-
for i in range(1, num_timesteps):
|
35 |
-
if len(range(0, num_timesteps, i)) == desired_count:
|
36 |
-
return set(range(0, num_timesteps, i))
|
37 |
-
raise ValueError(f"cannot create exactly {num_timesteps} steps with an integer stride")
|
38 |
-
section_counts = [int(x) for x in section_counts.split(",")]
|
39 |
-
size_per = num_timesteps // len(section_counts)
|
40 |
-
extra = num_timesteps % len(section_counts)
|
41 |
-
start_idx = 0
|
42 |
-
all_steps = []
|
43 |
-
for i, section_count in enumerate(section_counts):
|
44 |
-
size = size_per + (1 if i < extra else 0)
|
45 |
-
if size < section_count:
|
46 |
-
raise ValueError(f"cannot divide section of {size} steps into {section_count}")
|
47 |
-
if section_count <= 1:
|
48 |
-
frac_stride = 1
|
49 |
-
else:
|
50 |
-
frac_stride = (size - 1) / (section_count - 1)
|
51 |
-
cur_idx = 0.0
|
52 |
-
taken_steps = []
|
53 |
-
for _ in range(section_count):
|
54 |
-
taken_steps.append(start_idx + round(cur_idx))
|
55 |
-
cur_idx += frac_stride
|
56 |
-
all_steps += taken_steps
|
57 |
-
start_idx += size
|
58 |
-
return set(all_steps)
|
59 |
-
|
60 |
-
|
61 |
-
class SpacedDiffusion(GaussianDiffusion):
|
62 |
-
"""
|
63 |
-
A diffusion process which can skip steps in a base diffusion process.
|
64 |
-
:param use_timesteps: a collection (sequence or set) of timesteps from the
|
65 |
-
original diffusion process to retain.
|
66 |
-
:param kwargs: the kwargs to create the base diffusion process.
|
67 |
-
"""
|
68 |
-
|
69 |
-
def __init__(self, use_timesteps, **kwargs):
|
70 |
-
self.use_timesteps = set(use_timesteps)
|
71 |
-
self.timestep_map = []
|
72 |
-
self.original_num_steps = len(kwargs["betas"])
|
73 |
-
|
74 |
-
base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa
|
75 |
-
last_alpha_cumprod = 1.0
|
76 |
-
new_betas = []
|
77 |
-
for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod):
|
78 |
-
if i in self.use_timesteps:
|
79 |
-
new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
|
80 |
-
last_alpha_cumprod = alpha_cumprod
|
81 |
-
self.timestep_map.append(i)
|
82 |
-
kwargs["betas"] = np.array(new_betas)
|
83 |
-
super().__init__(**kwargs)
|
84 |
-
|
85 |
-
def p_mean_variance(self, model, *args, **kwargs): # pylint: disable=signature-differs
|
86 |
-
return super().p_mean_variance(self._wrap_model(model), *args, **kwargs)
|
87 |
-
|
88 |
-
def training_losses(self, model, *args, **kwargs): # pylint: disable=signature-differs
|
89 |
-
return super().training_losses(self._wrap_model(model), *args, **kwargs)
|
90 |
-
|
91 |
-
def condition_mean(self, cond_fn, *args, **kwargs):
|
92 |
-
return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs)
|
93 |
-
|
94 |
-
def condition_score(self, cond_fn, *args, **kwargs):
|
95 |
-
return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs)
|
96 |
-
|
97 |
-
def _wrap_model(self, model):
|
98 |
-
if isinstance(model, _WrappedModel):
|
99 |
-
return model
|
100 |
-
return _WrappedModel(model, self.timestep_map, self.original_num_steps)
|
101 |
-
|
102 |
-
def _scale_timesteps(self, t):
|
103 |
-
# Scaling is done by the wrapped model.
|
104 |
-
return t
|
105 |
-
|
106 |
-
|
107 |
-
class _WrappedModel:
|
108 |
-
def __init__(self, model, timestep_map, original_num_steps):
|
109 |
-
self.model = model
|
110 |
-
self.timestep_map = timestep_map
|
111 |
-
# self.rescale_timesteps = rescale_timesteps
|
112 |
-
self.original_num_steps = original_num_steps
|
113 |
-
|
114 |
-
def __call__(self, x, ts, **kwargs):
|
115 |
-
map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype)
|
116 |
-
new_ts = map_tensor[ts]
|
117 |
-
# if self.rescale_timesteps:
|
118 |
-
# new_ts = new_ts.float() * (1000.0 / self.original_num_steps)
|
119 |
-
return self.model(x, new_ts, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
videosys/diffusion/timestep_sampler.py
DELETED
@@ -1,143 +0,0 @@
|
|
1 |
-
# Modified from OpenAI's diffusion repos
|
2 |
-
# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
|
3 |
-
# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
|
4 |
-
# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
|
5 |
-
|
6 |
-
from abc import ABC, abstractmethod
|
7 |
-
|
8 |
-
import numpy as np
|
9 |
-
import torch as th
|
10 |
-
import torch.distributed as dist
|
11 |
-
|
12 |
-
|
13 |
-
def create_named_schedule_sampler(name, diffusion):
|
14 |
-
"""
|
15 |
-
Create a ScheduleSampler from a library of pre-defined samplers.
|
16 |
-
:param name: the name of the sampler.
|
17 |
-
:param diffusion: the diffusion object to sample for.
|
18 |
-
"""
|
19 |
-
if name == "uniform":
|
20 |
-
return UniformSampler(diffusion)
|
21 |
-
elif name == "loss-second-moment":
|
22 |
-
return LossSecondMomentResampler(diffusion)
|
23 |
-
else:
|
24 |
-
raise NotImplementedError(f"unknown schedule sampler: {name}")
|
25 |
-
|
26 |
-
|
27 |
-
class ScheduleSampler(ABC):
|
28 |
-
"""
|
29 |
-
A distribution over timesteps in the diffusion process, intended to reduce
|
30 |
-
variance of the objective.
|
31 |
-
By default, samplers perform unbiased importance sampling, in which the
|
32 |
-
objective's mean is unchanged.
|
33 |
-
However, subclasses may override sample() to change how the resampled
|
34 |
-
terms are reweighted, allowing for actual changes in the objective.
|
35 |
-
"""
|
36 |
-
|
37 |
-
@abstractmethod
|
38 |
-
def weights(self):
|
39 |
-
"""
|
40 |
-
Get a numpy array of weights, one per diffusion step.
|
41 |
-
The weights needn't be normalized, but must be positive.
|
42 |
-
"""
|
43 |
-
|
44 |
-
def sample(self, batch_size, device):
|
45 |
-
"""
|
46 |
-
Importance-sample timesteps for a batch.
|
47 |
-
:param batch_size: the number of timesteps.
|
48 |
-
:param device: the torch device to save to.
|
49 |
-
:return: a tuple (timesteps, weights):
|
50 |
-
- timesteps: a tensor of timestep indices.
|
51 |
-
- weights: a tensor of weights to scale the resulting losses.
|
52 |
-
"""
|
53 |
-
w = self.weights()
|
54 |
-
p = w / np.sum(w)
|
55 |
-
indices_np = np.random.choice(len(p), size=(batch_size,), p=p)
|
56 |
-
indices = th.from_numpy(indices_np).long().to(device)
|
57 |
-
weights_np = 1 / (len(p) * p[indices_np])
|
58 |
-
weights = th.from_numpy(weights_np).float().to(device)
|
59 |
-
return indices, weights
|
60 |
-
|
61 |
-
|
62 |
-
class UniformSampler(ScheduleSampler):
|
63 |
-
def __init__(self, diffusion):
|
64 |
-
self.diffusion = diffusion
|
65 |
-
self._weights = np.ones([diffusion.num_timesteps])
|
66 |
-
|
67 |
-
def weights(self):
|
68 |
-
return self._weights
|
69 |
-
|
70 |
-
|
71 |
-
class LossAwareSampler(ScheduleSampler):
|
72 |
-
def update_with_local_losses(self, local_ts, local_losses):
|
73 |
-
"""
|
74 |
-
Update the reweighting using losses from a model.
|
75 |
-
Call this method from each rank with a batch of timesteps and the
|
76 |
-
corresponding losses for each of those timesteps.
|
77 |
-
This method will perform synchronization to make sure all of the ranks
|
78 |
-
maintain the exact same reweighting.
|
79 |
-
:param local_ts: an integer Tensor of timesteps.
|
80 |
-
:param local_losses: a 1D Tensor of losses.
|
81 |
-
"""
|
82 |
-
batch_sizes = [th.tensor([0], dtype=th.int32, device=local_ts.device) for _ in range(dist.get_world_size())]
|
83 |
-
dist.all_gather(
|
84 |
-
batch_sizes,
|
85 |
-
th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device),
|
86 |
-
)
|
87 |
-
|
88 |
-
# Pad all_gather batches to be the maximum batch size.
|
89 |
-
batch_sizes = [x.item() for x in batch_sizes]
|
90 |
-
max_bs = max(batch_sizes)
|
91 |
-
|
92 |
-
timestep_batches = [th.zeros(max_bs).to(local_ts) for bs in batch_sizes]
|
93 |
-
loss_batches = [th.zeros(max_bs).to(local_losses) for bs in batch_sizes]
|
94 |
-
dist.all_gather(timestep_batches, local_ts)
|
95 |
-
dist.all_gather(loss_batches, local_losses)
|
96 |
-
timesteps = [x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs]]
|
97 |
-
losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]]
|
98 |
-
self.update_with_all_losses(timesteps, losses)
|
99 |
-
|
100 |
-
@abstractmethod
|
101 |
-
def update_with_all_losses(self, ts, losses):
|
102 |
-
"""
|
103 |
-
Update the reweighting using losses from a model.
|
104 |
-
Sub-classes should override this method to update the reweighting
|
105 |
-
using losses from the model.
|
106 |
-
This method directly updates the reweighting without synchronizing
|
107 |
-
between workers. It is called by update_with_local_losses from all
|
108 |
-
ranks with identical arguments. Thus, it should have deterministic
|
109 |
-
behavior to maintain state across workers.
|
110 |
-
:param ts: a list of int timesteps.
|
111 |
-
:param losses: a list of float losses, one per timestep.
|
112 |
-
"""
|
113 |
-
|
114 |
-
|
115 |
-
class LossSecondMomentResampler(LossAwareSampler):
|
116 |
-
def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001):
|
117 |
-
self.diffusion = diffusion
|
118 |
-
self.history_per_term = history_per_term
|
119 |
-
self.uniform_prob = uniform_prob
|
120 |
-
self._loss_history = np.zeros([diffusion.num_timesteps, history_per_term], dtype=np.float64)
|
121 |
-
self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int)
|
122 |
-
|
123 |
-
def weights(self):
|
124 |
-
if not self._warmed_up():
|
125 |
-
return np.ones([self.diffusion.num_timesteps], dtype=np.float64)
|
126 |
-
weights = np.sqrt(np.mean(self._loss_history**2, axis=-1))
|
127 |
-
weights /= np.sum(weights)
|
128 |
-
weights *= 1 - self.uniform_prob
|
129 |
-
weights += self.uniform_prob / len(weights)
|
130 |
-
return weights
|
131 |
-
|
132 |
-
def update_with_all_losses(self, ts, losses):
|
133 |
-
for t, loss in zip(ts, losses):
|
134 |
-
if self._loss_counts[t] == self.history_per_term:
|
135 |
-
# Shift out the oldest loss term.
|
136 |
-
self._loss_history[t, :-1] = self._loss_history[t, 1:]
|
137 |
-
self._loss_history[t, -1] = loss
|
138 |
-
else:
|
139 |
-
self._loss_history[t, self._loss_counts[t]] = loss
|
140 |
-
self._loss_counts[t] += 1
|
141 |
-
|
142 |
-
def _warmed_up(self):
|
143 |
-
return (self._loss_counts == self.history_per_term).all()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
{eval/pab/commom_metrics → videosys/models/autoencoders}/__init__.py
RENAMED
File without changes
|
videosys/models/{cogvideo/autoencoder_kl.py → autoencoders/autoencoder_kl_cogvideox.py}
RENAMED
@@ -20,16 +20,16 @@ from diffusers.models.activations import get_activation
|
|
20 |
from diffusers.models.autoencoders.vae import DecoderOutput, DiagonalGaussianDistribution
|
21 |
from diffusers.models.modeling_outputs import AutoencoderKLOutput
|
22 |
from diffusers.models.modeling_utils import ModelMixin
|
23 |
-
from diffusers.utils import logging
|
24 |
from diffusers.utils.accelerate_utils import apply_forward_hook
|
25 |
|
26 |
-
from .
|
27 |
|
28 |
-
|
|
|
29 |
|
30 |
|
31 |
class CogVideoXSafeConv3d(nn.Conv3d):
|
32 |
-
"""
|
33 |
A 3D convolution layer that splits the input tensor into smaller parts to avoid OOM in CogVideoX Model.
|
34 |
"""
|
35 |
|
@@ -61,12 +61,12 @@ class CogVideoXCausalConv3d(nn.Module):
|
|
61 |
r"""A 3D causal convolution layer that pads the input tensor to ensure causality in CogVideoX Model.
|
62 |
|
63 |
Args:
|
64 |
-
in_channels (int): Number of channels in the input tensor.
|
65 |
-
out_channels (int): Number of output channels.
|
66 |
-
kernel_size (
|
67 |
-
stride (int
|
68 |
-
dilation (int
|
69 |
-
pad_mode (str
|
70 |
"""
|
71 |
|
72 |
def __init__(
|
@@ -111,19 +111,10 @@ class CogVideoXCausalConv3d(nn.Module):
|
|
111 |
self.conv_cache = None
|
112 |
|
113 |
def fake_context_parallel_forward(self, inputs: torch.Tensor) -> torch.Tensor:
|
114 |
-
dim = self.temporal_dim
|
115 |
kernel_size = self.time_kernel_size
|
116 |
-
if kernel_size
|
117 |
-
|
118 |
-
|
119 |
-
inputs = inputs.transpose(0, dim)
|
120 |
-
|
121 |
-
if self.conv_cache is not None:
|
122 |
-
inputs = torch.cat([self.conv_cache.transpose(0, dim).to(inputs.device), inputs], dim=0)
|
123 |
-
else:
|
124 |
-
inputs = torch.cat([inputs[:1]] * (kernel_size - 1) + [inputs], dim=0)
|
125 |
-
|
126 |
-
inputs = inputs.transpose(0, dim).contiguous()
|
127 |
return inputs
|
128 |
|
129 |
def _clear_fake_context_parallel_cache(self):
|
@@ -131,16 +122,17 @@ class CogVideoXCausalConv3d(nn.Module):
|
|
131 |
self.conv_cache = None
|
132 |
|
133 |
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
|
134 |
-
|
135 |
|
136 |
self._clear_fake_context_parallel_cache()
|
137 |
-
|
|
|
|
|
138 |
|
139 |
padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad)
|
140 |
-
|
141 |
|
142 |
-
|
143 |
-
output = output_parallel
|
144 |
return output
|
145 |
|
146 |
|
@@ -156,6 +148,8 @@ class CogVideoXSpatialNorm3D(nn.Module):
|
|
156 |
The number of channels for input to group normalization layer, and output of the spatial norm layer.
|
157 |
zq_channels (`int`):
|
158 |
The number of channels for the quantized vector as described in the paper.
|
|
|
|
|
159 |
"""
|
160 |
|
161 |
def __init__(
|
@@ -190,17 +184,26 @@ class CogVideoXResnetBlock3D(nn.Module):
|
|
190 |
A 3D ResNet block used in the CogVideoX model.
|
191 |
|
192 |
Args:
|
193 |
-
in_channels (int):
|
194 |
-
|
195 |
-
|
196 |
-
|
197 |
-
|
198 |
-
|
199 |
-
|
200 |
-
|
201 |
-
|
202 |
-
|
203 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
204 |
"""
|
205 |
|
206 |
def __init__(
|
@@ -302,18 +305,28 @@ class CogVideoXDownBlock3D(nn.Module):
|
|
302 |
A downsampling block used in the CogVideoX model.
|
303 |
|
304 |
Args:
|
305 |
-
in_channels (int):
|
306 |
-
|
307 |
-
|
308 |
-
|
309 |
-
|
310 |
-
|
311 |
-
|
312 |
-
|
313 |
-
|
314 |
-
|
315 |
-
|
316 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
317 |
"""
|
318 |
|
319 |
_supports_gradient_checkpointing = True
|
@@ -398,15 +411,24 @@ class CogVideoXMidBlock3D(nn.Module):
|
|
398 |
A middle block used in the CogVideoX model.
|
399 |
|
400 |
Args:
|
401 |
-
in_channels (int):
|
402 |
-
|
403 |
-
|
404 |
-
|
405 |
-
|
406 |
-
|
407 |
-
|
408 |
-
|
409 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
410 |
"""
|
411 |
|
412 |
_supports_gradient_checkpointing = True
|
@@ -473,19 +495,30 @@ class CogVideoXUpBlock3D(nn.Module):
|
|
473 |
An upsampling block used in the CogVideoX model.
|
474 |
|
475 |
Args:
|
476 |
-
in_channels (int):
|
477 |
-
|
478 |
-
|
479 |
-
|
480 |
-
|
481 |
-
|
482 |
-
|
483 |
-
|
484 |
-
|
485 |
-
|
486 |
-
|
487 |
-
|
488 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
489 |
"""
|
490 |
|
491 |
def __init__(
|
@@ -576,14 +609,12 @@ class CogVideoXEncoder3D(nn.Module):
|
|
576 |
options.
|
577 |
block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
|
578 |
The number of output channels for each block.
|
|
|
|
|
579 |
layers_per_block (`int`, *optional*, defaults to 2):
|
580 |
The number of layers per block.
|
581 |
norm_num_groups (`int`, *optional*, defaults to 32):
|
582 |
The number of groups for normalization.
|
583 |
-
act_fn (`str`, *optional*, defaults to `"silu"`):
|
584 |
-
The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
|
585 |
-
double_z (`bool`, *optional*, defaults to `True`):
|
586 |
-
Whether to double the number of output channels for the last block.
|
587 |
"""
|
588 |
|
589 |
_supports_gradient_checkpointing = True
|
@@ -712,14 +743,12 @@ class CogVideoXDecoder3D(nn.Module):
|
|
712 |
The types of up blocks to use. See `~diffusers.models.unet_2d_blocks.get_up_block` for available options.
|
713 |
block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
|
714 |
The number of output channels for each block.
|
|
|
|
|
715 |
layers_per_block (`int`, *optional*, defaults to 2):
|
716 |
The number of layers per block.
|
717 |
norm_num_groups (`int`, *optional*, defaults to 32):
|
718 |
The number of groups for normalization.
|
719 |
-
act_fn (`str`, *optional*, defaults to `"silu"`):
|
720 |
-
The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
|
721 |
-
norm_type (`str`, *optional*, defaults to `"group"`):
|
722 |
-
The normalization type to use. Can be either `"group"` or `"spatial"`.
|
723 |
"""
|
724 |
|
725 |
_supports_gradient_checkpointing = True
|
@@ -860,7 +889,7 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
|
860 |
Tuple of block output channels.
|
861 |
act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
|
862 |
sample_size (`int`, *optional*, defaults to `32`): Sample input size.
|
863 |
-
scaling_factor (`float`, *optional*, defaults to
|
864 |
The component-wise standard deviation of the trained latent space computed using the first batch of the
|
865 |
training set. This is used to scale the latent space to have unit variance when training the diffusion
|
866 |
model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
|
@@ -900,7 +929,8 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
|
900 |
norm_eps: float = 1e-6,
|
901 |
norm_num_groups: int = 32,
|
902 |
temporal_compression_ratio: float = 4,
|
903 |
-
|
|
|
904 |
scaling_factor: float = 1.15258426,
|
905 |
shift_factor: Optional[float] = None,
|
906 |
latents_mean: Optional[Tuple[float]] = None,
|
@@ -939,25 +969,105 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
|
939 |
self.use_slicing = False
|
940 |
self.use_tiling = False
|
941 |
|
942 |
-
|
943 |
-
|
944 |
-
|
945 |
-
|
946 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
947 |
)
|
948 |
-
self.
|
949 |
-
|
|
|
|
|
|
|
|
|
|
|
950 |
|
951 |
def _set_gradient_checkpointing(self, module, value=False):
|
952 |
if isinstance(module, (CogVideoXEncoder3D, CogVideoXDecoder3D)):
|
953 |
module.gradient_checkpointing = value
|
954 |
|
955 |
-
def
|
956 |
for name, module in self.named_modules():
|
957 |
if isinstance(module, CogVideoXCausalConv3d):
|
958 |
logger.debug(f"Clearing fake Context Parallel cache for layer: {name}")
|
959 |
module._clear_fake_context_parallel_cache()
|
960 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
961 |
@apply_forward_hook
|
962 |
def encode(
|
963 |
self, x: torch.Tensor, return_dict: bool = True
|
@@ -982,8 +1092,34 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
|
982 |
return (posterior,)
|
983 |
return AutoencoderKLOutput(latent_dist=posterior)
|
984 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
985 |
@apply_forward_hook
|
986 |
-
def decode(self, z: torch.
|
987 |
"""
|
988 |
Decode a batch of images.
|
989 |
|
@@ -996,13 +1132,111 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
|
996 |
[`~models.vae.DecoderOutput`] or `tuple`:
|
997 |
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
|
998 |
returned.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
999 |
|
|
|
|
|
|
|
|
|
1000 |
"""
|
1001 |
-
|
1002 |
-
|
1003 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1004 |
if not return_dict:
|
1005 |
return (dec,)
|
|
|
1006 |
return DecoderOutput(sample=dec)
|
1007 |
|
1008 |
def forward(
|
|
|
20 |
from diffusers.models.autoencoders.vae import DecoderOutput, DiagonalGaussianDistribution
|
21 |
from diffusers.models.modeling_outputs import AutoencoderKLOutput
|
22 |
from diffusers.models.modeling_utils import ModelMixin
|
|
|
23 |
from diffusers.utils.accelerate_utils import apply_forward_hook
|
24 |
|
25 |
+
from videosys.utils.logging import logger
|
26 |
|
27 |
+
from ..modules.downsampling import CogVideoXDownsample3D
|
28 |
+
from ..modules.upsampling import CogVideoXUpsample3D
|
29 |
|
30 |
|
31 |
class CogVideoXSafeConv3d(nn.Conv3d):
|
32 |
+
r"""
|
33 |
A 3D convolution layer that splits the input tensor into smaller parts to avoid OOM in CogVideoX Model.
|
34 |
"""
|
35 |
|
|
|
61 |
r"""A 3D causal convolution layer that pads the input tensor to ensure causality in CogVideoX Model.
|
62 |
|
63 |
Args:
|
64 |
+
in_channels (`int`): Number of channels in the input tensor.
|
65 |
+
out_channels (`int`): Number of output channels produced by the convolution.
|
66 |
+
kernel_size (`int` or `Tuple[int, int, int]`): Kernel size of the convolutional kernel.
|
67 |
+
stride (`int`, defaults to `1`): Stride of the convolution.
|
68 |
+
dilation (`int`, defaults to `1`): Dilation rate of the convolution.
|
69 |
+
pad_mode (`str`, defaults to `"constant"`): Padding mode.
|
70 |
"""
|
71 |
|
72 |
def __init__(
|
|
|
111 |
self.conv_cache = None
|
112 |
|
113 |
def fake_context_parallel_forward(self, inputs: torch.Tensor) -> torch.Tensor:
|
|
|
114 |
kernel_size = self.time_kernel_size
|
115 |
+
if kernel_size > 1:
|
116 |
+
cached_inputs = [self.conv_cache] if self.conv_cache is not None else [inputs[:, :, :1]] * (kernel_size - 1)
|
117 |
+
inputs = torch.cat(cached_inputs + [inputs], dim=2)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
118 |
return inputs
|
119 |
|
120 |
def _clear_fake_context_parallel_cache(self):
|
|
|
122 |
self.conv_cache = None
|
123 |
|
124 |
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
|
125 |
+
inputs = self.fake_context_parallel_forward(inputs)
|
126 |
|
127 |
self._clear_fake_context_parallel_cache()
|
128 |
+
# Note: we could move these to the cpu for a lower maximum memory usage but its only a few
|
129 |
+
# hundred megabytes and so let's not do it for now
|
130 |
+
self.conv_cache = inputs[:, :, -self.time_kernel_size + 1 :].clone()
|
131 |
|
132 |
padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad)
|
133 |
+
inputs = F.pad(inputs, padding_2d, mode="constant", value=0)
|
134 |
|
135 |
+
output = self.conv(inputs)
|
|
|
136 |
return output
|
137 |
|
138 |
|
|
|
148 |
The number of channels for input to group normalization layer, and output of the spatial norm layer.
|
149 |
zq_channels (`int`):
|
150 |
The number of channels for the quantized vector as described in the paper.
|
151 |
+
groups (`int`):
|
152 |
+
Number of groups to separate the channels into for group normalization.
|
153 |
"""
|
154 |
|
155 |
def __init__(
|
|
|
184 |
A 3D ResNet block used in the CogVideoX model.
|
185 |
|
186 |
Args:
|
187 |
+
in_channels (`int`):
|
188 |
+
Number of input channels.
|
189 |
+
out_channels (`int`, *optional*):
|
190 |
+
Number of output channels. If None, defaults to `in_channels`.
|
191 |
+
dropout (`float`, defaults to `0.0`):
|
192 |
+
Dropout rate.
|
193 |
+
temb_channels (`int`, defaults to `512`):
|
194 |
+
Number of time embedding channels.
|
195 |
+
groups (`int`, defaults to `32`):
|
196 |
+
Number of groups to separate the channels into for group normalization.
|
197 |
+
eps (`float`, defaults to `1e-6`):
|
198 |
+
Epsilon value for normalization layers.
|
199 |
+
non_linearity (`str`, defaults to `"swish"`):
|
200 |
+
Activation function to use.
|
201 |
+
conv_shortcut (bool, defaults to `False`):
|
202 |
+
Whether or not to use a convolution shortcut.
|
203 |
+
spatial_norm_dim (`int`, *optional*):
|
204 |
+
The dimension to use for spatial norm if it is to be used instead of group norm.
|
205 |
+
pad_mode (str, defaults to `"first"`):
|
206 |
+
Padding mode.
|
207 |
"""
|
208 |
|
209 |
def __init__(
|
|
|
305 |
A downsampling block used in the CogVideoX model.
|
306 |
|
307 |
Args:
|
308 |
+
in_channels (`int`):
|
309 |
+
Number of input channels.
|
310 |
+
out_channels (`int`, *optional*):
|
311 |
+
Number of output channels. If None, defaults to `in_channels`.
|
312 |
+
temb_channels (`int`, defaults to `512`):
|
313 |
+
Number of time embedding channels.
|
314 |
+
num_layers (`int`, defaults to `1`):
|
315 |
+
Number of resnet layers.
|
316 |
+
dropout (`float`, defaults to `0.0`):
|
317 |
+
Dropout rate.
|
318 |
+
resnet_eps (`float`, defaults to `1e-6`):
|
319 |
+
Epsilon value for normalization layers.
|
320 |
+
resnet_act_fn (`str`, defaults to `"swish"`):
|
321 |
+
Activation function to use.
|
322 |
+
resnet_groups (`int`, defaults to `32`):
|
323 |
+
Number of groups to separate the channels into for group normalization.
|
324 |
+
add_downsample (`bool`, defaults to `True`):
|
325 |
+
Whether or not to use a downsampling layer. If not used, output dimension would be same as input dimension.
|
326 |
+
compress_time (`bool`, defaults to `False`):
|
327 |
+
Whether or not to downsample across temporal dimension.
|
328 |
+
pad_mode (str, defaults to `"first"`):
|
329 |
+
Padding mode.
|
330 |
"""
|
331 |
|
332 |
_supports_gradient_checkpointing = True
|
|
|
411 |
A middle block used in the CogVideoX model.
|
412 |
|
413 |
Args:
|
414 |
+
in_channels (`int`):
|
415 |
+
Number of input channels.
|
416 |
+
temb_channels (`int`, defaults to `512`):
|
417 |
+
Number of time embedding channels.
|
418 |
+
dropout (`float`, defaults to `0.0`):
|
419 |
+
Dropout rate.
|
420 |
+
num_layers (`int`, defaults to `1`):
|
421 |
+
Number of resnet layers.
|
422 |
+
resnet_eps (`float`, defaults to `1e-6`):
|
423 |
+
Epsilon value for normalization layers.
|
424 |
+
resnet_act_fn (`str`, defaults to `"swish"`):
|
425 |
+
Activation function to use.
|
426 |
+
resnet_groups (`int`, defaults to `32`):
|
427 |
+
Number of groups to separate the channels into for group normalization.
|
428 |
+
spatial_norm_dim (`int`, *optional*):
|
429 |
+
The dimension to use for spatial norm if it is to be used instead of group norm.
|
430 |
+
pad_mode (str, defaults to `"first"`):
|
431 |
+
Padding mode.
|
432 |
"""
|
433 |
|
434 |
_supports_gradient_checkpointing = True
|
|
|
495 |
An upsampling block used in the CogVideoX model.
|
496 |
|
497 |
Args:
|
498 |
+
in_channels (`int`):
|
499 |
+
Number of input channels.
|
500 |
+
out_channels (`int`, *optional*):
|
501 |
+
Number of output channels. If None, defaults to `in_channels`.
|
502 |
+
temb_channels (`int`, defaults to `512`):
|
503 |
+
Number of time embedding channels.
|
504 |
+
dropout (`float`, defaults to `0.0`):
|
505 |
+
Dropout rate.
|
506 |
+
num_layers (`int`, defaults to `1`):
|
507 |
+
Number of resnet layers.
|
508 |
+
resnet_eps (`float`, defaults to `1e-6`):
|
509 |
+
Epsilon value for normalization layers.
|
510 |
+
resnet_act_fn (`str`, defaults to `"swish"`):
|
511 |
+
Activation function to use.
|
512 |
+
resnet_groups (`int`, defaults to `32`):
|
513 |
+
Number of groups to separate the channels into for group normalization.
|
514 |
+
spatial_norm_dim (`int`, defaults to `16`):
|
515 |
+
The dimension to use for spatial norm if it is to be used instead of group norm.
|
516 |
+
add_upsample (`bool`, defaults to `True`):
|
517 |
+
Whether or not to use a upsampling layer. If not used, output dimension would be same as input dimension.
|
518 |
+
compress_time (`bool`, defaults to `False`):
|
519 |
+
Whether or not to downsample across temporal dimension.
|
520 |
+
pad_mode (str, defaults to `"first"`):
|
521 |
+
Padding mode.
|
522 |
"""
|
523 |
|
524 |
def __init__(
|
|
|
609 |
options.
|
610 |
block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
|
611 |
The number of output channels for each block.
|
612 |
+
act_fn (`str`, *optional*, defaults to `"silu"`):
|
613 |
+
The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
|
614 |
layers_per_block (`int`, *optional*, defaults to 2):
|
615 |
The number of layers per block.
|
616 |
norm_num_groups (`int`, *optional*, defaults to 32):
|
617 |
The number of groups for normalization.
|
|
|
|
|
|
|
|
|
618 |
"""
|
619 |
|
620 |
_supports_gradient_checkpointing = True
|
|
|
743 |
The types of up blocks to use. See `~diffusers.models.unet_2d_blocks.get_up_block` for available options.
|
744 |
block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
|
745 |
The number of output channels for each block.
|
746 |
+
act_fn (`str`, *optional*, defaults to `"silu"`):
|
747 |
+
The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
|
748 |
layers_per_block (`int`, *optional*, defaults to 2):
|
749 |
The number of layers per block.
|
750 |
norm_num_groups (`int`, *optional*, defaults to 32):
|
751 |
The number of groups for normalization.
|
|
|
|
|
|
|
|
|
752 |
"""
|
753 |
|
754 |
_supports_gradient_checkpointing = True
|
|
|
889 |
Tuple of block output channels.
|
890 |
act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
|
891 |
sample_size (`int`, *optional*, defaults to `32`): Sample input size.
|
892 |
+
scaling_factor (`float`, *optional*, defaults to `1.15258426`):
|
893 |
The component-wise standard deviation of the trained latent space computed using the first batch of the
|
894 |
training set. This is used to scale the latent space to have unit variance when training the diffusion
|
895 |
model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
|
|
|
929 |
norm_eps: float = 1e-6,
|
930 |
norm_num_groups: int = 32,
|
931 |
temporal_compression_ratio: float = 4,
|
932 |
+
sample_height: int = 480,
|
933 |
+
sample_width: int = 720,
|
934 |
scaling_factor: float = 1.15258426,
|
935 |
shift_factor: Optional[float] = None,
|
936 |
latents_mean: Optional[Tuple[float]] = None,
|
|
|
969 |
self.use_slicing = False
|
970 |
self.use_tiling = False
|
971 |
|
972 |
+
# Can be increased to decode more latent frames at once, but comes at a reasonable memory cost and it is not
|
973 |
+
# recommended because the temporal parts of the VAE, here, are tricky to understand.
|
974 |
+
# If you decode X latent frames together, the number of output frames is:
|
975 |
+
# (X + (2 conv cache) + (2 time upscale_1) + (4 time upscale_2) - (2 causal conv downscale)) => X + 6 frames
|
976 |
+
#
|
977 |
+
# Example with num_latent_frames_batch_size = 2:
|
978 |
+
# - 12 latent frames: (0, 1), (2, 3), (4, 5), (6, 7), (8, 9), (10, 11) are processed together
|
979 |
+
# => (12 // 2 frame slices) * ((2 num_latent_frames_batch_size) + (2 conv cache) + (2 time upscale_1) + (4 time upscale_2) - (2 causal conv downscale))
|
980 |
+
# => 6 * 8 = 48 frames
|
981 |
+
# - 13 latent frames: (0, 1, 2) (special case), (3, 4), (5, 6), (7, 8), (9, 10), (11, 12) are processed together
|
982 |
+
# => (1 frame slice) * ((3 num_latent_frames_batch_size) + (2 conv cache) + (2 time upscale_1) + (4 time upscale_2) - (2 causal conv downscale)) +
|
983 |
+
# ((13 - 3) // 2) * ((2 num_latent_frames_batch_size) + (2 conv cache) + (2 time upscale_1) + (4 time upscale_2) - (2 causal conv downscale))
|
984 |
+
# => 1 * 9 + 5 * 8 = 49 frames
|
985 |
+
# It has been implemented this way so as to not have "magic values" in the code base that would be hard to explain. Note that
|
986 |
+
# setting it to anything other than 2 would give poor results because the VAE hasn't been trained to be adaptive with different
|
987 |
+
# number of temporal frames.
|
988 |
+
self.num_latent_frames_batch_size = 2
|
989 |
+
|
990 |
+
# We make the minimum height and width of sample for tiling half that of the generally supported
|
991 |
+
self.tile_sample_min_height = sample_height // 2
|
992 |
+
self.tile_sample_min_width = sample_width // 2
|
993 |
+
self.tile_latent_min_height = int(
|
994 |
+
self.tile_sample_min_height / (2 ** (len(self.config.block_out_channels) - 1))
|
995 |
)
|
996 |
+
self.tile_latent_min_width = int(self.tile_sample_min_width / (2 ** (len(self.config.block_out_channels) - 1)))
|
997 |
+
|
998 |
+
# These are experimental overlap factors that were chosen based on experimentation and seem to work best for
|
999 |
+
# 720x480 (WxH) resolution. The above resolution is the strongly recommended generation resolution in CogVideoX
|
1000 |
+
# and so the tiling implementation has only been tested on those specific resolutions.
|
1001 |
+
self.tile_overlap_factor_height = 1 / 6
|
1002 |
+
self.tile_overlap_factor_width = 1 / 5
|
1003 |
|
1004 |
def _set_gradient_checkpointing(self, module, value=False):
|
1005 |
if isinstance(module, (CogVideoXEncoder3D, CogVideoXDecoder3D)):
|
1006 |
module.gradient_checkpointing = value
|
1007 |
|
1008 |
+
def _clear_fake_context_parallel_cache(self):
|
1009 |
for name, module in self.named_modules():
|
1010 |
if isinstance(module, CogVideoXCausalConv3d):
|
1011 |
logger.debug(f"Clearing fake Context Parallel cache for layer: {name}")
|
1012 |
module._clear_fake_context_parallel_cache()
|
1013 |
|
1014 |
+
def enable_tiling(
|
1015 |
+
self,
|
1016 |
+
tile_sample_min_height: Optional[int] = None,
|
1017 |
+
tile_sample_min_width: Optional[int] = None,
|
1018 |
+
tile_overlap_factor_height: Optional[float] = None,
|
1019 |
+
tile_overlap_factor_width: Optional[float] = None,
|
1020 |
+
) -> None:
|
1021 |
+
r"""
|
1022 |
+
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
|
1023 |
+
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
|
1024 |
+
processing larger images.
|
1025 |
+
|
1026 |
+
Args:
|
1027 |
+
tile_sample_min_height (`int`, *optional*):
|
1028 |
+
The minimum height required for a sample to be separated into tiles across the height dimension.
|
1029 |
+
tile_sample_min_width (`int`, *optional*):
|
1030 |
+
The minimum width required for a sample to be separated into tiles across the width dimension.
|
1031 |
+
tile_overlap_factor_height (`int`, *optional*):
|
1032 |
+
The minimum amount of overlap between two consecutive vertical tiles. This is to ensure that there are
|
1033 |
+
no tiling artifacts produced across the height dimension. Must be between 0 and 1. Setting a higher
|
1034 |
+
value might cause more tiles to be processed leading to slow down of the decoding process.
|
1035 |
+
tile_overlap_factor_width (`int`, *optional*):
|
1036 |
+
The minimum amount of overlap between two consecutive horizontal tiles. This is to ensure that there
|
1037 |
+
are no tiling artifacts produced across the width dimension. Must be between 0 and 1. Setting a higher
|
1038 |
+
value might cause more tiles to be processed leading to slow down of the decoding process.
|
1039 |
+
"""
|
1040 |
+
self.use_tiling = True
|
1041 |
+
self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height
|
1042 |
+
self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width
|
1043 |
+
self.tile_latent_min_height = int(
|
1044 |
+
self.tile_sample_min_height / (2 ** (len(self.config.block_out_channels) - 1))
|
1045 |
+
)
|
1046 |
+
self.tile_latent_min_width = int(self.tile_sample_min_width / (2 ** (len(self.config.block_out_channels) - 1)))
|
1047 |
+
self.tile_overlap_factor_height = tile_overlap_factor_height or self.tile_overlap_factor_height
|
1048 |
+
self.tile_overlap_factor_width = tile_overlap_factor_width or self.tile_overlap_factor_width
|
1049 |
+
|
1050 |
+
def disable_tiling(self) -> None:
|
1051 |
+
r"""
|
1052 |
+
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
|
1053 |
+
decoding in one step.
|
1054 |
+
"""
|
1055 |
+
self.use_tiling = False
|
1056 |
+
|
1057 |
+
def enable_slicing(self) -> None:
|
1058 |
+
r"""
|
1059 |
+
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
|
1060 |
+
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
|
1061 |
+
"""
|
1062 |
+
self.use_slicing = True
|
1063 |
+
|
1064 |
+
def disable_slicing(self) -> None:
|
1065 |
+
r"""
|
1066 |
+
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
|
1067 |
+
decoding in one step.
|
1068 |
+
"""
|
1069 |
+
self.use_slicing = False
|
1070 |
+
|
1071 |
@apply_forward_hook
|
1072 |
def encode(
|
1073 |
self, x: torch.Tensor, return_dict: bool = True
|
|
|
1092 |
return (posterior,)
|
1093 |
return AutoencoderKLOutput(latent_dist=posterior)
|
1094 |
|
1095 |
+
def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
|
1096 |
+
batch_size, num_channels, num_frames, height, width = z.shape
|
1097 |
+
|
1098 |
+
if self.use_tiling and (width > self.tile_latent_min_width or height > self.tile_latent_min_height):
|
1099 |
+
return self.tiled_decode(z, return_dict=return_dict)
|
1100 |
+
|
1101 |
+
frame_batch_size = self.num_latent_frames_batch_size
|
1102 |
+
dec = []
|
1103 |
+
for i in range(num_frames // frame_batch_size):
|
1104 |
+
remaining_frames = num_frames % frame_batch_size
|
1105 |
+
start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames)
|
1106 |
+
end_frame = frame_batch_size * (i + 1) + remaining_frames
|
1107 |
+
z_intermediate = z[:, :, start_frame:end_frame]
|
1108 |
+
if self.post_quant_conv is not None:
|
1109 |
+
z_intermediate = self.post_quant_conv(z_intermediate)
|
1110 |
+
z_intermediate = self.decoder(z_intermediate)
|
1111 |
+
dec.append(z_intermediate)
|
1112 |
+
|
1113 |
+
self._clear_fake_context_parallel_cache()
|
1114 |
+
dec = torch.cat(dec, dim=2)
|
1115 |
+
|
1116 |
+
if not return_dict:
|
1117 |
+
return (dec,)
|
1118 |
+
|
1119 |
+
return DecoderOutput(sample=dec)
|
1120 |
+
|
1121 |
@apply_forward_hook
|
1122 |
+
def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
|
1123 |
"""
|
1124 |
Decode a batch of images.
|
1125 |
|
|
|
1132 |
[`~models.vae.DecoderOutput`] or `tuple`:
|
1133 |
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
|
1134 |
returned.
|
1135 |
+
"""
|
1136 |
+
if self.use_slicing and z.shape[0] > 1:
|
1137 |
+
decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
|
1138 |
+
decoded = torch.cat(decoded_slices)
|
1139 |
+
else:
|
1140 |
+
decoded = self._decode(z).sample
|
1141 |
+
|
1142 |
+
if not return_dict:
|
1143 |
+
return (decoded,)
|
1144 |
+
return DecoderOutput(sample=decoded)
|
1145 |
+
|
1146 |
+
def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
|
1147 |
+
blend_extent = min(a.shape[3], b.shape[3], blend_extent)
|
1148 |
+
for y in range(blend_extent):
|
1149 |
+
b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (
|
1150 |
+
y / blend_extent
|
1151 |
+
)
|
1152 |
+
return b
|
1153 |
+
|
1154 |
+
def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
|
1155 |
+
blend_extent = min(a.shape[4], b.shape[4], blend_extent)
|
1156 |
+
for x in range(blend_extent):
|
1157 |
+
b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (
|
1158 |
+
x / blend_extent
|
1159 |
+
)
|
1160 |
+
return b
|
1161 |
+
|
1162 |
+
def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
|
1163 |
+
r"""
|
1164 |
+
Decode a batch of images using a tiled decoder.
|
1165 |
+
|
1166 |
+
Args:
|
1167 |
+
z (`torch.Tensor`): Input batch of latent vectors.
|
1168 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
1169 |
+
Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
|
1170 |
|
1171 |
+
Returns:
|
1172 |
+
[`~models.vae.DecoderOutput`] or `tuple`:
|
1173 |
+
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
|
1174 |
+
returned.
|
1175 |
"""
|
1176 |
+
# Rough memory assessment:
|
1177 |
+
# - In CogVideoX-2B, there are a total of 24 CausalConv3d layers.
|
1178 |
+
# - The biggest intermediate dimensions are: [1, 128, 9, 480, 720].
|
1179 |
+
# - Assume fp16 (2 bytes per value).
|
1180 |
+
# Memory required: 1 * 128 * 9 * 480 * 720 * 24 * 2 / 1024**3 = 17.8 GB
|
1181 |
+
#
|
1182 |
+
# Memory assessment when using tiling:
|
1183 |
+
# - Assume everything as above but now HxW is 240x360 by tiling in half
|
1184 |
+
# Memory required: 1 * 128 * 9 * 240 * 360 * 24 * 2 / 1024**3 = 4.5 GB
|
1185 |
+
|
1186 |
+
batch_size, num_channels, num_frames, height, width = z.shape
|
1187 |
+
|
1188 |
+
overlap_height = int(self.tile_latent_min_height * (1 - self.tile_overlap_factor_height))
|
1189 |
+
overlap_width = int(self.tile_latent_min_width * (1 - self.tile_overlap_factor_width))
|
1190 |
+
blend_extent_height = int(self.tile_sample_min_height * self.tile_overlap_factor_height)
|
1191 |
+
blend_extent_width = int(self.tile_sample_min_width * self.tile_overlap_factor_width)
|
1192 |
+
row_limit_height = self.tile_sample_min_height - blend_extent_height
|
1193 |
+
row_limit_width = self.tile_sample_min_width - blend_extent_width
|
1194 |
+
frame_batch_size = self.num_latent_frames_batch_size
|
1195 |
+
|
1196 |
+
# Split z into overlapping tiles and decode them separately.
|
1197 |
+
# The tiles have an overlap to avoid seams between tiles.
|
1198 |
+
rows = []
|
1199 |
+
for i in range(0, height, overlap_height):
|
1200 |
+
row = []
|
1201 |
+
for j in range(0, width, overlap_width):
|
1202 |
+
time = []
|
1203 |
+
for k in range(num_frames // frame_batch_size):
|
1204 |
+
remaining_frames = num_frames % frame_batch_size
|
1205 |
+
start_frame = frame_batch_size * k + (0 if k == 0 else remaining_frames)
|
1206 |
+
end_frame = frame_batch_size * (k + 1) + remaining_frames
|
1207 |
+
tile = z[
|
1208 |
+
:,
|
1209 |
+
:,
|
1210 |
+
start_frame:end_frame,
|
1211 |
+
i : i + self.tile_latent_min_height,
|
1212 |
+
j : j + self.tile_latent_min_width,
|
1213 |
+
]
|
1214 |
+
if self.post_quant_conv is not None:
|
1215 |
+
tile = self.post_quant_conv(tile)
|
1216 |
+
tile = self.decoder(tile)
|
1217 |
+
time.append(tile)
|
1218 |
+
self._clear_fake_context_parallel_cache()
|
1219 |
+
row.append(torch.cat(time, dim=2))
|
1220 |
+
rows.append(row)
|
1221 |
+
|
1222 |
+
result_rows = []
|
1223 |
+
for i, row in enumerate(rows):
|
1224 |
+
result_row = []
|
1225 |
+
for j, tile in enumerate(row):
|
1226 |
+
# blend the above tile and the left tile
|
1227 |
+
# to the current tile and add the current tile to the result row
|
1228 |
+
if i > 0:
|
1229 |
+
tile = self.blend_v(rows[i - 1][j], tile, blend_extent_height)
|
1230 |
+
if j > 0:
|
1231 |
+
tile = self.blend_h(row[j - 1], tile, blend_extent_width)
|
1232 |
+
result_row.append(tile[:, :, :, :row_limit_height, :row_limit_width])
|
1233 |
+
result_rows.append(torch.cat(result_row, dim=4))
|
1234 |
+
|
1235 |
+
dec = torch.cat(result_rows, dim=3)
|
1236 |
+
|
1237 |
if not return_dict:
|
1238 |
return (dec,)
|
1239 |
+
|
1240 |
return DecoderOutput(sample=dec)
|
1241 |
|
1242 |
def forward(
|
videosys/models/{open_sora/vae.py → autoencoders/autoencoder_kl_open_sora.py}
RENAMED
@@ -18,8 +18,6 @@ from diffusers.models import AutoencoderKL, AutoencoderKLTemporalDecoder
|
|
18 |
from einops import rearrange
|
19 |
from transformers import PretrainedConfig, PreTrainedModel
|
20 |
|
21 |
-
from .utils import load_checkpoint
|
22 |
-
|
23 |
|
24 |
class DiagonalGaussianDistribution(object):
|
25 |
def __init__(
|
@@ -474,7 +472,7 @@ class VAE_Temporal(nn.Module):
|
|
474 |
return recon_video, posterior, z
|
475 |
|
476 |
|
477 |
-
def VAE_Temporal_SD(
|
478 |
model = VAE_Temporal(
|
479 |
in_out_channels=4,
|
480 |
latent_embed_dim=4,
|
@@ -485,8 +483,6 @@ def VAE_Temporal_SD(from_pretrained=None, **kwargs):
|
|
485 |
temporal_downsample=(False, True, True),
|
486 |
**kwargs,
|
487 |
)
|
488 |
-
if from_pretrained is not None:
|
489 |
-
load_checkpoint(model, from_pretrained)
|
490 |
return model
|
491 |
|
492 |
|
@@ -634,7 +630,7 @@ class VideoAutoencoderPipeline(PreTrainedModel):
|
|
634 |
micro_batch_size=4,
|
635 |
subfolder="vae",
|
636 |
)
|
637 |
-
self.temporal_vae = VAE_Temporal_SD(
|
638 |
self.cal_loss = config.cal_loss
|
639 |
self.micro_frame_size = config.micro_frame_size
|
640 |
self.micro_z_frame_size = self.temporal_vae.get_latent_size([config.micro_frame_size, None, None])[0]
|
@@ -763,7 +759,4 @@ def OpenSoraVAE_V1_2(
|
|
763 |
else:
|
764 |
config = VideoAutoencoderPipelineConfig(**kwargs)
|
765 |
model = VideoAutoencoderPipeline(config)
|
766 |
-
|
767 |
-
if from_pretrained:
|
768 |
-
load_checkpoint(model, from_pretrained)
|
769 |
return model
|
|
|
18 |
from einops import rearrange
|
19 |
from transformers import PretrainedConfig, PreTrainedModel
|
20 |
|
|
|
|
|
21 |
|
22 |
class DiagonalGaussianDistribution(object):
|
23 |
def __init__(
|
|
|
472 |
return recon_video, posterior, z
|
473 |
|
474 |
|
475 |
+
def VAE_Temporal_SD(**kwargs):
|
476 |
model = VAE_Temporal(
|
477 |
in_out_channels=4,
|
478 |
latent_embed_dim=4,
|
|
|
483 |
temporal_downsample=(False, True, True),
|
484 |
**kwargs,
|
485 |
)
|
|
|
|
|
486 |
return model
|
487 |
|
488 |
|
|
|
630 |
micro_batch_size=4,
|
631 |
subfolder="vae",
|
632 |
)
|
633 |
+
self.temporal_vae = VAE_Temporal_SD()
|
634 |
self.cal_loss = config.cal_loss
|
635 |
self.micro_frame_size = config.micro_frame_size
|
636 |
self.micro_z_frame_size = self.temporal_vae.get_latent_size([config.micro_frame_size, None, None])[0]
|
|
|
759 |
else:
|
760 |
config = VideoAutoencoderPipelineConfig(**kwargs)
|
761 |
model = VideoAutoencoderPipeline(config)
|
|
|
|
|
|
|
762 |
return model
|
videosys/models/{open_sora_plan/ae.py → autoencoders/autoencoder_kl_open_sora_plan.py}
RENAMED
@@ -6,20 +6,24 @@
|
|
6 |
# References:
|
7 |
# Open-Sora-Plan: https://github.com/PKU-YuanGroup/Open-Sora-Plan
|
8 |
# --------------------------------------------------------
|
9 |
-
|
10 |
import glob
|
11 |
-
import importlib
|
12 |
import os
|
13 |
from typing import Optional, Tuple, Union
|
14 |
|
15 |
import numpy as np
|
16 |
import torch
|
|
|
|
|
|
|
17 |
from diffusers import ConfigMixin, ModelMixin
|
18 |
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
19 |
from diffusers.models.modeling_utils import ModelMixin
|
|
|
20 |
from einops import rearrange
|
21 |
from torch import nn
|
22 |
|
|
|
|
|
23 |
|
24 |
def Normalize(in_channels, num_groups=32):
|
25 |
return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
|
@@ -80,13 +84,7 @@ class DiagonalGaussianDistribution(object):
|
|
80 |
|
81 |
|
82 |
def resolve_str_to_obj(str_val, append=True):
|
83 |
-
|
84 |
-
str_val = "videosys.models.open_sora_plan.modules." + str_val
|
85 |
-
if "opensora.models.ae.videobase." in str_val:
|
86 |
-
str_val = str_val.replace("opensora.models.ae.videobase.", "videosys.models.open_sora_plan.")
|
87 |
-
module_name, class_name = str_val.rsplit(".", 1)
|
88 |
-
module = importlib.import_module(module_name)
|
89 |
-
return getattr(module, class_name)
|
90 |
|
91 |
|
92 |
class VideoBaseAE_PL(ModelMixin, ConfigMixin):
|
@@ -130,7 +128,6 @@ class VideoBaseAE_PL(ModelMixin, ConfigMixin):
|
|
130 |
model.init_from_ckpt(last_ckpt_file)
|
131 |
return model
|
132 |
else:
|
133 |
-
print(f"Loading model from {pretrained_model_name_or_path}")
|
134 |
return super().from_pretrained(pretrained_model_name_or_path, **kwargs)
|
135 |
|
136 |
|
@@ -431,8 +428,6 @@ class CausalVAEModel(VideoBaseAE_PL):
|
|
431 |
self.learning_rate = lr
|
432 |
self.lr_g_factor = 1.0
|
433 |
|
434 |
-
self.loss = resolve_str_to_obj(loss_type, append=False)(**loss_params)
|
435 |
-
|
436 |
self.encoder = Encoder(
|
437 |
z_channels=z_channels,
|
438 |
hidden_size=hidden_size,
|
@@ -471,8 +466,6 @@ class CausalVAEModel(VideoBaseAE_PL):
|
|
471 |
quant_conv_cls = resolve_str_to_obj(q_conv)
|
472 |
self.quant_conv = quant_conv_cls(2 * z_channels, 2 * embed_dim, 1)
|
473 |
self.post_quant_conv = quant_conv_cls(embed_dim, z_channels, 1)
|
474 |
-
if hasattr(self.loss, "discriminator"):
|
475 |
-
self.automatic_optimization = False
|
476 |
|
477 |
def encode(self, x):
|
478 |
if self.use_tiling and (
|
@@ -855,3 +848,793 @@ def getae_wrapper(ae):
|
|
855 |
ae = videobase_ae.get(ae, None)
|
856 |
assert ae is not None
|
857 |
return ae
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
6 |
# References:
|
7 |
# Open-Sora-Plan: https://github.com/PKU-YuanGroup/Open-Sora-Plan
|
8 |
# --------------------------------------------------------
|
|
|
9 |
import glob
|
|
|
10 |
import os
|
11 |
from typing import Optional, Tuple, Union
|
12 |
|
13 |
import numpy as np
|
14 |
import torch
|
15 |
+
import torch.distributed as dist
|
16 |
+
import torch.nn as nn
|
17 |
+
import torch.nn.functional as F
|
18 |
from diffusers import ConfigMixin, ModelMixin
|
19 |
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
20 |
from diffusers.models.modeling_utils import ModelMixin
|
21 |
+
from diffusers.utils import logging
|
22 |
from einops import rearrange
|
23 |
from torch import nn
|
24 |
|
25 |
+
logging.set_verbosity_error()
|
26 |
+
|
27 |
|
28 |
def Normalize(in_channels, num_groups=32):
|
29 |
return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
|
|
|
84 |
|
85 |
|
86 |
def resolve_str_to_obj(str_val, append=True):
|
87 |
+
return globals()[str_val]
|
|
|
|
|
|
|
|
|
|
|
|
|
88 |
|
89 |
|
90 |
class VideoBaseAE_PL(ModelMixin, ConfigMixin):
|
|
|
128 |
model.init_from_ckpt(last_ckpt_file)
|
129 |
return model
|
130 |
else:
|
|
|
131 |
return super().from_pretrained(pretrained_model_name_or_path, **kwargs)
|
132 |
|
133 |
|
|
|
428 |
self.learning_rate = lr
|
429 |
self.lr_g_factor = 1.0
|
430 |
|
|
|
|
|
431 |
self.encoder = Encoder(
|
432 |
z_channels=z_channels,
|
433 |
hidden_size=hidden_size,
|
|
|
466 |
quant_conv_cls = resolve_str_to_obj(q_conv)
|
467 |
self.quant_conv = quant_conv_cls(2 * z_channels, 2 * embed_dim, 1)
|
468 |
self.post_quant_conv = quant_conv_cls(embed_dim, z_channels, 1)
|
|
|
|
|
469 |
|
470 |
def encode(self, x):
|
471 |
if self.use_tiling and (
|
|
|
848 |
ae = videobase_ae.get(ae, None)
|
849 |
assert ae is not None
|
850 |
return ae
|
851 |
+
|
852 |
+
|
853 |
+
def video_to_image(func):
|
854 |
+
def wrapper(self, x, *args, **kwargs):
|
855 |
+
if x.dim() == 5:
|
856 |
+
t = x.shape[2]
|
857 |
+
x = rearrange(x, "b c t h w -> (b t) c h w")
|
858 |
+
x = func(self, x, *args, **kwargs)
|
859 |
+
x = rearrange(x, "(b t) c h w -> b c t h w", t=t)
|
860 |
+
return x
|
861 |
+
|
862 |
+
return wrapper
|
863 |
+
|
864 |
+
|
865 |
+
class Block(nn.Module):
|
866 |
+
def __init__(self, *args, **kwargs) -> None:
|
867 |
+
super().__init__(*args, **kwargs)
|
868 |
+
|
869 |
+
|
870 |
+
class LinearAttention(Block):
|
871 |
+
def __init__(self, dim, heads=4, dim_head=32):
|
872 |
+
super().__init__()
|
873 |
+
self.heads = heads
|
874 |
+
hidden_dim = dim_head * heads
|
875 |
+
self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
|
876 |
+
self.to_out = nn.Conv2d(hidden_dim, dim, 1)
|
877 |
+
|
878 |
+
def forward(self, x):
|
879 |
+
b, c, h, w = x.shape
|
880 |
+
qkv = self.to_qkv(x)
|
881 |
+
q, k, v = rearrange(qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3)
|
882 |
+
k = k.softmax(dim=-1)
|
883 |
+
context = torch.einsum("bhdn,bhen->bhde", k, v)
|
884 |
+
out = torch.einsum("bhde,bhdn->bhen", context, q)
|
885 |
+
out = rearrange(out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w)
|
886 |
+
return self.to_out(out)
|
887 |
+
|
888 |
+
|
889 |
+
class LinAttnBlock(LinearAttention):
|
890 |
+
"""to match AttnBlock usage"""
|
891 |
+
|
892 |
+
def __init__(self, in_channels):
|
893 |
+
super().__init__(dim=in_channels, heads=1, dim_head=in_channels)
|
894 |
+
|
895 |
+
|
896 |
+
class AttnBlock3D(Block):
|
897 |
+
"""Compatible with old versions, there are issues, use with caution."""
|
898 |
+
|
899 |
+
def __init__(self, in_channels):
|
900 |
+
super().__init__()
|
901 |
+
self.in_channels = in_channels
|
902 |
+
|
903 |
+
self.norm = Normalize(in_channels)
|
904 |
+
self.q = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1)
|
905 |
+
self.k = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1)
|
906 |
+
self.v = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1)
|
907 |
+
self.proj_out = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1)
|
908 |
+
|
909 |
+
def forward(self, x):
|
910 |
+
h_ = x
|
911 |
+
h_ = self.norm(h_)
|
912 |
+
q = self.q(h_)
|
913 |
+
k = self.k(h_)
|
914 |
+
v = self.v(h_)
|
915 |
+
|
916 |
+
# compute attention
|
917 |
+
b, c, t, h, w = q.shape
|
918 |
+
q = q.reshape(b * t, c, h * w)
|
919 |
+
q = q.permute(0, 2, 1) # b,hw,c
|
920 |
+
k = k.reshape(b * t, c, h * w) # b,c,hw
|
921 |
+
w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
|
922 |
+
w_ = w_ * (int(c) ** (-0.5))
|
923 |
+
w_ = torch.nn.functional.softmax(w_, dim=2)
|
924 |
+
|
925 |
+
# attend to values
|
926 |
+
v = v.reshape(b * t, c, h * w)
|
927 |
+
w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
|
928 |
+
h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
|
929 |
+
h_ = h_.reshape(b, c, t, h, w)
|
930 |
+
|
931 |
+
h_ = self.proj_out(h_)
|
932 |
+
|
933 |
+
return x + h_
|
934 |
+
|
935 |
+
|
936 |
+
class AttnBlock3DFix(nn.Module):
|
937 |
+
"""
|
938 |
+
Thanks to https://github.com/PKU-YuanGroup/Open-Sora-Plan/pull/172.
|
939 |
+
"""
|
940 |
+
|
941 |
+
def __init__(self, in_channels):
|
942 |
+
super().__init__()
|
943 |
+
self.in_channels = in_channels
|
944 |
+
|
945 |
+
self.norm = Normalize(in_channels)
|
946 |
+
self.q = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1)
|
947 |
+
self.k = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1)
|
948 |
+
self.v = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1)
|
949 |
+
self.proj_out = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1)
|
950 |
+
|
951 |
+
def forward(self, x):
|
952 |
+
h_ = x
|
953 |
+
h_ = self.norm(h_)
|
954 |
+
q = self.q(h_)
|
955 |
+
k = self.k(h_)
|
956 |
+
v = self.v(h_)
|
957 |
+
|
958 |
+
# compute attention
|
959 |
+
# q: (b c t h w) -> (b t c h w) -> (b*t c h*w) -> (b*t h*w c)
|
960 |
+
b, c, t, h, w = q.shape
|
961 |
+
q = q.permute(0, 2, 1, 3, 4)
|
962 |
+
q = q.reshape(b * t, c, h * w)
|
963 |
+
q = q.permute(0, 2, 1)
|
964 |
+
|
965 |
+
# k: (b c t h w) -> (b t c h w) -> (b*t c h*w)
|
966 |
+
k = k.permute(0, 2, 1, 3, 4)
|
967 |
+
k = k.reshape(b * t, c, h * w)
|
968 |
+
|
969 |
+
# w: (b*t hw hw)
|
970 |
+
w_ = torch.bmm(q, k)
|
971 |
+
w_ = w_ * (int(c) ** (-0.5))
|
972 |
+
w_ = torch.nn.functional.softmax(w_, dim=2)
|
973 |
+
|
974 |
+
# attend to values
|
975 |
+
# v: (b c t h w) -> (b t c h w) -> (bt c hw)
|
976 |
+
# w_: (bt hw hw) -> (bt hw hw)
|
977 |
+
v = v.permute(0, 2, 1, 3, 4)
|
978 |
+
v = v.reshape(b * t, c, h * w)
|
979 |
+
w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
|
980 |
+
h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
|
981 |
+
|
982 |
+
# h_: (b*t c hw) -> (b t c h w) -> (b c t h w)
|
983 |
+
h_ = h_.reshape(b, t, c, h, w)
|
984 |
+
h_ = h_.permute(0, 2, 1, 3, 4)
|
985 |
+
|
986 |
+
h_ = self.proj_out(h_)
|
987 |
+
|
988 |
+
return x + h_
|
989 |
+
|
990 |
+
|
991 |
+
class AttnBlock(Block):
|
992 |
+
def __init__(self, in_channels):
|
993 |
+
super().__init__()
|
994 |
+
self.in_channels = in_channels
|
995 |
+
|
996 |
+
self.norm = Normalize(in_channels)
|
997 |
+
self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
998 |
+
self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
999 |
+
self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
1000 |
+
self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
1001 |
+
|
1002 |
+
@video_to_image
|
1003 |
+
def forward(self, x):
|
1004 |
+
h_ = x
|
1005 |
+
h_ = self.norm(h_)
|
1006 |
+
q = self.q(h_)
|
1007 |
+
k = self.k(h_)
|
1008 |
+
v = self.v(h_)
|
1009 |
+
|
1010 |
+
# compute attention
|
1011 |
+
b, c, h, w = q.shape
|
1012 |
+
q = q.reshape(b, c, h * w)
|
1013 |
+
q = q.permute(0, 2, 1) # b,hw,c
|
1014 |
+
k = k.reshape(b, c, h * w) # b,c,hw
|
1015 |
+
w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
|
1016 |
+
w_ = w_ * (int(c) ** (-0.5))
|
1017 |
+
w_ = torch.nn.functional.softmax(w_, dim=2)
|
1018 |
+
|
1019 |
+
# attend to values
|
1020 |
+
v = v.reshape(b, c, h * w)
|
1021 |
+
w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
|
1022 |
+
h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
|
1023 |
+
h_ = h_.reshape(b, c, h, w)
|
1024 |
+
|
1025 |
+
h_ = self.proj_out(h_)
|
1026 |
+
|
1027 |
+
return x + h_
|
1028 |
+
|
1029 |
+
|
1030 |
+
class TemporalAttnBlock(Block):
|
1031 |
+
def __init__(self, in_channels):
|
1032 |
+
super().__init__()
|
1033 |
+
self.in_channels = in_channels
|
1034 |
+
|
1035 |
+
self.norm = Normalize(in_channels)
|
1036 |
+
self.q = torch.nn.Conv3d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
1037 |
+
self.k = torch.nn.Conv3d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
1038 |
+
self.v = torch.nn.Conv3d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
1039 |
+
self.proj_out = torch.nn.Conv3d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
1040 |
+
|
1041 |
+
def forward(self, x):
|
1042 |
+
h_ = x
|
1043 |
+
h_ = self.norm(h_)
|
1044 |
+
q = self.q(h_)
|
1045 |
+
k = self.k(h_)
|
1046 |
+
v = self.v(h_)
|
1047 |
+
|
1048 |
+
# compute attention
|
1049 |
+
b, c, t, h, w = q.shape
|
1050 |
+
q = rearrange(q, "b c t h w -> (b h w) t c")
|
1051 |
+
k = rearrange(k, "b c t h w -> (b h w) c t")
|
1052 |
+
v = rearrange(v, "b c t h w -> (b h w) c t")
|
1053 |
+
w_ = torch.bmm(q, k)
|
1054 |
+
w_ = w_ * (int(c) ** (-0.5))
|
1055 |
+
w_ = torch.nn.functional.softmax(w_, dim=2)
|
1056 |
+
|
1057 |
+
# attend to values
|
1058 |
+
w_ = w_.permute(0, 2, 1)
|
1059 |
+
h_ = torch.bmm(v, w_)
|
1060 |
+
h_ = rearrange(h_, "(b h w) c t -> b c t h w", h=h, w=w)
|
1061 |
+
h_ = self.proj_out(h_)
|
1062 |
+
|
1063 |
+
return x + h_
|
1064 |
+
|
1065 |
+
|
1066 |
+
def make_attn(in_channels, attn_type="vanilla"):
|
1067 |
+
assert attn_type in ["vanilla", "linear", "none", "vanilla3D"], f"attn_type {attn_type} unknown"
|
1068 |
+
print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
|
1069 |
+
print(attn_type)
|
1070 |
+
if attn_type == "vanilla":
|
1071 |
+
return AttnBlock(in_channels)
|
1072 |
+
elif attn_type == "vanilla3D":
|
1073 |
+
return AttnBlock3D(in_channels)
|
1074 |
+
elif attn_type == "none":
|
1075 |
+
return nn.Identity(in_channels)
|
1076 |
+
else:
|
1077 |
+
return LinAttnBlock(in_channels)
|
1078 |
+
|
1079 |
+
|
1080 |
+
class Conv2d(nn.Conv2d):
|
1081 |
+
def __init__(
|
1082 |
+
self,
|
1083 |
+
in_channels: int,
|
1084 |
+
out_channels: int,
|
1085 |
+
kernel_size: Union[int, Tuple[int]] = 3,
|
1086 |
+
stride: Union[int, Tuple[int]] = 1,
|
1087 |
+
padding: Union[str, int, Tuple[int]] = 0,
|
1088 |
+
dilation: Union[int, Tuple[int]] = 1,
|
1089 |
+
groups: int = 1,
|
1090 |
+
bias: bool = True,
|
1091 |
+
padding_mode: str = "zeros",
|
1092 |
+
device=None,
|
1093 |
+
dtype=None,
|
1094 |
+
) -> None:
|
1095 |
+
super().__init__(
|
1096 |
+
in_channels,
|
1097 |
+
out_channels,
|
1098 |
+
kernel_size,
|
1099 |
+
stride,
|
1100 |
+
padding,
|
1101 |
+
dilation,
|
1102 |
+
groups,
|
1103 |
+
bias,
|
1104 |
+
padding_mode,
|
1105 |
+
device,
|
1106 |
+
dtype,
|
1107 |
+
)
|
1108 |
+
|
1109 |
+
@video_to_image
|
1110 |
+
def forward(self, x):
|
1111 |
+
return super().forward(x)
|
1112 |
+
|
1113 |
+
|
1114 |
+
class CausalConv3d(nn.Module):
|
1115 |
+
def __init__(
|
1116 |
+
self, chan_in, chan_out, kernel_size: Union[int, Tuple[int, int, int]], init_method="random", **kwargs
|
1117 |
+
):
|
1118 |
+
super().__init__()
|
1119 |
+
self.kernel_size = cast_tuple(kernel_size, 3)
|
1120 |
+
self.time_kernel_size = self.kernel_size[0]
|
1121 |
+
self.chan_in = chan_in
|
1122 |
+
self.chan_out = chan_out
|
1123 |
+
stride = kwargs.pop("stride", 1)
|
1124 |
+
padding = kwargs.pop("padding", 0)
|
1125 |
+
padding = list(cast_tuple(padding, 3))
|
1126 |
+
padding[0] = 0
|
1127 |
+
stride = cast_tuple(stride, 3)
|
1128 |
+
self.conv = nn.Conv3d(chan_in, chan_out, self.kernel_size, stride=stride, padding=padding)
|
1129 |
+
self._init_weights(init_method)
|
1130 |
+
|
1131 |
+
def _init_weights(self, init_method):
|
1132 |
+
torch.tensor(self.kernel_size)
|
1133 |
+
if init_method == "avg":
|
1134 |
+
assert self.kernel_size[1] == 1 and self.kernel_size[2] == 1, "only support temporal up/down sample"
|
1135 |
+
assert self.chan_in == self.chan_out, "chan_in must be equal to chan_out"
|
1136 |
+
weight = torch.zeros((self.chan_out, self.chan_in, *self.kernel_size))
|
1137 |
+
|
1138 |
+
eyes = torch.concat(
|
1139 |
+
[
|
1140 |
+
torch.eye(self.chan_in).unsqueeze(-1) * 1 / 3,
|
1141 |
+
torch.eye(self.chan_in).unsqueeze(-1) * 1 / 3,
|
1142 |
+
torch.eye(self.chan_in).unsqueeze(-1) * 1 / 3,
|
1143 |
+
],
|
1144 |
+
dim=-1,
|
1145 |
+
)
|
1146 |
+
weight[:, :, :, 0, 0] = eyes
|
1147 |
+
|
1148 |
+
self.conv.weight = nn.Parameter(
|
1149 |
+
weight,
|
1150 |
+
requires_grad=True,
|
1151 |
+
)
|
1152 |
+
elif init_method == "zero":
|
1153 |
+
self.conv.weight = nn.Parameter(
|
1154 |
+
torch.zeros((self.chan_out, self.chan_in, *self.kernel_size)),
|
1155 |
+
requires_grad=True,
|
1156 |
+
)
|
1157 |
+
if self.conv.bias is not None:
|
1158 |
+
nn.init.constant_(self.conv.bias, 0)
|
1159 |
+
|
1160 |
+
def forward(self, x):
|
1161 |
+
# 1 + 16 16 as video, 1 as image
|
1162 |
+
first_frame_pad = x[:, :, :1, :, :].repeat((1, 1, self.time_kernel_size - 1, 1, 1)) # b c t h w
|
1163 |
+
x = torch.concatenate((first_frame_pad, x), dim=2) # 3 + 16
|
1164 |
+
return self.conv(x)
|
1165 |
+
|
1166 |
+
|
1167 |
+
class GroupNorm(Block):
|
1168 |
+
def __init__(self, num_channels, num_groups=32, eps=1e-6, *args, **kwargs) -> None:
|
1169 |
+
super().__init__(*args, **kwargs)
|
1170 |
+
self.norm = torch.nn.GroupNorm(num_groups=num_groups, num_channels=num_channels, eps=1e-6, affine=True)
|
1171 |
+
|
1172 |
+
def forward(self, x):
|
1173 |
+
return self.norm(x)
|
1174 |
+
|
1175 |
+
|
1176 |
+
def Normalize(in_channels, num_groups=32):
|
1177 |
+
return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
|
1178 |
+
|
1179 |
+
|
1180 |
+
class ActNorm(nn.Module):
|
1181 |
+
def __init__(self, num_features, logdet=False, affine=True, allow_reverse_init=False):
|
1182 |
+
assert affine
|
1183 |
+
super().__init__()
|
1184 |
+
self.logdet = logdet
|
1185 |
+
self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1))
|
1186 |
+
self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1))
|
1187 |
+
self.allow_reverse_init = allow_reverse_init
|
1188 |
+
|
1189 |
+
self.register_buffer("initialized", torch.tensor(0, dtype=torch.uint8))
|
1190 |
+
|
1191 |
+
def initialize(self, input):
|
1192 |
+
with torch.no_grad():
|
1193 |
+
flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1)
|
1194 |
+
mean = flatten.mean(1).unsqueeze(1).unsqueeze(2).unsqueeze(3).permute(1, 0, 2, 3)
|
1195 |
+
std = flatten.std(1).unsqueeze(1).unsqueeze(2).unsqueeze(3).permute(1, 0, 2, 3)
|
1196 |
+
|
1197 |
+
self.loc.data.copy_(-mean)
|
1198 |
+
self.scale.data.copy_(1 / (std + 1e-6))
|
1199 |
+
|
1200 |
+
def forward(self, input, reverse=False):
|
1201 |
+
if reverse:
|
1202 |
+
return self.reverse(input)
|
1203 |
+
if len(input.shape) == 2:
|
1204 |
+
input = input[:, :, None, None]
|
1205 |
+
squeeze = True
|
1206 |
+
else:
|
1207 |
+
squeeze = False
|
1208 |
+
|
1209 |
+
_, _, height, width = input.shape
|
1210 |
+
|
1211 |
+
if self.training and self.initialized.item() == 0:
|
1212 |
+
self.initialize(input)
|
1213 |
+
self.initialized.fill_(1)
|
1214 |
+
|
1215 |
+
h = self.scale * (input + self.loc)
|
1216 |
+
|
1217 |
+
if squeeze:
|
1218 |
+
h = h.squeeze(-1).squeeze(-1)
|
1219 |
+
|
1220 |
+
if self.logdet:
|
1221 |
+
log_abs = torch.log(torch.abs(self.scale))
|
1222 |
+
logdet = height * width * torch.sum(log_abs)
|
1223 |
+
logdet = logdet * torch.ones(input.shape[0]).to(input)
|
1224 |
+
return h, logdet
|
1225 |
+
|
1226 |
+
return h
|
1227 |
+
|
1228 |
+
def reverse(self, output):
|
1229 |
+
if self.training and self.initialized.item() == 0:
|
1230 |
+
if not self.allow_reverse_init:
|
1231 |
+
raise RuntimeError(
|
1232 |
+
"Initializing ActNorm in reverse direction is "
|
1233 |
+
"disabled by default. Use allow_reverse_init=True to enable."
|
1234 |
+
)
|
1235 |
+
else:
|
1236 |
+
self.initialize(output)
|
1237 |
+
self.initialized.fill_(1)
|
1238 |
+
|
1239 |
+
if len(output.shape) == 2:
|
1240 |
+
output = output[:, :, None, None]
|
1241 |
+
squeeze = True
|
1242 |
+
else:
|
1243 |
+
squeeze = False
|
1244 |
+
|
1245 |
+
h = output / self.scale - self.loc
|
1246 |
+
|
1247 |
+
if squeeze:
|
1248 |
+
h = h.squeeze(-1).squeeze(-1)
|
1249 |
+
return h
|
1250 |
+
|
1251 |
+
|
1252 |
+
def nonlinearity(x):
|
1253 |
+
return x * torch.sigmoid(x)
|
1254 |
+
|
1255 |
+
|
1256 |
+
def cast_tuple(t, length=1):
|
1257 |
+
return t if isinstance(t, tuple) else ((t,) * length)
|
1258 |
+
|
1259 |
+
|
1260 |
+
def shift_dim(x, src_dim=-1, dest_dim=-1, make_contiguous=True):
|
1261 |
+
n_dims = len(x.shape)
|
1262 |
+
if src_dim < 0:
|
1263 |
+
src_dim = n_dims + src_dim
|
1264 |
+
if dest_dim < 0:
|
1265 |
+
dest_dim = n_dims + dest_dim
|
1266 |
+
assert 0 <= src_dim < n_dims and 0 <= dest_dim < n_dims
|
1267 |
+
dims = list(range(n_dims))
|
1268 |
+
del dims[src_dim]
|
1269 |
+
permutation = []
|
1270 |
+
ctr = 0
|
1271 |
+
for i in range(n_dims):
|
1272 |
+
if i == dest_dim:
|
1273 |
+
permutation.append(src_dim)
|
1274 |
+
else:
|
1275 |
+
permutation.append(dims[ctr])
|
1276 |
+
ctr += 1
|
1277 |
+
x = x.permute(permutation)
|
1278 |
+
if make_contiguous:
|
1279 |
+
x = x.contiguous()
|
1280 |
+
return x
|
1281 |
+
|
1282 |
+
|
1283 |
+
class Codebook(nn.Module):
|
1284 |
+
def __init__(self, n_codes, embedding_dim):
|
1285 |
+
super().__init__()
|
1286 |
+
self.register_buffer("embeddings", torch.randn(n_codes, embedding_dim))
|
1287 |
+
self.register_buffer("N", torch.zeros(n_codes))
|
1288 |
+
self.register_buffer("z_avg", self.embeddings.data.clone())
|
1289 |
+
|
1290 |
+
self.n_codes = n_codes
|
1291 |
+
self.embedding_dim = embedding_dim
|
1292 |
+
self._need_init = True
|
1293 |
+
|
1294 |
+
def _tile(self, x):
|
1295 |
+
d, ew = x.shape
|
1296 |
+
if d < self.n_codes:
|
1297 |
+
n_repeats = (self.n_codes + d - 1) // d
|
1298 |
+
std = 0.01 / np.sqrt(ew)
|
1299 |
+
x = x.repeat(n_repeats, 1)
|
1300 |
+
x = x + torch.randn_like(x) * std
|
1301 |
+
return x
|
1302 |
+
|
1303 |
+
def _init_embeddings(self, z):
|
1304 |
+
# z: [b, c, t, h, w]
|
1305 |
+
self._need_init = False
|
1306 |
+
flat_inputs = shift_dim(z, 1, -1).flatten(end_dim=-2)
|
1307 |
+
y = self._tile(flat_inputs)
|
1308 |
+
|
1309 |
+
y.shape[0]
|
1310 |
+
_k_rand = y[torch.randperm(y.shape[0])][: self.n_codes]
|
1311 |
+
if dist.is_initialized():
|
1312 |
+
dist.broadcast(_k_rand, 0)
|
1313 |
+
self.embeddings.data.copy_(_k_rand)
|
1314 |
+
self.z_avg.data.copy_(_k_rand)
|
1315 |
+
self.N.data.copy_(torch.ones(self.n_codes))
|
1316 |
+
|
1317 |
+
def forward(self, z):
|
1318 |
+
# z: [b, c, t, h, w]
|
1319 |
+
if self._need_init and self.training:
|
1320 |
+
self._init_embeddings(z)
|
1321 |
+
flat_inputs = shift_dim(z, 1, -1).flatten(end_dim=-2)
|
1322 |
+
distances = (
|
1323 |
+
(flat_inputs**2).sum(dim=1, keepdim=True)
|
1324 |
+
- 2 * flat_inputs @ self.embeddings.t()
|
1325 |
+
+ (self.embeddings.t() ** 2).sum(dim=0, keepdim=True)
|
1326 |
+
)
|
1327 |
+
|
1328 |
+
encoding_indices = torch.argmin(distances, dim=1)
|
1329 |
+
encode_onehot = F.one_hot(encoding_indices, self.n_codes).type_as(flat_inputs)
|
1330 |
+
encoding_indices = encoding_indices.view(z.shape[0], *z.shape[2:])
|
1331 |
+
|
1332 |
+
embeddings = F.embedding(encoding_indices, self.embeddings)
|
1333 |
+
embeddings = shift_dim(embeddings, -1, 1)
|
1334 |
+
|
1335 |
+
commitment_loss = 0.25 * F.mse_loss(z, embeddings.detach())
|
1336 |
+
|
1337 |
+
# EMA codebook update
|
1338 |
+
if self.training:
|
1339 |
+
n_total = encode_onehot.sum(dim=0)
|
1340 |
+
encode_sum = flat_inputs.t() @ encode_onehot
|
1341 |
+
if dist.is_initialized():
|
1342 |
+
dist.all_reduce(n_total)
|
1343 |
+
dist.all_reduce(encode_sum)
|
1344 |
+
|
1345 |
+
self.N.data.mul_(0.99).add_(n_total, alpha=0.01)
|
1346 |
+
self.z_avg.data.mul_(0.99).add_(encode_sum.t(), alpha=0.01)
|
1347 |
+
|
1348 |
+
n = self.N.sum()
|
1349 |
+
weights = (self.N + 1e-7) / (n + self.n_codes * 1e-7) * n
|
1350 |
+
encode_normalized = self.z_avg / weights.unsqueeze(1)
|
1351 |
+
self.embeddings.data.copy_(encode_normalized)
|
1352 |
+
|
1353 |
+
y = self._tile(flat_inputs)
|
1354 |
+
_k_rand = y[torch.randperm(y.shape[0])][: self.n_codes]
|
1355 |
+
if dist.is_initialized():
|
1356 |
+
dist.broadcast(_k_rand, 0)
|
1357 |
+
|
1358 |
+
usage = (self.N.view(self.n_codes, 1) >= 1).float()
|
1359 |
+
self.embeddings.data.mul_(usage).add_(_k_rand * (1 - usage))
|
1360 |
+
|
1361 |
+
embeddings_st = (embeddings - z).detach() + z
|
1362 |
+
|
1363 |
+
avg_probs = torch.mean(encode_onehot, dim=0)
|
1364 |
+
perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))
|
1365 |
+
|
1366 |
+
return dict(
|
1367 |
+
embeddings=embeddings_st,
|
1368 |
+
encodings=encoding_indices,
|
1369 |
+
commitment_loss=commitment_loss,
|
1370 |
+
perplexity=perplexity,
|
1371 |
+
)
|
1372 |
+
|
1373 |
+
def dictionary_lookup(self, encodings):
|
1374 |
+
embeddings = F.embedding(encodings, self.embeddings)
|
1375 |
+
return embeddings
|
1376 |
+
|
1377 |
+
|
1378 |
+
class ResnetBlock2D(Block):
|
1379 |
+
def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropout):
|
1380 |
+
super().__init__()
|
1381 |
+
self.in_channels = in_channels
|
1382 |
+
self.out_channels = in_channels if out_channels is None else out_channels
|
1383 |
+
self.use_conv_shortcut = conv_shortcut
|
1384 |
+
|
1385 |
+
self.norm1 = Normalize(in_channels)
|
1386 |
+
self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
1387 |
+
self.norm2 = Normalize(out_channels)
|
1388 |
+
self.dropout = torch.nn.Dropout(dropout)
|
1389 |
+
self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
1390 |
+
if self.in_channels != self.out_channels:
|
1391 |
+
if self.use_conv_shortcut:
|
1392 |
+
self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
1393 |
+
else:
|
1394 |
+
self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
|
1395 |
+
|
1396 |
+
@video_to_image
|
1397 |
+
def forward(self, x):
|
1398 |
+
h = x
|
1399 |
+
h = self.norm1(h)
|
1400 |
+
h = nonlinearity(h)
|
1401 |
+
h = self.conv1(h)
|
1402 |
+
h = self.norm2(h)
|
1403 |
+
h = nonlinearity(h)
|
1404 |
+
h = self.dropout(h)
|
1405 |
+
h = self.conv2(h)
|
1406 |
+
if self.in_channels != self.out_channels:
|
1407 |
+
if self.use_conv_shortcut:
|
1408 |
+
x = self.conv_shortcut(x)
|
1409 |
+
else:
|
1410 |
+
x = self.nin_shortcut(x)
|
1411 |
+
x = x + h
|
1412 |
+
return x
|
1413 |
+
|
1414 |
+
|
1415 |
+
class ResnetBlock3D(Block):
|
1416 |
+
def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropout):
|
1417 |
+
super().__init__()
|
1418 |
+
self.in_channels = in_channels
|
1419 |
+
self.out_channels = in_channels if out_channels is None else out_channels
|
1420 |
+
self.use_conv_shortcut = conv_shortcut
|
1421 |
+
|
1422 |
+
self.norm1 = Normalize(in_channels)
|
1423 |
+
self.conv1 = CausalConv3d(in_channels, out_channels, 3, padding=1)
|
1424 |
+
self.norm2 = Normalize(out_channels)
|
1425 |
+
self.dropout = torch.nn.Dropout(dropout)
|
1426 |
+
self.conv2 = CausalConv3d(out_channels, out_channels, 3, padding=1)
|
1427 |
+
if self.in_channels != self.out_channels:
|
1428 |
+
if self.use_conv_shortcut:
|
1429 |
+
self.conv_shortcut = CausalConv3d(in_channels, out_channels, 3, padding=1)
|
1430 |
+
else:
|
1431 |
+
self.nin_shortcut = CausalConv3d(in_channels, out_channels, 1, padding=0)
|
1432 |
+
|
1433 |
+
def forward(self, x):
|
1434 |
+
h = x
|
1435 |
+
h = self.norm1(h)
|
1436 |
+
h = nonlinearity(h)
|
1437 |
+
h = self.conv1(h)
|
1438 |
+
h = self.norm2(h)
|
1439 |
+
h = nonlinearity(h)
|
1440 |
+
h = self.dropout(h)
|
1441 |
+
h = self.conv2(h)
|
1442 |
+
if self.in_channels != self.out_channels:
|
1443 |
+
if self.use_conv_shortcut:
|
1444 |
+
x = self.conv_shortcut(x)
|
1445 |
+
else:
|
1446 |
+
x = self.nin_shortcut(x)
|
1447 |
+
return x + h
|
1448 |
+
|
1449 |
+
|
1450 |
+
class Upsample(Block):
|
1451 |
+
def __init__(self, in_channels, out_channels):
|
1452 |
+
super().__init__()
|
1453 |
+
self.with_conv = True
|
1454 |
+
if self.with_conv:
|
1455 |
+
self.conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
1456 |
+
|
1457 |
+
@video_to_image
|
1458 |
+
def forward(self, x):
|
1459 |
+
x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
|
1460 |
+
if self.with_conv:
|
1461 |
+
x = self.conv(x)
|
1462 |
+
return x
|
1463 |
+
|
1464 |
+
|
1465 |
+
class Downsample(Block):
|
1466 |
+
def __init__(self, in_channels, out_channels):
|
1467 |
+
super().__init__()
|
1468 |
+
self.with_conv = True
|
1469 |
+
if self.with_conv:
|
1470 |
+
# no asymmetric padding in torch conv, must do it ourselves
|
1471 |
+
self.conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=0)
|
1472 |
+
|
1473 |
+
@video_to_image
|
1474 |
+
def forward(self, x):
|
1475 |
+
if self.with_conv:
|
1476 |
+
pad = (0, 1, 0, 1)
|
1477 |
+
x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
|
1478 |
+
x = self.conv(x)
|
1479 |
+
else:
|
1480 |
+
x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
|
1481 |
+
return x
|
1482 |
+
|
1483 |
+
|
1484 |
+
class SpatialDownsample2x(Block):
|
1485 |
+
def __init__(
|
1486 |
+
self,
|
1487 |
+
chan_in,
|
1488 |
+
chan_out,
|
1489 |
+
kernel_size: Union[int, Tuple[int]] = (3, 3),
|
1490 |
+
stride: Union[int, Tuple[int]] = (2, 2),
|
1491 |
+
):
|
1492 |
+
super().__init__()
|
1493 |
+
kernel_size = cast_tuple(kernel_size, 2)
|
1494 |
+
stride = cast_tuple(stride, 2)
|
1495 |
+
self.chan_in = chan_in
|
1496 |
+
self.chan_out = chan_out
|
1497 |
+
self.kernel_size = kernel_size
|
1498 |
+
self.conv = CausalConv3d(self.chan_in, self.chan_out, (1,) + self.kernel_size, stride=(1,) + stride, padding=0)
|
1499 |
+
|
1500 |
+
def forward(self, x):
|
1501 |
+
pad = (0, 1, 0, 1, 0, 0)
|
1502 |
+
x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
|
1503 |
+
x = self.conv(x)
|
1504 |
+
return x
|
1505 |
+
|
1506 |
+
|
1507 |
+
class SpatialUpsample2x(Block):
|
1508 |
+
def __init__(
|
1509 |
+
self,
|
1510 |
+
chan_in,
|
1511 |
+
chan_out,
|
1512 |
+
kernel_size: Union[int, Tuple[int]] = (3, 3),
|
1513 |
+
stride: Union[int, Tuple[int]] = (1, 1),
|
1514 |
+
):
|
1515 |
+
super().__init__()
|
1516 |
+
self.chan_in = chan_in
|
1517 |
+
self.chan_out = chan_out
|
1518 |
+
self.kernel_size = kernel_size
|
1519 |
+
self.conv = CausalConv3d(self.chan_in, self.chan_out, (1,) + self.kernel_size, stride=(1,) + stride, padding=1)
|
1520 |
+
|
1521 |
+
def forward(self, x):
|
1522 |
+
t = x.shape[2]
|
1523 |
+
x = rearrange(x, "b c t h w -> b (c t) h w")
|
1524 |
+
x = F.interpolate(x, scale_factor=(2, 2), mode="nearest")
|
1525 |
+
x = rearrange(x, "b (c t) h w -> b c t h w", t=t)
|
1526 |
+
x = self.conv(x)
|
1527 |
+
return x
|
1528 |
+
|
1529 |
+
|
1530 |
+
class TimeDownsample2x(Block):
|
1531 |
+
def __init__(self, chan_in, chan_out, kernel_size: int = 3):
|
1532 |
+
super().__init__()
|
1533 |
+
self.kernel_size = kernel_size
|
1534 |
+
self.conv = nn.AvgPool3d((kernel_size, 1, 1), stride=(2, 1, 1))
|
1535 |
+
|
1536 |
+
def forward(self, x):
|
1537 |
+
first_frame_pad = x[:, :, :1, :, :].repeat((1, 1, self.kernel_size - 1, 1, 1))
|
1538 |
+
x = torch.concatenate((first_frame_pad, x), dim=2)
|
1539 |
+
return self.conv(x)
|
1540 |
+
|
1541 |
+
|
1542 |
+
class TimeUpsample2x(Block):
|
1543 |
+
def __init__(self, chan_in, chan_out):
|
1544 |
+
super().__init__()
|
1545 |
+
|
1546 |
+
def forward(self, x):
|
1547 |
+
if x.size(2) > 1:
|
1548 |
+
x, x_ = x[:, :, :1], x[:, :, 1:]
|
1549 |
+
x_ = F.interpolate(x_, scale_factor=(2, 1, 1), mode="trilinear")
|
1550 |
+
x = torch.concat([x, x_], dim=2)
|
1551 |
+
return x
|
1552 |
+
|
1553 |
+
|
1554 |
+
class TimeDownsampleRes2x(nn.Module):
|
1555 |
+
def __init__(
|
1556 |
+
self,
|
1557 |
+
in_channels,
|
1558 |
+
out_channels,
|
1559 |
+
kernel_size: int = 3,
|
1560 |
+
mix_factor: float = 2.0,
|
1561 |
+
):
|
1562 |
+
super().__init__()
|
1563 |
+
self.kernel_size = cast_tuple(kernel_size, 3)
|
1564 |
+
self.avg_pool = nn.AvgPool3d((kernel_size, 1, 1), stride=(2, 1, 1))
|
1565 |
+
self.conv = nn.Conv3d(in_channels, out_channels, self.kernel_size, stride=(2, 1, 1), padding=(0, 1, 1))
|
1566 |
+
self.mix_factor = torch.nn.Parameter(torch.Tensor([mix_factor]))
|
1567 |
+
|
1568 |
+
def forward(self, x):
|
1569 |
+
alpha = torch.sigmoid(self.mix_factor)
|
1570 |
+
first_frame_pad = x[:, :, :1, :, :].repeat((1, 1, self.kernel_size[0] - 1, 1, 1))
|
1571 |
+
x = torch.concatenate((first_frame_pad, x), dim=2)
|
1572 |
+
return alpha * self.avg_pool(x) + (1 - alpha) * self.conv(x)
|
1573 |
+
|
1574 |
+
|
1575 |
+
class TimeUpsampleRes2x(nn.Module):
|
1576 |
+
def __init__(
|
1577 |
+
self,
|
1578 |
+
in_channels,
|
1579 |
+
out_channels,
|
1580 |
+
kernel_size: int = 3,
|
1581 |
+
mix_factor: float = 2.0,
|
1582 |
+
):
|
1583 |
+
super().__init__()
|
1584 |
+
self.conv = CausalConv3d(in_channels, out_channels, kernel_size, padding=1)
|
1585 |
+
self.mix_factor = torch.nn.Parameter(torch.Tensor([mix_factor]))
|
1586 |
+
|
1587 |
+
def forward(self, x):
|
1588 |
+
alpha = torch.sigmoid(self.mix_factor)
|
1589 |
+
if x.size(2) > 1:
|
1590 |
+
x, x_ = x[:, :, :1], x[:, :, 1:]
|
1591 |
+
x_ = F.interpolate(x_, scale_factor=(2, 1, 1), mode="trilinear")
|
1592 |
+
x = torch.concat([x, x_], dim=2)
|
1593 |
+
return alpha * x + (1 - alpha) * self.conv(x)
|
1594 |
+
|
1595 |
+
|
1596 |
+
class TimeDownsampleResAdv2x(nn.Module):
|
1597 |
+
def __init__(
|
1598 |
+
self,
|
1599 |
+
in_channels,
|
1600 |
+
out_channels,
|
1601 |
+
kernel_size: int = 3,
|
1602 |
+
mix_factor: float = 1.5,
|
1603 |
+
):
|
1604 |
+
super().__init__()
|
1605 |
+
self.kernel_size = cast_tuple(kernel_size, 3)
|
1606 |
+
self.avg_pool = nn.AvgPool3d((kernel_size, 1, 1), stride=(2, 1, 1))
|
1607 |
+
self.attn = TemporalAttnBlock(in_channels)
|
1608 |
+
self.res = ResnetBlock3D(in_channels=in_channels, out_channels=in_channels, dropout=0.0)
|
1609 |
+
self.conv = nn.Conv3d(in_channels, out_channels, self.kernel_size, stride=(2, 1, 1), padding=(0, 1, 1))
|
1610 |
+
self.mix_factor = torch.nn.Parameter(torch.Tensor([mix_factor]))
|
1611 |
+
|
1612 |
+
def forward(self, x):
|
1613 |
+
first_frame_pad = x[:, :, :1, :, :].repeat((1, 1, self.kernel_size[0] - 1, 1, 1))
|
1614 |
+
x = torch.concatenate((first_frame_pad, x), dim=2)
|
1615 |
+
alpha = torch.sigmoid(self.mix_factor)
|
1616 |
+
return alpha * self.avg_pool(x) + (1 - alpha) * self.conv(self.attn((self.res(x))))
|
1617 |
+
|
1618 |
+
|
1619 |
+
class TimeUpsampleResAdv2x(nn.Module):
|
1620 |
+
def __init__(
|
1621 |
+
self,
|
1622 |
+
in_channels,
|
1623 |
+
out_channels,
|
1624 |
+
kernel_size: int = 3,
|
1625 |
+
mix_factor: float = 1.5,
|
1626 |
+
):
|
1627 |
+
super().__init__()
|
1628 |
+
self.res = ResnetBlock3D(in_channels=in_channels, out_channels=in_channels, dropout=0.0)
|
1629 |
+
self.attn = TemporalAttnBlock(in_channels)
|
1630 |
+
self.norm = Normalize(in_channels=in_channels)
|
1631 |
+
self.conv = CausalConv3d(in_channels, out_channels, kernel_size, padding=1)
|
1632 |
+
self.mix_factor = torch.nn.Parameter(torch.Tensor([mix_factor]))
|
1633 |
+
|
1634 |
+
def forward(self, x):
|
1635 |
+
if x.size(2) > 1:
|
1636 |
+
x, x_ = x[:, :, :1], x[:, :, 1:]
|
1637 |
+
x_ = F.interpolate(x_, scale_factor=(2, 1, 1), mode="trilinear")
|
1638 |
+
x = torch.concat([x, x_], dim=2)
|
1639 |
+
alpha = torch.sigmoid(self.mix_factor)
|
1640 |
+
return alpha * x + (1 - alpha) * self.conv(self.attn(self.res(x)))
|
videosys/models/cogvideo/__init__.py
DELETED
@@ -1,6 +0,0 @@
|
|
1 |
-
from .pipeline import CogVideoConfig, CogVideoPipeline
|
2 |
-
|
3 |
-
__all__ = [
|
4 |
-
"CogVideoConfig",
|
5 |
-
"CogVideoPipeline",
|
6 |
-
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
videosys/models/cogvideo/modules.py
DELETED
@@ -1,317 +0,0 @@
|
|
1 |
-
# Adapted from CogVideo
|
2 |
-
|
3 |
-
# This source code is licensed under the license found in the
|
4 |
-
# LICENSE file in the root directory of this source tree.
|
5 |
-
# --------------------------------------------------------
|
6 |
-
# References:
|
7 |
-
# CogVideo: https://github.com/THUDM/CogVideo
|
8 |
-
# diffusers: https://github.com/huggingface/diffusers
|
9 |
-
# --------------------------------------------------------
|
10 |
-
|
11 |
-
from typing import Optional, Tuple, Union
|
12 |
-
|
13 |
-
import numpy as np
|
14 |
-
import torch
|
15 |
-
import torch.nn as nn
|
16 |
-
import torch.nn.functional as F
|
17 |
-
from diffusers.models.embeddings import get_1d_sincos_pos_embed_from_grid, get_2d_sincos_pos_embed_from_grid
|
18 |
-
|
19 |
-
|
20 |
-
class CogVideoXDownsample3D(nn.Module):
|
21 |
-
# Todo: Wait for paper relase.
|
22 |
-
r"""
|
23 |
-
A 3D Downsampling layer using in [CogVideoX]() by Tsinghua University & ZhipuAI
|
24 |
-
|
25 |
-
Args:
|
26 |
-
in_channels (`int`):
|
27 |
-
Number of channels in the input image.
|
28 |
-
out_channels (`int`):
|
29 |
-
Number of channels produced by the convolution.
|
30 |
-
kernel_size (`int`, defaults to `3`):
|
31 |
-
Size of the convolving kernel.
|
32 |
-
stride (`int`, defaults to `2`):
|
33 |
-
Stride of the convolution.
|
34 |
-
padding (`int`, defaults to `0`):
|
35 |
-
Padding added to all four sides of the input.
|
36 |
-
compress_time (`bool`, defaults to `False`):
|
37 |
-
Whether or not to compress the time dimension.
|
38 |
-
"""
|
39 |
-
|
40 |
-
def __init__(
|
41 |
-
self,
|
42 |
-
in_channels: int,
|
43 |
-
out_channels: int,
|
44 |
-
kernel_size: int = 3,
|
45 |
-
stride: int = 2,
|
46 |
-
padding: int = 0,
|
47 |
-
compress_time: bool = False,
|
48 |
-
):
|
49 |
-
super().__init__()
|
50 |
-
|
51 |
-
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
|
52 |
-
self.compress_time = compress_time
|
53 |
-
|
54 |
-
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
55 |
-
if self.compress_time:
|
56 |
-
batch_size, channels, frames, height, width = x.shape
|
57 |
-
|
58 |
-
# (batch_size, channels, frames, height, width) -> (batch_size, height, width, channels, frames) -> (batch_size * height * width, channels, frames)
|
59 |
-
x = x.permute(0, 3, 4, 1, 2).reshape(batch_size * height * width, channels, frames)
|
60 |
-
|
61 |
-
if x.shape[-1] % 2 == 1:
|
62 |
-
x_first, x_rest = x[..., 0], x[..., 1:]
|
63 |
-
if x_rest.shape[-1] > 0:
|
64 |
-
# (batch_size * height * width, channels, frames - 1) -> (batch_size * height * width, channels, (frames - 1) // 2)
|
65 |
-
x_rest = F.avg_pool1d(x_rest, kernel_size=2, stride=2)
|
66 |
-
|
67 |
-
x = torch.cat([x_first[..., None], x_rest], dim=-1)
|
68 |
-
# (batch_size * height * width, channels, (frames // 2) + 1) -> (batch_size, height, width, channels, (frames // 2) + 1) -> (batch_size, channels, (frames // 2) + 1, height, width)
|
69 |
-
x = x.reshape(batch_size, height, width, channels, x.shape[-1]).permute(0, 3, 4, 1, 2)
|
70 |
-
else:
|
71 |
-
# (batch_size * height * width, channels, frames) -> (batch_size * height * width, channels, frames // 2)
|
72 |
-
x = F.avg_pool1d(x, kernel_size=2, stride=2)
|
73 |
-
# (batch_size * height * width, channels, frames // 2) -> (batch_size, height, width, channels, frames // 2) -> (batch_size, channels, frames // 2, height, width)
|
74 |
-
x = x.reshape(batch_size, height, width, channels, x.shape[-1]).permute(0, 3, 4, 1, 2)
|
75 |
-
|
76 |
-
# Pad the tensor
|
77 |
-
pad = (0, 1, 0, 1)
|
78 |
-
x = F.pad(x, pad, mode="constant", value=0)
|
79 |
-
batch_size, channels, frames, height, width = x.shape
|
80 |
-
# (batch_size, channels, frames, height, width) -> (batch_size, frames, channels, height, width) -> (batch_size * frames, channels, height, width)
|
81 |
-
x = x.permute(0, 2, 1, 3, 4).reshape(batch_size * frames, channels, height, width)
|
82 |
-
x = self.conv(x)
|
83 |
-
# (batch_size * frames, channels, height, width) -> (batch_size, frames, channels, height, width) -> (batch_size, channels, frames, height, width)
|
84 |
-
x = x.reshape(batch_size, frames, x.shape[1], x.shape[2], x.shape[3]).permute(0, 2, 1, 3, 4)
|
85 |
-
return x
|
86 |
-
|
87 |
-
|
88 |
-
class CogVideoXUpsample3D(nn.Module):
|
89 |
-
r"""
|
90 |
-
A 3D Upsample layer using in CogVideoX by Tsinghua University & ZhipuAI # Todo: Wait for paper relase.
|
91 |
-
|
92 |
-
Args:
|
93 |
-
in_channels (`int`):
|
94 |
-
Number of channels in the input image.
|
95 |
-
out_channels (`int`):
|
96 |
-
Number of channels produced by the convolution.
|
97 |
-
kernel_size (`int`, defaults to `3`):
|
98 |
-
Size of the convolving kernel.
|
99 |
-
stride (`int`, defaults to `1`):
|
100 |
-
Stride of the convolution.
|
101 |
-
padding (`int`, defaults to `1`):
|
102 |
-
Padding added to all four sides of the input.
|
103 |
-
compress_time (`bool`, defaults to `False`):
|
104 |
-
Whether or not to compress the time dimension.
|
105 |
-
"""
|
106 |
-
|
107 |
-
def __init__(
|
108 |
-
self,
|
109 |
-
in_channels: int,
|
110 |
-
out_channels: int,
|
111 |
-
kernel_size: int = 3,
|
112 |
-
stride: int = 1,
|
113 |
-
padding: int = 1,
|
114 |
-
compress_time: bool = False,
|
115 |
-
) -> None:
|
116 |
-
super().__init__()
|
117 |
-
|
118 |
-
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
|
119 |
-
self.compress_time = compress_time
|
120 |
-
|
121 |
-
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
|
122 |
-
if self.compress_time:
|
123 |
-
if inputs.shape[2] > 1 and inputs.shape[2] % 2 == 1:
|
124 |
-
# split first frame
|
125 |
-
x_first, x_rest = inputs[:, :, 0], inputs[:, :, 1:]
|
126 |
-
|
127 |
-
x_first = F.interpolate(x_first, scale_factor=2.0)
|
128 |
-
x_rest = F.interpolate(x_rest, scale_factor=2.0)
|
129 |
-
x_first = x_first[:, :, None, :, :]
|
130 |
-
inputs = torch.cat([x_first, x_rest], dim=2)
|
131 |
-
elif inputs.shape[2] > 1:
|
132 |
-
inputs = F.interpolate(inputs, scale_factor=2.0)
|
133 |
-
else:
|
134 |
-
inputs = inputs.squeeze(2)
|
135 |
-
inputs = F.interpolate(inputs, scale_factor=2.0)
|
136 |
-
inputs = inputs[:, :, None, :, :]
|
137 |
-
else:
|
138 |
-
# only interpolate 2D
|
139 |
-
b, c, t, h, w = inputs.shape
|
140 |
-
inputs = inputs.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
|
141 |
-
inputs = F.interpolate(inputs, scale_factor=2.0)
|
142 |
-
inputs = inputs.reshape(b, t, c, *inputs.shape[2:]).permute(0, 2, 1, 3, 4)
|
143 |
-
|
144 |
-
b, c, t, h, w = inputs.shape
|
145 |
-
inputs = inputs.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
|
146 |
-
inputs = self.conv(inputs)
|
147 |
-
inputs = inputs.reshape(b, t, *inputs.shape[1:]).permute(0, 2, 1, 3, 4)
|
148 |
-
|
149 |
-
return inputs
|
150 |
-
|
151 |
-
|
152 |
-
def get_3d_sincos_pos_embed(
|
153 |
-
embed_dim: int,
|
154 |
-
spatial_size: Union[int, Tuple[int, int]],
|
155 |
-
temporal_size: int,
|
156 |
-
spatial_interpolation_scale: float = 1.0,
|
157 |
-
temporal_interpolation_scale: float = 1.0,
|
158 |
-
) -> np.ndarray:
|
159 |
-
r"""
|
160 |
-
Args:
|
161 |
-
embed_dim (`int`):
|
162 |
-
spatial_size (`int` or `Tuple[int, int]`):
|
163 |
-
temporal_size (`int`):
|
164 |
-
spatial_interpolation_scale (`float`, defaults to 1.0):
|
165 |
-
temporal_interpolation_scale (`float`, defaults to 1.0):
|
166 |
-
"""
|
167 |
-
if embed_dim % 4 != 0:
|
168 |
-
raise ValueError("`embed_dim` must be divisible by 4")
|
169 |
-
if isinstance(spatial_size, int):
|
170 |
-
spatial_size = (spatial_size, spatial_size)
|
171 |
-
|
172 |
-
embed_dim_spatial = 3 * embed_dim // 4
|
173 |
-
embed_dim_temporal = embed_dim // 4
|
174 |
-
|
175 |
-
# 1. Spatial
|
176 |
-
grid_h = np.arange(spatial_size[1], dtype=np.float32) / spatial_interpolation_scale
|
177 |
-
grid_w = np.arange(spatial_size[0], dtype=np.float32) / spatial_interpolation_scale
|
178 |
-
grid = np.meshgrid(grid_w, grid_h) # here w goes first
|
179 |
-
grid = np.stack(grid, axis=0)
|
180 |
-
|
181 |
-
grid = grid.reshape([2, 1, spatial_size[1], spatial_size[0]])
|
182 |
-
pos_embed_spatial = get_2d_sincos_pos_embed_from_grid(embed_dim_spatial, grid)
|
183 |
-
|
184 |
-
# 2. Temporal
|
185 |
-
grid_t = np.arange(temporal_size, dtype=np.float32) / temporal_interpolation_scale
|
186 |
-
pos_embed_temporal = get_1d_sincos_pos_embed_from_grid(embed_dim_temporal, grid_t)
|
187 |
-
|
188 |
-
# 3. Concat
|
189 |
-
pos_embed_spatial = pos_embed_spatial[np.newaxis, :, :]
|
190 |
-
pos_embed_spatial = np.repeat(pos_embed_spatial, temporal_size, axis=0) # [T, H*W, D // 4 * 3]
|
191 |
-
|
192 |
-
pos_embed_temporal = pos_embed_temporal[:, np.newaxis, :]
|
193 |
-
pos_embed_temporal = np.repeat(pos_embed_temporal, spatial_size[0] * spatial_size[1], axis=1) # [T, H*W, D // 4]
|
194 |
-
|
195 |
-
pos_embed = np.concatenate([pos_embed_temporal, pos_embed_spatial], axis=-1) # [T, H*W, D]
|
196 |
-
return pos_embed
|
197 |
-
|
198 |
-
|
199 |
-
class CogVideoXPatchEmbed(nn.Module):
|
200 |
-
def __init__(
|
201 |
-
self,
|
202 |
-
patch_size: int = 2,
|
203 |
-
in_channels: int = 16,
|
204 |
-
embed_dim: int = 1920,
|
205 |
-
text_embed_dim: int = 4096,
|
206 |
-
bias: bool = True,
|
207 |
-
) -> None:
|
208 |
-
super().__init__()
|
209 |
-
self.patch_size = patch_size
|
210 |
-
|
211 |
-
self.proj = nn.Conv2d(
|
212 |
-
in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
|
213 |
-
)
|
214 |
-
self.text_proj = nn.Linear(text_embed_dim, embed_dim)
|
215 |
-
|
216 |
-
def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor):
|
217 |
-
r"""
|
218 |
-
Args:
|
219 |
-
text_embeds (`torch.Tensor`):
|
220 |
-
Input text embeddings. Expected shape: (batch_size, seq_length, embedding_dim).
|
221 |
-
image_embeds (`torch.Tensor`):
|
222 |
-
Input image embeddings. Expected shape: (batch_size, num_frames, channels, height, width).
|
223 |
-
"""
|
224 |
-
text_embeds = self.text_proj(text_embeds)
|
225 |
-
|
226 |
-
batch, num_frames, channels, height, width = image_embeds.shape
|
227 |
-
image_embeds = image_embeds.reshape(-1, channels, height, width)
|
228 |
-
image_embeds = self.proj(image_embeds)
|
229 |
-
image_embeds = image_embeds.view(batch, num_frames, *image_embeds.shape[1:])
|
230 |
-
image_embeds = image_embeds.flatten(3).transpose(2, 3) # [batch, num_frames, height x width, channels]
|
231 |
-
image_embeds = image_embeds.flatten(1, 2) # [batch, num_frames x height x width, channels]
|
232 |
-
|
233 |
-
embeds = torch.cat(
|
234 |
-
[text_embeds, image_embeds], dim=1
|
235 |
-
).contiguous() # [batch, seq_length + num_frames x height x width, channels]
|
236 |
-
return embeds
|
237 |
-
|
238 |
-
|
239 |
-
class CogVideoXLayerNormZero(nn.Module):
|
240 |
-
def __init__(
|
241 |
-
self,
|
242 |
-
conditioning_dim: int,
|
243 |
-
embedding_dim: int,
|
244 |
-
elementwise_affine: bool = True,
|
245 |
-
eps: float = 1e-5,
|
246 |
-
bias: bool = True,
|
247 |
-
) -> None:
|
248 |
-
super().__init__()
|
249 |
-
|
250 |
-
self.silu = nn.SiLU()
|
251 |
-
self.linear = nn.Linear(conditioning_dim, 6 * embedding_dim, bias=bias)
|
252 |
-
self.norm = nn.LayerNorm(embedding_dim, eps=eps, elementwise_affine=elementwise_affine)
|
253 |
-
|
254 |
-
def forward(
|
255 |
-
self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, temb: torch.Tensor
|
256 |
-
) -> Tuple[torch.Tensor, torch.Tensor]:
|
257 |
-
shift, scale, gate, enc_shift, enc_scale, enc_gate = self.linear(self.silu(temb)).chunk(6, dim=1)
|
258 |
-
hidden_states = self.norm(hidden_states) * (1 + scale)[:, None, :] + shift[:, None, :]
|
259 |
-
encoder_hidden_states = self.norm(encoder_hidden_states) * (1 + enc_scale)[:, None, :] + enc_shift[:, None, :]
|
260 |
-
return hidden_states, encoder_hidden_states, gate[:, None, :], enc_gate[:, None, :]
|
261 |
-
|
262 |
-
|
263 |
-
class AdaLayerNorm(nn.Module):
|
264 |
-
r"""
|
265 |
-
Norm layer modified to incorporate timestep embeddings.
|
266 |
-
|
267 |
-
Parameters:
|
268 |
-
embedding_dim (`int`): The size of each embedding vector.
|
269 |
-
num_embeddings (`int`, *optional*): The size of the embeddings dictionary.
|
270 |
-
output_dim (`int`, *optional*):
|
271 |
-
norm_elementwise_affine (`bool`, defaults to `False):
|
272 |
-
norm_eps (`bool`, defaults to `False`):
|
273 |
-
chunk_dim (`int`, defaults to `0`):
|
274 |
-
"""
|
275 |
-
|
276 |
-
def __init__(
|
277 |
-
self,
|
278 |
-
embedding_dim: int,
|
279 |
-
num_embeddings: Optional[int] = None,
|
280 |
-
output_dim: Optional[int] = None,
|
281 |
-
norm_elementwise_affine: bool = False,
|
282 |
-
norm_eps: float = 1e-5,
|
283 |
-
chunk_dim: int = 0,
|
284 |
-
):
|
285 |
-
super().__init__()
|
286 |
-
|
287 |
-
self.chunk_dim = chunk_dim
|
288 |
-
output_dim = output_dim or embedding_dim * 2
|
289 |
-
|
290 |
-
if num_embeddings is not None:
|
291 |
-
self.emb = nn.Embedding(num_embeddings, embedding_dim)
|
292 |
-
else:
|
293 |
-
self.emb = None
|
294 |
-
|
295 |
-
self.silu = nn.SiLU()
|
296 |
-
self.linear = nn.Linear(embedding_dim, output_dim)
|
297 |
-
self.norm = nn.LayerNorm(output_dim // 2, norm_eps, norm_elementwise_affine)
|
298 |
-
|
299 |
-
def forward(
|
300 |
-
self, x: torch.Tensor, timestep: Optional[torch.Tensor] = None, temb: Optional[torch.Tensor] = None
|
301 |
-
) -> torch.Tensor:
|
302 |
-
if self.emb is not None:
|
303 |
-
temb = self.emb(timestep)
|
304 |
-
|
305 |
-
temb = self.linear(self.silu(temb))
|
306 |
-
|
307 |
-
if self.chunk_dim == 1:
|
308 |
-
# This is a bit weird why we have the order of "shift, scale" here and "scale, shift" in the
|
309 |
-
# other if-branch. This branch is specific to CogVideoX for now.
|
310 |
-
shift, scale = temb.chunk(2, dim=1)
|
311 |
-
shift = shift[:, None, :]
|
312 |
-
scale = scale[:, None, :]
|
313 |
-
else:
|
314 |
-
scale, shift = temb.chunk(2, dim=0)
|
315 |
-
|
316 |
-
x = self.norm(x) * (1 + scale) + shift
|
317 |
-
return x
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
videosys/models/cogvideo/retrieve_timesteps.py
DELETED
@@ -1,74 +0,0 @@
|
|
1 |
-
# Adapted from CogVideo
|
2 |
-
|
3 |
-
# This source code is licensed under the license found in the
|
4 |
-
# LICENSE file in the root directory of this source tree.
|
5 |
-
# --------------------------------------------------------
|
6 |
-
# References:
|
7 |
-
# CogVideo: https://github.com/THUDM/CogVideo
|
8 |
-
# diffusers: https://github.com/huggingface/diffusers
|
9 |
-
# --------------------------------------------------------
|
10 |
-
|
11 |
-
import inspect
|
12 |
-
from typing import List, Optional, Union
|
13 |
-
|
14 |
-
import torch
|
15 |
-
|
16 |
-
|
17 |
-
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
18 |
-
def retrieve_timesteps(
|
19 |
-
scheduler,
|
20 |
-
num_inference_steps: Optional[int] = None,
|
21 |
-
device: Optional[Union[str, torch.device]] = None,
|
22 |
-
timesteps: Optional[List[int]] = None,
|
23 |
-
sigmas: Optional[List[float]] = None,
|
24 |
-
**kwargs,
|
25 |
-
):
|
26 |
-
"""
|
27 |
-
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
28 |
-
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
29 |
-
|
30 |
-
Args:
|
31 |
-
scheduler (`SchedulerMixin`):
|
32 |
-
The scheduler to get timesteps from.
|
33 |
-
num_inference_steps (`int`):
|
34 |
-
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
35 |
-
must be `None`.
|
36 |
-
device (`str` or `torch.device`, *optional*):
|
37 |
-
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
38 |
-
timesteps (`List[int]`, *optional*):
|
39 |
-
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
|
40 |
-
`num_inference_steps` and `sigmas` must be `None`.
|
41 |
-
sigmas (`List[float]`, *optional*):
|
42 |
-
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
|
43 |
-
`num_inference_steps` and `timesteps` must be `None`.
|
44 |
-
|
45 |
-
Returns:
|
46 |
-
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
47 |
-
second element is the number of inference steps.
|
48 |
-
"""
|
49 |
-
if timesteps is not None and sigmas is not None:
|
50 |
-
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
|
51 |
-
if timesteps is not None:
|
52 |
-
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
53 |
-
if not accepts_timesteps:
|
54 |
-
raise ValueError(
|
55 |
-
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
56 |
-
f" timestep schedules. Please check whether you are using the correct scheduler."
|
57 |
-
)
|
58 |
-
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
59 |
-
timesteps = scheduler.timesteps
|
60 |
-
num_inference_steps = len(timesteps)
|
61 |
-
elif sigmas is not None:
|
62 |
-
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
63 |
-
if not accept_sigmas:
|
64 |
-
raise ValueError(
|
65 |
-
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
66 |
-
f" sigmas schedules. Please check whether you are using the correct scheduler."
|
67 |
-
)
|
68 |
-
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
69 |
-
timesteps = scheduler.timesteps
|
70 |
-
num_inference_steps = len(timesteps)
|
71 |
-
else:
|
72 |
-
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
73 |
-
timesteps = scheduler.timesteps
|
74 |
-
return timesteps, num_inference_steps
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
videosys/models/latte/__init__.py
DELETED
@@ -1,7 +0,0 @@
|
|
1 |
-
from .pipeline import LatteConfig, LattePABConfig, LattePipeline
|
2 |
-
|
3 |
-
__all__ = [
|
4 |
-
"LattePipeline",
|
5 |
-
"LattePABConfig",
|
6 |
-
"LatteConfig",
|
7 |
-
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
{eval/pab/experiments → videosys/models/modules}/__init__.py
RENAMED
File without changes
|
videosys/models/modules/activations.py
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
import torch.nn as nn
|
2 |
+
|
3 |
+
approx_gelu = lambda: nn.GELU(approximate="tanh")
|
videosys/{modules/attn.py → models/modules/attentions.py}
RENAMED
@@ -1,12 +1,8 @@
|
|
1 |
-
from dataclasses import dataclass
|
2 |
-
from typing import Iterable, List, Optional, Sequence, Tuple
|
3 |
-
|
4 |
import torch
|
5 |
import torch.nn as nn
|
6 |
-
import torch.nn.functional as F
|
7 |
import torch.utils.checkpoint
|
8 |
|
9 |
-
from videosys.modules.
|
10 |
|
11 |
|
12 |
class Attention(nn.Module):
|
@@ -19,8 +15,9 @@ class Attention(nn.Module):
|
|
19 |
attn_drop: float = 0.0,
|
20 |
proj_drop: float = 0.0,
|
21 |
norm_layer: nn.Module = LlamaRMSNorm,
|
22 |
-
|
23 |
rope=None,
|
|
|
24 |
) -> None:
|
25 |
super().__init__()
|
26 |
assert dim % num_heads == 0, "dim should be divisible by num_heads"
|
@@ -28,11 +25,12 @@ class Attention(nn.Module):
|
|
28 |
self.num_heads = num_heads
|
29 |
self.head_dim = dim // num_heads
|
30 |
self.scale = self.head_dim**-0.5
|
31 |
-
self.
|
32 |
|
33 |
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
34 |
self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
|
35 |
self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
|
|
|
36 |
self.attn_drop = nn.Dropout(attn_drop)
|
37 |
self.proj = nn.Linear(dim, dim)
|
38 |
self.proj_drop = nn.Dropout(proj_drop)
|
@@ -44,18 +42,32 @@ class Attention(nn.Module):
|
|
44 |
|
45 |
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
46 |
B, N, C = x.shape
|
47 |
-
|
|
|
48 |
qkv = self.qkv(x)
|
49 |
-
|
|
|
|
|
50 |
q, k, v = qkv.unbind(0)
|
51 |
-
if self.
|
52 |
-
|
53 |
-
|
54 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
55 |
|
56 |
-
if
|
57 |
from flash_attn import flash_attn_func
|
58 |
|
|
|
|
|
|
|
|
|
59 |
x = flash_attn_func(
|
60 |
q,
|
61 |
k,
|
@@ -64,13 +76,17 @@ class Attention(nn.Module):
|
|
64 |
softmax_scale=self.scale,
|
65 |
)
|
66 |
else:
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
)
|
|
|
|
|
|
|
|
|
71 |
|
72 |
x_output_shape = (B, N, C)
|
73 |
-
if not
|
74 |
x = x.transpose(1, 2)
|
75 |
x = x.reshape(x_output_shape)
|
76 |
x = self.proj(x)
|
@@ -79,139 +95,37 @@ class Attention(nn.Module):
|
|
79 |
|
80 |
|
81 |
class MultiHeadCrossAttention(nn.Module):
|
82 |
-
def __init__(self, d_model, num_heads, attn_drop=0.0, proj_drop=0.0
|
83 |
super(MultiHeadCrossAttention, self).__init__()
|
84 |
assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
|
85 |
|
86 |
self.d_model = d_model
|
87 |
self.num_heads = num_heads
|
88 |
self.head_dim = d_model // num_heads
|
89 |
-
self.enable_flashattn = enable_flashattn
|
90 |
|
91 |
self.q_linear = nn.Linear(d_model, d_model)
|
92 |
self.kv_linear = nn.Linear(d_model, d_model * 2)
|
93 |
self.attn_drop = nn.Dropout(attn_drop)
|
94 |
self.proj = nn.Linear(d_model, d_model)
|
95 |
self.proj_drop = nn.Dropout(proj_drop)
|
96 |
-
self.last_out = None
|
97 |
-
self.count = 0
|
98 |
|
99 |
-
def forward(self, x, cond, mask=None
|
100 |
# query/value: img tokens; key: condition; mask: if padding tokens
|
101 |
B, N, C = x.shape
|
102 |
|
103 |
q = self.q_linear(x).view(1, -1, self.num_heads, self.head_dim)
|
104 |
kv = self.kv_linear(cond).view(1, -1, 2, self.num_heads, self.head_dim)
|
105 |
k, v = kv.unbind(2)
|
106 |
-
x = self.flash_attn_impl(q, k, v, mask, B, N, C)
|
107 |
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
|
112 |
-
def flash_attn_impl(self, q, k, v, mask, B, N, C):
|
113 |
-
from flash_attn import flash_attn_varlen_func
|
114 |
-
|
115 |
-
q_seqinfo = _SeqLenInfo.from_seqlens([N] * B)
|
116 |
-
k_seqinfo = _SeqLenInfo.from_seqlens(mask)
|
117 |
-
|
118 |
-
x = flash_attn_varlen_func(
|
119 |
-
q.view(-1, self.num_heads, self.head_dim),
|
120 |
-
k.view(-1, self.num_heads, self.head_dim),
|
121 |
-
v.view(-1, self.num_heads, self.head_dim),
|
122 |
-
cu_seqlens_q=q_seqinfo.seqstart.cuda(),
|
123 |
-
cu_seqlens_k=k_seqinfo.seqstart.cuda(),
|
124 |
-
max_seqlen_q=q_seqinfo.max_seqlen,
|
125 |
-
max_seqlen_k=k_seqinfo.max_seqlen,
|
126 |
-
dropout_p=self.attn_drop.p if self.training else 0.0,
|
127 |
-
)
|
128 |
-
x = x.view(B, N, C)
|
129 |
-
return x
|
130 |
-
|
131 |
-
def torch_impl(self, q, k, v, mask, B, N, C):
|
132 |
-
q = q.view(B, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
133 |
-
k = k.view(B, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
134 |
-
v = v.view(B, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
135 |
-
|
136 |
-
attn_mask = torch.zeros(B, N, k.shape[2], dtype=torch.float32, device=q.device)
|
137 |
-
for i, m in enumerate(mask):
|
138 |
-
attn_mask[i, :, m:] = -1e8
|
139 |
-
|
140 |
-
scale = 1 / q.shape[-1] ** 0.5
|
141 |
-
q = q * scale
|
142 |
-
attn = q @ k.transpose(-2, -1)
|
143 |
-
attn = attn.to(torch.float32)
|
144 |
if mask is not None:
|
145 |
-
|
146 |
-
|
147 |
-
attn = attn.to(v.dtype)
|
148 |
-
out = attn @ v
|
149 |
|
150 |
-
x =
|
|
|
|
|
151 |
return x
|
152 |
-
|
153 |
-
|
154 |
-
@dataclass
|
155 |
-
class _SeqLenInfo:
|
156 |
-
"""
|
157 |
-
copied from xformers
|
158 |
-
|
159 |
-
(Internal) Represents the division of a dimension into blocks.
|
160 |
-
For example, to represents a dimension of length 7 divided into
|
161 |
-
three blocks of lengths 2, 3 and 2, use `from_seqlength([2, 3, 2])`.
|
162 |
-
The members will be:
|
163 |
-
max_seqlen: 3
|
164 |
-
min_seqlen: 2
|
165 |
-
seqstart_py: [0, 2, 5, 7]
|
166 |
-
seqstart: torch.IntTensor([0, 2, 5, 7])
|
167 |
-
"""
|
168 |
-
|
169 |
-
seqstart: torch.Tensor
|
170 |
-
max_seqlen: int
|
171 |
-
min_seqlen: int
|
172 |
-
seqstart_py: List[int]
|
173 |
-
|
174 |
-
def to(self, device: torch.device) -> None:
|
175 |
-
self.seqstart = self.seqstart.to(device, non_blocking=True)
|
176 |
-
|
177 |
-
def intervals(self) -> Iterable[Tuple[int, int]]:
|
178 |
-
yield from zip(self.seqstart_py, self.seqstart_py[1:])
|
179 |
-
|
180 |
-
@classmethod
|
181 |
-
def from_seqlens(cls, seqlens: Iterable[int]) -> "_SeqLenInfo":
|
182 |
-
"""
|
183 |
-
Input tensors are assumed to be in shape [B, M, *]
|
184 |
-
"""
|
185 |
-
assert not isinstance(seqlens, torch.Tensor)
|
186 |
-
seqstart_py = [0]
|
187 |
-
max_seqlen = -1
|
188 |
-
min_seqlen = -1
|
189 |
-
for seqlen in seqlens:
|
190 |
-
min_seqlen = min(min_seqlen, seqlen) if min_seqlen != -1 else seqlen
|
191 |
-
max_seqlen = max(max_seqlen, seqlen)
|
192 |
-
seqstart_py.append(seqstart_py[len(seqstart_py) - 1] + seqlen)
|
193 |
-
seqstart = torch.tensor(seqstart_py, dtype=torch.int32)
|
194 |
-
return cls(
|
195 |
-
max_seqlen=max_seqlen,
|
196 |
-
min_seqlen=min_seqlen,
|
197 |
-
seqstart=seqstart,
|
198 |
-
seqstart_py=seqstart_py,
|
199 |
-
)
|
200 |
-
|
201 |
-
def split(self, x: torch.Tensor, batch_sizes: Optional[Sequence[int]] = None) -> List[torch.Tensor]:
|
202 |
-
if self.seqstart_py[-1] != x.shape[1] or x.shape[0] != 1:
|
203 |
-
raise ValueError(
|
204 |
-
f"Invalid `torch.Tensor` of shape {x.shape}, expected format "
|
205 |
-
f"(B, M, *) with B=1 and M={self.seqstart_py[-1]}\n"
|
206 |
-
f" seqstart: {self.seqstart_py}"
|
207 |
-
)
|
208 |
-
if batch_sizes is None:
|
209 |
-
batch_sizes = [1] * (len(self.seqstart_py) - 1)
|
210 |
-
split_chunks = []
|
211 |
-
it = 0
|
212 |
-
for batch_size in batch_sizes:
|
213 |
-
split_chunks.append(self.seqstart_py[it + batch_size] - self.seqstart_py[it])
|
214 |
-
it += batch_size
|
215 |
-
return [
|
216 |
-
tensor.reshape([bs, -1, *tensor.shape[2:]]) for bs, tensor in zip(batch_sizes, x.split(split_chunks, dim=1))
|
217 |
-
]
|
|
|
|
|
|
|
|
|
1 |
import torch
|
2 |
import torch.nn as nn
|
|
|
3 |
import torch.utils.checkpoint
|
4 |
|
5 |
+
from videosys.models.modules.normalization import LlamaRMSNorm
|
6 |
|
7 |
|
8 |
class Attention(nn.Module):
|
|
|
15 |
attn_drop: float = 0.0,
|
16 |
proj_drop: float = 0.0,
|
17 |
norm_layer: nn.Module = LlamaRMSNorm,
|
18 |
+
enable_flash_attn: bool = False,
|
19 |
rope=None,
|
20 |
+
qk_norm_legacy: bool = False,
|
21 |
) -> None:
|
22 |
super().__init__()
|
23 |
assert dim % num_heads == 0, "dim should be divisible by num_heads"
|
|
|
25 |
self.num_heads = num_heads
|
26 |
self.head_dim = dim // num_heads
|
27 |
self.scale = self.head_dim**-0.5
|
28 |
+
self.enable_flash_attn = enable_flash_attn
|
29 |
|
30 |
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
31 |
self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
|
32 |
self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
|
33 |
+
self.qk_norm_legacy = qk_norm_legacy
|
34 |
self.attn_drop = nn.Dropout(attn_drop)
|
35 |
self.proj = nn.Linear(dim, dim)
|
36 |
self.proj_drop = nn.Dropout(proj_drop)
|
|
|
42 |
|
43 |
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
44 |
B, N, C = x.shape
|
45 |
+
# flash attn is not memory efficient for small sequences, this is empirical
|
46 |
+
enable_flash_attn = self.enable_flash_attn and (N > B)
|
47 |
qkv = self.qkv(x)
|
48 |
+
qkv_shape = (B, N, 3, self.num_heads, self.head_dim)
|
49 |
+
|
50 |
+
qkv = qkv.view(qkv_shape).permute(2, 0, 3, 1, 4)
|
51 |
q, k, v = qkv.unbind(0)
|
52 |
+
if self.qk_norm_legacy:
|
53 |
+
# WARNING: this may be a bug
|
54 |
+
if self.rope:
|
55 |
+
q = self.rotary_emb(q)
|
56 |
+
k = self.rotary_emb(k)
|
57 |
+
q, k = self.q_norm(q), self.k_norm(k)
|
58 |
+
else:
|
59 |
+
q, k = self.q_norm(q), self.k_norm(k)
|
60 |
+
if self.rope:
|
61 |
+
q = self.rotary_emb(q)
|
62 |
+
k = self.rotary_emb(k)
|
63 |
|
64 |
+
if enable_flash_attn:
|
65 |
from flash_attn import flash_attn_func
|
66 |
|
67 |
+
# (B, #heads, N, #dim) -> (B, N, #heads, #dim)
|
68 |
+
q = q.permute(0, 2, 1, 3)
|
69 |
+
k = k.permute(0, 2, 1, 3)
|
70 |
+
v = v.permute(0, 2, 1, 3)
|
71 |
x = flash_attn_func(
|
72 |
q,
|
73 |
k,
|
|
|
76 |
softmax_scale=self.scale,
|
77 |
)
|
78 |
else:
|
79 |
+
dtype = q.dtype
|
80 |
+
q = q * self.scale
|
81 |
+
attn = q @ k.transpose(-2, -1) # translate attn to float32
|
82 |
+
attn = attn.to(torch.float32)
|
83 |
+
attn = attn.softmax(dim=-1)
|
84 |
+
attn = attn.to(dtype) # cast back attn to original dtype
|
85 |
+
attn = self.attn_drop(attn)
|
86 |
+
x = attn @ v
|
87 |
|
88 |
x_output_shape = (B, N, C)
|
89 |
+
if not enable_flash_attn:
|
90 |
x = x.transpose(1, 2)
|
91 |
x = x.reshape(x_output_shape)
|
92 |
x = self.proj(x)
|
|
|
95 |
|
96 |
|
97 |
class MultiHeadCrossAttention(nn.Module):
|
98 |
+
def __init__(self, d_model, num_heads, attn_drop=0.0, proj_drop=0.0):
|
99 |
super(MultiHeadCrossAttention, self).__init__()
|
100 |
assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
|
101 |
|
102 |
self.d_model = d_model
|
103 |
self.num_heads = num_heads
|
104 |
self.head_dim = d_model // num_heads
|
|
|
105 |
|
106 |
self.q_linear = nn.Linear(d_model, d_model)
|
107 |
self.kv_linear = nn.Linear(d_model, d_model * 2)
|
108 |
self.attn_drop = nn.Dropout(attn_drop)
|
109 |
self.proj = nn.Linear(d_model, d_model)
|
110 |
self.proj_drop = nn.Dropout(proj_drop)
|
|
|
|
|
111 |
|
112 |
+
def forward(self, x, cond, mask=None):
|
113 |
# query/value: img tokens; key: condition; mask: if padding tokens
|
114 |
B, N, C = x.shape
|
115 |
|
116 |
q = self.q_linear(x).view(1, -1, self.num_heads, self.head_dim)
|
117 |
kv = self.kv_linear(cond).view(1, -1, 2, self.num_heads, self.head_dim)
|
118 |
k, v = kv.unbind(2)
|
|
|
119 |
|
120 |
+
attn_bias = None
|
121 |
+
# TODO: support torch computation
|
122 |
+
import xformers.ops
|
123 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
124 |
if mask is not None:
|
125 |
+
attn_bias = xformers.ops.fmha.BlockDiagonalMask.from_seqlens([N] * B, mask)
|
126 |
+
x = xformers.ops.memory_efficient_attention(q, k, v, p=self.attn_drop.p, attn_bias=attn_bias)
|
|
|
|
|
127 |
|
128 |
+
x = x.view(B, -1, C)
|
129 |
+
x = self.proj(x)
|
130 |
+
x = self.proj_drop(x)
|
131 |
return x
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
videosys/models/modules/downsampling.py
ADDED
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
|
5 |
+
|
6 |
+
class CogVideoXDownsample3D(nn.Module):
|
7 |
+
# Todo: Wait for paper relase.
|
8 |
+
r"""
|
9 |
+
A 3D Downsampling layer using in [CogVideoX]() by Tsinghua University & ZhipuAI
|
10 |
+
|
11 |
+
Args:
|
12 |
+
in_channels (`int`):
|
13 |
+
Number of channels in the input image.
|
14 |
+
out_channels (`int`):
|
15 |
+
Number of channels produced by the convolution.
|
16 |
+
kernel_size (`int`, defaults to `3`):
|
17 |
+
Size of the convolving kernel.
|
18 |
+
stride (`int`, defaults to `2`):
|
19 |
+
Stride of the convolution.
|
20 |
+
padding (`int`, defaults to `0`):
|
21 |
+
Padding added to all four sides of the input.
|
22 |
+
compress_time (`bool`, defaults to `False`):
|
23 |
+
Whether or not to compress the time dimension.
|
24 |
+
"""
|
25 |
+
|
26 |
+
def __init__(
|
27 |
+
self,
|
28 |
+
in_channels: int,
|
29 |
+
out_channels: int,
|
30 |
+
kernel_size: int = 3,
|
31 |
+
stride: int = 2,
|
32 |
+
padding: int = 0,
|
33 |
+
compress_time: bool = False,
|
34 |
+
):
|
35 |
+
super().__init__()
|
36 |
+
|
37 |
+
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
|
38 |
+
self.compress_time = compress_time
|
39 |
+
|
40 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
41 |
+
if self.compress_time:
|
42 |
+
batch_size, channels, frames, height, width = x.shape
|
43 |
+
|
44 |
+
# (batch_size, channels, frames, height, width) -> (batch_size, height, width, channels, frames) -> (batch_size * height * width, channels, frames)
|
45 |
+
x = x.permute(0, 3, 4, 1, 2).reshape(batch_size * height * width, channels, frames)
|
46 |
+
|
47 |
+
if x.shape[-1] % 2 == 1:
|
48 |
+
x_first, x_rest = x[..., 0], x[..., 1:]
|
49 |
+
if x_rest.shape[-1] > 0:
|
50 |
+
# (batch_size * height * width, channels, frames - 1) -> (batch_size * height * width, channels, (frames - 1) // 2)
|
51 |
+
x_rest = F.avg_pool1d(x_rest, kernel_size=2, stride=2)
|
52 |
+
|
53 |
+
x = torch.cat([x_first[..., None], x_rest], dim=-1)
|
54 |
+
# (batch_size * height * width, channels, (frames // 2) + 1) -> (batch_size, height, width, channels, (frames // 2) + 1) -> (batch_size, channels, (frames // 2) + 1, height, width)
|
55 |
+
x = x.reshape(batch_size, height, width, channels, x.shape[-1]).permute(0, 3, 4, 1, 2)
|
56 |
+
else:
|
57 |
+
# (batch_size * height * width, channels, frames) -> (batch_size * height * width, channels, frames // 2)
|
58 |
+
x = F.avg_pool1d(x, kernel_size=2, stride=2)
|
59 |
+
# (batch_size * height * width, channels, frames // 2) -> (batch_size, height, width, channels, frames // 2) -> (batch_size, channels, frames // 2, height, width)
|
60 |
+
x = x.reshape(batch_size, height, width, channels, x.shape[-1]).permute(0, 3, 4, 1, 2)
|
61 |
+
|
62 |
+
# Pad the tensor
|
63 |
+
pad = (0, 1, 0, 1)
|
64 |
+
x = F.pad(x, pad, mode="constant", value=0)
|
65 |
+
batch_size, channels, frames, height, width = x.shape
|
66 |
+
# (batch_size, channels, frames, height, width) -> (batch_size, frames, channels, height, width) -> (batch_size * frames, channels, height, width)
|
67 |
+
x = x.permute(0, 2, 1, 3, 4).reshape(batch_size * frames, channels, height, width)
|
68 |
+
x = self.conv(x)
|
69 |
+
# (batch_size * frames, channels, height, width) -> (batch_size, frames, channels, height, width) -> (batch_size, channels, frames, height, width)
|
70 |
+
x = x.reshape(batch_size, frames, x.shape[1], x.shape[2], x.shape[3]).permute(0, 2, 1, 3, 4)
|
71 |
+
return x
|
videosys/models/{open_sora/modules.py → modules/embeddings.py}
RENAMED
@@ -1,16 +1,8 @@
|
|
1 |
-
# Adapted from OpenSora
|
2 |
-
|
3 |
-
# This source code is licensed under the license found in the
|
4 |
-
# LICENSE file in the root directory of this source tree.
|
5 |
-
# --------------------------------------------------------
|
6 |
-
# References:
|
7 |
-
# OpenSora: https://github.com/hpcaitech/Open-Sora
|
8 |
-
# --------------------------------------------------------
|
9 |
-
|
10 |
import functools
|
11 |
import math
|
12 |
-
from typing import Optional
|
13 |
|
|
|
14 |
import torch
|
15 |
import torch.nn as nn
|
16 |
import torch.nn.functional as F
|
@@ -18,40 +10,48 @@ import torch.utils.checkpoint
|
|
18 |
from einops import rearrange
|
19 |
from timm.models.vision_transformer import Mlp
|
20 |
|
21 |
-
approx_gelu = lambda: nn.GELU(approximate="tanh")
|
22 |
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
|
|
|
|
|
|
29 |
super().__init__()
|
30 |
-
self.
|
31 |
-
self.variance_epsilon = eps
|
32 |
-
|
33 |
-
def forward(self, hidden_states):
|
34 |
-
input_dtype = hidden_states.dtype
|
35 |
-
hidden_states = hidden_states.to(torch.float32)
|
36 |
-
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
37 |
-
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
38 |
-
return self.weight * hidden_states.to(input_dtype)
|
39 |
-
|
40 |
-
|
41 |
-
def get_layernorm(hidden_size: torch.Tensor, eps: float, affine: bool):
|
42 |
-
return nn.LayerNorm(hidden_size, eps, elementwise_affine=affine)
|
43 |
-
|
44 |
|
45 |
-
|
46 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
47 |
|
|
|
|
|
|
|
|
|
|
|
|
|
48 |
|
49 |
-
|
50 |
-
|
51 |
-
#
|
|
|
52 |
|
53 |
|
54 |
-
class
|
55 |
"""Video to Patch Embedding.
|
56 |
|
57 |
Args:
|
@@ -104,176 +104,6 @@ class PatchEmbed3D(nn.Module):
|
|
104 |
return x
|
105 |
|
106 |
|
107 |
-
class Attention(nn.Module):
|
108 |
-
def __init__(
|
109 |
-
self,
|
110 |
-
dim: int,
|
111 |
-
num_heads: int = 8,
|
112 |
-
qkv_bias: bool = False,
|
113 |
-
qk_norm: bool = False,
|
114 |
-
attn_drop: float = 0.0,
|
115 |
-
proj_drop: float = 0.0,
|
116 |
-
norm_layer: nn.Module = LlamaRMSNorm,
|
117 |
-
enable_flash_attn: bool = False,
|
118 |
-
rope=None,
|
119 |
-
qk_norm_legacy: bool = False,
|
120 |
-
) -> None:
|
121 |
-
super().__init__()
|
122 |
-
assert dim % num_heads == 0, "dim should be divisible by num_heads"
|
123 |
-
self.dim = dim
|
124 |
-
self.num_heads = num_heads
|
125 |
-
self.head_dim = dim // num_heads
|
126 |
-
self.scale = self.head_dim**-0.5
|
127 |
-
self.enable_flash_attn = enable_flash_attn
|
128 |
-
|
129 |
-
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
130 |
-
self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
|
131 |
-
self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
|
132 |
-
self.qk_norm_legacy = qk_norm_legacy
|
133 |
-
self.attn_drop = nn.Dropout(attn_drop)
|
134 |
-
self.proj = nn.Linear(dim, dim)
|
135 |
-
self.proj_drop = nn.Dropout(proj_drop)
|
136 |
-
|
137 |
-
self.rope = False
|
138 |
-
if rope is not None:
|
139 |
-
self.rope = True
|
140 |
-
self.rotary_emb = rope
|
141 |
-
|
142 |
-
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
143 |
-
B, N, C = x.shape
|
144 |
-
# flash attn is not memory efficient for small sequences, this is empirical
|
145 |
-
enable_flash_attn = self.enable_flash_attn and (N > B)
|
146 |
-
qkv = self.qkv(x)
|
147 |
-
qkv_shape = (B, N, 3, self.num_heads, self.head_dim)
|
148 |
-
|
149 |
-
qkv = qkv.view(qkv_shape).permute(2, 0, 3, 1, 4)
|
150 |
-
q, k, v = qkv.unbind(0)
|
151 |
-
if self.qk_norm_legacy:
|
152 |
-
# WARNING: this may be a bug
|
153 |
-
if self.rope:
|
154 |
-
q = self.rotary_emb(q)
|
155 |
-
k = self.rotary_emb(k)
|
156 |
-
q, k = self.q_norm(q), self.k_norm(k)
|
157 |
-
else:
|
158 |
-
q, k = self.q_norm(q), self.k_norm(k)
|
159 |
-
if self.rope:
|
160 |
-
q = self.rotary_emb(q)
|
161 |
-
k = self.rotary_emb(k)
|
162 |
-
|
163 |
-
if enable_flash_attn:
|
164 |
-
from flash_attn import flash_attn_func
|
165 |
-
|
166 |
-
# (B, #heads, N, #dim) -> (B, N, #heads, #dim)
|
167 |
-
q = q.permute(0, 2, 1, 3)
|
168 |
-
k = k.permute(0, 2, 1, 3)
|
169 |
-
v = v.permute(0, 2, 1, 3)
|
170 |
-
x = flash_attn_func(
|
171 |
-
q,
|
172 |
-
k,
|
173 |
-
v,
|
174 |
-
dropout_p=self.attn_drop.p if self.training else 0.0,
|
175 |
-
softmax_scale=self.scale,
|
176 |
-
)
|
177 |
-
else:
|
178 |
-
dtype = q.dtype
|
179 |
-
q = q * self.scale
|
180 |
-
attn = q @ k.transpose(-2, -1) # translate attn to float32
|
181 |
-
attn = attn.to(torch.float32)
|
182 |
-
attn = attn.softmax(dim=-1)
|
183 |
-
attn = attn.to(dtype) # cast back attn to original dtype
|
184 |
-
attn = self.attn_drop(attn)
|
185 |
-
x = attn @ v
|
186 |
-
|
187 |
-
x_output_shape = (B, N, C)
|
188 |
-
if not enable_flash_attn:
|
189 |
-
x = x.transpose(1, 2)
|
190 |
-
x = x.reshape(x_output_shape)
|
191 |
-
x = self.proj(x)
|
192 |
-
x = self.proj_drop(x)
|
193 |
-
return x
|
194 |
-
|
195 |
-
|
196 |
-
class MultiHeadCrossAttention(nn.Module):
|
197 |
-
def __init__(self, d_model, num_heads, attn_drop=0.0, proj_drop=0.0):
|
198 |
-
super(MultiHeadCrossAttention, self).__init__()
|
199 |
-
assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
|
200 |
-
|
201 |
-
self.d_model = d_model
|
202 |
-
self.num_heads = num_heads
|
203 |
-
self.head_dim = d_model // num_heads
|
204 |
-
|
205 |
-
self.q_linear = nn.Linear(d_model, d_model)
|
206 |
-
self.kv_linear = nn.Linear(d_model, d_model * 2)
|
207 |
-
self.attn_drop = nn.Dropout(attn_drop)
|
208 |
-
self.proj = nn.Linear(d_model, d_model)
|
209 |
-
self.proj_drop = nn.Dropout(proj_drop)
|
210 |
-
|
211 |
-
def forward(self, x, cond, mask=None):
|
212 |
-
# query/value: img tokens; key: condition; mask: if padding tokens
|
213 |
-
B, N, C = x.shape
|
214 |
-
|
215 |
-
q = self.q_linear(x).view(1, -1, self.num_heads, self.head_dim)
|
216 |
-
kv = self.kv_linear(cond).view(1, -1, 2, self.num_heads, self.head_dim)
|
217 |
-
k, v = kv.unbind(2)
|
218 |
-
|
219 |
-
attn_bias = None
|
220 |
-
# TODO: support torch computation
|
221 |
-
import xformers.ops
|
222 |
-
|
223 |
-
if mask is not None:
|
224 |
-
attn_bias = xformers.ops.fmha.BlockDiagonalMask.from_seqlens([N] * B, mask)
|
225 |
-
x = xformers.ops.memory_efficient_attention(q, k, v, p=self.attn_drop.p, attn_bias=attn_bias)
|
226 |
-
|
227 |
-
x = x.view(B, -1, C)
|
228 |
-
x = self.proj(x)
|
229 |
-
x = self.proj_drop(x)
|
230 |
-
return x
|
231 |
-
|
232 |
-
|
233 |
-
class T2IFinalLayer(nn.Module):
|
234 |
-
"""
|
235 |
-
The final layer of PixArt.
|
236 |
-
"""
|
237 |
-
|
238 |
-
def __init__(self, hidden_size, num_patch, out_channels, d_t=None, d_s=None):
|
239 |
-
super().__init__()
|
240 |
-
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
241 |
-
self.linear = nn.Linear(hidden_size, num_patch * out_channels, bias=True)
|
242 |
-
self.scale_shift_table = nn.Parameter(torch.randn(2, hidden_size) / hidden_size**0.5)
|
243 |
-
self.out_channels = out_channels
|
244 |
-
self.d_t = d_t
|
245 |
-
self.d_s = d_s
|
246 |
-
|
247 |
-
def t_mask_select(self, x_mask, x, masked_x, T, S):
|
248 |
-
# x: [B, (T, S), C]
|
249 |
-
# mased_x: [B, (T, S), C]
|
250 |
-
# x_mask: [B, T]
|
251 |
-
x = rearrange(x, "B (T S) C -> B T S C", T=T, S=S)
|
252 |
-
masked_x = rearrange(masked_x, "B (T S) C -> B T S C", T=T, S=S)
|
253 |
-
x = torch.where(x_mask[:, :, None, None], x, masked_x)
|
254 |
-
x = rearrange(x, "B T S C -> B (T S) C")
|
255 |
-
return x
|
256 |
-
|
257 |
-
def forward(self, x, t, x_mask=None, t0=None, T=None, S=None):
|
258 |
-
if T is None:
|
259 |
-
T = self.d_t
|
260 |
-
if S is None:
|
261 |
-
S = self.d_s
|
262 |
-
shift, scale = (self.scale_shift_table[None] + t[:, None]).chunk(2, dim=1)
|
263 |
-
x = t2i_modulate(self.norm_final(x), shift, scale)
|
264 |
-
if x_mask is not None:
|
265 |
-
shift_zero, scale_zero = (self.scale_shift_table[None] + t0[:, None]).chunk(2, dim=1)
|
266 |
-
x_zero = t2i_modulate(self.norm_final(x), shift_zero, scale_zero)
|
267 |
-
x = self.t_mask_select(x_mask, x, x_zero, T, S)
|
268 |
-
x = self.linear(x)
|
269 |
-
return x
|
270 |
-
|
271 |
-
|
272 |
-
# ===============================================
|
273 |
-
# Embedding Layers for Timesteps and Class Labels
|
274 |
-
# ===============================================
|
275 |
-
|
276 |
-
|
277 |
class TimestepEmbedder(nn.Module):
|
278 |
"""
|
279 |
Embeds scalar timesteps into vector representations.
|
@@ -350,7 +180,7 @@ class SizeEmbedder(TimestepEmbedder):
|
|
350 |
return next(self.parameters()).dtype
|
351 |
|
352 |
|
353 |
-
class
|
354 |
"""
|
355 |
Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
|
356 |
"""
|
@@ -398,7 +228,7 @@ class CaptionEmbedder(nn.Module):
|
|
398 |
return caption
|
399 |
|
400 |
|
401 |
-
class
|
402 |
def __init__(self, dim: int) -> None:
|
403 |
super().__init__()
|
404 |
self.dim = dim
|
@@ -448,3 +278,135 @@ class PositionEmbedding2D(nn.Module):
|
|
448 |
base_size: Optional[int] = None,
|
449 |
) -> torch.Tensor:
|
450 |
return self._get_cached_emb(x.device, x.dtype, h, w, scale, base_size)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import functools
|
2 |
import math
|
3 |
+
from typing import Optional, Tuple, Union
|
4 |
|
5 |
+
import numpy as np
|
6 |
import torch
|
7 |
import torch.nn as nn
|
8 |
import torch.nn.functional as F
|
|
|
10 |
from einops import rearrange
|
11 |
from timm.models.vision_transformer import Mlp
|
12 |
|
|
|
13 |
|
14 |
+
class CogVideoXPatchEmbed(nn.Module):
|
15 |
+
def __init__(
|
16 |
+
self,
|
17 |
+
patch_size: int = 2,
|
18 |
+
in_channels: int = 16,
|
19 |
+
embed_dim: int = 1920,
|
20 |
+
text_embed_dim: int = 4096,
|
21 |
+
bias: bool = True,
|
22 |
+
) -> None:
|
23 |
super().__init__()
|
24 |
+
self.patch_size = patch_size
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
25 |
|
26 |
+
self.proj = nn.Conv2d(
|
27 |
+
in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
|
28 |
+
)
|
29 |
+
self.text_proj = nn.Linear(text_embed_dim, embed_dim)
|
30 |
+
|
31 |
+
def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor):
|
32 |
+
r"""
|
33 |
+
Args:
|
34 |
+
text_embeds (`torch.Tensor`):
|
35 |
+
Input text embeddings. Expected shape: (batch_size, seq_length, embedding_dim).
|
36 |
+
image_embeds (`torch.Tensor`):
|
37 |
+
Input image embeddings. Expected shape: (batch_size, num_frames, channels, height, width).
|
38 |
+
"""
|
39 |
+
text_embeds = self.text_proj(text_embeds)
|
40 |
|
41 |
+
batch, num_frames, channels, height, width = image_embeds.shape
|
42 |
+
image_embeds = image_embeds.reshape(-1, channels, height, width)
|
43 |
+
image_embeds = self.proj(image_embeds)
|
44 |
+
image_embeds = image_embeds.view(batch, num_frames, *image_embeds.shape[1:])
|
45 |
+
image_embeds = image_embeds.flatten(3).transpose(2, 3) # [batch, num_frames, height x width, channels]
|
46 |
+
image_embeds = image_embeds.flatten(1, 2) # [batch, num_frames x height x width, channels]
|
47 |
|
48 |
+
embeds = torch.cat(
|
49 |
+
[text_embeds, image_embeds], dim=1
|
50 |
+
).contiguous() # [batch, seq_length + num_frames x height x width, channels]
|
51 |
+
return embeds
|
52 |
|
53 |
|
54 |
+
class OpenSoraPatchEmbed3D(nn.Module):
|
55 |
"""Video to Patch Embedding.
|
56 |
|
57 |
Args:
|
|
|
104 |
return x
|
105 |
|
106 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
107 |
class TimestepEmbedder(nn.Module):
|
108 |
"""
|
109 |
Embeds scalar timesteps into vector representations.
|
|
|
180 |
return next(self.parameters()).dtype
|
181 |
|
182 |
|
183 |
+
class OpenSoraCaptionEmbedder(nn.Module):
|
184 |
"""
|
185 |
Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
|
186 |
"""
|
|
|
228 |
return caption
|
229 |
|
230 |
|
231 |
+
class OpenSoraPositionEmbedding2D(nn.Module):
|
232 |
def __init__(self, dim: int) -> None:
|
233 |
super().__init__()
|
234 |
self.dim = dim
|
|
|
278 |
base_size: Optional[int] = None,
|
279 |
) -> torch.Tensor:
|
280 |
return self._get_cached_emb(x.device, x.dtype, h, w, scale, base_size)
|
281 |
+
|
282 |
+
|
283 |
+
def get_3d_rotary_pos_embed(
|
284 |
+
embed_dim, crops_coords, grid_size, temporal_size, theta: int = 10000, use_real: bool = True
|
285 |
+
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
286 |
+
"""
|
287 |
+
RoPE for video tokens with 3D structure.
|
288 |
+
|
289 |
+
Args:
|
290 |
+
embed_dim: (`int`):
|
291 |
+
The embedding dimension size, corresponding to hidden_size_head.
|
292 |
+
crops_coords (`Tuple[int]`):
|
293 |
+
The top-left and bottom-right coordinates of the crop.
|
294 |
+
grid_size (`Tuple[int]`):
|
295 |
+
The grid size of the spatial positional embedding (height, width).
|
296 |
+
temporal_size (`int`):
|
297 |
+
The size of the temporal dimension.
|
298 |
+
theta (`float`):
|
299 |
+
Scaling factor for frequency computation.
|
300 |
+
use_real (`bool`):
|
301 |
+
If True, return real part and imaginary part separately. Otherwise, return complex numbers.
|
302 |
+
|
303 |
+
Returns:
|
304 |
+
`torch.Tensor`: positional embedding with shape `(temporal_size * grid_size[0] * grid_size[1], embed_dim/2)`.
|
305 |
+
"""
|
306 |
+
start, stop = crops_coords
|
307 |
+
grid_h = np.linspace(start[0], stop[0], grid_size[0], endpoint=False, dtype=np.float32)
|
308 |
+
grid_w = np.linspace(start[1], stop[1], grid_size[1], endpoint=False, dtype=np.float32)
|
309 |
+
grid_t = np.linspace(0, temporal_size, temporal_size, endpoint=False, dtype=np.float32)
|
310 |
+
|
311 |
+
# Compute dimensions for each axis
|
312 |
+
dim_t = embed_dim // 4
|
313 |
+
dim_h = embed_dim // 8 * 3
|
314 |
+
dim_w = embed_dim // 8 * 3
|
315 |
+
|
316 |
+
# Temporal frequencies
|
317 |
+
freqs_t = 1.0 / (theta ** (torch.arange(0, dim_t, 2).float() / dim_t))
|
318 |
+
grid_t = torch.from_numpy(grid_t).float()
|
319 |
+
freqs_t = torch.einsum("n , f -> n f", grid_t, freqs_t)
|
320 |
+
freqs_t = freqs_t.repeat_interleave(2, dim=-1)
|
321 |
+
|
322 |
+
# Spatial frequencies for height and width
|
323 |
+
freqs_h = 1.0 / (theta ** (torch.arange(0, dim_h, 2).float() / dim_h))
|
324 |
+
freqs_w = 1.0 / (theta ** (torch.arange(0, dim_w, 2).float() / dim_w))
|
325 |
+
grid_h = torch.from_numpy(grid_h).float()
|
326 |
+
grid_w = torch.from_numpy(grid_w).float()
|
327 |
+
freqs_h = torch.einsum("n , f -> n f", grid_h, freqs_h)
|
328 |
+
freqs_w = torch.einsum("n , f -> n f", grid_w, freqs_w)
|
329 |
+
freqs_h = freqs_h.repeat_interleave(2, dim=-1)
|
330 |
+
freqs_w = freqs_w.repeat_interleave(2, dim=-1)
|
331 |
+
|
332 |
+
# Broadcast and concatenate tensors along specified dimension
|
333 |
+
def broadcast(tensors, dim=-1):
|
334 |
+
num_tensors = len(tensors)
|
335 |
+
shape_lens = {len(t.shape) for t in tensors}
|
336 |
+
assert len(shape_lens) == 1, "tensors must all have the same number of dimensions"
|
337 |
+
shape_len = list(shape_lens)[0]
|
338 |
+
dim = (dim + shape_len) if dim < 0 else dim
|
339 |
+
dims = list(zip(*(list(t.shape) for t in tensors)))
|
340 |
+
expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]
|
341 |
+
assert all(
|
342 |
+
[*(len(set(t[1])) <= 2 for t in expandable_dims)]
|
343 |
+
), "invalid dimensions for broadcastable concatenation"
|
344 |
+
max_dims = [(t[0], max(t[1])) for t in expandable_dims]
|
345 |
+
expanded_dims = [(t[0], (t[1],) * num_tensors) for t in max_dims]
|
346 |
+
expanded_dims.insert(dim, (dim, dims[dim]))
|
347 |
+
expandable_shapes = list(zip(*(t[1] for t in expanded_dims)))
|
348 |
+
tensors = [t[0].expand(*t[1]) for t in zip(tensors, expandable_shapes)]
|
349 |
+
return torch.cat(tensors, dim=dim)
|
350 |
+
|
351 |
+
freqs = broadcast((freqs_t[:, None, None, :], freqs_h[None, :, None, :], freqs_w[None, None, :, :]), dim=-1)
|
352 |
+
|
353 |
+
t, h, w, d = freqs.shape
|
354 |
+
freqs = freqs.view(t * h * w, d)
|
355 |
+
|
356 |
+
# Generate sine and cosine components
|
357 |
+
sin = freqs.sin()
|
358 |
+
cos = freqs.cos()
|
359 |
+
|
360 |
+
if use_real:
|
361 |
+
return cos, sin
|
362 |
+
else:
|
363 |
+
freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
|
364 |
+
return freqs_cis
|
365 |
+
|
366 |
+
|
367 |
+
def apply_rotary_emb(
|
368 |
+
x: torch.Tensor,
|
369 |
+
freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
|
370 |
+
use_real: bool = True,
|
371 |
+
use_real_unbind_dim: int = -1,
|
372 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
373 |
+
"""
|
374 |
+
Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings
|
375 |
+
to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are
|
376 |
+
reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting
|
377 |
+
tensors contain rotary embeddings and are returned as real tensors.
|
378 |
+
|
379 |
+
Args:
|
380 |
+
x (`torch.Tensor`):
|
381 |
+
Query or key tensor to apply rotary embeddings. [B, H, S, D] xk (torch.Tensor): Key tensor to apply
|
382 |
+
freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],)
|
383 |
+
|
384 |
+
Returns:
|
385 |
+
Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
|
386 |
+
"""
|
387 |
+
if use_real:
|
388 |
+
cos, sin = freqs_cis # [S, D]
|
389 |
+
cos = cos[None, None]
|
390 |
+
sin = sin[None, None]
|
391 |
+
cos, sin = cos.to(x.device), sin.to(x.device)
|
392 |
+
|
393 |
+
if use_real_unbind_dim == -1:
|
394 |
+
# Use for example in Lumina
|
395 |
+
x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2]
|
396 |
+
x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
|
397 |
+
elif use_real_unbind_dim == -2:
|
398 |
+
# Use for example in Stable Audio
|
399 |
+
x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2) # [B, S, H, D//2]
|
400 |
+
x_rotated = torch.cat([-x_imag, x_real], dim=-1)
|
401 |
+
else:
|
402 |
+
raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.")
|
403 |
+
|
404 |
+
out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
|
405 |
+
|
406 |
+
return out
|
407 |
+
else:
|
408 |
+
x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
|
409 |
+
freqs_cis = freqs_cis.unsqueeze(2)
|
410 |
+
x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3)
|
411 |
+
|
412 |
+
return x_out.type_as(x)
|
videosys/models/modules/normalization.py
ADDED
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Optional, Tuple
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
|
6 |
+
|
7 |
+
class LlamaRMSNorm(nn.Module):
|
8 |
+
def __init__(self, hidden_size, eps=1e-6):
|
9 |
+
"""
|
10 |
+
LlamaRMSNorm is equivalent to T5LayerNorm
|
11 |
+
"""
|
12 |
+
super().__init__()
|
13 |
+
self.weight = nn.Parameter(torch.ones(hidden_size))
|
14 |
+
self.variance_epsilon = eps
|
15 |
+
|
16 |
+
def forward(self, hidden_states):
|
17 |
+
input_dtype = hidden_states.dtype
|
18 |
+
hidden_states = hidden_states.to(torch.float32)
|
19 |
+
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
20 |
+
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
21 |
+
return self.weight * hidden_states.to(input_dtype)
|
22 |
+
|
23 |
+
|
24 |
+
class CogVideoXLayerNormZero(nn.Module):
|
25 |
+
def __init__(
|
26 |
+
self,
|
27 |
+
conditioning_dim: int,
|
28 |
+
embedding_dim: int,
|
29 |
+
elementwise_affine: bool = True,
|
30 |
+
eps: float = 1e-5,
|
31 |
+
bias: bool = True,
|
32 |
+
) -> None:
|
33 |
+
super().__init__()
|
34 |
+
|
35 |
+
self.silu = nn.SiLU()
|
36 |
+
self.linear = nn.Linear(conditioning_dim, 6 * embedding_dim, bias=bias)
|
37 |
+
self.norm = nn.LayerNorm(embedding_dim, eps=eps, elementwise_affine=elementwise_affine)
|
38 |
+
|
39 |
+
def forward(
|
40 |
+
self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, temb: torch.Tensor
|
41 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
42 |
+
shift, scale, gate, enc_shift, enc_scale, enc_gate = self.linear(self.silu(temb)).chunk(6, dim=1)
|
43 |
+
hidden_states = self.norm(hidden_states) * (1 + scale)[:, None, :] + shift[:, None, :]
|
44 |
+
encoder_hidden_states = self.norm(encoder_hidden_states) * (1 + enc_scale)[:, None, :] + enc_shift[:, None, :]
|
45 |
+
return hidden_states, encoder_hidden_states, gate[:, None, :], enc_gate[:, None, :]
|
46 |
+
|
47 |
+
|
48 |
+
class AdaLayerNorm(nn.Module):
|
49 |
+
r"""
|
50 |
+
Norm layer modified to incorporate timestep embeddings.
|
51 |
+
|
52 |
+
Parameters:
|
53 |
+
embedding_dim (`int`): The size of each embedding vector.
|
54 |
+
num_embeddings (`int`, *optional*): The size of the embeddings dictionary.
|
55 |
+
output_dim (`int`, *optional*):
|
56 |
+
norm_elementwise_affine (`bool`, defaults to `False):
|
57 |
+
norm_eps (`bool`, defaults to `False`):
|
58 |
+
chunk_dim (`int`, defaults to `0`):
|
59 |
+
"""
|
60 |
+
|
61 |
+
def __init__(
|
62 |
+
self,
|
63 |
+
embedding_dim: int,
|
64 |
+
num_embeddings: Optional[int] = None,
|
65 |
+
output_dim: Optional[int] = None,
|
66 |
+
norm_elementwise_affine: bool = False,
|
67 |
+
norm_eps: float = 1e-5,
|
68 |
+
chunk_dim: int = 0,
|
69 |
+
):
|
70 |
+
super().__init__()
|
71 |
+
|
72 |
+
self.chunk_dim = chunk_dim
|
73 |
+
output_dim = output_dim or embedding_dim * 2
|
74 |
+
|
75 |
+
if num_embeddings is not None:
|
76 |
+
self.emb = nn.Embedding(num_embeddings, embedding_dim)
|
77 |
+
else:
|
78 |
+
self.emb = None
|
79 |
+
|
80 |
+
self.silu = nn.SiLU()
|
81 |
+
self.linear = nn.Linear(embedding_dim, output_dim)
|
82 |
+
self.norm = nn.LayerNorm(output_dim // 2, norm_eps, norm_elementwise_affine)
|
83 |
+
|
84 |
+
def forward(
|
85 |
+
self, x: torch.Tensor, timestep: Optional[torch.Tensor] = None, temb: Optional[torch.Tensor] = None
|
86 |
+
) -> torch.Tensor:
|
87 |
+
if self.emb is not None:
|
88 |
+
temb = self.emb(timestep)
|
89 |
+
|
90 |
+
temb = self.linear(self.silu(temb))
|
91 |
+
|
92 |
+
if self.chunk_dim == 1:
|
93 |
+
# This is a bit weird why we have the order of "shift, scale" here and "scale, shift" in the
|
94 |
+
# other if-branch. This branch is specific to CogVideoX for now.
|
95 |
+
shift, scale = temb.chunk(2, dim=1)
|
96 |
+
shift = shift[:, None, :]
|
97 |
+
scale = scale[:, None, :]
|
98 |
+
else:
|
99 |
+
scale, shift = temb.chunk(2, dim=0)
|
100 |
+
|
101 |
+
x = self.norm(x) * (1 + scale) + shift
|
102 |
+
return x
|
videosys/models/modules/upsampling.py
ADDED
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
|
5 |
+
|
6 |
+
class CogVideoXUpsample3D(nn.Module):
|
7 |
+
r"""
|
8 |
+
A 3D Upsample layer using in CogVideoX by Tsinghua University & ZhipuAI # Todo: Wait for paper relase.
|
9 |
+
|
10 |
+
Args:
|
11 |
+
in_channels (`int`):
|
12 |
+
Number of channels in the input image.
|
13 |
+
out_channels (`int`):
|
14 |
+
Number of channels produced by the convolution.
|
15 |
+
kernel_size (`int`, defaults to `3`):
|
16 |
+
Size of the convolving kernel.
|
17 |
+
stride (`int`, defaults to `1`):
|
18 |
+
Stride of the convolution.
|
19 |
+
padding (`int`, defaults to `1`):
|
20 |
+
Padding added to all four sides of the input.
|
21 |
+
compress_time (`bool`, defaults to `False`):
|
22 |
+
Whether or not to compress the time dimension.
|
23 |
+
"""
|
24 |
+
|
25 |
+
def __init__(
|
26 |
+
self,
|
27 |
+
in_channels: int,
|
28 |
+
out_channels: int,
|
29 |
+
kernel_size: int = 3,
|
30 |
+
stride: int = 1,
|
31 |
+
padding: int = 1,
|
32 |
+
compress_time: bool = False,
|
33 |
+
) -> None:
|
34 |
+
super().__init__()
|
35 |
+
|
36 |
+
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
|
37 |
+
self.compress_time = compress_time
|
38 |
+
|
39 |
+
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
|
40 |
+
if self.compress_time:
|
41 |
+
if inputs.shape[2] > 1 and inputs.shape[2] % 2 == 1:
|
42 |
+
# split first frame
|
43 |
+
x_first, x_rest = inputs[:, :, 0], inputs[:, :, 1:]
|
44 |
+
|
45 |
+
x_first = F.interpolate(x_first, scale_factor=2.0)
|
46 |
+
x_rest = F.interpolate(x_rest, scale_factor=2.0)
|
47 |
+
x_first = x_first[:, :, None, :, :]
|
48 |
+
inputs = torch.cat([x_first, x_rest], dim=2)
|
49 |
+
elif inputs.shape[2] > 1:
|
50 |
+
inputs = F.interpolate(inputs, scale_factor=2.0)
|
51 |
+
else:
|
52 |
+
inputs = inputs.squeeze(2)
|
53 |
+
inputs = F.interpolate(inputs, scale_factor=2.0)
|
54 |
+
inputs = inputs[:, :, None, :, :]
|
55 |
+
else:
|
56 |
+
# only interpolate 2D
|
57 |
+
b, c, t, h, w = inputs.shape
|
58 |
+
inputs = inputs.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
|
59 |
+
inputs = F.interpolate(inputs, scale_factor=2.0)
|
60 |
+
inputs = inputs.reshape(b, t, c, *inputs.shape[2:]).permute(0, 2, 1, 3, 4)
|
61 |
+
|
62 |
+
b, c, t, h, w = inputs.shape
|
63 |
+
inputs = inputs.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
|
64 |
+
inputs = self.conv(inputs)
|
65 |
+
inputs = inputs.reshape(b, t, *inputs.shape[1:]).permute(0, 2, 1, 3, 4)
|
66 |
+
|
67 |
+
return inputs
|